33import sys
44import ast
55import math
6+ import json
67import importlib
78import inspect
89import subprocess
@@ -60,6 +61,36 @@ def get_number_of_returns(file_path, class_name, func_name):
6061 return 0
6162
6263
64+ def read_graph_source_and_tag (model_path ):
65+ try :
66+ with open (os .path .join (model_path , "graph_net.json" ), "r" ) as f :
67+ data = json .load (f )
68+ return data ["source" ], data ["heuristic_tag" ]
69+ except Exception :
70+ if "cosyvoice" in model_path :
71+ return "cosyvoice" , "audio"
72+ elif "torchaudio" in model_path :
73+ return "torchaudio" , "audio"
74+ elif "ultralytics" in model_path :
75+ return "ultralytics" , "computer_vision"
76+ elif "torchvision" in model_path :
77+ return "torchvision" , "computer_vision"
78+ elif "timm" in model_path :
79+ return "timm" , "computer_vision"
80+ elif "mmseg" in model_path :
81+ return "mmseg" , "computer_vision"
82+ elif "mmpose" in model_path :
83+ return "mmpose" , "computer_vision"
84+ elif "torchgeometric" in model_path :
85+ return "torchgeometric" , "other"
86+ elif "transformers-auto-model" in model_path :
87+ return "huggingface_hub" , "unknown"
88+ elif "nemo" in model_path :
89+ return "nemo" , "unknown"
90+ else :
91+ return "unknown" , "unknown"
92+
93+
6394def get_input_dict (model_path , device ):
6495 inputs_params = utils .load_converted_from_text (f"{ model_path } " )
6596 params = inputs_params ["weight_info" ]
@@ -456,6 +487,8 @@ def collect_model_stats(model_path, device, log_prompt):
456487 model_size_in_billion = model_size / 1e9
457488 num_inputs = len (argument_name2types ) - num_params
458489
490+ source , heuristic_tag = read_graph_source_and_tag (model_path )
491+
459492 def dict_to_string (d ):
460493 kv_list = [f"{ k } :{ v } " for k , v in d .items ()]
461494 return " " .join (kv_list )
@@ -475,6 +508,8 @@ def print_with_log_prompt(key, value):
475508 print_with_log_prompt ("param_dtypes" , dict_to_string (param_dtypes ))
476509 print_with_log_prompt ("op_dtypes" , dict_to_string (op_dtypes ))
477510 print_with_log_prompt ("ops" , dict_to_string (ops_count_dict ))
511+ print_with_log_prompt ("source" , source )
512+ print_with_log_prompt ("heuristic_tag" , heuristic_tag )
478513 print_with_log_prompt ("method" , method )
479514 print_with_log_prompt ("is_complete" , is_complete )
480515
@@ -505,7 +540,10 @@ def main(args):
505540
506541 i = 0
507542 for root , dirs , files in os .walk (graph_net_samples_path ):
508- if is_single_model_dir (root ) and root in previous_failed_model_pathes :
543+ if is_single_model_dir (root ) and (
544+ args .previous_collect_result_path is None
545+ or root in previous_failed_model_pathes
546+ ):
509547 print (f"[{ i } ] Collect information for { root } " )
510548 cmd = [
511549 "python" ,
0 commit comments