diff --git a/dm_env_rpc/v1/dm_env_adaptor_test.py b/dm_env_rpc/v1/dm_env_adaptor_test.py index 10e8688..3ef151e 100644 --- a/dm_env_rpc/v1/dm_env_adaptor_test.py +++ b/dm_env_rpc/v1/dm_env_adaptor_test.py @@ -152,8 +152,8 @@ def test_first_running_step(self): self._connection.send.assert_called_once_with(_SAMPLE_STEP_REQUEST) self.assertEqual(dm_env.StepType.FIRST, timestep.step_type) - self.assertEqual(None, timestep.reward) - self.assertEqual(None, timestep.discount) + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) self.assertEqual({'foo': 5, 'bar': 'goodbye'}, timestep.observation) def test_mid_running_step(self): @@ -195,8 +195,8 @@ def test_reset(self): timestep = self._env.reset() self.assertEqual(dm_env.StepType.FIRST, timestep.step_type) - self.assertEqual(None, timestep.reward) - self.assertEqual(None, timestep.discount) + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) self.assertEqual({'foo': 5, 'bar': 'goodbye'}, timestep.observation) def test_reset_changes_spec_raises_error(self):