Skip to content

Commit 36493d9

Browse files
committed
Cleanup dm_env_adaptor extension type hint.
No longer support None for extensions, as an empty extension mapping is effectively the same. PiperOrigin-RevId: 499442345 Change-Id: Ic09de4134de4e75b363083a53b415b7257d56c9a
1 parent 65bde68 commit 36493d9

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

dm_env_rpc/v1/dm_env_adaptor.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
specs: dm_env_rpc_pb2.ActionObservationSpecs,
6161
requested_observations: Optional[Sequence[str]] = None,
6262
nested_tensors: bool = True,
63-
extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()):
63+
extensions: Mapping[str, Any] = immutabledict.immutabledict()):
6464
"""Initializes the environment with the provided dm_env_rpc connection.
6565
6666
Args:
@@ -71,8 +71,8 @@ def __init__(
7171
environment when step is called. If None is specified then all
7272
observations will be requested.
7373
nested_tensors: Boolean to determine whether to flatten/unflatten tensors.
74-
extensions: Optional mapping of extension instances to DmEnvAdaptor
75-
attributes. Raises ValueError if attribute already exists.
74+
extensions: Mapping of extension instances to DmEnvAdaptor attributes.
75+
Raises ValueError if attribute already exists.
7676
"""
7777
self._dm_env_rpc_specs = specs
7878
self._action_specs = spec_manager.SpecManager(specs.actions)
@@ -117,15 +117,12 @@ def __init__(
117117
# Not strictly necessary but it makes the unit tests deterministic.
118118
self._requested_observation_uids.sort()
119119

120-
if extensions is not None:
121-
self._extension_names = extensions.keys()
122-
for extension_name, extension in extensions.items():
123-
if hasattr(self, extension_name):
124-
raise ValueError(
125-
f'DmEnvAdaptor already has attribute "{extension_name}"!')
126-
setattr(self, extension_name, extension)
127-
else:
128-
self._extension_names = None
120+
self._extension_names = extensions.keys()
121+
for extension_name, extension in extensions.items():
122+
if hasattr(self, extension_name):
123+
raise ValueError(
124+
f'DmEnvAdaptor already has attribute "{extension_name}"!')
125+
setattr(self, extension_name, extension)
129126

130127
def reset(self):
131128
"""Implements dm_env.Environment.reset."""
@@ -269,7 +266,7 @@ def discount_spec(self):
269266
def close(self):
270267
"""Implements dm_env.Environment.close."""
271268
# Release any extensions associated with this EnvAdaptor:
272-
for extension_name in (self._extension_names or []):
269+
for extension_name in self._extension_names:
273270
setattr(self, extension_name, None)
274271
# Leaves the world if we were joined. If not, this will be a no-op anyway.
275272
if self._connection is not None:

0 commit comments

Comments
 (0)