Skip to content

Commit

Permalink
Fix JoinLeaveWorld compliance tests for invalid join_settings.
Browse files Browse the repository at this point in the history
Add invalid setting tests in catch example for validation.

PiperOrigin-RevId: 345224776
Change-Id: Ib0682193abb8fdd976977e05901d6757286a2228
  • Loading branch information
tomwardio authored and copybara-github committed Dec 2, 2020
1 parent 9c0d805 commit 3974d8e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
11 changes: 7 additions & 4 deletions dm_env_rpc/v1/compliance/join_leave_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def required_join_settings(self):
@property
def invalid_join_settings(self):
"""A list of dicts of Join World settings which are invalid in some way."""
return []
return {}

@abc.abstractproperty
def world_name(self):
Expand Down Expand Up @@ -105,10 +105,13 @@ def test_cannot_join_with_wrong_world_name(self):

def test_cannot_join_world_with_invalid_settings(self):
settings = self.required_join_settings
for invalid_settings in self.invalid_join_settings:
for name, tensor in self.invalid_join_settings.items():
with self.assertRaises(error.DmEnvRpcError):
self.join_world(world_name=self.world_name,
settings={**settings, **invalid_settings})
self.join_world(
world_name=self.world_name, settings={
name: tensor,
**settings
})

def test_cannot_join_world_twice(self):
self.join_world(
Expand Down
15 changes: 15 additions & 0 deletions examples/catch_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ def _action_spec():
return {1: paddle_action_spec}


def _validate_settings(settings, valid_settings):
""""Validate the provided settings with list of valid setting keys."""
unrecognized_settings = [
setting for setting in settings if setting not in valid_settings
]

if unrecognized_settings:
raise ValueError('Unrecognized settings provided! Invalid settings:'
f' {unrecognized_settings}')


class CatchGameFactory(object):
"""Factory for creating new CatchGame instances."""

Expand Down Expand Up @@ -188,10 +199,12 @@ def Process(self, request_iterator, context):
_check_message_type(env, is_joined, message_type)

if message_type == 'create_world':
_validate_settings(request.create_world.settings, valid_settings=[])
env = env_factory.new_game()
skip_next_frame = True
response = dm_env_rpc_pb2.CreateWorldResponse(world_name=_WORLD_NAME)
elif message_type == 'join_world':
_validate_settings(request.join_world.settings, valid_settings=[])
if is_joined:
raise RuntimeError(
f'Tried to join world "{internal_request.world_name}" but '
Expand Down Expand Up @@ -238,6 +251,7 @@ def Process(self, request_iterator, context):
env = env_factory.new_game()
skip_next_frame = True
elif message_type == 'reset':
_validate_settings(request.reset.settings, valid_settings=[])
env = env_factory.new_game()
skip_next_frame = True
response = dm_env_rpc_pb2.ResetResponse()
Expand All @@ -246,6 +260,7 @@ def Process(self, request_iterator, context):
for uid, observation in _observation_spec().items():
response.specs.observations[uid].CopyFrom(observation)
elif message_type == 'reset_world':
_validate_settings(request.reset_world.settings, valid_settings=[])
env = env_factory.new_game()
skip_next_frame = True
response = dm_env_rpc_pb2.ResetWorldResponse()
Expand Down
7 changes: 6 additions & 1 deletion examples/catch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import dm_env_rpc_pb2_grpc
from dm_env_rpc.v1 import error
from dm_env_rpc.v1 import tensor_utils


class ServerConnection:
Expand Down Expand Up @@ -108,7 +109,7 @@ def required_world_settings(self):
@property
def invalid_world_settings(self):
"""World creation settings which are invalid in some way."""
return {}
return {'invalid_setting': tensor_utils.pack_tensor(123)}

@property
def has_multiple_world_support(self):
Expand All @@ -134,6 +135,10 @@ def connection(self):
def world_name(self):
return self._world_name

@property
def invalid_join_settings(self):
return {'invalid_setting': tensor_utils.pack_tensor(123)}

def setUp(self):
self._server_connection = ServerConnection()
response = self.connection.send(dm_env_rpc_pb2.CreateWorldRequest())
Expand Down

0 comments on commit 3974d8e

Please sign in to comment.