diff --git a/tf_agents/bandits/agents/exp3_mixture_agent_test.py b/tf_agents/bandits/agents/exp3_mixture_agent_test.py index ca1766378..6feb4941c 100644 --- a/tf_agents/bandits/agents/exp3_mixture_agent_test.py +++ b/tf_agents/bandits/agents/exp3_mixture_agent_test.py @@ -195,7 +195,9 @@ def testMixtureUpdate( reward_aggregates = self.evaluate( mixed_agent._variable_collection.reward_aggregates ) - self.assertAllInSet(reward_aggregates[: num_agents - 1], [0.999]) + self.assertAllClose( + reward_aggregates[: num_agents - 1], [0.999] * (num_agents - 1) + ) agent_prob = 1 / num_agents est_rewards = 0.5 / agent_prob per_step_update = est_rewards