Skip to content

Commit

Permalink
No longer test text file in favor of yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Colangelo committed Dec 10, 2024
1 parent 76102d0 commit 112403b
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 25 deletions.
6 changes: 2 additions & 4 deletions src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,9 @@ def load_onnx(self, filepath: str):
basename = os.path.splitext(os.path.basename(filepath))
model_name = basename[0]

# Save the model proto so we can use the Freeze Inputs feature
digest_model = DigestOnnxModel(
onnx_model=model, model_name=model_name, save_proto=False
onnx_model=opt_model, model_name=model_name, save_proto=True
)
model_id = digest_model.unique_id

Expand All @@ -484,9 +485,6 @@ def load_onnx(self, filepath: str):

self.digest_models[model_id] = digest_model

# We must set the proto for the model_summary freeze_inputs
digest_model.model_proto = opt_model

model_summary = modelSummary(digest_model)
if model_summary.freeze_inputs:
model_summary.freeze_inputs.complete_signal.connect(self.load_onnx)
Expand Down
14 changes: 9 additions & 5 deletions src/digest/model_class/digest_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from typing import List, Dict, Optional, Tuple, Union, cast
from datetime import datetime
from collections import OrderedDict
import yaml
import numpy as np
import onnx
Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(
self.producer_version: Optional[str] = None
self.ir_version: Optional[int] = None
self.opset: Optional[int] = None
self.imports: Dict[str, int] = {}
self.imports: OrderedDict[str, int] = OrderedDict()

# Private members not intended to be exposed
self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {}
Expand All @@ -55,9 +56,12 @@ def update_state(self, model_proto: onnx.ModelProto) -> None:
self.producer_version = model_proto.producer_version
self.ir_version = model_proto.ir_version
self.opset = onnx_utils.get_opset(model_proto)
self.imports = {
import_.domain: import_.version for import_ in model_proto.opset_import
}
self.imports = OrderedDict(
sorted(
(import_.domain, import_.version)
for import_ in model_proto.opset_import
)
)

self.model_inputs = onnx_utils.get_model_input_shapes_types(model_proto)
self.model_outputs = onnx_utils.get_model_output_shapes_types(model_proto)
Expand Down Expand Up @@ -527,7 +531,7 @@ def save_yaml_report(self, filepath: str) -> None:
"producer_version": self.producer_version,
"ir_version": self.ir_version,
"opset": self.opset,
"import_list": self.imports,
"import_list": dict(self.imports),
"graph_nodes": sum(self.node_type_counts.values()),
"model_parameters": self.model_parameters,
"model_flops": self.model_flops,
Expand Down
4 changes: 2 additions & 2 deletions test/resnet18_reports/resnet18_report.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ Opset: 17

Import list
: 17
com.microsoft.nchwc: 1
ai.onnx.ml: 5
ai.onnx.training: 1
ai.onnx.preview.training: 1
ai.onnx.training: 1
com.microsoft: 1
com.microsoft.experimental: 1
com.microsoft.nchwc: 1
org.pytorch.aten: 1

Total graph nodes: 49
Expand Down
4 changes: 2 additions & 2 deletions test/resnet18_reports/resnet18_report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ opset: 17
import_list:
? ''
: 17
com.microsoft.nchwc: 1
ai.onnx.ml: 5
ai.onnx.training: 1
ai.onnx.preview.training: 1
ai.onnx.training: 1
com.microsoft: 1
com.microsoft.experimental: 1
com.microsoft.nchwc: 1
org.pytorch.aten: 1
graph_nodes: 49
model_parameters: 11684712
Expand Down
81 changes: 69 additions & 12 deletions test/test_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import unittest
import tempfile
import csv
from typing import List, Optional, Dict, Any
import yaml
import utils.onnx_utils as onnx_utils
from digest.model_class.digest_onnx_model import DigestOnnxModel

Expand Down Expand Up @@ -49,6 +51,67 @@ def compare_csv_files(self, file1, file2, skip_lines=0):
for row1, row2 in zip(reader1, reader2):
self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}")

def compare_yaml_files(
self, file1: str, file2: str, skip_keys: Optional[List[str]] = None
) -> bool:
"""
Compare two YAML files, ignoring specified keys.
:param file1: Path to the first YAML file
:param file2: Path to the second YAML file
:param skip_keys: List of keys to ignore in the comparison
:return: True if the files are equal (ignoring specified keys), False otherwise
"""

def load_yaml(file_path: str) -> Dict[str, Any]:
with open(file_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)

def compare_dicts(
dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = ""
) -> List[str]:
differences = []
all_keys = set(dict1.keys()) | set(dict2.keys())

for key in all_keys:
if skip_keys and key in skip_keys:
continue

current_path = f"{path}.{key}" if path else key

if key not in dict1:
differences.append(
f"Key '{current_path}' is missing in the first file"
)
elif key not in dict2:
differences.append(
f"Key '{current_path}' is missing in the second file"
)
elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
differences.extend(
compare_dicts(dict1[key], dict2[key], current_path)
)
elif dict1[key] != dict2[key]:
differences.append(
f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}"
)

return differences

yaml1 = load_yaml(file1)
yaml2 = load_yaml(file2)

differences = compare_dicts(yaml1, yaml2)

if differences:
print("Differences found:")
for diff in differences:
print(f"- {diff}")
return False
else:
print("No differences found.")
return True

def test_against_example_reports(self):
model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False)
model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0]
Expand All @@ -61,22 +124,16 @@ def test_against_example_reports(self):
)

with tempfile.TemporaryDirectory() as tmpdir:
# Model text report
text_report_filepath = os.path.join(tmpdir, f"{model_name}_report.txt")
digest_model.save_text_report(text_report_filepath)
with self.subTest("Testing report text file"):
self.compare_files_line_by_line(
TEST_SUMMARY_TEXT_REPORT,
text_report_filepath,
skip_lines=2,
)

# Model yaml report
yaml_report_filepath = os.path.join(tmpdir, f"{model_name}_report.yaml")
digest_model.save_yaml_report(yaml_report_filepath)
with self.subTest("Testing report yaml file"):
self.compare_files_line_by_line(
TEST_SUMMARY_YAML_REPORT, yaml_report_filepath, skip_lines=2
self.assertTrue(
self.compare_yaml_files(
TEST_SUMMARY_YAML_REPORT,
yaml_report_filepath,
skip_keys=["report_date", "onnx_file"],
)
)

# Save CSV containing node-level information
Expand Down

0 comments on commit 112403b

Please sign in to comment.