|
53 | 53 | from parlai.core.params import ParlaiParser
|
54 | 54 | from parlai.core.teachers import Teacher, create_task_agent_from_taskname
|
55 | 55 | from parlai.utils.data import DatatypeHelper
|
56 |
| -from parlai.utils.misc import Timer, display_messages |
| 56 | +from parlai.utils.misc import Timer, display_messages, warn_once |
57 | 57 | from parlai.tasks.tasks import ids_to_tasks
|
58 | 58 | from parlai.utils.misc import error_once
|
59 | 59 |
|
@@ -562,10 +562,17 @@ def __init__(self, opt: Opt, agents=None, shared=None, default_world=None):
|
562 | 562 | self.parleys = -1
|
563 | 563 | # Check to see if we are training
|
564 | 564 | self.is_training = DatatypeHelper.is_training(opt.get('datatype'))
|
| 565 | + # Check to see if we should shuffle |
| 566 | + self.should_shuffle = DatatypeHelper.should_shuffle(opt.get('datatype')) |
565 | 567 | # Make multi-task task probabilities.
|
566 | 568 | self.cum_task_weights = [1] * len(self.worlds)
|
567 | 569 | self.task_choices = range(len(self.worlds))
|
568 | 570 | weights = self.opt.get('multitask_weights', [1])
|
| 571 | + # Warn about multi-task weights being ignored if we are in a datatype that doesn't involve shuffling |
| 572 | + if weights != [1] and not self.should_shuffle: |
| 573 | + warn_once( |
| 574 | + f"WARNING: multitask weights are ignored for datatype {opt.get('datatype')} as we iterate through tasks in a round robin" |
| 575 | + ) |
569 | 576 | if weights == 'stochastic':
|
570 | 577 | weights = [w.num_episodes() for w in self.worlds]
|
571 | 578 | sum = 0
|
@@ -672,7 +679,7 @@ def parley_init(self):
|
672 | 679 | if self.new_world:
|
673 | 680 | self.new_world = False
|
674 | 681 | self.parleys = 0
|
675 |
| - if self.is_training: |
| 682 | + if self.should_shuffle: |
676 | 683 | # select random world
|
677 | 684 | self.world_idx = random.choices(
|
678 | 685 | self.task_choices, cum_weights=self.cum_task_weights
|
|
0 commit comments