@@ -60,7 +60,7 @@ def __init__(
60
60
specs : dm_env_rpc_pb2 .ActionObservationSpecs ,
61
61
requested_observations : Optional [Sequence [str ]] = None ,
62
62
nested_tensors : bool = True ,
63
- extensions : Optional [ Mapping [str , Any ] ] = immutabledict .immutabledict ()):
63
+ extensions : Mapping [str , Any ] = immutabledict .immutabledict ()):
64
64
"""Initializes the environment with the provided dm_env_rpc connection.
65
65
66
66
Args:
@@ -71,8 +71,8 @@ def __init__(
71
71
environment when step is called. If None is specified then all
72
72
observations will be requested.
73
73
nested_tensors: Boolean to determine whether to flatten/unflatten tensors.
74
- extensions: Optional mapping of extension instances to DmEnvAdaptor
75
- attributes. Raises ValueError if attribute already exists.
74
+ extensions: Mapping of extension instances to DmEnvAdaptor attributes.
75
+ Raises ValueError if attribute already exists.
76
76
"""
77
77
self ._dm_env_rpc_specs = specs
78
78
self ._action_specs = spec_manager .SpecManager (specs .actions )
@@ -117,15 +117,12 @@ def __init__(
117
117
# Not strictly necessary but it makes the unit tests deterministic.
118
118
self ._requested_observation_uids .sort ()
119
119
120
- if extensions is not None :
121
- self ._extension_names = extensions .keys ()
122
- for extension_name , extension in extensions .items ():
123
- if hasattr (self , extension_name ):
124
- raise ValueError (
125
- f'DmEnvAdaptor already has attribute "{ extension_name } "!' )
126
- setattr (self , extension_name , extension )
127
- else :
128
- self ._extension_names = None
120
+ self ._extension_names = extensions .keys ()
121
+ for extension_name , extension in extensions .items ():
122
+ if hasattr (self , extension_name ):
123
+ raise ValueError (
124
+ f'DmEnvAdaptor already has attribute "{ extension_name } "!' )
125
+ setattr (self , extension_name , extension )
129
126
130
127
def reset (self ):
131
128
"""Implements dm_env.Environment.reset."""
@@ -269,7 +266,7 @@ def discount_spec(self):
269
266
def close (self ):
270
267
"""Implements dm_env.Environment.close."""
271
268
# Release any extensions associated with this EnvAdaptor:
272
- for extension_name in ( self ._extension_names or []) :
269
+ for extension_name in self ._extension_names :
273
270
setattr (self , extension_name , None )
274
271
# Leaves the world if we were joined. If not, this will be a no-op anyway.
275
272
if self ._connection is not None :
0 commit comments