From 92d40ebec61fad03d3aa9ffc273ef9c18cdaeda4 Mon Sep 17 00:00:00 2001 From: Tom Ward Date: Thu, 3 Dec 2020 06:30:32 -0800 Subject: [PATCH] Add support for overriding the catch environment seed on create/reset requests. PiperOrigin-RevId: 345438095 Change-Id: I6632df499f2289fbb437e30202105ec3be574e47 --- examples/catch_environment.py | 26 ++++++++++++++++--- examples/catch_test.py | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/examples/catch_environment.py b/examples/catch_environment.py index 3579b1f..cc9f784 100644 --- a/examples/catch_environment.py +++ b/examples/catch_environment.py @@ -21,6 +21,7 @@ from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc from dm_env_rpc.v1 import spec_manager from dm_env_rpc.v1 import tensor_spec_utils +from dm_env_rpc.v1 import tensor_utils _ACTION_PADDLE = 'paddle' _DEFAULT_ACTION = 0 @@ -31,6 +32,7 @@ _OBSERVATION_BOARD = 'board' _WORLD_NAME = 'catch' _VALID_ACTIONS = [-1, 0, 1] +_VALID_CREATE_AND_RESET_SETTINGS = ['seed'] class CatchGame(object): @@ -161,6 +163,9 @@ def new_game(self): self._seed += 1 return env + def reset_seed(self, seed): + self._seed = seed + class CatchEnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer): """Runs the Catch game as a gRPC EnvironmentServicer.""" @@ -199,7 +204,12 @@ def Process(self, request_iterator, context): _check_message_type(env, is_joined, message_type) if message_type == 'create_world': - _validate_settings(request.create_world.settings, valid_settings=[]) + _validate_settings( + request.create_world.settings, + valid_settings=_VALID_CREATE_AND_RESET_SETTINGS) + seed = request.create_world.settings.get('seed', None) + if seed is not None: + env_factory.reset_seed(tensor_utils.unpack_tensor(seed)) env = env_factory.new_game() skip_next_frame = True response = dm_env_rpc_pb2.CreateWorldResponse(world_name=_WORLD_NAME) @@ -251,7 +261,12 @@ def Process(self, request_iterator, context): env = env_factory.new_game() skip_next_frame = True elif message_type == 'reset': - _validate_settings(request.reset.settings, valid_settings=[]) + _validate_settings( + request.reset.settings, + valid_settings=_VALID_CREATE_AND_RESET_SETTINGS) + seed = request.reset.settings.get('seed', None) + if seed is not None: + env_factory.reset_seed(tensor_utils.unpack_tensor(seed)) env = env_factory.new_game() skip_next_frame = True response = dm_env_rpc_pb2.ResetResponse() @@ -260,7 +275,12 @@ def Process(self, request_iterator, context): for uid, observation in _observation_spec().items(): response.specs.observations[uid].CopyFrom(observation) elif message_type == 'reset_world': - _validate_settings(request.reset_world.settings, valid_settings=[]) + _validate_settings( + request.reset_world.settings, + valid_settings=_VALID_CREATE_AND_RESET_SETTINGS) + seed = request.reset_world.settings.get('seed', None) + if seed is not None: + env_factory.reset_seed(tensor_utils.unpack_tensor(seed)) env = env_factory.new_game() skip_next_frame = True response = dm_env_rpc_pb2.ResetWorldResponse() diff --git a/examples/catch_test.py b/examples/catch_test.py index 5c225b4..7e32ef5 100644 --- a/examples/catch_test.py +++ b/examples/catch_test.py @@ -205,6 +205,55 @@ def make_object_under_test(self): return self._dm_env +class CatchTestSettings(absltest.TestCase): + + def setUp(self): + super().setUp() + + self._server_connection = ServerConnection() + self._connection = self._server_connection.connection + self._world_name = None + + def tearDown(self): + try: + if self._world_name: + self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest()) + self._connection.send( + dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name)) + finally: + self._server_connection.close() + super().tearDown() + + def test_reset_world_seed_setting(self): + self._world_name = self._connection.send( + dm_env_rpc_pb2.CreateWorldRequest( + settings={'seed': tensor_utils.pack_tensor(1234)})).world_name + self._connection.send( + dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name)) + + step_response = self._connection.send(dm_env_rpc_pb2.StepRequest()) + self._connection.send( + dm_env_rpc_pb2.ResetWorldRequest( + world_name=self._world_name, + settings={'seed': tensor_utils.pack_tensor(1234)})) + self.assertEqual(step_response, + self._connection.send(dm_env_rpc_pb2.StepRequest())) + + def test_reset_seed_setting(self): + self._world_name = self._connection.send( + dm_env_rpc_pb2.CreateWorldRequest( + settings={'seed': tensor_utils.pack_tensor(1234)})).world_name + self._connection.send( + dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name)) + + step_response = self._connection.send(dm_env_rpc_pb2.StepRequest()) + self._connection.send( + dm_env_rpc_pb2.ResetRequest( + settings={'seed': tensor_utils.pack_tensor(1234)})) + self.assertEqual(step_response, + self._connection.send(dm_env_rpc_pb2.StepRequest())) + + class CatchTest(absltest.TestCase): def setUp(self):