@@ -473,20 +473,21 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
 def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
     batch_size = initial_batch_size
     results = dict()
+    error_str = 'Unknown'
     while batch_size >= 1:
         torch.cuda.empty_cache()
         try:
             bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
             results = bench.run()
             return results
         except RuntimeError as e:
-            e_str = str(e)
-            print(e_str)
-            if 'channels_last' in e_str:
-                print(f'Error: {model_name} not supported in channels_last, skipping.')
+            error_str = str(e)
+            if 'channels_last' in error_str:
+                _logger.error(f'{model_name} not supported in channels_last, skipping.')
                 break
-            print(f'Error: "{e_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
+            _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
             batch_size = decay_batch_exp(batch_size)
+    results['error'] = error_str
     return results


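For context, here is a minimal self-contained sketch of the retry-and-report pattern this hunk introduces. The names toy_bench, halve_batch, and try_run are illustrative stand-ins, not the real benchmark.py helpers, and the simulated OOM threshold is invented:

def halve_batch(batch_size):
    # stand-in for decay_batch_exp(); simply halves the batch size
    return batch_size // 2


def toy_bench(batch_size):
    # pretend anything above 32 exhausts GPU memory
    if batch_size > 32:
        raise RuntimeError('CUDA out of memory (simulated)')
    return {'batch_size': batch_size, 'samples_per_sec': 1000.0 / batch_size}


def try_run(initial_batch_size):
    batch_size = initial_batch_size
    results = dict()
    error_str = 'Unknown'
    while batch_size >= 1:
        try:
            return toy_bench(batch_size)
        except RuntimeError as e:
            error_str = str(e)
            batch_size = halve_batch(batch_size)
    # as in the patched _try_run, report the last error instead of returning an empty dict
    results['error'] = error_str
    return results


print(try_run(256))  # retries 256 -> 128 -> 64 -> 32, then succeeds
print(try_run(0))    # loop never runs, so the 'Unknown' placeholder is reported

The point of the change is visible in the second call: a run that never succeeds now comes back with an 'error' key rather than silently as an empty dict.
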
@@ -528,13 +529,14 @@ def benchmark(args):
     model_results = OrderedDict(model=model)
     for prefix, bench_fn in zip(prefixes, bench_fns):
         run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
-        if prefix:
+        if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
-    param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
-    model_results.setdefault('param_count', param_count)
-    model_results.pop('train_param_count', 0)
-    return model_results if model_results['param_count'] else dict()
+    if 'error' not in model_results:
+        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
+        model_results.setdefault('param_count', param_count)
+        model_results.pop('train_param_count', 0)
+    return model_results


 def main():
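As a small illustration of why the new 'error' check guards the key prefixing: the dict comprehension renames every key, which would turn 'error' into 'infer_error' or 'train_error' and break the single 'error' check used here and in main(). The values below are invented:

run_results = {'samples_per_sec': 915.2, 'batch_size': 256, 'param_count': 5.3}  # made-up numbers
prefix = 'infer'
prefixed = {'_'.join([prefix, k]): v for k, v in run_results.items()}
print(prefixed)
# {'infer_samples_per_sec': 915.2, 'infer_batch_size': 256, 'infer_param_count': 5.3}

# With the patch, a failed run such as {'error': 'CUDA out of memory'} skips the renaming,
# so 'error' stays a top-level key no matter which bench step failed.
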
@@ -578,13 +580,15 @@ def main():
             sort_key = 'train_samples_per_sec'
         elif 'profile' in args.bench:
             sort_key = 'infer_gmacs'
+        results = filter(lambda x: sort_key in x, results)
         results = sorted(results, key=lambda x: x[sort_key], reverse=True)
         if len(results):
             write_results(results_file, results)
     else:
         results = benchmark(args)
-    json_str = json.dumps(results, indent=4)
-    print(json_str)
+
+    # output results in JSON to stdout w/ delimiter for runner script
+    print(f'--result\n{json.dumps(results, indent=4)}')


 def write_results(results_file, results):
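A hypothetical runner-side counterpart for the '--result' delimiter added above; only the delimiter format comes from this diff, while the subprocess invocation, command line, and function name are assumptions:

import json
import subprocess


def collect_result(cmd):
    # cmd e.g. ['python', 'benchmark.py', '--model', 'resnet50']  (illustrative only)
    output = subprocess.run(cmd, capture_output=True, text=True).stdout
    # everything after the final '--result' line is the JSON payload; log lines before it are ignored
    _, found, payload = output.rpartition('--result\n')
    return json.loads(payload) if found else {}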