Skip to content

Commit 65bde68

Browse files
committed
Support nested create/join world settings in dm_env_adaptor helper functions.
PiperOrigin-RevId: 499439311 Change-Id: I5beb9e33b6ad7fc5f5dfa41fd094777fe12f4c60
1 parent fb4fb53 commit 65bde68

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

dm_env_rpc/v1/dm_env_adaptor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ class DmEnvAdaptor(dm_env.Environment):
4949
Users can also optionally provide a mapping of objects to DmEnvAdaptor
5050
attributes. This is to accommodate user-created protocol extensions that
5151
compliment the core protocol.
52-
5352
"""
5453

5554
# Disable pytype attribute checking for dynamically created extension attrs.
@@ -285,12 +284,16 @@ def create_world(connection: dm_env_rpc_connection.ConnectionType,
285284
Args:
286285
connection: An instance of Connection already connected to a dm_env_rpc
287286
server.
288-
create_world_settings: Settings used to create the world. Values must be
289-
packable into a Tensor proto or already packed.
287+
create_world_settings: Settings used to create the world. Nested settings
288+
will be automatically flattened before sending to the server. Values must
289+
be packable into a Tensor proto or already packed.
290290
291291
Returns:
292292
Created world name.
293293
"""
294+
create_world_settings = dm_env_flatten_utils.flatten_dict(
295+
create_world_settings, DEFAULT_KEY_SEPARATOR
296+
)
294297

295298
create_world_settings = {
296299
key: (value if isinstance(value, dm_env_rpc_pb2.Tensor) else
@@ -314,14 +317,18 @@ def join_world(
314317
connection: An instance of Connection already connected to a dm_env_rpc
315318
server.
316319
world_name: Name of the world to join.
317-
join_world_settings: Settings used to join the world. Values must be
320+
join_world_settings: Settings used to join the world. Nested settings will
321+
be automatically flattened before sending to the server. Values must be
318322
packable into a Tensor message or already packed.
319323
**adaptor_kwargs: Additional keyword args used to create the DmEnvAdaptor
320324
instance.
321325
322326
Returns:
323327
Instance of DmEnvAdaptor.
324328
"""
329+
join_world_settings = dm_env_flatten_utils.flatten_dict(
330+
join_world_settings, DEFAULT_KEY_SEPARATOR
331+
)
325332

326333
join_world_settings = {
327334
key: (value if isinstance(value, dm_env_rpc_pb2.Tensor) else
@@ -356,7 +363,8 @@ def create_and_join_world(connection: dm_env_rpc_connection.ConnectionType,
356363
server.
357364
create_world_settings: Settings used to create the world. Values must be
358365
packable into a Tensor proto or already packed.
359-
join_world_settings: Settings used to join the world. Values must be
366+
join_world_settings: Settings used to join the world. Nested settings will
367+
be automatically flattened before sending to the server. Values must be
360368
packable into a Tensor message.
361369
**adaptor_kwargs: Additional keyword args used to create the DmEnvAdaptor
362370
instance.
@@ -372,4 +380,3 @@ def create_and_join_world(connection: dm_env_rpc_connection.ConnectionType,
372380
except (error.DmEnvRpcError, ValueError):
373381
connection.send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
374382
raise
375-

dm_env_rpc/v1/dm_env_adaptor_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,38 @@ def test_create_join_world(self):
636636
}""", dm_env_rpc_pb2.JoinWorldRequest())),
637637
])
638638

639+
def test_flatten_create_join_world_settings(self):
640+
connection = mock.MagicMock()
641+
connection.send = mock.MagicMock(side_effect=[
642+
dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
643+
dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC)
644+
])
645+
env, world_name = dm_env_adaptor.create_and_join_world(
646+
connection,
647+
create_world_settings={'nested': {'planet': 'Damogran'}},
648+
join_world_settings={'nested': {'ship_type': 1, 'player': 'zaphod'}})
649+
self.assertIsNotNone(env)
650+
self.assertEqual('Damogran_01', world_name)
651+
652+
connection.send.assert_has_calls([
653+
mock.call(
654+
text_format.Parse(
655+
"""settings: {
656+
key: 'nested.planet', value: { strings: { array: 'Damogran' } }
657+
}""", dm_env_rpc_pb2.CreateWorldRequest())),
658+
mock.call(
659+
text_format.Parse(
660+
"""world_name: 'Damogran_01'
661+
settings: {
662+
key: 'nested.ship_type', value: { int64s: { array: 1 } }
663+
}
664+
settings: {
665+
key: 'nested.player', value: {
666+
strings: { array: 'zaphod' }
667+
}
668+
}""", dm_env_rpc_pb2.JoinWorldRequest())),
669+
])
670+
639671
def test_create_join_world_with_packed_settings(self):
640672
connection = mock.MagicMock()
641673
connection.send = mock.MagicMock(side_effect=[

0 commit comments

Comments
 (0)