k-Armed Bandit 1.0.0
A collection of k-armed bandits and associated agents for reinforcement learning
Loading...
Searching...
No Matches
test_base_agent.py
Go to the documentation of this file.
1from agent import BaseAgent
2import unittest
3
4
class FakeAgent(BaseAgent):
    """
    A fake child class to allow testing of BaseAgent.

    BaseAgent is an abstract class, so can't be instantiated directly. This FakeAgent class implements the bare
    minimum to allow testing of the elements of the base class that can be tested.
    """

    def __init__(self, k: int, start_value: float = 0.0) -> None:
        """
        Construct the agent by forwarding both arguments to BaseAgent unchanged.

        :param k: Passed straight to BaseAgent; the tests in this file expect the resulting table to have k entries.
        :param start_value: Passed straight to BaseAgent as a keyword argument.
            NOTE(review): presumably the initial estimate stored in every table entry -- confirm against BaseAgent.
        """
        super().__init__(k, start_value=start_value)

    def act(self) -> int:
        """
        Return a fixed action index.

        The concrete policy is irrelevant to the BaseAgent tests, so this stub always returns 0.

        :return: Always 0.
        """
        return 0

    def update(self, action: int, reward: float) -> None:
        """
        Accept an (action, reward) pair and deliberately do nothing with it.

        The stub only exists so that the class is concrete; the table is never modified here.

        :param action: Ignored.
        :param reward: Ignored.
        """
21
22
class TestBaseAgent(unittest.TestCase):
    """
    Test the BaseAgent class.

    This tests the creation of the Q-table, as well as the exploration/exploitation helper methods. BaseAgent is
    exercised through the minimal FakeAgent subclass.
    """

    def test_q_table_creation(self) -> None:
        """
        Verify that the table is created with the requested size and that invalid sizes are rejected.
        """
        # Positive integers are valid and should produce a Q-table of the same size.
        for k in (1, 10, 100):
            # subTest reports every failing k instead of stopping at the first one.
            with self.subTest(k=k):
                agent = FakeAgent(k=k, start_value=0.0)
                self.assertEqual(agent.table.size, k,
                                 msg='BaseAgent did not make a correct table size.')
        # Anything else should fail.
        for k in (0, -1, 0.5, '1', 'the', None):
            with self.subTest(k=k):
                # The concrete exception type raised by BaseAgent is not visible from this file, so the broad
                # Exception check is kept rather than guessing at a narrower type.
                with self.assertRaises(Exception, msg='BaseAgent did not reject invalid k input.'):
                    FakeAgent(k)  # type:ignore

    def test_exploitation(self) -> None:
        """
        Test that the class picks the best action based on the table.

        This tests that the class will reliably return the index associated with the best action. It should also
        arbitrarily break ties in the case of multiple entries with an equivalent score.
        """
        agent = FakeAgent(k=4, start_value=0.0)
        # Set one clearly better than the others to make sure it is returned.
        best_index = 2
        best_reward = 100.0
        # The private table is written directly because FakeAgent.update() is a no-op.
        agent._table[best_index] = best_reward
        self.assertEqual(agent.exploit(), best_index,
                         'Exploitation picked an incorrect index.')
        # Set another equal to force a tie.
        agent._table[best_index + 1] = best_reward
        tied_indices = (best_index, best_index + 1)
        # Sample several times to make sure it never picks anything else.
        for _ in range(100):
            self.assertIn(agent.exploit(), tied_indices,
                          msg='Exploitation picked an incorrect index when breaking a tie.')

    def test_exploration(self) -> None:
        """
        Test that the class picks a random valid action from the table.

        Only the validity of each returned index is checked; the distribution of the picks is not.
        """
        num_actions = 4
        agent = FakeAgent(k=num_actions, start_value=0.0)
        # Sample several times and make sure the result is always a valid index
        valid_actions = list(range(num_actions))
        for _ in range(100):
            self.assertIn(agent.explore(), valid_actions,
                          msg='Exploration produced an invalid index.')
A base class used to create a variety of bandit solving agents.
Definition base_agent.py:5
A fake child class to allow testing of BaseAgent.
None __init__(self, int k, float start_value=0.0)
Construct the agent.
int act(self)
Use a specific algorithm to determine which action to take.
None update(self, int action, float reward)
Update the Q-Table.
test_q_table_creation(self)
Verify that the table initializes with the right values.
test_exploitation(self)
Test that the class picks the best action based on the table.
test_exploration(self)
Test that the class picks a random valid action from the table.