From 03af35c914111e341bda61e9782400ecbf767e41 Mon Sep 17 00:00:00 2001 From: Philip Colangelo Date: Tue, 31 Dec 2024 13:48:27 -0500 Subject: [PATCH] handle dupes with reports in multimodel analysis --- src/digest/model_class/digest_report_model.py | 142 +++++++++++++----- src/digest/multi_model_analysis.py | 36 +++-- src/digest/multi_model_selection_page.py | 49 ++++-- test/test_reports.py | 66 +------- 4 files changed, 157 insertions(+), 136 deletions(-) diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py index 4478285..50e76de 100644 --- a/src/digest/model_class/digest_report_model.py +++ b/src/digest/model_class/digest_report_model.py @@ -3,7 +3,7 @@ import csv import ast import re -from typing import Tuple, Optional +from typing import Tuple, Optional, List, Dict, Any, Union import yaml from digest.model_class.digest_model import ( DigestModel, @@ -15,7 +15,9 @@ ) -def parse_tensor_info(csv_tensor_cell_value) -> Tuple[str, list, str, float]: +def parse_tensor_info( + csv_tensor_cell_value, +) -> Tuple[str, list, str, Union[str, float]]: """This is a helper function that expects the input to come from parsing the nodes csv and extracting either an input or output tensor.""" @@ -38,7 +40,10 @@ def parse_tensor_info(csv_tensor_cell_value) -> Tuple[str, list, str, float]: if not isinstance(shape, list): shape = list(shape) - return name.strip(), shape, dtype.strip(), float(size.split()[0]) + if size != "None": + size = float(size.split()[0]) + + return name.strip(), shape, dtype.strip(), size class DigestReportModel(DigestModel): @@ -49,7 +54,7 @@ def __init__( self.model_type = SupportedModelTypes.REPORT - self.is_valid = self.validate_yaml(report_filepath) + self.is_valid = validate_yaml(report_filepath) if not self.is_valid: print(f"The yaml file {report_filepath} is not a valid digest report.") @@ -131,41 +136,6 @@ def __init__( } ) - def validate_yaml(self, report_file_path: str) -> bool: - """Check that the provided yaml file is indeed a Digest Report file.""" - expected_keys = [ - "report_date", - "model_file", - "model_type", - "model_name", - "flops", - "node_type_flops", - "node_type_parameters", - "node_type_counts", - "input_tensors", - "output_tensors", - ] - try: - with open(report_file_path, "r", encoding="utf-8") as file: - yaml_content = yaml.safe_load(file) - - if not isinstance(yaml_content, dict): - print("Error: YAML content is not a dictionary") - return False - - for key in expected_keys: - if key not in yaml_content: - # print(f"Error: Missing required key '{key}'") - return False - - return True - except yaml.YAMLError as _: - # print(f"Error parsing YAML file: {e}") - return False - except IOError as _: - # print(f"Error reading file: {e}") - return False - def parse_model_nodes(self) -> None: """There are no model nodes to parse""" @@ -174,3 +144,97 @@ def save_yaml_report(self, filepath: str) -> None: def save_text_report(self, filepath: str) -> None: """Report models are not intended to be saved""" + + +def validate_yaml(report_file_path: str) -> bool: + """Check that the provided yaml file is indeed a Digest Report file.""" + expected_keys = [ + "report_date", + "model_file", + "model_type", + "model_name", + "flops", + "node_type_flops", + "node_type_parameters", + "node_type_counts", + "input_tensors", + "output_tensors", + ] + try: + with open(report_file_path, "r", encoding="utf-8") as file: + yaml_content = yaml.safe_load(file) + + if not isinstance(yaml_content, dict): + print("Error: YAML content is not a dictionary") + return False + + for key in expected_keys: + if key not in yaml_content: + # print(f"Error: Missing required key '{key}'") + return False + + return True + except yaml.YAMLError as _: + # print(f"Error parsing YAML file: {e}") + return False + except IOError as _: + # print(f"Error reading file: {e}") + return False + + +def compare_yaml_files( + file1: str, file2: str, skip_keys: Optional[List[str]] = None +) -> bool: + """ + Compare two YAML files, ignoring specified keys. + + :param file1: Path to the first YAML file + :param file2: Path to the second YAML file + :param skip_keys: List of keys to ignore in the comparison + :return: True if the files are equal (ignoring specified keys), False otherwise + """ + + def load_yaml(file_path: str) -> Dict[str, Any]: + with open(file_path, "r", encoding="utf-8") as file: + return yaml.safe_load(file) + + def compare_dicts( + dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = "" + ) -> List[str]: + differences = [] + all_keys = set(dict1.keys()) | set(dict2.keys()) + + for key in all_keys: + if skip_keys and key in skip_keys: + continue + + current_path = f"{path}.{key}" if path else key + + if key not in dict1: + differences.append(f"Key '{current_path}' is missing in the first file") + elif key not in dict2: + differences.append( + f"Key '{current_path}' is missing in the second file" + ) + elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + differences.extend(compare_dicts(dict1[key], dict2[key], current_path)) + elif dict1[key] != dict2[key]: + differences.append( + f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}" + ) + + return differences + + yaml1 = load_yaml(file1) + yaml2 = load_yaml(file2) + + differences = compare_dicts(yaml1, yaml2) + + if differences: + # print("Differences found:") + # for diff in differences: + # print(f"- {diff}") + return False + else: + # print("No differences found.") + return True diff --git a/src/digest/multi_model_analysis.py b/src/digest/multi_model_analysis.py index d5937bc..a19d40f 100644 --- a/src/digest/multi_model_analysis.py +++ b/src/digest/multi_model_analysis.py @@ -101,7 +101,13 @@ def __init__( self.ui.dataTable.resizeColumnsToContents() self.ui.dataTable.resizeRowsToContents() - node_type_counter = {} + # Until we use the unique_id to represent the model contents we store + # the entire model as the key so that we can store models that happen to have + # the same name. There is a guarantee that the models will not be duplicates. + node_type_counter: Dict[ + Union[DigestOnnxModel, DigestReportModel], NodeTypeCounts + ] = {} + for i, digest_model in enumerate(model_list): progress.step() progress.setLabelText(f"Analyzing model {digest_model.model_name}") @@ -143,26 +149,18 @@ def __init__( "flops": digest_model.flops, } - # Here we are creating a name that is a combination of the model name - # and the model type. - node_type_counter_key = ( - f"{digest_model.model_name}-{digest_model.model_type.value}" - ) - - if node_type_counter_key in node_type_counter: + if digest_model in node_type_counter: print( f"Warning! {digest_model.model_name} with model type " - f"{digest_model.model_type.value} has already been added to " - "to the stacked histogram, skipping." + f"{digest_model.model_type.value} and id {digest_model.unique_id} " + "has already been added to the stacked histogram, skipping." ) continue - node_type_counter[node_type_counter_key] = digest_model.node_type_counts + node_type_counter[digest_model] = digest_model.node_type_counts # Update global data structure for node type counter - self.global_node_type_counter.update( - node_type_counter[node_type_counter_key] - ) + self.global_node_type_counter.update(node_type_counter[digest_model]) node_shape_counts = digest_model.get_node_shape_counts() @@ -180,20 +178,20 @@ def __init__( # Create stacked op histograms max_count = 0 top_ops = [key for key, _ in self.global_node_type_counter.most_common(20)] - for model_name, _ in node_type_counter.items(): - max_local = Counter(node_type_counter[model_name]).most_common()[0][1] + for model, _ in node_type_counter.items(): + max_local = Counter(node_type_counter[model]).most_common()[0][1] if max_local > max_count: max_count = max_local - for idx, model_name in enumerate(node_type_counter): + for idx, model in enumerate(node_type_counter): stacked_histogram_widget = StackedHistogramWidget() ordered_dict = OrderedDict() - model_counter = Counter(node_type_counter[model_name]) + model_counter = Counter(node_type_counter[model]) for key in top_ops: ordered_dict[key] = model_counter.get(key, 0) title = "Stacked Op Histogram" if idx == 0 else "" stacked_histogram_widget.set_data( ordered_dict, - model_name=model_name, + model_name=model.model_name, y_max=max_count, title=title, set_ticks=False, diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index 0d0e5cc..e9d5c2b 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -23,7 +23,7 @@ from digest.multi_model_analysis import MultiModelAnalysis from digest.qt_utils import apply_dark_style_sheet, prompt_user_ram_limit from digest.model_class.digest_onnx_model import DigestOnnxModel -from digest.model_class.digest_report_model import DigestReportModel +from digest.model_class.digest_report_model import DigestReportModel, compare_yaml_files from utils import onnx_utils @@ -203,7 +203,7 @@ def set_directory(self, directory: str): else: return - progress = ProgressDialog("Searching Directory for ONNX Files", 0, self) + progress = ProgressDialog("Searching directory for model files", 0, self) onnx_file_list = list( glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True) @@ -227,11 +227,11 @@ def set_directory(self, directory: str): serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list) progress.close() - progress = ProgressDialog("Loading Models", total_num_models, self) + progress = ProgressDialog("Loading models", total_num_models, self) memory_limit_percentage = 90 models_loaded = 0 - for filepath in onnx_file_list: + for filepath in onnx_file_list + report_file_list: progress.step() if progress.user_canceled: break @@ -284,17 +284,38 @@ def set_directory(self, directory: str): self.ui.duplicateListWidget.addItem(paths[0]) for dupe in paths[1:]: self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}") - item = QStandardItem(paths[0]) - item.setCheckable(True) - item.setCheckState(Qt.CheckState.Checked) - self.item_model.appendRow(item) - else: - item = QStandardItem(paths[0]) - item.setCheckable(True) - item.setCheckState(Qt.CheckState.Checked) - self.item_model.appendRow(item) + item = QStandardItem(paths[0]) + item.setCheckable(True) + item.setCheckState(Qt.CheckState.Checked) + self.item_model.appendRow(item) - for path in report_file_list: + # Use a standard nested loop to detect duplicate reports + duplicate_reports: Dict[str, List[str]] = {} + processed_files = set() + for i in range(len(report_file_list)): + progress.step() + if progress.user_canceled: + break + path1 = report_file_list[i] + if path1 in processed_files: + continue # Skip already processed files + + # We will use path1 as the unique model and save a list of duplicates + duplicate_reports[path1] = [] + for j in range(i + 1, len(report_file_list)): + path2 = report_file_list[j] + if compare_yaml_files( + path1, path2, ["report_date", "model_files", "digest_version"] + ): + num_duplicates += 1 + duplicate_reports[path1].append(path2) + processed_files.add(path2) + + for path, dupes in duplicate_reports.items(): + if dupes: + self.ui.duplicateListWidget.addItem(path) + for dupe in dupes: + self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}") item = QStandardItem(path) item.setCheckable(True) item.setCheckState(Qt.CheckState.Checked) diff --git a/test/test_reports.py b/test/test_reports.py index 4653464..ae99ab9 100644 --- a/test/test_reports.py +++ b/test/test_reports.py @@ -4,10 +4,9 @@ import unittest import tempfile import csv -from typing import List, Optional, Dict, Any -import yaml import utils.onnx_utils as onnx_utils from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import compare_yaml_files TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_ONNX = os.path.join(TEST_DIR, "resnet18.onnx") @@ -51,67 +50,6 @@ def compare_csv_files(self, file1, file2, skip_lines=0): for row1, row2 in zip(reader1, reader2): self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}") - def compare_yaml_files( - self, file1: str, file2: str, skip_keys: Optional[List[str]] = None - ) -> bool: - """ - Compare two YAML files, ignoring specified keys. - - :param file1: Path to the first YAML file - :param file2: Path to the second YAML file - :param skip_keys: List of keys to ignore in the comparison - :return: True if the files are equal (ignoring specified keys), False otherwise - """ - - def load_yaml(file_path: str) -> Dict[str, Any]: - with open(file_path, "r", encoding="utf-8") as file: - return yaml.safe_load(file) - - def compare_dicts( - dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = "" - ) -> List[str]: - differences = [] - all_keys = set(dict1.keys()) | set(dict2.keys()) - - for key in all_keys: - if skip_keys and key in skip_keys: - continue - - current_path = f"{path}.{key}" if path else key - - if key not in dict1: - differences.append( - f"Key '{current_path}' is missing in the first file" - ) - elif key not in dict2: - differences.append( - f"Key '{current_path}' is missing in the second file" - ) - elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict): - differences.extend( - compare_dicts(dict1[key], dict2[key], current_path) - ) - elif dict1[key] != dict2[key]: - differences.append( - f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}" - ) - - return differences - - yaml1 = load_yaml(file1) - yaml2 = load_yaml(file2) - - differences = compare_dicts(yaml1, yaml2) - - if differences: - print("Differences found:") - for diff in differences: - print(f"- {diff}") - return False - else: - print("No differences found.") - return True - def test_against_example_reports(self): model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False) model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0] @@ -129,7 +67,7 @@ def test_against_example_reports(self): digest_model.save_yaml_report(yaml_report_filepath) with self.subTest("Testing report yaml file"): self.assertTrue( - self.compare_yaml_files( + compare_yaml_files( TEST_SUMMARY_YAML_REPORT, yaml_report_filepath, skip_keys=["report_date", "model_file", "digest_version"],