Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit d6773a0

Browse files
authored
Fixes train_model worldlogging for multitask with mutators. (#4414)
* Fixes train_model worldlogging for multitask with mutators.
* Fix bug in train_model when evaltask doesn't match task.
1 parent 573c76c commit d6773a0

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed

parlai/scripts/train_model.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -619,7 +619,7 @@ def validate(self):
619619
return True
620620
return False
621621

622-
def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):
622+
def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):
623623

624624
# run evaluation on a single world
625625
valid_world.reset()
@@ -629,7 +629,7 @@ def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):
629629
# set up world logger for the "test" fold
630630
if opt['world_logs'] and datatype == 'test':
631631
task_opt['world_logs'] = get_task_world_logs(
632-
valid_world.getID(), opt['world_logs'], is_multitask
632+
task, opt['world_logs'], is_multitask
633633
)
634634
world_logger = WorldLogger(task_opt)
635635

@@ -691,9 +691,13 @@ def _run_eval(
691691

692692
max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
693693
is_multitask = len(valid_worlds) > 1
694-
for v_world in valid_worlds:
694+
for index, v_world in enumerate(valid_worlds):
695+
if opt.get('evaltask'):
696+
task = opt['evaltask'].split(',')[index]
697+
else:
698+
task = opt['task'].split(',')[index]
695699
task_report = self._run_single_eval(
696-
opt, v_world, max_exs_per_worker, datatype, is_multitask
700+
opt, v_world, max_exs_per_worker, datatype, is_multitask, task
697701
)
698702
reports.append(task_report)
699703

tests/test_train_model.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,62 @@ def test_save_multiple_world_logs(self):
271271
json_lines = f.readlines()
272272
assert len(json_lines) == 5
273273

274+
def test_save_multiple_world_logs_evaltask(self):
275+
"""
276+
Test that we can save multiple world_logs from train model on multiple tasks
277+
where there are more evaltasks than tasks.
278+
"""
279+
with testing_utils.tempdir() as tmpdir:
280+
log_report = os.path.join(tmpdir, 'world_logs.jsonl')
281+
multitask = 'integration_tests,integration_tests:ReverseTeacher'
282+
evaltask = 'integration_tests,integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
283+
valid, test = testing_utils.train_model(
284+
{
285+
'task': multitask,
286+
'evaltask': evaltask,
287+
'validation_max_exs': 10,
288+
'model': 'repeat_label',
289+
'short_final_eval': True,
290+
'num_epochs': 1.0,
291+
'world_logs': log_report,
292+
}
293+
)
294+
295+
for task in evaltask.split(','):
296+
task_log_report = get_task_world_logs(
297+
task, log_report, is_multitask=True
298+
)
299+
with PathManager.open(task_log_report) as f:
300+
json_lines = f.readlines()
301+
assert len(json_lines) == 4
302+
303+
def test_save_multiple_world_logs_mutator(self):
304+
"""
305+
Test that we can save multiple world_logs from train model on multiple tasks
306+
with mutators present.
307+
"""
308+
with testing_utils.tempdir() as tmpdir:
309+
log_report = os.path.join(tmpdir, 'world_logs.jsonl')
310+
multitask = 'integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
311+
valid, test = testing_utils.train_model(
312+
{
313+
'task': multitask,
314+
'validation_max_exs': 10,
315+
'model': 'repeat_label',
316+
'short_final_eval': True,
317+
'num_epochs': 1.0,
318+
'world_logs': log_report,
319+
}
320+
)
321+
322+
for task in multitask.split(','):
323+
task_log_report = get_task_world_logs(
324+
task, log_report, is_multitask=True
325+
)
326+
with PathManager.open(task_log_report) as f:
327+
json_lines = f.readlines()
328+
assert len(json_lines) == 5
329+
274330

275331
@register_agent("fake_report")
276332
class FakeReportAgent(Agent):

0 commit comments

Comments
 (0)