@@ -271,6 +271,62 @@ def test_save_multiple_world_logs(self):
271
271
json_lines = f .readlines ()
272
272
assert len (json_lines ) == 5
273
273
274
+ def test_save_multiple_world_logs_evaltask (self ):
275
+ """
276
+ Test that we can save multiple world_logs from train model on multiple tasks
277
+ where there are more evaltasks than tasks.
278
+ """
279
+ with testing_utils .tempdir () as tmpdir :
280
+ log_report = os .path .join (tmpdir , 'world_logs.jsonl' )
281
+ multitask = 'integration_tests,integration_tests:ReverseTeacher'
282
+ evaltask = 'integration_tests,integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
283
+ valid , test = testing_utils .train_model (
284
+ {
285
+ 'task' : multitask ,
286
+ 'evaltask' : evaltask ,
287
+ 'validation_max_exs' : 10 ,
288
+ 'model' : 'repeat_label' ,
289
+ 'short_final_eval' : True ,
290
+ 'num_epochs' : 1.0 ,
291
+ 'world_logs' : log_report ,
292
+ }
293
+ )
294
+
295
+ for task in evaltask .split (',' ):
296
+ task_log_report = get_task_world_logs (
297
+ task , log_report , is_multitask = True
298
+ )
299
+ with PathManager .open (task_log_report ) as f :
300
+ json_lines = f .readlines ()
301
+ assert len (json_lines ) == 4
302
+
303
+ def test_save_multiple_world_logs_mutator (self ):
304
+ """
305
+ Test that we can save multiple world_logs from train model on multiple tasks
306
+ with mutators present.
307
+ """
308
+ with testing_utils .tempdir () as tmpdir :
309
+ log_report = os .path .join (tmpdir , 'world_logs.jsonl' )
310
+ multitask = 'integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
311
+ valid , test = testing_utils .train_model (
312
+ {
313
+ 'task' : multitask ,
314
+ 'validation_max_exs' : 10 ,
315
+ 'model' : 'repeat_label' ,
316
+ 'short_final_eval' : True ,
317
+ 'num_epochs' : 1.0 ,
318
+ 'world_logs' : log_report ,
319
+ }
320
+ )
321
+
322
+ for task in multitask .split (',' ):
323
+ task_log_report = get_task_world_logs (
324
+ task , log_report , is_multitask = True
325
+ )
326
+ with PathManager .open (task_log_report ) as f :
327
+ json_lines = f .readlines ()
328
+ assert len (json_lines ) == 5
329
+
274
330
275
331
@register_agent ("fake_report" )
276
332
class FakeReportAgent (Agent ):
0 commit comments