diff --git a/src/digest/main.py b/src/digest/main.py index 7e71b58..fe46bf0 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -474,8 +474,9 @@ def load_onnx(self, filepath: str): basename = os.path.splitext(os.path.basename(filepath)) model_name = basename[0] + # Save the model proto so we can use the Freeze Inputs feature digest_model = DigestOnnxModel( - onnx_model=model, model_name=model_name, save_proto=False + onnx_model=opt_model, model_name=model_name, save_proto=True ) model_id = digest_model.unique_id @@ -484,9 +485,6 @@ def load_onnx(self, filepath: str): self.digest_models[model_id] = digest_model - # We must set the proto for the model_summary freeze_inputs - digest_model.model_proto = opt_model - model_summary = modelSummary(digest_model) if model_summary.freeze_inputs: model_summary.freeze_inputs.complete_signal.connect(self.load_onnx) diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py index 2ee4583..c8b5af3 100644 --- a/src/digest/model_class/digest_onnx_model.py +++ b/src/digest/model_class/digest_onnx_model.py @@ -3,6 +3,7 @@ import os from typing import List, Dict, Optional, Tuple, Union, cast from datetime import datetime +from collections import OrderedDict import yaml import numpy as np import onnx @@ -38,7 +39,7 @@ def __init__( self.producer_version: Optional[str] = None self.ir_version: Optional[int] = None self.opset: Optional[int] = None - self.imports: Dict[str, int] = {} + self.imports: OrderedDict[str, int] = OrderedDict() # Private members not intended to be exposed self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {} @@ -55,9 +56,12 @@ def update_state(self, model_proto: onnx.ModelProto) -> None: self.producer_version = model_proto.producer_version self.ir_version = model_proto.ir_version self.opset = onnx_utils.get_opset(model_proto) - self.imports = { - import_.domain: import_.version for import_ in model_proto.opset_import - } + self.imports = OrderedDict( + sorted( + (import_.domain, import_.version) + for import_ in model_proto.opset_import + ) + ) self.model_inputs = onnx_utils.get_model_input_shapes_types(model_proto) self.model_outputs = onnx_utils.get_model_output_shapes_types(model_proto) @@ -527,7 +531,7 @@ def save_yaml_report(self, filepath: str) -> None: "producer_version": self.producer_version, "ir_version": self.ir_version, "opset": self.opset, - "import_list": self.imports, + "import_list": dict(self.imports), "graph_nodes": sum(self.node_type_counts.values()), "model_parameters": self.model_parameters, "model_flops": self.model_flops, diff --git a/test/resnet18_reports/resnet18_report.txt b/test/resnet18_reports/resnet18_report.txt index a72ae03..fdda0bf 100644 --- a/test/resnet18_reports/resnet18_report.txt +++ b/test/resnet18_reports/resnet18_report.txt @@ -9,12 +9,12 @@ Opset: 17 Import list : 17 - com.microsoft.nchwc: 1 ai.onnx.ml: 5 - ai.onnx.training: 1 ai.onnx.preview.training: 1 + ai.onnx.training: 1 com.microsoft: 1 com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 org.pytorch.aten: 1 Total graph nodes: 49 diff --git a/test/resnet18_reports/resnet18_report.yaml b/test/resnet18_reports/resnet18_report.yaml index 531f840..8fe8eea 100644 --- a/test/resnet18_reports/resnet18_report.yaml +++ b/test/resnet18_reports/resnet18_report.yaml @@ -10,12 +10,12 @@ opset: 17 import_list: ? '' : 17 - com.microsoft.nchwc: 1 ai.onnx.ml: 5 - ai.onnx.training: 1 ai.onnx.preview.training: 1 + ai.onnx.training: 1 com.microsoft: 1 com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 org.pytorch.aten: 1 graph_nodes: 49 model_parameters: 11684712 diff --git a/test/test_reports.py b/test/test_reports.py index 9740121..cc99063 100644 --- a/test/test_reports.py +++ b/test/test_reports.py @@ -4,6 +4,8 @@ import unittest import tempfile import csv +from typing import List, Optional, Dict, Any +import yaml import utils.onnx_utils as onnx_utils from digest.model_class.digest_onnx_model import DigestOnnxModel @@ -49,6 +51,67 @@ def compare_csv_files(self, file1, file2, skip_lines=0): for row1, row2 in zip(reader1, reader2): self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}") + def compare_yaml_files( + self, file1: str, file2: str, skip_keys: Optional[List[str]] = None + ) -> bool: + """ + Compare two YAML files, ignoring specified keys. + + :param file1: Path to the first YAML file + :param file2: Path to the second YAML file + :param skip_keys: List of keys to ignore in the comparison + :return: True if the files are equal (ignoring specified keys), False otherwise + """ + + def load_yaml(file_path: str) -> Dict[str, Any]: + with open(file_path, "r", encoding="utf-8") as file: + return yaml.safe_load(file) + + def compare_dicts( + dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = "" + ) -> List[str]: + differences = [] + all_keys = set(dict1.keys()) | set(dict2.keys()) + + for key in all_keys: + if skip_keys and key in skip_keys: + continue + + current_path = f"{path}.{key}" if path else key + + if key not in dict1: + differences.append( + f"Key '{current_path}' is missing in the first file" + ) + elif key not in dict2: + differences.append( + f"Key '{current_path}' is missing in the second file" + ) + elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + differences.extend( + compare_dicts(dict1[key], dict2[key], current_path) + ) + elif dict1[key] != dict2[key]: + differences.append( + f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}" + ) + + return differences + + yaml1 = load_yaml(file1) + yaml2 = load_yaml(file2) + + differences = compare_dicts(yaml1, yaml2) + + if differences: + print("Differences found:") + for diff in differences: + print(f"- {diff}") + return False + else: + print("No differences found.") + return True + def test_against_example_reports(self): model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False) model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0] @@ -61,22 +124,16 @@ def test_against_example_reports(self): ) with tempfile.TemporaryDirectory() as tmpdir: - # Model text report - text_report_filepath = os.path.join(tmpdir, f"{model_name}_report.txt") - digest_model.save_text_report(text_report_filepath) - with self.subTest("Testing report text file"): - self.compare_files_line_by_line( - TEST_SUMMARY_TEXT_REPORT, - text_report_filepath, - skip_lines=2, - ) - # Model yaml report yaml_report_filepath = os.path.join(tmpdir, f"{model_name}_report.yaml") digest_model.save_yaml_report(yaml_report_filepath) with self.subTest("Testing report yaml file"): - self.compare_files_line_by_line( - TEST_SUMMARY_YAML_REPORT, yaml_report_filepath, skip_lines=2 + self.assertTrue( + self.compare_yaml_files( + TEST_SUMMARY_YAML_REPORT, + yaml_report_filepath, + skip_keys=["report_date", "onnx_file"], + ) ) # Save CSV containing node-level information