Skip to content

Commit

Permalink
Add support for overriding the catch environment seed on create/reset requests.
Browse files (browse the repository at this point in the history)

PiperOrigin-RevId: 345438095
Change-Id: I6632df499f2289fbb437e30202105ec3be574e47
  • Loading branch information
tomwardio authored and copybara-github committed Dec 3, 2020
1 parent 3974d8e commit 92d40eb
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
26 changes: 23 additions & 3 deletions examples/catch_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
from dm_env_rpc.v1 import spec_manager
from dm_env_rpc.v1 import tensor_spec_utils
from dm_env_rpc.v1 import tensor_utils

_ACTION_PADDLE = 'paddle'
_DEFAULT_ACTION = 0
Expand All @@ -31,6 +32,7 @@
_OBSERVATION_BOARD = 'board'
_WORLD_NAME = 'catch'
_VALID_ACTIONS = [-1, 0, 1]
_VALID_CREATE_AND_RESET_SETTINGS = ['seed']


class CatchGame(object):
Expand Down Expand Up @@ -161,6 +163,9 @@ def new_game(self):
self._seed += 1
return env

def reset_seed(self, seed):
  """Overrides the seed stored on this object.

  The create_world, reset and reset_world request handlers call this just
  before `new_game()` whenever the request carries a 'seed' setting.

  Args:
    seed: Replacement seed value. The request handlers pass the result of
      `tensor_utils.unpack_tensor` applied to the 'seed' setting.
  """
  self._seed = seed


class CatchEnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer):
"""Runs the Catch game as a gRPC EnvironmentServicer."""
Expand Down Expand Up @@ -199,7 +204,12 @@ def Process(self, request_iterator, context):
_check_message_type(env, is_joined, message_type)

if message_type == 'create_world':
_validate_settings(request.create_world.settings, valid_settings=[])
_validate_settings(
request.create_world.settings,
valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
seed = request.create_world.settings.get('seed', None)
if seed is not None:
env_factory.reset_seed(tensor_utils.unpack_tensor(seed))
env = env_factory.new_game()
skip_next_frame = True
response = dm_env_rpc_pb2.CreateWorldResponse(world_name=_WORLD_NAME)
Expand Down Expand Up @@ -251,7 +261,12 @@ def Process(self, request_iterator, context):
env = env_factory.new_game()
skip_next_frame = True
elif message_type == 'reset':
_validate_settings(request.reset.settings, valid_settings=[])
_validate_settings(
request.reset.settings,
valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
seed = request.reset.settings.get('seed', None)
if seed is not None:
env_factory.reset_seed(tensor_utils.unpack_tensor(seed))
env = env_factory.new_game()
skip_next_frame = True
response = dm_env_rpc_pb2.ResetResponse()
Expand All @@ -260,7 +275,12 @@ def Process(self, request_iterator, context):
for uid, observation in _observation_spec().items():
response.specs.observations[uid].CopyFrom(observation)
elif message_type == 'reset_world':
_validate_settings(request.reset_world.settings, valid_settings=[])
_validate_settings(
request.reset_world.settings,
valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
seed = request.reset_world.settings.get('seed', None)
if seed is not None:
env_factory.reset_seed(tensor_utils.unpack_tensor(seed))
env = env_factory.new_game()
skip_next_frame = True
response = dm_env_rpc_pb2.ResetWorldResponse()
Expand Down
49 changes: 49 additions & 0 deletions examples/catch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,55 @@ def make_object_under_test(self):
return self._dm_env


class CatchTestSettings(absltest.TestCase):
  """Exercises the 'seed' setting on create, reset-world and reset requests.

  Each test creates and joins a world with a fixed seed, records the first
  step response, sends a reset request carrying the same seed, and asserts
  that the following step response equals the recorded one.
  """

  def setUp(self):
    super().setUp()

    server_connection = ServerConnection()
    self._server_connection = server_connection
    self._connection = server_connection.connection
    # Stays None until a test creates a world; tearDown checks it before
    # trying to leave/destroy anything.
    self._world_name = None

  def tearDown(self):
    try:
      if self._world_name:
        leave_request = dm_env_rpc_pb2.LeaveWorldRequest()
        self._connection.send(leave_request)
        destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
            world_name=self._world_name)
        self._connection.send(destroy_request)
    finally:
      # Always release the server connection, even if the sends above raised.
      self._server_connection.close()
      super().tearDown()

  @staticmethod
  def _seed_settings(seed):
    """Returns a settings mapping holding only the packed 'seed' tensor."""
    return {'seed': tensor_utils.pack_tensor(seed)}

  def _create_and_join_world(self, seed):
    """Creates a world with the given seed, records its name and joins it."""
    create_response = self._connection.send(
        dm_env_rpc_pb2.CreateWorldRequest(settings=self._seed_settings(seed)))
    self._world_name = create_response.world_name
    self._connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))

  def _step(self):
    """Sends an empty StepRequest and returns the response."""
    return self._connection.send(dm_env_rpc_pb2.StepRequest())

  def test_reset_world_seed_setting(self):
    self._create_and_join_world(seed=1234)

    first_step = self._step()
    self._connection.send(
        dm_env_rpc_pb2.ResetWorldRequest(
            world_name=self._world_name,
            settings=self._seed_settings(1234)))
    step_after_reset = self._step()
    self.assertEqual(first_step, step_after_reset)

  def test_reset_seed_setting(self):
    self._create_and_join_world(seed=1234)

    first_step = self._step()
    self._connection.send(
        dm_env_rpc_pb2.ResetRequest(settings=self._seed_settings(1234)))
    step_after_reset = self._step()
    self.assertEqual(first_step, step_after_reset)


class CatchTest(absltest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 92d40eb

Please sign in to comment.