diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 0a2f237908a..8796de7a6c5 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -2158,6 +2158,13 @@ def act(self): self.self_observe(response) return response + def _init_batch_reply(self, num_observations): + batch_reply = [ + Message({'id': self.getID(), 'episode_done': False}) + for _ in range(num_observations) + ] + return batch_reply + def batch_act(self, observations): """ Process a batch of observations (batchsize list of message dicts). @@ -2182,10 +2189,7 @@ def batch_act(self, observations): num_observations = len(observations) # initialize a list of replies with this agent's id - batch_reply = [ - Message({'id': self.getID(), 'episode_done': False}) - for _ in range(num_observations) - ] + batch_reply = self._init_batch_reply(num_observations) self.is_training = batch.is_training