1313
1414class ContextExchanger :
1515
16- def __init__ (
17- self ,
18- skip_n_iter : int = 1 ,
19- storage_loader : Optional [StorageLoader ] = None ,
20- ) -> None :
16+ def __init__ (self , skip_n_iter : int = 1 , storage_loader : Optional [StorageLoader ] = None ) -> None :
2117 """
2218 Overview:
2319 Exchange context between processes,
@@ -41,9 +37,8 @@ def __init__(
4137 self ._storage_loader = storage_loader
4238
4339 # Both nng and torchrpc use background threads to trigger the receiver's recv action,
44- # there is a race condition between sender and sender, and between senders and receiver .
40+ # there is a race condition between the listen thread and the polling thread .
4541 self ._put_lock = LockContext (LockContextType .THREAD_LOCK )
46- self ._recv_ready = False
4742 self ._bypass_eventloop = task .router .mq_type == MQType .RPC
4843
4944 for role in task .role : # Only subscribe to other roles
@@ -101,7 +96,6 @@ def callback(payload: Dict):
10196 getattr (self , fn_name )(item )
10297 else :
10398 logging .warning ("Receive unexpected key ({}) in context exchanger" .format (key ))
104- self ._recv_ready = True
10599
106100 if isinstance (payload , Storage ):
107101 assert self ._storage_loader is not None , "Storage loader is not defined when data is a storage object."
@@ -126,19 +120,27 @@ def fetch(self, ctx: "Context") -> Dict[str, Any]:
126120 return payload
127121
128122 def merge (self , ctx : "Context" ):
129-
123+ # Dict's assignment is not an atomic operation, even if len(self._state)
124+ # is not 0, the value corresponding to the key maybe empty.
125+ ready = 0
130126 if task .has_role (task .role .LEARNER ):
131127 # Learner should always wait for trajs.
132128 # TODO: Automaticlly wait based on properties, not roles.
133- while self ._recv_ready is False :
134- sleep (0.01 )
129+ while ready == 0 :
130+ with self ._put_lock :
131+ ready = len (self ._state )
132+ if ready == 0 :
133+ sleep (0.01 )
135134 elif ctx .total_step >= self ._skip_n_iter :
136135 start = time ()
137- while self ._recv_ready is False :
138- if time () - start > 60 :
139- logging .warning ("Timeout when waiting for new context! Node id: {}" .format (task .router .node_id ))
140- break
141- sleep (0.01 )
136+ while ready == 0 :
137+ with self ._put_lock :
138+ ready = len (self ._state )
139+ if ready == 0 :
140+ if time () - start > 60 :
141+ logging .warning ("Timeout when waiting for new context! Node id: {}" .format (task .router .node_id ))
142+ break
143+ sleep (0.01 )
142144
143145 with self ._put_lock :
144146 for k , v in self ._state .items ():
@@ -148,7 +150,6 @@ def merge(self, ctx: "Context"):
148150 else :
149151 setattr (ctx , k , v )
150152 self ._state = {}
151- self ._recv_ready = False
152153
153154 # Handle each attibute of context
154155 def _put_trajectories (self , traj : List [Any ]):
@@ -173,14 +174,14 @@ def _fetch_episodes(self, episodes: List[Any]):
173174 if task .has_role (task .role .COLLECTOR ):
174175 return episodes
175176
176- def _put_trajectory_end_idx (self , trajectory_end_idx : List [int ]):
177+ def _put_trajectory_end_idx (self , trajectory_end_idx : List [str ]):
177178 if not task .has_role (task .role .LEARNER ):
178179 return
179180 if "trajectory_end_idx" not in self ._state :
180181 self ._state ["trajectory_end_idx" ] = []
181182 self ._state ["trajectory_end_idx" ].extend (trajectory_end_idx )
182183
183- def _fetch_trajectory_end_idx (self , trajectory_end_idx : List [int ]):
184+ def _fetch_trajectory_end_idx (self , trajectory_end_idx : List [str ]):
184185 if task .has_role (task .role .COLLECTOR ):
185186 return trajectory_end_idx
186187
@@ -202,6 +203,12 @@ def _put_env_episode(self, increment_env_episode: int):
202203 self ._state ['increment_env_episode' ] = 0
203204 self ._state ["increment_env_episode" ] += increment_env_episode
204205
206+ def _fetch_env_episode (self , env_episode : int ):
207+ if task .has_role (task .role .COLLECTOR ):
208+ increment_env_episode = env_episode - self ._local_state ['env_episode' ]
209+ self ._local_state ['env_episode' ] = env_episode
210+ return increment_env_episode
211+
205212 def _put_train_iter (self , train_iter : int ):
206213 if not task .has_role (task .role .LEARNER ):
207214 self ._state ["train_iter" ] = train_iter
0 commit comments