k-Armed Bandit 1.0.0
A collection of k-armed bandits and associated agents for reinforcement learning.
Loading...
Searching...
No Matches
test_static.py
Go to the documentation of this file.
1from bandit import Static
2import numpy
3import unittest
4
5
class TestStaticBandit(unittest.TestCase):
    """
    Test case verifying that the static bandit works correctly.
    """

    def test_instantiate_rewards(self):
        """
        Test that the class can handle reward values. Acceptable values should
        be some sort of iterable with a number of elements equal to k. Each
        element should be numeric. Alternatively, None can be used to have the
        class randomly select values.
        """
        k = 4
        # Iterables can include tuples, lists, or numpy arrays, for starters.
        for values in ((1, 2, 3, 4), [1, 2, 3, 4], numpy.array([1, 2, 3, 4])):
            with self.subTest(values=values):
                bandit = Static(k, values)
                self.assertIsNotNone(bandit)
                # Make sure that the stored rewards are what was provided.
                rewards = bandit.trueValues()
                for i, value in enumerate(values):
                    self.assertEqual(
                        value, rewards[i], 'Stored reward does not match provided.')
        # The user can pass in None to randomly select rewards between [0, 1).
        # Since this is technically random, try several times.
        for _ in range(50):
            bandit = Static(k, None)
            values = bandit.trueValues()
            # Randomly generated rewards must still provide one value per arm.
            self.assertEqual(len(values), k)
            for value in values:
                self.assertGreaterEqual(value, 0)
                self.assertLess(value, 1)
        # An iterable must have the same length as k, otherwise it should fail.
        with self.assertRaises(ValueError, msg='Static bandit did not reject incorrect length rewards.'):
            Static(k=3, rewards=(1, 2))
        # Other, non-numeric iterables and data types should produce some form
        # of error. This includes multi-dimensional numpy arrays.
        # NOTE(review): the 2x2 array below has only 2 rows, so it is also
        # rejected by the length check; a (4, 1) array would isolate the
        # multi-dimensional case.
        for values in ((1, 2), ('the', 'it', 4, 5), 4, '1 2 3 4',
                       numpy.array([[1, 2], [3, 4]])):
            # subTest reports exactly which input was not rejected.
            with self.subTest(values=values):
                with self.assertRaises(Exception, msg='Static bandit did not reject non-numeric rewards'):
                    Static(k, values)

    def test_correct_rewards(self):
        """
        Test that the right reward is returned when each arm is selected. This
        should allow any indexing inputs that you could use for numpy arrays
        and reject everything else. The correct inputs should return the
        appropriate rewards that match the reward values set by the class.
        """
        k = 10
        # Size the reward vector from k so the two can never drift apart.
        true_rewards = numpy.random.uniform(low=0, high=1, size=k)
        bandit = Static(k, true_rewards)
        # This method can accept a single integer, including negative indices
        # as numpy arrays do.
        for arm in (-3, 0, 2, 9):
            with self.subTest(arm=arm):
                reward = bandit.select(arm)
                self.assertEqual(
                    true_rewards[arm], reward, 'Static bandit did not provide a correct reward.')
        # It can also accept an iterable of arm indices.
        arms = range(k)
        rewards = bandit.select(arms)
        self.assertTrue(numpy.array_equal(true_rewards[arms], rewards))
        # Additionally, incorrect inputs should be rejected: non-integer
        # indices, strings, and indices out of range (valid range is 0..k-1).
        for i in (0.5, k, '1', 'the'):
            with self.subTest(i=i):
                with self.assertRaises(Exception, msg='Incorrect indices not rejected.'):
                    bandit.select(i)
        # None is a special case and will return None
        self.assertIsNone(bandit.select(None))
76
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
This class implements a bandit with a constant reward value each time an arm is chosen.
Definition static.py:5
Test case to test that the static bandit works correctly.
Definition test_static.py:6
test_correct_rewards(self)
Test that the right reward is returned when each arm is selected.
test_instantiate_rewards(self)
Test that the class can handle reward values.