diff --git a/dm_env_rpc/v1/compliance/join_leave_world.py b/dm_env_rpc/v1/compliance/join_leave_world.py index faf3ed3..a6b3e56 100644 --- a/dm_env_rpc/v1/compliance/join_leave_world.py +++ b/dm_env_rpc/v1/compliance/join_leave_world.py @@ -59,7 +59,7 @@ def required_join_settings(self): @property def invalid_join_settings(self): """A list of dicts of Join World settings which are invalid in some way.""" - return [] + return {} @abc.abstractproperty def world_name(self): @@ -105,10 +105,13 @@ def test_cannot_join_with_wrong_world_name(self): def test_cannot_join_world_with_invalid_settings(self): settings = self.required_join_settings - for invalid_settings in self.invalid_join_settings: + for name, tensor in self.invalid_join_settings.items(): with self.assertRaises(error.DmEnvRpcError): - self.join_world(world_name=self.world_name, - settings={**settings, **invalid_settings}) + self.join_world( + world_name=self.world_name, settings={ + name: tensor, + **settings + }) def test_cannot_join_world_twice(self): self.join_world( diff --git a/examples/catch_environment.py b/examples/catch_environment.py index 632d58e..3579b1f 100644 --- a/examples/catch_environment.py +++ b/examples/catch_environment.py @@ -139,6 +139,17 @@ def _action_spec(): return {1: paddle_action_spec} +def _validate_settings(settings, valid_settings): + """"Validate the provided settings with list of valid setting keys.""" + unrecognized_settings = [ + setting for setting in settings if setting not in valid_settings + ] + + if unrecognized_settings: + raise ValueError('Unrecognized settings provided! Invalid settings:' + f' {unrecognized_settings}') + + class CatchGameFactory(object): """Factory for creating new CatchGame instances.""" @@ -188,10 +199,12 @@ def Process(self, request_iterator, context): _check_message_type(env, is_joined, message_type) if message_type == 'create_world': + _validate_settings(request.create_world.settings, valid_settings=[]) env = env_factory.new_game() skip_next_frame = True response = dm_env_rpc_pb2.CreateWorldResponse(world_name=_WORLD_NAME) elif message_type == 'join_world': + _validate_settings(request.join_world.settings, valid_settings=[]) if is_joined: raise RuntimeError( f'Tried to join world "{internal_request.world_name}" but ' @@ -238,6 +251,7 @@ def Process(self, request_iterator, context): env = env_factory.new_game() skip_next_frame = True elif message_type == 'reset': + _validate_settings(request.reset.settings, valid_settings=[]) env = env_factory.new_game() skip_next_frame = True response = dm_env_rpc_pb2.ResetResponse() @@ -246,6 +260,7 @@ def Process(self, request_iterator, context): for uid, observation in _observation_spec().items(): response.specs.observations[uid].CopyFrom(observation) elif message_type == 'reset_world': + _validate_settings(request.reset_world.settings, valid_settings=[]) env = env_factory.new_game() skip_next_frame = True response = dm_env_rpc_pb2.ResetWorldResponse() diff --git a/examples/catch_test.py b/examples/catch_test.py index 7392f3b..5c225b4 100644 --- a/examples/catch_test.py +++ b/examples/catch_test.py @@ -29,6 +29,7 @@ from dm_env_rpc.v1 import dm_env_rpc_pb2 from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc from dm_env_rpc.v1 import error +from dm_env_rpc.v1 import tensor_utils class ServerConnection: @@ -108,7 +109,7 @@ def required_world_settings(self): @property def invalid_world_settings(self): """World creation settings which are invalid in some way.""" - return {} + return {'invalid_setting': tensor_utils.pack_tensor(123)} @property def has_multiple_world_support(self): @@ -134,6 +135,10 @@ def connection(self): def world_name(self): return self._world_name + @property + def invalid_join_settings(self): + return {'invalid_setting': tensor_utils.pack_tensor(123)} + def setUp(self): self._server_connection = ServerConnection() response = self.connection.send(dm_env_rpc_pb2.CreateWorldRequest())