Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit 573c76c

Browse files
authored
use should_shuffle instead of is_training to determine whether to ran… (#4425)
* use should_shuffle instead of is_training to determine whether to randomly sample from worlds * adde unit test showing that multiworld acts deterministically with -dt train:ordered * added warning about using datatype that doesn't shuffle with multitask weights
1 parent 4f7b38e commit 573c76c

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

parlai/core/worlds.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from parlai.core.params import ParlaiParser
5454
from parlai.core.teachers import Teacher, create_task_agent_from_taskname
5555
from parlai.utils.data import DatatypeHelper
56-
from parlai.utils.misc import Timer, display_messages
56+
from parlai.utils.misc import Timer, display_messages, warn_once
5757
from parlai.tasks.tasks import ids_to_tasks
5858
from parlai.utils.misc import error_once
5959

@@ -562,10 +562,17 @@ def __init__(self, opt: Opt, agents=None, shared=None, default_world=None):
562562
self.parleys = -1
563563
# Check to see if we are training
564564
self.is_training = DatatypeHelper.is_training(opt.get('datatype'))
565+
# Check to see if we should shuffle
566+
self.should_shuffle = DatatypeHelper.should_shuffle(opt.get('datatype'))
565567
# Make multi-task task probabilities.
566568
self.cum_task_weights = [1] * len(self.worlds)
567569
self.task_choices = range(len(self.worlds))
568570
weights = self.opt.get('multitask_weights', [1])
571+
# Warn about multi-task weights being ignored if we are in a datatype that doesn't involve shuffling
572+
if weights != [1] and not self.should_shuffle:
573+
warn_once(
574+
f"WARNING: multitask weights are ignored for datatype {opt.get('datatype')} as we iterate through tasks in a round robin"
575+
)
569576
if weights == 'stochastic':
570577
weights = [w.num_episodes() for w in self.worlds]
571578
sum = 0
@@ -672,7 +679,7 @@ def parley_init(self):
672679
if self.new_world:
673680
self.new_world = False
674681
self.parleys = 0
675-
if self.is_training:
682+
if self.should_shuffle:
676683
# select random world
677684
self.world_idx = random.choices(
678685
self.task_choices, cum_weights=self.cum_task_weights

tests/test_multiworld.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,31 @@ def test_with_stream(self):
111111
exs = report[f'{task}/exs'].value()
112112
assert exs > 0, err
113113
world.reset_metrics()
114+
115+
def test_with_ordered(self):
116+
"""
117+
Test that multi-tasking works deterministically with datatype train:ordered.
118+
"""
119+
120+
opt = ParlaiParser(True, True).parse_kwargs(
121+
task='teacher1,teacher2',
122+
multitask_weights='1,1',
123+
model='fixed_response',
124+
fixed_response='None',
125+
datatype='train:ordered',
126+
batchsize=1,
127+
)
128+
multiworld1 = create_task(opt, create_agent(opt))
129+
multiworld2 = create_task(opt, create_agent(opt))
130+
131+
while not (multiworld1.epoch_done() or multiworld2.epoch_done()):
132+
multiworld1.parley()
133+
acts1 = multiworld1.get_acts()
134+
135+
multiworld2.parley()
136+
acts2 = multiworld2.get_acts()
137+
138+
self.assertEqual(len(acts1), len(acts2))
139+
assert all([act1 == act2 for act1, act2 in zip(acts1, acts2)])
140+
141+
assert multiworld1.epoch_done() and multiworld2.epoch_done()

0 commit comments

Comments
 (0)