diff --git a/.pylintrc b/.pylintrc
index b831756..efb14e7 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -81,7 +81,6 @@ enable =
expression-not-assigned,
confusing-with-statement,
unnecessary-lambda,
- assign-to-new-keyword,
redeclared-assigned-name,
pointless-statement,
pointless-string-statement,
@@ -123,7 +122,6 @@ enable =
invalid-length-returned,
protected-access,
attribute-defined-outside-init,
- no-init,
abstract-method,
invalid-overridden-method,
arguments-differ,
@@ -165,9 +163,7 @@ enable =
### format
# Line length, indentation, whitespace:
bad-indentation,
- mixed-indentation,
unnecessary-semicolon,
- bad-whitespace,
missing-final-newline,
line-too-long,
mixed-line-endings,
@@ -187,7 +183,6 @@ enable =
import-self,
preferred-module,
reimported,
- relative-import,
deprecated-module,
wildcard-import,
misplaced-future,
@@ -282,12 +277,6 @@ indent-string = ' '
# black doesn't always obey its own limit. See pyproject.toml.
max-line-length = 100
-# List of optional constructs for which whitespace checking is disabled. `dict-
-# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
-# `trailing-comma` allows a space between comma and closing bracket: (a, ).
-# `empty-line` allows space-only lines.
-no-space-check =
-
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt = no
diff --git a/README.md b/README.md
index 7338b35..63f8aa7 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ SPDX-License-Identifier: Apache-2.0
DigestAI
===========================
+
DigestAI is a powerful model analysis tool that extracts insights from your models, enabling optimization and direct modification.
-**Get started quickly!** Download the DigestAI installer directly from [coming soon!].
+
+**Get started quickly!** Download the DigestAI executable directly from [link coming soon].
+
**Developers: Contribute to DigestAI** Follow the installation instruction below to get started.
@@ -65,19 +69,18 @@ The following steps are recommended because they are reproducible, however, ther
**Workflow**
1. **Open Qt Designer:**
- - **Activate Conda Environment:** Ensure your `digest` Conda environment is activated.
- - **Launch:** Run `pyside6-designer.exe` from your terminal.
+ * **Activate Conda Environment:** Ensure your `digest` Conda environment is activated.
+ * **Launch:** Run `pyside6-designer.exe` from your terminal.
2. **Work with UI Files:**
- - Open any existing UI file (`.ui`) from `src/digest/ui`.
- - Design your interface using the drag-and-drop tools and property editor.
- - Resource Files (Optional): If your UI uses custom icons, images, or stylesheets, please leverage the Qt resource file (`.qrc`). This makes it easier to manage and package resources with the application.
- - Please add any new `.ui` files to the `.pylintrc` file.
+ * Open any existing UI file (`.ui`) from `src/digest/ui`.
+ * Design your interface using the drag-and-drop tools and property editor.
+ * Resource Files (Optional): If your UI uses custom icons, images, or stylesheets, please leverage the Qt resource file (`.qrc`). This makes it easier to manage and package resources with the application.
+ * Please add any new `.ui` files to the `.pylintrc` file.
3. **Recompile UI Files (After Making Changes):**
- - From your terminal, navigate to the project's root directory.
- - Run: `python src/digest/compile_digest_gui.py`
-
+ * From your terminal, navigate to the project's root directory.
+ * Run: `python src/digest/compile_digest_gui.py`
## Building EXE for Windows Deployment
@@ -114,8 +117,9 @@ pytest test/test_gui.py
```
## License
-This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE.txt) file for details.
+This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE.txt) file for details.
## Copyright
+
Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
diff --git a/examples/analysis.py b/examples/analysis.py
index e9b9c63..da89068 100644
--- a/examples/analysis.py
+++ b/examples/analysis.py
@@ -6,14 +6,16 @@
import csv
from collections import Counter, defaultdict
from tqdm import tqdm
+from digest.model_class.digest_model import (
+ NodeShapeCounts,
+ NodeTypeCounts,
+ save_node_shape_counts_csv_report,
+ save_node_type_counts_csv_report,
+)
+from digest.model_class.digest_onnx_model import DigestOnnxModel
from utils.onnx_utils import (
get_dynamic_input_dims,
load_onnx,
- DigestOnnxModel,
- save_node_shape_counts_csv_report,
- save_node_type_counts_csv_report,
- NodeTypeCounts,
- NodeShapeCounts,
)
GLOBAL_MODEL_HEADERS = [
@@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str):
global_model_data[model_name] = {
"opset": digest_model.opset,
- "parameters": digest_model.model_parameters,
- "flops": digest_model.model_flops,
+ "parameters": digest_model.parameters,
+ "flops": digest_model.flops,
}
# Model summary text report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt")
- digest_model.save_txt_report(summary_filepath)
+ digest_model.save_text_report(summary_filepath)
+
+ # Model summary yaml report
+ summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml")
+ digest_model.save_yaml_report(summary_filepath)
# Save csv containing node-level information
nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)
# Save csv containing node type counter
- node_type_counter = digest_model.get_node_type_counts()
node_type_filepath = os.path.join(
output_dir, f"{model_name}_node_type_counts.csv"
)
- if node_type_counter:
- save_node_type_counts_csv_report(node_type_counter, node_type_filepath)
+
+ digest_model.save_node_type_counts_csv_report(node_type_filepath)
# Update global data structure for node type counter
- global_node_type_counter.update(node_type_counter)
+ global_node_type_counter.update(digest_model.node_type_counts)
# Save csv containing node shape counts per op_type
- node_shape_counts = digest_model.get_node_shape_counts()
node_shape_filepath = os.path.join(
output_dir, f"{model_name}_node_shape_counts.csv"
)
- save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath)
+ digest_model.save_node_shape_counts_csv_report(node_shape_filepath)
# Update global data structure for node shape counter
- for node_type, shape_counts in node_shape_counts.items():
+ for node_type, shape_counts in digest_model.get_node_shape_counts().items():
global_node_shape_counter[node_type].update(shape_counts)
if len(onnx_file_list) > 1:
global_filepath = os.path.join(output_dir, "global_node_type_counts.csv")
- global_node_type_counter = NodeTypeCounts(
- global_node_type_counter.most_common()
- )
- save_node_type_counts_csv_report(global_node_type_counter, global_filepath)
+ global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common())
+ save_node_type_counts_csv_report(global_node_type_counts, global_filepath)
global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv")
save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)
diff --git a/setup.py b/setup.py
index ca21f4a..b2ad16d 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
setup(
name="digestai",
- version="1.0.0",
+ version="1.1.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
diff --git a/src/digest/dialog.py b/src/digest/dialog.py
index d2f834e..ae9986d 100644
--- a/src/digest/dialog.py
+++ b/src/digest/dialog.py
@@ -125,13 +125,23 @@ class WarnDialog(QDialog):
def __init__(self, warning_message: str, parent=None):
super().__init__(parent)
- self.setWindowTitle("Warning Message")
+
self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))
+
+ self.setWindowTitle("Warning Message")
+ self.setWindowFlags(Qt.WindowType.Dialog)
self.setMinimumWidth(300)
+ self.setWindowModality(Qt.WindowModality.WindowModal)
+
layout = QVBoxLayout()
# Application Version
- layout.addWidget(QLabel("
Something went wrong"))
+ layout.addWidget(QLabel("
Warning"))
layout.addWidget(QLabel(warning_message))
+
+ ok_button = QPushButton("OK")
+ ok_button.clicked.connect(self.accept) # Close dialog when clicked
+ layout.addWidget(ok_button)
+
self.setLayout(layout)
diff --git a/src/digest/histogramchartwidget.py b/src/digest/histogramchartwidget.py
index 97d5f16..f72befb 100644
--- a/src/digest/histogramchartwidget.py
+++ b/src/digest/histogramchartwidget.py
@@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
super(StackedHistogramWidget, self).__init__(*args, **kwargs)
self.plot_widget = pg.PlotWidget()
- self.plot_widget.setMaximumHeight(150)
+ self.plot_widget.setMaximumHeight(200)
plot_item = self.plot_widget.getPlotItem()
if plot_item:
plot_item.setContentsMargins(0, 0, 0, 0)
@@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs):
self.bar_spacing = 25
def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False):
-
title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)"
self.plot_widget.setLabel(
"left",
@@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal
x_positions = list(range(len(op_count)))
total_count = sum(op_count)
width = 0.6
- self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing)
+ self.plot_widget.setFixedWidth(500)
+
for count, x_pos, tick in zip(op_count, x_positions, op_names):
x0 = x_pos - width / 2
y0 = 0
diff --git a/src/digest/main.py b/src/digest/main.py
index 08c401a..db4c685 100644
--- a/src/digest/main.py
+++ b/src/digest/main.py
@@ -3,11 +3,13 @@
import os
import sys
+import shutil
import argparse
from datetime import datetime
-from typing import Dict, Tuple, Optional
+from typing import Dict, Tuple, Optional, Union
import tempfile
from enum import IntEnum
+import pandas as pd
import yaml
# This is a temporary workaround since the Qt designer generated files
@@ -33,10 +35,10 @@
QMenu,
)
from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QMovie, QIcon, QFont
-from PySide6.QtCore import Qt, QDir
+from PySide6.QtCore import Qt, QSize
from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog
-from digest.thread import StatsThread, SimilarityThread
+from digest.thread import StatsThread, SimilarityThread, post_process
from digest.popup_window import PopupWindow
from digest.huggingface_page import HuggingfacePage
from digest.multi_model_selection_page import MultiModelSelectionPage
@@ -44,6 +46,9 @@
from digest.modelsummary import modelSummary
from digest.node_summary import NodeSummary
from digest.qt_utils import apply_dark_style_sheet
+from digest.model_class.digest_model import DigestModel
+from digest.model_class.digest_onnx_model import DigestOnnxModel
+from digest.model_class.digest_report_model import DigestReportModel
from utils import onnx_utils
GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml")
@@ -161,11 +166,12 @@ def __init__(self, model_file: Optional[str] = None):
self.status_dialog = None
self.err_open_dialog = None
self.temp_dir = tempfile.TemporaryDirectory()
- self.digest_models: Dict[str, onnx_utils.DigestOnnxModel] = {}
+ self.digest_models: Dict[str, Union[DigestOnnxModel, DigestReportModel]] = {}
# QThread containers
self.model_nodes_stats_thread: Dict[str, StatsThread] = {}
self.model_similarity_thread: Dict[str, SimilarityThread] = {}
+
self.model_similarity_report: Dict[str, SimilarityAnalysisReport] = {}
self.ui.singleModelWidget.hide()
@@ -209,7 +215,7 @@ def __init__(self, model_file: Optional[str] = None):
# Set up the HUGGINGFACE Page
huggingface_page = HuggingfacePage()
- huggingface_page.model_signal.connect(self.load_onnx)
+ huggingface_page.model_signal.connect(self.load_model)
self.ui.stackedWidget.insertWidget(self.Page.HUGGINGFACE, huggingface_page)
# Set up the multi model page and relevant button
@@ -217,15 +223,16 @@ def __init__(self, model_file: Optional[str] = None):
self.ui.stackedWidget.insertWidget(
self.Page.MULTIMODEL, self.multimodelselection_page
)
- self.multimodelselection_page.model_signal.connect(self.load_onnx)
+ self.multimodelselection_page.model_signal.connect(self.load_model)
# Load model file if given as input to the executable
if model_file:
- if (
- os.path.exists(model_file)
- and os.path.splitext(model_file)[-1] == ".onnx"
- ):
+ exists = os.path.exists(model_file)
+ ext = os.path.splitext(model_file)[-1]
+ if exists and ext == ".onnx":
self.load_onnx(model_file)
+ elif exists and ext == ".yaml":
+ self.load_report(model_file)
else:
self.err_open_dialog = StatusDialog(
f"Could not open {model_file}", parent=self
@@ -243,10 +250,11 @@ def uncheck_ingest_buttons(self):
def tab_focused(self, index):
widget = self.ui.tabWidget.widget(index)
if isinstance(widget, modelSummary):
- model_id = widget.digest_model.unique_id
+ unique_id = widget.digest_model.unique_id
if (
- self.stats_save_button_flag[model_id]
- and self.similarity_save_button_flag[model_id]
+ self.stats_save_button_flag[unique_id]
+ and self.similarity_save_button_flag[unique_id]
+ and not isinstance(widget.digest_model, DigestReportModel)
):
self.ui.saveBtn.setEnabled(True)
else:
@@ -257,11 +265,17 @@ def closeTab(self, index):
if isinstance(summary_widget, modelSummary):
unique_id = summary_widget.digest_model.unique_id
summary_widget.deleteLater()
- tab_thread = self.model_nodes_stats_thread[unique_id]
+
+ tab_thread = self.model_nodes_stats_thread.get(unique_id)
if tab_thread:
tab_thread.exit()
+ tab_thread.wait(5000)
+
if not tab_thread.isRunning():
del self.model_nodes_stats_thread[unique_id]
+ else:
+ print(f"Warning: Thread for {unique_id} did not finish in time")
+
# delete the digest model to free up used memory
if unique_id in self.digest_models:
del self.digest_models[unique_id]
@@ -272,40 +286,41 @@ def closeTab(self, index):
self.ui.singleModelWidget.hide()
def openFile(self):
- filename, _ = QFileDialog.getOpenFileName(
- self, "Open File", "", "ONNX Files (*.onnx)"
+ file_name, _ = QFileDialog.getOpenFileName(
+ self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)"
)
- if (
- filename and os.path.splitext(filename)[-1] == ".onnx"
- ): # Only if user selects a file and clicks OK
- self.load_onnx(filename)
+ if not file_name:
+ return
- def update_flops_label(
+ self.load_model(file_name)
+
+ def update_cards(
self,
- digest_model: onnx_utils.DigestOnnxModel,
+ digest_model: DigestModel,
unique_id: str,
):
- self.digest_models[unique_id].model_flops = digest_model.model_flops
+ self.digest_models[unique_id].flops = digest_model.flops
self.digest_models[unique_id].node_type_flops = digest_model.node_type_flops
- self.digest_models[unique_id].model_parameters = digest_model.model_parameters
+ self.digest_models[unique_id].parameters = digest_model.parameters
self.digest_models[unique_id].node_type_parameters = (
digest_model.node_type_parameters
)
- self.digest_models[unique_id].per_node_info = digest_model.per_node_info
+ self.digest_models[unique_id].node_data = digest_model.node_data
# We must iterate over the tabWidget and match to the tab_name because the user
# may have switched the currentTab during the threads execution.
+ curr_index = -1
for index in range(self.ui.tabWidget.count()):
widget = self.ui.tabWidget.widget(index)
if (
isinstance(widget, modelSummary)
and widget.digest_model.unique_id == unique_id
):
- if digest_model.model_flops is None:
+ if digest_model.flops is None:
flops_str = "--"
else:
- flops_str = format(digest_model.model_flops, ",")
+ flops_str = format(digest_model.flops, ",")
# Set up the pie chart
pie_chart_labels, pie_chart_data = zip(
@@ -328,11 +343,14 @@ def update_flops_label(
pie_chart_labels,
pie_chart_data,
)
+ curr_index = index
break
self.stats_save_button_flag[unique_id] = True
- if self.ui.tabWidget.currentIndex() == index:
- if self.similarity_save_button_flag[unique_id]:
+ if self.ui.tabWidget.currentIndex() == curr_index:
+ if self.similarity_save_button_flag[unique_id] and not isinstance(
+ digest_model, DigestReportModel
+ ):
self.ui.saveBtn.setEnabled(True)
def open_similarity_report(self, model_id: str, image_path, most_similar_models):
@@ -346,10 +364,12 @@ def update_similarity_widget(
completed_successfully: bool,
model_id: str,
most_similar: str,
- png_filepath: str,
+ png_filepath: Optional[str] = None,
+ df_sorted: Optional[pd.DataFrame] = None,
):
-
widget = None
+ digest_model = None
+ curr_index = -1
for index in range(self.ui.tabWidget.count()):
tab_widget = self.ui.tabWidget.widget(index)
if (
@@ -357,49 +377,90 @@ def update_similarity_widget(
and tab_widget.digest_model.unique_id == model_id
):
widget = tab_widget
+ digest_model = tab_widget.digest_model
+ curr_index = index
break
- if completed_successfully and isinstance(widget, modelSummary):
+ # convert back to a List[str]
+ most_similar_list = most_similar.split(",")
+
+ if (
+ completed_successfully
+ and isinstance(widget, modelSummary)
+ and digest_model
+ and png_filepath
+ ):
+
+ if df_sorted is not None:
+ post_process(
+ digest_model.model_name, most_similar_list, df_sorted, png_filepath
+ )
+
+ widget.load_gif.stop()
+ widget.ui.similarityImg.clear()
+ # We give the image a 10% haircut to fit it more aesthetically
widget_width = widget.ui.similarityImg.width()
- widget.ui.similarityImg.setPixmap(
- QPixmap(png_filepath).scaledToWidth(widget_width)
+
+ pixmap = QPixmap(png_filepath)
+ aspect_ratio = pixmap.width() / pixmap.height()
+ target_height = int(widget_width / aspect_ratio)
+ pixmap_scaled = pixmap.scaled(
+ QSize(widget_width, target_height),
+ Qt.AspectRatioMode.KeepAspectRatio,
+ Qt.TransformationMode.SmoothTransformation,
)
+
+ widget.ui.similarityImg.setPixmap(pixmap_scaled)
widget.ui.similarityImg.setText("")
widget.ui.similarityImg.setCursor(Qt.CursorShape.PointingHandCursor)
# Show most correlated models
widget.ui.similarityCorrelation.show()
widget.ui.similarityCorrelationStatic.show()
- most_similar_models = most_similar.split(",")
- text = (
- "\n
"
- f"{most_similar_models[0]}, {most_similar_models[1]}, and {most_similar_models[2]}."
- ""
- )
+
+ most_similar_list = most_similar_list[1:4]
+ if most_similar:
+ text = (
+ "\n
"
+ f"{most_similar_list[0]}, {most_similar_list[1]}, "
+ f"and {most_similar_list[2]}. "
+ ""
+ )
+ else:
+            # Currently the similarity widget expects most_similar_models to
+            # always contain 3 models. For now we send three empty strings,
+            # but at some point we should handle an arbitrary count.
+ most_similar_list = ["", "", ""]
+ text = "NTD"
# Create option to click to enlarge image
widget.ui.similarityImg.mousePressEvent = (
lambda event: self.open_similarity_report(
- model_id, png_filepath, most_similar_models
+ model_id, png_filepath, most_similar_list
)
)
# Create option to click to enlarge image
self.model_similarity_report[model_id] = SimilarityAnalysisReport(
- png_filepath, most_similar_models
+ png_filepath, most_similar_list
)
widget.ui.similarityCorrelation.setText(text)
elif isinstance(widget, modelSummary):
# Remove animation and set text to failing message
- widget.ui.similarityImg.setMovie(QMovie(None))
+ widget.load_gif.stop()
+ widget.ui.similarityImg.clear()
widget.ui.similarityImg.setText("Failed to perform similarity analysis")
else:
- print("Tab widget is not of type modelSummary which is unexpected.")
+ print(
+ f"Tab widget is of type {type(widget)} and not of type modelSummary "
+ "which is unexpected."
+ )
- #
self.similarity_save_button_flag[model_id] = True
- if self.ui.tabWidget.currentIndex() == index:
- if self.stats_save_button_flag[model_id]:
+ if self.ui.tabWidget.currentIndex() == curr_index:
+ if self.stats_save_button_flag[model_id] and not isinstance(
+ digest_model, DigestReportModel
+ ):
self.ui.saveBtn.setEnabled(True)
def load_onnx(self, filepath: str):
@@ -432,8 +493,9 @@ def load_onnx(self, filepath: str):
basename = os.path.splitext(os.path.basename(filepath))
model_name = basename[0]
- digest_model = onnx_utils.DigestOnnxModel(
- onnx_model=model, model_name=model_name, save_proto=False
+ # Save the model proto so we can use the Freeze Inputs feature
+ digest_model = DigestOnnxModel(
+ onnx_model=opt_model, model_name=model_name, save_proto=True
)
model_id = digest_model.unique_id
@@ -442,11 +504,9 @@ def load_onnx(self, filepath: str):
self.digest_models[model_id] = digest_model
- # We must set the proto for the model_summary freeze_inputs
- self.digest_models[model_id].model_proto = opt_model
-
- model_summary = modelSummary(self.digest_models[model_id])
- model_summary.freeze_inputs.complete_signal.connect(self.load_onnx)
+ model_summary = modelSummary(digest_model)
+ if model_summary.freeze_inputs:
+ model_summary.freeze_inputs.complete_signal.connect(self.load_onnx)
dynamic_input_dims = onnx_utils.get_dynamic_input_dims(opt_model)
if dynamic_input_dims:
@@ -480,14 +540,13 @@ def load_onnx(self, filepath: str):
model_summary.ui.modelFilename.setText(filepath)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))
- self.digest_models[model_id].model_name = model_name
- self.digest_models[model_id].filepath = filepath
-
- self.digest_models[model_id].model_inputs = (
- onnx_utils.get_model_input_shapes_types(opt_model)
+ digest_model.model_name = model_name
+ digest_model.filepath = filepath
+ digest_model.model_inputs = onnx_utils.get_model_input_shapes_types(
+ opt_model
)
- self.digest_models[model_id].model_outputs = (
- onnx_utils.get_model_output_shapes_types(opt_model)
+ digest_model.model_outputs = onnx_utils.get_model_output_shapes_types(
+ opt_model
)
progress.step()
@@ -498,9 +557,7 @@ def load_onnx(self, filepath: str):
# Kick off model stats thread
self.model_nodes_stats_thread[model_id] = StatsThread()
- self.model_nodes_stats_thread[model_id].completed.connect(
- self.update_flops_label
- )
+ self.model_nodes_stats_thread[model_id].completed.connect(self.update_cards)
self.model_nodes_stats_thread[model_id].model = opt_model
self.model_nodes_stats_thread[model_id].tab_name = model_name
@@ -518,7 +575,7 @@ def load_onnx(self, filepath: str):
model_summary.ui.opHistogramChart.bar_spacing = bar_spacing
model_summary.ui.opHistogramChart.set_data(node_type_counts)
model_summary.ui.nodes.setText(str(sum(node_type_counts.values())))
- self.digest_models[model_id].node_type_counts = node_type_counts
+ digest_model.node_type_counts = node_type_counts
progress.step()
progress.setLabelText("Gathering Model Inputs and Outputs")
@@ -577,24 +634,24 @@ def load_onnx(self, filepath: str):
model_summary.ui.modelProtoTable.setItem(
0, 1, QTableWidgetItem(str(opt_model.model_version))
)
- self.digest_models[model_id].model_version = opt_model.model_version
+ digest_model.model_version = opt_model.model_version
model_summary.ui.modelProtoTable.setItem(
1, 1, QTableWidgetItem(str(opt_model.graph.name))
)
- self.digest_models[model_id].graph_name = opt_model.graph.name
+ digest_model.graph_name = opt_model.graph.name
producer_txt = f"{opt_model.producer_name} {opt_model.producer_version}"
model_summary.ui.modelProtoTable.setItem(
2, 1, QTableWidgetItem(producer_txt)
)
- self.digest_models[model_id].producer_name = opt_model.producer_name
- self.digest_models[model_id].producer_version = opt_model.producer_version
+ digest_model.producer_name = opt_model.producer_name
+ digest_model.producer_version = opt_model.producer_version
model_summary.ui.modelProtoTable.setItem(
3, 1, QTableWidgetItem(str(opt_model.ir_version))
)
- self.digest_models[model_id].ir_version = opt_model.ir_version
+ digest_model.ir_version = opt_model.ir_version
for imp in opt_model.opset_import:
row_idx = model_summary.ui.importsTable.rowCount()
@@ -602,7 +659,7 @@ def load_onnx(self, filepath: str):
if imp.domain == "" or imp.domain == "ai.onnx":
model_summary.ui.opsetVersion.setText(str(imp.version))
domain = "ai.onnx"
- self.digest_models[model_id].opset = imp.version
+ digest_model.opset = imp.version
else:
domain = imp.domain
model_summary.ui.importsTable.setItem(
@@ -613,7 +670,7 @@ def load_onnx(self, filepath: str):
)
row_idx += 1
- self.digest_models[model_id].imports[imp.domain] = imp.version
+ digest_model.imports[imp.domain] = imp.version
progress.step()
progress.setLabelText("Wrapping Up Model Analysis")
@@ -628,14 +685,11 @@ def load_onnx(self, filepath: str):
self.ui.singleModelWidget.show()
progress.step()
- movie = QMovie(":/assets/gifs/load.gif")
- model_summary.ui.similarityImg.setMovie(movie)
- movie.start()
-
# Start similarity Analysis
# Note: Should only be started after the model tab has been created
png_tmp_path = os.path.join(self.temp_dir.name, model_id)
os.makedirs(png_tmp_path, exist_ok=True)
+ assert os.path.exists(png_tmp_path), f"Error with creating {png_tmp_path}"
self.model_similarity_thread[model_id] = SimilarityThread()
self.model_similarity_thread[model_id].completed_successfully.connect(
self.update_similarity_widget
@@ -652,6 +706,217 @@ def load_onnx(self, filepath: str):
except FileNotFoundError as e:
print(f"File not found: {e.filename}")
+ def load_report(self, filepath: str):
+
+        # Normalize the filepath to a standard format before using it.
+ filepath = os.path.normpath(filepath)
+
+ if not os.path.exists(filepath):
+ return
+
+ # Every time a report is loaded we should emulate a model summary button click
+ self.summary_clicked()
+
+ # Before opening the file, check to see if it is already opened.
+ for index in range(self.ui.tabWidget.count()):
+ widget = self.ui.tabWidget.widget(index)
+ if isinstance(widget, modelSummary) and filepath == widget.file:
+ self.ui.tabWidget.setCurrentIndex(index)
+ return
+
+ try:
+
+ progress = ProgressDialog("Loading Digest Report File...", 2, self)
+ QApplication.processEvents() # Process pending events
+
+ digest_model = DigestReportModel(filepath)
+
+ if not digest_model.is_valid:
+ progress.close()
+ invalid_yaml_dialog = StatusDialog(
+ title="Warning",
+ status_message=f"YAML file {filepath} is not a valid digest report",
+ )
+ invalid_yaml_dialog.show()
+
+ return
+
+ model_id = digest_model.unique_id
+
+ # There is no sense in offering to save the report
+ self.stats_save_button_flag[model_id] = False
+ self.similarity_save_button_flag[model_id] = False
+
+ self.digest_models[model_id] = digest_model
+
+ model_summary = modelSummary(digest_model)
+
+ self.ui.tabWidget.addTab(model_summary, "")
+ model_summary.ui.flops.setText("Loading...")
+
+ # Hide some of the components
+ model_summary.ui.similarityCorrelation.hide()
+ model_summary.ui.similarityCorrelationStatic.hide()
+
+ model_summary.file = filepath
+ model_summary.setObjectName(digest_model.model_name)
+ model_summary.ui.modelName.setText(digest_model.model_name)
+ model_summary.ui.modelFilename.setText(filepath)
+ model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))
+
+ model_summary.ui.parameters.setText(format(digest_model.parameters, ","))
+
+ node_type_counts = digest_model.node_type_counts
+ if len(node_type_counts) < 15:
+ bar_spacing = 40
+ else:
+ bar_spacing = 20
+
+ model_summary.ui.opHistogramChart.bar_spacing = bar_spacing
+ model_summary.ui.opHistogramChart.set_data(node_type_counts)
+ model_summary.ui.nodes.setText(str(sum(node_type_counts.values())))
+
+ progress.step()
+ progress.setLabelText("Gathering Model Inputs and Outputs")
+
+ # Inputs Table
+ model_summary.ui.inputsTable.setRowCount(
+ len(self.digest_models[model_id].model_inputs)
+ )
+
+ for row_idx, (input_name, input_info) in enumerate(
+ self.digest_models[model_id].model_inputs.items()
+ ):
+ model_summary.ui.inputsTable.setItem(
+ row_idx, 0, QTableWidgetItem(input_name)
+ )
+ model_summary.ui.inputsTable.setItem(
+ row_idx, 1, QTableWidgetItem(str(input_info.shape))
+ )
+ model_summary.ui.inputsTable.setItem(
+ row_idx, 2, QTableWidgetItem(str(input_info.dtype))
+ )
+ model_summary.ui.inputsTable.setItem(
+ row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes))
+ )
+
+ model_summary.ui.inputsTable.resizeColumnsToContents()
+ model_summary.ui.inputsTable.resizeRowsToContents()
+
+ # Outputs Table
+ model_summary.ui.outputsTable.setRowCount(
+ len(self.digest_models[model_id].model_outputs)
+ )
+ for row_idx, (output_name, output_info) in enumerate(
+ self.digest_models[model_id].model_outputs.items()
+ ):
+ model_summary.ui.outputsTable.setItem(
+ row_idx, 0, QTableWidgetItem(output_name)
+ )
+ model_summary.ui.outputsTable.setItem(
+ row_idx, 1, QTableWidgetItem(str(output_info.shape))
+ )
+ model_summary.ui.outputsTable.setItem(
+ row_idx, 2, QTableWidgetItem(str(output_info.dtype))
+ )
+ model_summary.ui.outputsTable.setItem(
+ row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes))
+ )
+
+ model_summary.ui.outputsTable.resizeColumnsToContents()
+ model_summary.ui.outputsTable.resizeRowsToContents()
+
+ progress.step()
+ progress.setLabelText("Gathering Model Proto Data")
+
+ # ModelProto Info
+ model_summary.ui.modelProtoTable.setItem(
+ 0, 1, QTableWidgetItem(str(digest_model.model_data["model_version"]))
+ )
+
+ model_summary.ui.modelProtoTable.setItem(
+ 1, 1, QTableWidgetItem(str(digest_model.model_data["graph_name"]))
+ )
+
+ producer_txt = (
+ f"{digest_model.model_data['producer_name']} "
+ f"{digest_model.model_data['producer_version']}"
+ )
+ model_summary.ui.modelProtoTable.setItem(
+ 2, 1, QTableWidgetItem(producer_txt)
+ )
+
+ model_summary.ui.modelProtoTable.setItem(
+ 3, 1, QTableWidgetItem(str(digest_model.model_data["ir_version"]))
+ )
+
+ for domain, version in digest_model.model_data["import_list"].items():
+ row_idx = model_summary.ui.importsTable.rowCount()
+ model_summary.ui.importsTable.insertRow(row_idx)
+ if domain == "" or domain == "ai.onnx":
+ model_summary.ui.opsetVersion.setText(str(version))
+ domain = "ai.onnx"
+
+ model_summary.ui.importsTable.setItem(
+ row_idx, 0, QTableWidgetItem(str(domain))
+ )
+ model_summary.ui.importsTable.setItem(
+ row_idx, 1, QTableWidgetItem(str(version))
+ )
+ row_idx += 1
+
+ progress.step()
+ progress.setLabelText("Wrapping Up Model Analysis")
+
+ model_summary.ui.importsTable.resizeColumnsToContents()
+ model_summary.ui.modelProtoTable.resizeColumnsToContents()
+ model_summary.setObjectName(digest_model.model_name)
+ new_tab_idx = self.ui.tabWidget.count() - 1
+ self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name))
+ self.ui.tabWidget.setCurrentIndex(new_tab_idx)
+ self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY)
+ self.ui.singleModelWidget.show()
+ progress.step()
+
+ self.update_cards(digest_model, digest_model.unique_id)
+
+ movie = QMovie(":/assets/gifs/load.gif")
+ model_summary.ui.similarityImg.setMovie(movie)
+ movie.start()
+
+ self.update_similarity_widget(
+ completed_successfully=bool(digest_model.similarity_heatmap_path),
+ model_id=digest_model.unique_id,
+ most_similar="",
+ png_filepath=digest_model.similarity_heatmap_path,
+ )
+
+ progress.close()
+
+ except FileNotFoundError as e:
+ print(f"File not found: {e.filename}")
+
+ def load_model(self, file_path: str):
+
+        # Normalize the filepath to a standard format before using it.
+ file_path = os.path.normpath(file_path)
+
+ if not os.path.exists(file_path):
+ return
+
+ file_ext = os.path.splitext(file_path)[-1]
+
+ if file_ext == ".onnx":
+ self.load_onnx(file_path)
+ elif file_ext == ".yaml":
+ self.load_report(file_path)
+ else:
+ bad_ext_dialog = StatusDialog(
+ f"Digest does not support files with the extension {file_ext}",
+ parent=self,
+ )
+ bad_ext_dialog.show()
+
def dragEnterEvent(self, event: QDragEnterEvent):
if event.mimeData().hasUrls():
event.acceptProposedAction()
@@ -660,9 +925,7 @@ def dropEvent(self, event: QDropEvent):
if event.mimeData().hasUrls():
for url in event.mimeData().urls():
file_path = url.toLocalFile()
- if file_path.endswith(".onnx"):
- self.load_onnx(file_path)
- break
+ self.load_model(file_path)
## functions for changing menu page
def logo_clicked(self):
@@ -710,14 +973,8 @@ def save_reports(self):
self, "Select Directory"
)
- if not save_directory:
- return
-
- # Create a QDir object
- directory = QDir(save_directory)
-
# Check if the directory exists and is writable
- if not directory.exists() and directory.isWritable(): # type: ignore
+ if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK):
self.show_warning_dialog(
f"The directory {save_directory} is not valid or writable."
)
@@ -726,43 +983,56 @@ def save_reports(self):
save_directory, str(digest_model.model_name) + "_reports"
)
- os.makedirs(save_directory, exist_ok=True)
-
- # Save the node histogram image
- node_histogram = current_tab.ui.opHistogramChart.grab()
- node_histogram.save(
- os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG"
- )
+ try:
+ os.makedirs(save_directory, exist_ok=True)
- # Save csv of node type counts
- node_type_filepath = os.path.join(
- save_directory, f"{model_name}_node_type_counts.csv"
- )
- node_counter = digest_model.get_node_type_counts()
- if node_counter:
- onnx_utils.save_node_type_counts_csv_report(
- node_counter, node_type_filepath
+ # Save the node histogram image
+ node_histogram = current_tab.ui.opHistogramChart.grab()
+ node_histogram.save(
+ os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG"
)
- # Save the similarity image
- similarity_png = self.model_similarity_report[digest_model.unique_id].grab()
- similarity_png.save(
- os.path.join(save_directory, f"{model_name}_heatmap.png"), "PNG"
- )
+ # Save csv of node type counts
+ node_type_filepath = os.path.join(
+ save_directory, f"{model_name}_node_type_counts.csv"
+ )
+ digest_model.save_node_type_counts_csv_report(node_type_filepath)
+
+ # Save (copy) the similarity image
+ png_file_path = self.model_similarity_thread[
+ digest_model.unique_id
+ ].png_filepath
+ png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png")
+ if png_file_path and os.path.exists(png_file_path):
+ shutil.copy(png_file_path, png_save_path)
+
+ # Save the text report
+ txt_report_filepath = os.path.join(
+ save_directory, f"{model_name}_report.txt"
+ )
+ digest_model.save_text_report(txt_report_filepath)
- # Save the text report
- txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt")
- digest_model.save_txt_report(txt_report_filepath)
+ # Save the yaml report
+ yaml_report_filepath = os.path.join(
+ save_directory, f"{model_name}_report.yaml"
+ )
+ digest_model.save_yaml_report(yaml_report_filepath)
- # Save the node list
- nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv")
- self.save_nodes_csv(nodes_report_filepath, False)
+ # Save the node list
+ nodes_report_filepath = os.path.join(
+ save_directory, f"{model_name}_nodes.csv"
+ )
- self.status_dialog = StatusDialog(
- f"Saved reports to: \n{os.path.abspath(save_directory)}",
- "Successfully saved reports!",
- )
- self.status_dialog.show()
+ self.save_nodes_csv(nodes_report_filepath, False)
+ except Exception as exception: # pylint: disable=broad-exception-caught
+ self.status_dialog = StatusDialog(f"{exception}")
+ self.status_dialog.show()
+ else:
+ self.status_dialog = StatusDialog(
+ f"Saved reports to: \n{os.path.abspath(save_directory)}",
+ "Successfully saved reports!",
+ )
+ self.status_dialog.show()
def on_dialog_closed(self):
self.infoDialog = None
@@ -829,7 +1099,7 @@ def open_node_summary(self):
digest_models = self.digest_models[model_id]
node_summary = NodeSummary(
- model_name=model_name, node_data=digest_models.per_node_info
+ model_name=model_name, node_data=digest_models.node_data
)
self.nodes_window[model_id] = PopupWindow(
diff --git a/src/digest/model_class/digest_model.py b/src/digest/model_class/digest_model.py
new file mode 100644
index 0000000..87f399d
--- /dev/null
+++ b/src/digest/model_class/digest_model.py
@@ -0,0 +1,232 @@
+# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
+
+import os
+import csv
+from enum import Enum
+from dataclasses import dataclass, field
+from uuid import uuid4
+from abc import ABC, abstractmethod
+from collections import Counter, OrderedDict, defaultdict
+from typing import List, Dict, Optional, Any, Union
+
+
+class SupportedModelTypes(Enum):
+ ONNX = "onnx"
+ REPORT = "report"
+
+
+class NodeParsingException(Exception):
+ pass
+
+
+# These classes are for type aliasing. Once Python 3.10 is the minimum we can switch to TypeAlias
+class NodeShapeCounts(defaultdict[str, Counter]):
+ def __init__(self):
+ super().__init__(Counter) # Initialize with the Counter factory
+
+
+class NodeTypeCounts(Dict[str, int]):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+@dataclass
+class TensorInfo:
+ "Used to store node input and output tensor information"
+ dtype: Optional[str] = None
+ dtype_bytes: Optional[int] = None
+ size_kbytes: Optional[float] = None
+ shape: List[Union[int, str]] = field(default_factory=list)
+
+
+class TensorData(OrderedDict[str, TensorInfo]):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+class NodeInfo:
+ def __init__(self) -> None:
+ self.flops: Optional[int] = None
+ self.parameters: int = 0 # TODO: should we make this Optional[int] = None?
+ self.node_type: Optional[str] = None
+ self.attributes: OrderedDict[str, Any] = OrderedDict()
+ # We use an ordered dictionary because the order in which
+        # the inputs and outputs are listed in the node matters.
+ self.inputs = TensorData()
+ self.outputs = TensorData()
+
+ def get_input(self, index: int) -> TensorInfo:
+ return list(self.inputs.values())[index]
+
+ def get_output(self, index: int) -> TensorInfo:
+ return list(self.outputs.values())[index]
+
+ def __str__(self):
+ """Provides a human-readable string representation of NodeInfo."""
+ output = [
+ f"Node Type: {self.node_type}",
+ f"FLOPs: {self.flops if self.flops is not None else 'N/A'}",
+ f"Parameters: {self.parameters}",
+ ]
+
+ if self.attributes:
+ output.append("Attributes:")
+ for key, value in self.attributes.items():
+ output.append(f" - {key}: {value}")
+
+ if self.inputs:
+ output.append("Inputs:")
+ for name, tensor in self.inputs.items():
+ output.append(f" - {name}: {tensor}")
+
+ if self.outputs:
+ output.append("Outputs:")
+ for name, tensor in self.outputs.items():
+ output.append(f" - {name}: {tensor}")
+
+ return "\n".join(output)
+
+
+# These classes are for type aliasing. Once Python 3.10 is the minimum we can switch to TypeAlias
+class NodeData(OrderedDict[str, NodeInfo]):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+class DigestModel(ABC):
+ def __init__(self, filepath: str, model_name: str, model_type: SupportedModelTypes):
+ # Public members exposed to the API
+ self.unique_id: str = str(uuid4())
+ self.filepath: Optional[str] = filepath
+ self.model_name: str = model_name
+ self.model_type: SupportedModelTypes = model_type
+ self.node_type_counts: NodeTypeCounts = NodeTypeCounts()
+ self.flops: Optional[int] = None
+ self.parameters: int = 0
+ self.node_type_flops: Dict[str, int] = {}
+ self.node_type_parameters: Dict[str, int] = {}
+ self.node_data = NodeData()
+ self.model_inputs = TensorData()
+ self.model_outputs = TensorData()
+
+ def get_node_shape_counts(self) -> NodeShapeCounts:
+ tensor_shape_counter = NodeShapeCounts()
+ for _, info in self.node_data.items():
+ shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()])
+ if info.node_type:
+ tensor_shape_counter[info.node_type][shape_hash] += 1
+ return tensor_shape_counter
+
+ @abstractmethod
+ def parse_model_nodes(self, *args, **kwargs) -> None:
+ pass
+
+ @abstractmethod
+ def save_yaml_report(self, filepath: str) -> None:
+ pass
+
+ @abstractmethod
+ def save_text_report(self, filepath: str) -> None:
+ pass
+
+ def save_nodes_csv_report(self, filepath: str) -> None:
+ save_nodes_csv_report(self.node_data, filepath)
+
+ def save_node_type_counts_csv_report(self, filepath: str) -> None:
+ if self.node_type_counts:
+ save_node_type_counts_csv_report(self.node_type_counts, filepath)
+
+ def save_node_shape_counts_csv_report(self, filepath: str) -> None:
+ save_node_shape_counts_csv_report(self.get_node_shape_counts(), filepath)
+
+
+def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None:
+
+ parent_dir = os.path.dirname(os.path.abspath(filepath))
+ if not os.path.exists(parent_dir):
+ raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
+
+ flattened_data = []
+ fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"]
+ input_fieldnames = []
+ output_fieldnames = []
+ for name, node_info in node_data.items():
+ row = OrderedDict()
+ row["Node Name"] = name
+ row["Node Type"] = str(node_info.node_type)
+ row["Parameters"] = str(node_info.parameters)
+ row["FLOPs"] = str(node_info.flops)
+ if node_info.attributes:
+ row["Attributes"] = str({k: v for k, v in node_info.attributes.items()})
+ else:
+ row["Attributes"] = ""
+
+ for i, (input_name, input_info) in enumerate(node_info.inputs.items()):
+ column_name = f"Input{i+1} (Shape, Dtype, Size (kB))"
+ row[column_name] = (
+ f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})"
+ )
+
+ # Dynamically add input column names to fieldnames if not already present
+ if column_name not in input_fieldnames:
+ input_fieldnames.append(column_name)
+
+ for i, (output_name, output_info) in enumerate(node_info.outputs.items()):
+ column_name = f"Output{i+1} (Shape, Dtype, Size (kB))"
+ row[column_name] = (
+ f"{output_name} ({output_info.shape}, "
+ f"{output_info.dtype}, {output_info.size_kbytes})"
+ )
+
+            # Dynamically add output column names to fieldnames if not already present
+ if column_name not in output_fieldnames:
+ output_fieldnames.append(column_name)
+
+ flattened_data.append(row)
+
+ fieldnames = fieldnames + input_fieldnames + output_fieldnames
+ try:
+ with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n")
+ writer.writeheader()
+ writer.writerows(flattened_data)
+ except PermissionError as exception:
+ raise PermissionError(
+ f"Saving reports to {filepath} failed with error {exception}"
+ )
+
+
+def save_node_type_counts_csv_report(
+ node_type_counts: NodeTypeCounts, filepath: str
+) -> None:
+
+ parent_dir = os.path.dirname(os.path.abspath(filepath))
+ if not os.path.exists(parent_dir):
+ raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
+
+ header = ["Node Type", "Count"]
+
+ with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
+ writer = csv.writer(csvfile, lineterminator="\n")
+ writer.writerow(header)
+ for node_type, node_count in node_type_counts.items():
+ writer.writerow([node_type, node_count])
+
+
+def save_node_shape_counts_csv_report(
+ node_shape_counts: NodeShapeCounts, filepath: str
+) -> None:
+
+ parent_dir = os.path.dirname(os.path.abspath(filepath))
+ if not os.path.exists(parent_dir):
+ raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
+
+ header = ["Node Type", "Input Tensors Shapes", "Count"]
+
+ with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
+ writer = csv.writer(csvfile, dialect="excel", lineterminator="\n")
+ writer.writerow(header)
+ for node_type, node_info in node_shape_counts.items():
+ info_iter = iter(node_info.items())
+ for shape, count in info_iter:
+ writer.writerow([node_type, shape, count])
diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py
new file mode 100644
index 0000000..8c8dd7f
--- /dev/null
+++ b/src/digest/model_class/digest_onnx_model.py
@@ -0,0 +1,656 @@
+# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
+
+import os
+from typing import List, Dict, Optional, Tuple, cast
+from datetime import datetime
+import importlib.metadata
+from collections import OrderedDict
+import yaml
+import numpy as np
+import onnx
+from prettytable import PrettyTable
+from digest.model_class.digest_model import (
+ DigestModel,
+ SupportedModelTypes,
+ NodeInfo,
+ TensorData,
+ TensorInfo,
+)
+import utils.onnx_utils as onnx_utils
+
+
+class DigestOnnxModel(DigestModel):
+ def __init__(
+ self,
+ onnx_model: onnx.ModelProto,
+ onnx_filepath: str = "",
+ model_name: str = "",
+ save_proto: bool = True,
+ ) -> None:
+ super().__init__(onnx_filepath, model_name, SupportedModelTypes.ONNX)
+
+ self.model_type = SupportedModelTypes.ONNX
+
+ # Public members exposed to the API
+ self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None
+ self.model_version: Optional[int] = None
+ self.graph_name: Optional[str] = None
+ self.producer_name: Optional[str] = None
+ self.producer_version: Optional[str] = None
+ self.ir_version: Optional[int] = None
+ self.opset: Optional[int] = None
+ self.imports: OrderedDict[str, int] = OrderedDict()
+
+ # Private members not intended to be exposed
+ self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {}
+ self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {}
+ self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {}
+ self.init_tensors_: Dict[str, onnx.TensorProto] = {}
+
+ self.update_state(onnx_model)
+
+ def update_state(self, model_proto: onnx.ModelProto) -> None:
+ self.model_version = model_proto.model_version
+ self.graph_name = model_proto.graph.name
+ self.producer_name = model_proto.producer_name
+ self.producer_version = model_proto.producer_version
+ self.ir_version = model_proto.ir_version
+ self.opset = onnx_utils.get_opset(model_proto)
+ self.imports = OrderedDict(
+ sorted(
+ (import_.domain, import_.version)
+ for import_ in model_proto.opset_import
+ )
+ )
+
+ self.model_inputs = onnx_utils.get_model_input_shapes_types(model_proto)
+ self.model_outputs = onnx_utils.get_model_output_shapes_types(model_proto)
+
+ self.node_type_counts = onnx_utils.get_node_type_counts(model_proto)
+ self.parse_model_nodes(model_proto)
+
+ def get_node_tensor_info_(
+ self, onnx_node: onnx.NodeProto
+ ) -> Tuple[TensorData, TensorData]:
+ """
+ This function is set to private because it is not intended to be used
+ outside of the DigestOnnxModel class.
+ """
+
+ input_tensor_info = TensorData()
+ for node_input in onnx_node.input:
+ input_tensor_info[node_input] = TensorInfo()
+ if (
+ node_input in self.input_tensors_
+ or node_input in self.value_tensors_
+ or node_input in self.output_tensors_
+ ):
+ tensor = (
+ self.input_tensors_.get(node_input)
+ or self.value_tensors_.get(node_input)
+ or self.output_tensors_.get(node_input)
+ )
+ if tensor:
+ for dim in tensor.type.tensor_type.shape.dim:
+ if dim.HasField("dim_value"):
+ input_tensor_info[node_input].shape.append(dim.dim_value)
+ elif dim.HasField("dim_param"):
+ input_tensor_info[node_input].shape.append(dim.dim_param)
+
+ dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size(
+ tensor.type.tensor_type.elem_type
+ )
+ elif node_input in self.init_tensors_:
+ input_tensor_info[node_input].shape.extend(
+ [dim for dim in self.init_tensors_[node_input].dims]
+ )
+ dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size(
+ self.init_tensors_[node_input].data_type
+ )
+ else:
+ dtype_str = None
+ dtype_bytes = None
+
+ input_tensor_info[node_input].dtype = dtype_str
+ input_tensor_info[node_input].dtype_bytes = dtype_bytes
+
+ if (
+ all(isinstance(s, int) for s in input_tensor_info[node_input].shape)
+ and dtype_bytes
+ ):
+ tensor_size = float(
+ np.prod(np.array(input_tensor_info[node_input].shape))
+ )
+ input_tensor_info[node_input].size_kbytes = (
+ tensor_size * float(dtype_bytes) / 1024.0
+ )
+
+ output_tensor_info = TensorData()
+ for node_output in onnx_node.output:
+ output_tensor_info[node_output] = TensorInfo()
+ if (
+ node_output in self.input_tensors_
+ or node_output in self.value_tensors_
+ or node_output in self.output_tensors_
+ ):
+ tensor = (
+ self.input_tensors_.get(node_output)
+ or self.value_tensors_.get(node_output)
+ or self.output_tensors_.get(node_output)
+ )
+ if tensor:
+ output_tensor_info[node_output].shape.extend(
+ [
+ int(dim.dim_value)
+ for dim in tensor.type.tensor_type.shape.dim
+ ]
+ )
+ dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size(
+ tensor.type.tensor_type.elem_type
+ )
+ elif node_output in self.init_tensors_:
+ output_tensor_info[node_output].shape.extend(
+ [dim for dim in self.init_tensors_[node_output].dims]
+ )
+ dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size(
+ self.init_tensors_[node_output].data_type
+ )
+
+ else:
+ dtype_str = None
+ dtype_bytes = None
+
+ output_tensor_info[node_output].dtype = dtype_str
+ output_tensor_info[node_output].dtype_bytes = dtype_bytes
+
+ if (
+ all(isinstance(s, int) for s in output_tensor_info[node_output].shape)
+ and dtype_bytes
+ ):
+ tensor_size = float(
+ np.prod(np.array(output_tensor_info[node_output].shape))
+ )
+ output_tensor_info[node_output].size_kbytes = (
+ tensor_size * float(dtype_bytes) / 1024.0
+ )
+
+ return input_tensor_info, output_tensor_info
+
+ def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None:
+ """
+ Calculate total number of FLOPs found in the onnx model.
+        A FLOP is defined as one floating-point operation. This distinguishes
+        it from multiply-accumulates (MACs), where FLOPs == 2 * MACs.
+ """
+
+        # Initialize to zero so we can accumulate. Set to None during the
+ # model FLOPs calculation if it errors out.
+ self.flops = 0
+
+ # Check to see if the model inputs have any dynamic shapes
+ if onnx_utils.get_dynamic_input_dims(onnx_model):
+ self.flops = None
+
+ try:
+ onnx_model, _ = onnx_utils.optimize_onnx_model(onnx_model)
+
+ onnx_model = onnx.shape_inference.infer_shapes(
+ onnx_model, strict_mode=True, data_prop=True
+ )
+ except Exception as e: # pylint: disable=broad-except
+ print(f"ONNX utils: {str(e)}")
+ self.flops = None
+
+ # If the ONNX model contains one of the following unsupported ops, then this
+ # function will return None since the FLOP total is expected to be incorrect
+ unsupported_ops = [
+ "Einsum",
+ "RNN",
+ "GRU",
+ "DeformConv",
+ ]
+
+ if not self.input_tensors_:
+ self.input_tensors_ = {
+ tensor.name: tensor for tensor in onnx_model.graph.input
+ }
+
+ if not self.output_tensors_:
+ self.output_tensors_ = {
+ tensor.name: tensor for tensor in onnx_model.graph.output
+ }
+
+ if not self.value_tensors_:
+ self.value_tensors_ = {
+ tensor.name: tensor for tensor in onnx_model.graph.value_info
+ }
+
+ if not self.init_tensors_:
+ self.init_tensors_ = {
+ tensor.name: tensor for tensor in onnx_model.graph.initializer
+ }
+
+ for node in onnx_model.graph.node: # pylint: disable=E1101
+
+ node_info = NodeInfo()
+
+ # TODO: I have encountered models containing nodes with no name. It would be a good idea
+ # to have this type of model info fed back to the user through a warnings section.
+ if not node.name:
+ node.name = f"{node.op_type}_{len(self.node_data)}"
+
+ node_info.node_type = node.op_type
+ input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node)
+ node_info.inputs = input_tensor_info
+ node_info.outputs = output_tensor_info
+
+ # Check if this node has parameters through the init tensors
+ for input_name, input_tensor in node_info.inputs.items():
+ if input_name in self.init_tensors_:
+ if all(isinstance(dim, int) for dim in input_tensor.shape):
+ input_parameters = int(np.prod(np.array(input_tensor.shape)))
+ node_info.parameters += input_parameters
+ self.parameters += input_parameters
+ self.node_type_parameters[node.op_type] = (
+ self.node_type_parameters.get(node.op_type, 0)
+ + input_parameters
+ )
+ else:
+ print(f"Tensor with params has unknown shape: {input_name}")
+
+ for attribute in node.attribute:
+ node_info.attributes.update(onnx_utils.attribute_to_dict(attribute))
+
+ # if node.name in self.node_data:
+ # print(f"Node name {node.name} is a duplicate.")
+
+ self.node_data[node.name] = node_info
+
+ if node.op_type in unsupported_ops:
+ self.flops = None
+ node_info.flops = None
+
+ try:
+
+ if (
+ node.op_type == "MatMul"
+ or node.op_type == "MatMulInteger"
+ or node.op_type == "QLinearMatMul"
+ ):
+
+ input_a = node_info.get_input(0).shape
+ if node.op_type == "QLinearMatMul":
+ input_b = node_info.get_input(3).shape
+ else:
+ input_b = node_info.get_input(1).shape
+
+ if not all(
+ isinstance(dim, int) for dim in input_a
+ ) or not isinstance(input_b[-1], int):
+ node_info.flops = None
+ self.flops = None
+ continue
+
+ node_info.flops = int(
+ 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1]
+ )
+
+ elif (
+ node.op_type == "Mul"
+ or node.op_type == "Div"
+ or node.op_type == "Add"
+ ):
+ input_a = node_info.get_input(0).shape
+ input_b = node_info.get_input(1).shape
+
+ if not all(isinstance(dim, int) for dim in input_a) or not all(
+ isinstance(dim, int) for dim in input_b
+ ):
+ node_info.flops = None
+ self.flops = None
+ continue
+
+ node_info.flops = int(
+ np.prod(np.array(input_a), dtype=np.int64)
+ ) + int(np.prod(np.array(input_b), dtype=np.int64))
+
+ elif node.op_type == "Gemm" or node.op_type == "QGemm":
+ x_shape = node_info.get_input(0).shape
+ if node.op_type == "Gemm":
+ w_shape = node_info.get_input(1).shape
+ else:
+ w_shape = node_info.get_input(3).shape
+
+ if not all(isinstance(dim, int) for dim in x_shape) or not all(
+ isinstance(dim, int) for dim in w_shape
+ ):
+ node_info.flops = None
+ self.flops = None
+ continue
+
+ mm_dims = [
+ (
+ x_shape[0]
+ if not node_info.attributes.get("transA", 0)
+ else x_shape[1]
+ ),
+ (
+ x_shape[1]
+ if not node_info.attributes.get("transA", 0)
+ else x_shape[0]
+ ),
+ (
+ w_shape[1]
+ if not node_info.attributes.get("transB", 0)
+ else w_shape[0]
+ ),
+ ]
+
+ node_info.flops = int(
+ 2 * np.prod(np.array(mm_dims), dtype=np.int64)
+ )
+
+                    if len(mm_dims) == 3:  # BUG(review): always true — mm_dims always has 3 entries; bias check should be len(node_info.inputs) == 3
+ bias_shape = node_info.get_input(2).shape
+ node_info.flops += int(np.prod(np.array(bias_shape)))
+
+ elif (
+ node.op_type == "Conv"
+ or node.op_type == "ConvInteger"
+ or node.op_type == "QLinearConv"
+ or node.op_type == "ConvTranspose"
+ ):
+ # N, C, d1, ..., dn
+ x_shape = node_info.get_input(0).shape
+
+ # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose
+ if node.op_type == "QLinearConv":
+ w_shape = node_info.get_input(3).shape
+ else:
+ w_shape = node_info.get_input(1).shape
+
+ if not all(isinstance(dim, int) for dim in x_shape):
+ node_info.flops = None
+ self.flops = None
+ continue
+
+ x_shape_ints = cast(List[int], x_shape)
+ w_shape_ints = cast(List[int], w_shape)
+
+ has_bias = False # Note, ConvInteger has no bias
+ if node.op_type == "Conv" and len(node_info.inputs) == 3:
+ has_bias = True
+ elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9:
+ has_bias = True
+
+ num_dims = len(x_shape_ints) - 2
+ strides = node_info.attributes.get(
+ "strides", [1] * num_dims
+ ) # type: List[int]
+ dilation = node_info.attributes.get(
+ "dilations", [1] * num_dims
+ ) # type: List[int]
+ kernel_shape = w_shape_ints[2:]
+ batch_size = x_shape_ints[0]
+ out_channels = w_shape_ints[0]
+ out_dims = [batch_size, out_channels]
+ output_shape = node_info.attributes.get(
+ "output_shape", []
+ ) # type: List[int]
+
+ # If output_shape is given then we do not need to compute it ourselves
+ # The output_shape attribute does not include batch_size or channels and
+ # is only valid for ConvTranspose
+ if output_shape:
+ out_dims.extend(output_shape)
+ else:
+ auto_pad = node_info.attributes.get(
+ "auto_pad", "NOTSET".encode()
+ ).decode()
+ # SAME expects padding so that the output_shape = CEIL(input_shape / stride)
+ if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
+ out_dims.extend(
+ [x * s for x, s in zip(x_shape_ints[2:], strides)]
+ )
+ else:
+ # NOTSET means just use pads attribute
+ if auto_pad == "NOTSET":
+ pads = node_info.attributes.get(
+ "pads", [0] * num_dims * 2
+ )
+ # VALID essentially means no padding
+ elif auto_pad == "VALID":
+ pads = [0] * num_dims * 2
+
+ for i in range(num_dims):
+ dim_in = x_shape_ints[i + 2] # type: int
+
+ if node.op_type == "ConvTranspose":
+ out_dim = (
+ strides[i] * (dim_in - 1)
+ + ((kernel_shape[i] - 1) * dilation[i] + 1)
+ - pads[i]
+ - pads[i + num_dims]
+ )
+ else:
+ out_dim = (
+ dim_in
+ + pads[i]
+ + pads[i + num_dims]
+ - dilation[i] * (kernel_shape[i] - 1)
+ - 1
+ ) // strides[i] + 1
+
+ out_dims.append(out_dim)
+
+ kernel_flops = int(
+ np.prod(np.array(kernel_shape)) * w_shape_ints[1]
+ )
+ output_points = int(np.prod(np.array(out_dims)))
+ bias_ops = output_points if has_bias else int(0)
+ node_info.flops = 2 * kernel_flops * output_points + bias_ops
+
+ elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM":
+
+ x_shape = node_info.get_input(
+ 0
+ ).shape # seq_length, batch_size, input_dim
+
+ if not all(isinstance(dim, int) for dim in x_shape):
+ node_info.flops = None
+ self.flops = None
+ continue
+
+ x_shape_ints = cast(List[int], x_shape)
+ hidden_size = node_info.attributes["hidden_size"]
+ direction = (
+ 2
+ if node_info.attributes.get("direction")
+ == "bidirectional".encode()
+ else 1
+ )
+
+ has_bias = True if len(node_info.inputs) >= 4 else False
+ if has_bias:
+ bias_shape = node_info.get_input(3).shape
+ if isinstance(bias_shape[1], int):
+ bias_ops = bias_shape[1]
+ else:
+ bias_ops = 0
+ else:
+ bias_ops = 0
+ # seq_length, batch_size, input_dim = x_shape
+ if not isinstance(bias_ops, int):
+ bias_ops = int(0)
+ num_gates = int(4)
+ gate_input_flops = int(2 * x_shape_ints[2] * hidden_size)
+ gate_hid_flops = int(2 * hidden_size * hidden_size)
+ unit_flops = (
+ num_gates * (gate_input_flops + gate_hid_flops) + bias_ops
+ )
+ node_info.flops = (
+ x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops
+ )
+ # In this case we just hit an op that doesn't have FLOPs
+ else:
+ node_info.flops = None
+
+ except IndexError as err:
+ print(f"Error parsing node {node.name}: {err}")
+ node_info.flops = None
+ self.flops = None
+ continue
+
+ # Update the model level flops count
+ if node_info.flops is not None and self.flops is not None:
+ self.flops += node_info.flops
+
+ # Update the node type flops count
+ self.node_type_flops[node.op_type] = (
+ self.node_type_flops.get(node.op_type, 0) + node_info.flops
+ )
+
+ def save_yaml_report(self, filepath: str) -> None:
+
+ parent_dir = os.path.dirname(os.path.abspath(filepath))
+ if not os.path.exists(parent_dir):
+ raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
+
+ report_date = datetime.now().strftime("%B %d, %Y")
+
+ input_tensors = dict({k: vars(v) for k, v in self.model_inputs.items()})
+ output_tensors = dict({k: vars(v) for k, v in self.model_outputs.items()})
+ digest_version = importlib.metadata.version("digestai")
+
+ yaml_data = {
+ "report_date": report_date,
+ "digest_version": digest_version,
+ "model_type": self.model_type.value,
+ "model_file": self.filepath,
+ "model_name": self.model_name,
+ "model_version": self.model_version,
+ "graph_name": self.graph_name,
+ "producer_name": self.producer_name,
+ "producer_version": self.producer_version,
+ "ir_version": self.ir_version,
+ "opset": self.opset,
+ "import_list": dict(self.imports),
+ "graph_nodes": sum(self.node_type_counts.values()),
+ "parameters": self.parameters,
+ "flops": self.flops,
+ "node_type_counts": dict(self.node_type_counts),
+ "node_type_flops": dict(self.node_type_flops),
+ "node_type_parameters": self.node_type_parameters,
+ "input_tensors": input_tensors,
+ "output_tensors": output_tensors,
+ }
+
+ with open(filepath, "w", encoding="utf-8") as f_p:
+ yaml.dump(yaml_data, f_p, sort_keys=False)
+
+ def save_text_report(self, filepath: str) -> None:
+
+ parent_dir = os.path.dirname(os.path.abspath(filepath))
+ if not os.path.exists(parent_dir):
+ raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
+
+ report_date = datetime.now().strftime("%B %d, %Y")
+
+ digest_version = importlib.metadata.version("digestai")
+
+ with open(filepath, "w", encoding="utf-8") as f_p:
+ f_p.write(f"Report created on {report_date}\n")
+ f_p.write(f"Digest version: {digest_version}\n")
+ f_p.write(f"Model type: {self.model_type.name}\n")
+ if self.filepath:
+ f_p.write(f"ONNX file: {self.filepath}\n")
+ f_p.write(f"Name of the model: {self.model_name}\n")
+ f_p.write(f"Model version: {self.model_version}\n")
+ f_p.write(f"Name of the graph: {self.graph_name}\n")
+ f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n")
+ f_p.write(f"Ir version: {self.ir_version}\n")
+ f_p.write(f"Opset: {self.opset}\n\n")
+ f_p.write("Import list\n")
+ for name, version in self.imports.items():
+ f_p.write(f"\t{name}: {version}\n")
+
+ f_p.write("\n")
+ f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n")
+ f_p.write(f"Number of parameters: {self.parameters}\n")
+ if self.flops:
+ f_p.write(f"Number of FLOPs: {self.flops}\n")
+ f_p.write("\n")
+
+ table_op_intensity = PrettyTable()
+ table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"]
+ for op_type, count in self.node_type_flops.items():
+ if count > 0:
+ table_op_intensity.add_row(
+ [
+ op_type,
+ count,
+ 100.0 * float(count) / float(self.flops),
+ ]
+ )
+
+ f_p.write("Op intensity:\n")
+ f_p.write(table_op_intensity.get_string())
+ f_p.write("\n\n")
+
+ node_counts_table = PrettyTable()
+ node_counts_table.field_names = ["Node", "Occurrences"]
+ for op, count in self.node_type_counts.items():
+ node_counts_table.add_row([op, count])
+ f_p.write("Nodes and their occurrences:\n")
+ f_p.write(node_counts_table.get_string())
+ f_p.write("\n\n")
+
+ input_table = PrettyTable()
+ input_table.field_names = [
+ "Input Name",
+ "Shape",
+ "Type",
+ "Tensor Size (KB)",
+ ]
+ for input_name, input_details in self.model_inputs.items():
+ if input_details.size_kbytes:
+ kbytes = f"{input_details.size_kbytes:.2f}"
+ else:
+ kbytes = ""
+
+ input_table.add_row(
+ [
+ input_name,
+ input_details.shape,
+ input_details.dtype,
+ kbytes,
+ ]
+ )
+ f_p.write("Input Tensor(s) Information:\n")
+ f_p.write(input_table.get_string())
+ f_p.write("\n\n")
+
+ output_table = PrettyTable()
+ output_table.field_names = [
+ "Output Name",
+ "Shape",
+ "Type",
+ "Tensor Size (KB)",
+ ]
+ for output_name, output_details in self.model_outputs.items():
+ if output_details.size_kbytes:
+ kbytes = f"{output_details.size_kbytes:.2f}"
+ else:
+ kbytes = ""
+
+ output_table.add_row(
+ [
+ output_name,
+ output_details.shape,
+ output_details.dtype,
+ kbytes,
+ ]
+ )
+ f_p.write("Output Tensor(s) Information:\n")
+ f_p.write(output_table.get_string())
+ f_p.write("\n\n")
diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py
new file mode 100644
index 0000000..f2ccd26
--- /dev/null
+++ b/src/digest/model_class/digest_report_model.py
@@ -0,0 +1,239 @@
+import os
+from collections import OrderedDict
+import csv
+import ast
+import re
+from typing import Tuple, Optional, List, Dict, Any, Union
+import yaml
+from digest.model_class.digest_model import (
+ DigestModel,
+ SupportedModelTypes,
+ NodeData,
+ NodeInfo,
+ TensorData,
+ TensorInfo,
+)
+
+
def parse_tensor_info(
    csv_tensor_cell_value,
) -> Tuple[str, list, str, Union[str, float]]:
    """Split one tensor cell from the nodes csv into its components.

    The cell is expected to look like ``name ([d0, d1, ...], dtype, size)``.
    Returns a ``(name, shape, dtype, size)`` tuple where ``size`` is the
    kbyte count as a float, or the literal string ``"None"`` when no size
    was recorded in the csv.
    """

    # First split "name (details)" into the two halves.
    name_details = re.match(r"(.*?)\s*\((.*)\)$", csv_tensor_cell_value)
    if name_details is None:
        raise ValueError(f"Invalid format for tensor info: {csv_tensor_cell_value}")

    name, details = name_details.groups()

    # Then split the details, keeping the bracketed shape as one token.
    detail_parts = re.match(r"(\[.*?\])\s*,\s*(.*?)\s*,\s*(.*)", details)
    if detail_parts is None:
        raise ValueError(f"Invalid format for tensor details: {details}")

    shape_str, dtype, size = detail_parts.groups()

    # The shape is stored textually; evaluate it and coerce to a list.
    shape = ast.literal_eval(shape_str)
    if not isinstance(shape, list):
        shape = list(shape)

    # A recorded size looks like "123.4 KB"; keep just the number.
    if size != "None":
        size = float(size.split()[0])

    return name.strip(), shape, dtype.strip(), size
+
+
class DigestReportModel(DigestModel):
    """A DigestModel that is reconstructed from a previously saved digest
    report yaml file (and its sibling cache artifacts) instead of a live
    ONNX model."""

    def __init__(
        self,
        report_filepath: str,
    ) -> None:

        self.model_type = SupportedModelTypes.REPORT

        # Validate before doing any real work. NOTE(review): on failure we
        # return early, leaving the instance only partially initialized
        # (super().__init__ is never called) — callers are expected to
        # check `is_valid` before using the object.
        self.is_valid = validate_yaml(report_filepath)

        if not self.is_valid:
            print(f"The yaml file {report_filepath} is not a valid digest report.")
            return

        self.model_data = OrderedDict()
        with open(report_filepath, "r", encoding="utf-8") as yaml_f:
            self.model_data = yaml.safe_load(yaml_f)

        model_name = self.model_data["model_name"]
        super().__init__(report_filepath, model_name, SupportedModelTypes.REPORT)

        # Path to the cached similarity heatmap image, if one exists.
        self.similarity_heatmap_path: Optional[str] = None
        # Per-node information rebuilt from the cached nodes csv (if any).
        self.node_data = NodeData()

        # Given the path to the digest report, let's check if it's a complete
        # cache so we can grab the nodes csv data and the similarity heatmap.
        cache_dir = os.path.dirname(os.path.abspath(report_filepath))
        expected_heatmap_file = os.path.join(cache_dir, f"{model_name}_heatmap.png")
        if os.path.exists(expected_heatmap_file):
            self.similarity_heatmap_path = expected_heatmap_file

        expected_nodes_file = os.path.join(cache_dir, f"{model_name}_nodes.csv")
        if os.path.exists(expected_nodes_file):
            with open(expected_nodes_file, "r", encoding="utf-8") as csvfile:
                reader = csv.DictReader(csvfile)
                for row in reader:
                    node_name = row["Node Name"]
                    node_info = NodeInfo()
                    node_info.node_type = row["Node Type"]
                    if row["Parameters"]:
                        node_info.parameters = int(row["Parameters"])
                    # The FLOPs column may contain the text "None";
                    # literal_eval maps that to None (falsy) so only real
                    # counts get converted and stored.
                    if ast.literal_eval(row["FLOPs"]):
                        node_info.flops = int(row["FLOPs"])
                    node_info.attributes = (
                        OrderedDict(ast.literal_eval(row["Attributes"]))
                        if row["Attributes"]
                        else OrderedDict()
                    )

                    node_info.inputs = TensorData()
                    node_info.outputs = TensorData()

                    # Process inputs and outputs. Columns are matched by
                    # their "Input*"/"Output*" header prefix.
                    for key, value in row.items():
                        if key.startswith("Input") and value:
                            input_name, shape, dtype, size = parse_tensor_info(value)
                            node_info.inputs[input_name] = TensorInfo()
                            node_info.inputs[input_name].shape = shape
                            node_info.inputs[input_name].dtype = dtype
                            node_info.inputs[input_name].size_kbytes = size

                        elif key.startswith("Output") and value:
                            output_name, shape, dtype, size = parse_tensor_info(value)
                            node_info.outputs[output_name] = TensorInfo()
                            node_info.outputs[output_name].shape = shape
                            node_info.outputs[output_name].dtype = dtype
                            node_info.outputs[output_name].size_kbytes = size

                    self.node_data[node_name] = node_info

        # Unpack the model type agnostic values straight from the yaml.
        self.flops = self.model_data["flops"]
        self.parameters = self.model_data["parameters"]
        self.node_type_flops = self.model_data["node_type_flops"]
        self.node_type_parameters = self.model_data["node_type_parameters"]
        self.node_type_counts = self.model_data["node_type_counts"]

        self.model_inputs = TensorData(
            {
                key: TensorInfo(**val)
                for key, val in self.model_data["input_tensors"].items()
            }
        )
        self.model_outputs = TensorData(
            {
                key: TensorInfo(**val)
                for key, val in self.model_data["output_tensors"].items()
            }
        )

    def parse_model_nodes(self) -> None:
        """There are no model nodes to parse"""
        return

    def save_yaml_report(self, filepath: str) -> None:
        """Report models are not intended to be saved"""
        return

    def save_text_report(self, filepath: str) -> None:
        """Report models are not intended to be saved"""
        return
+
+
def validate_yaml(report_file_path: str) -> bool:
    """Check that the provided yaml file is indeed a Digest Report file."""
    # All of these top-level keys must be present for the file to be
    # considered a digest report.
    expected_keys = (
        "report_date",
        "model_file",
        "model_type",
        "model_name",
        "flops",
        "node_type_flops",
        "node_type_parameters",
        "node_type_counts",
        "input_tensors",
        "output_tensors",
    )

    try:
        with open(report_file_path, "r", encoding="utf-8") as file:
            yaml_content = yaml.safe_load(file)
    except (yaml.YAMLError, IOError):
        # Unparseable or unreadable files are simply not valid reports.
        return False

    if not isinstance(yaml_content, dict):
        print("Error: YAML content is not a dictionary")
        return False

    return all(key in yaml_content for key in expected_keys)
+
+
def compare_yaml_files(
    file1: str, file2: str, skip_keys: Optional[List[str]] = None
) -> bool:
    """
    Compare two YAML files, ignoring specified keys.

    :param file1: Path to the first YAML file
    :param file2: Path to the second YAML file
    :param skip_keys: List of keys to ignore in the comparison
    :return: True if the files are equal (ignoring specified keys), False otherwise
    """

    def load_yaml(file_path: str) -> Dict[str, Any]:
        # Read and parse one YAML file.
        with open(file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)

    def compare_dicts(
        dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = ""
    ) -> List[str]:
        # Recursively walk both dicts and collect a human-readable
        # description of every mismatch, skipping the ignored keys.
        differences = []
        all_keys = set(dict1.keys()) | set(dict2.keys())

        for key in all_keys:
            if skip_keys and key in skip_keys:
                continue

            current_path = f"{path}.{key}" if path else key

            if key not in dict1:
                differences.append(f"Key '{current_path}' is missing in the first file")
            elif key not in dict2:
                differences.append(
                    f"Key '{current_path}' is missing in the second file"
                )
            elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
                differences.extend(compare_dicts(dict1[key], dict2[key], current_path))
            elif dict1[key] != dict2[key]:
                differences.append(
                    f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}"
                )

        return differences

    # The files are equal exactly when no differences were collected;
    # the previous `if differences: return False else: return True`
    # branch collapses to a single boolean expression.
    return not compare_dicts(load_yaml(file1), load_yaml(file2))
diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py
index 1e3872e..a92b756 100644
--- a/src/digest/modelsummary.py
+++ b/src/digest/modelsummary.py
@@ -3,10 +3,12 @@
import os
# pylint: disable=invalid-name
-from typing import Optional
+from typing import Optional, Union
# pylint: disable=no-name-in-module
from PySide6.QtWidgets import QWidget
+from PySide6.QtGui import QMovie
+from PySide6.QtCore import QSize
from onnx import ModelProto
@@ -14,37 +16,54 @@
from digest.freeze_inputs import FreezeInputs
from digest.popup_window import PopupWindow
from digest.qt_utils import apply_dark_style_sheet
-from utils import onnx_utils
+from digest.model_class.digest_onnx_model import DigestOnnxModel
+from digest.model_class.digest_report_model import DigestReportModel
+
ROOT_FOLDER = os.path.dirname(os.path.abspath(__file__))
class modelSummary(QWidget):
- def __init__(self, digest_model: onnx_utils.DigestOnnxModel, parent=None):
+ def __init__(
+ self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None
+ ):
super().__init__(parent)
self.ui = Ui_modelSummary()
self.ui.setupUi(self)
apply_dark_style_sheet(self)
self.file: Optional[str] = None
- self.ui.freezeButton.setVisible(False)
- self.ui.freezeButton.clicked.connect(self.open_freeze_inputs)
self.ui.warningLabel.hide()
self.digest_model = digest_model
- self.model_proto: ModelProto = (
- digest_model.model_proto if digest_model.model_proto else ModelProto()
- )
+ self.model_proto: Optional[ModelProto] = None
model_name: str = digest_model.model_name if digest_model.model_name else ""
- self.freeze_inputs = FreezeInputs(self.model_proto, model_name)
- self.freeze_inputs.complete_signal.connect(self.close_freeze_window)
+
+ self.load_gif = QMovie(":/assets/gifs/load.gif")
+ # We set the size of the GIF to half the original
+ self.load_gif.setScaledSize(QSize(214, 120))
+ self.ui.similarityImg.setMovie(self.load_gif)
+ self.load_gif.start()
+
+ # There is no freezing if the model is not ONNX
+ self.ui.freezeButton.setVisible(False)
+ self.freeze_inputs: Optional[FreezeInputs] = None
self.freeze_window: Optional[QWidget] = None
+ if isinstance(digest_model, DigestOnnxModel):
+ self.model_proto = (
+ digest_model.model_proto if digest_model.model_proto else ModelProto()
+ )
+ self.freeze_inputs = FreezeInputs(self.model_proto, model_name)
+ self.ui.freezeButton.clicked.connect(self.open_freeze_inputs)
+ self.freeze_inputs.complete_signal.connect(self.close_freeze_window)
+
def open_freeze_inputs(self):
- self.freeze_window = PopupWindow(
- self.freeze_inputs, "Freeze Model Inputs", self
- )
- self.freeze_window.open()
+ if self.freeze_inputs:
+ self.freeze_window = PopupWindow(
+ self.freeze_inputs, "Freeze Model Inputs", self
+ )
+ self.freeze_window.open()
def close_freeze_window(self):
if self.freeze_window:
diff --git a/src/digest/multi_model_analysis.py b/src/digest/multi_model_analysis.py
index d7f6bab..1c09905 100644
--- a/src/digest/multi_model_analysis.py
+++ b/src/digest/multi_model_analysis.py
@@ -1,17 +1,27 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
import os
+from datetime import datetime
import csv
from typing import List, Dict, Union
from collections import Counter, defaultdict, OrderedDict
# pylint: disable=no-name-in-module
from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog
+from PySide6.QtCore import Qt
from digest.dialog import ProgressDialog, StatusDialog
from digest.ui.multimodelanalysis_ui import Ui_multiModelAnalysis
from digest.histogramchartwidget import StackedHistogramWidget
from digest.qt_utils import apply_dark_style_sheet
-from utils import onnx_utils
+from digest.model_class.digest_model import (
+ NodeTypeCounts,
+ NodeShapeCounts,
+ save_node_shape_counts_csv_report,
+ save_node_type_counts_csv_report,
+)
+from digest.model_class.digest_onnx_model import DigestOnnxModel
+from digest.model_class.digest_report_model import DigestReportModel
+import utils.onnx_utils as onnx_utils
ROOT_FOLDER = os.path.dirname(__file__)
@@ -21,7 +31,7 @@ class MultiModelAnalysis(QWidget):
def __init__(
self,
- model_list: List[onnx_utils.DigestOnnxModel],
+ model_list: List[Union[DigestOnnxModel, DigestReportModel]],
parent=None,
):
super().__init__(parent)
@@ -34,6 +44,9 @@ def __init__(
self.ui.individualCheckBox.stateChanged.connect(self.check_box_changed)
self.ui.multiCheckBox.stateChanged.connect(self.check_box_changed)
+ # For some reason setting alignments in designer lead to bugs in *ui.py files
+ self.ui.opHistogramChart.layout().setAlignment(Qt.AlignmentFlag.AlignTop)
+
if not model_list:
return
@@ -41,41 +54,60 @@ def __init__(
self.global_node_type_counter: Counter[str] = Counter()
# Holds the data for node shape counts across all models
- self.global_node_shape_counter: onnx_utils.NodeShapeCounts = defaultdict(
- Counter
- )
+ self.global_node_shape_counter: NodeShapeCounts = defaultdict(Counter)
# Holds the data for all models statistics
- self.global_model_data: Dict[str, Dict[str, Union[int, None]]] = {}
+ self.global_model_data: Dict[str, Dict[str, Union[int, str, None]]] = {}
progress = ProgressDialog("", len(model_list), self)
- header_labels = ["Model", "Opset", "Total Nodes", "Parameters", "FLOPs"]
+ header_labels = [
+ "Model Name",
+ "Model Type",
+ "Opset",
+ "Total Nodes",
+ "Parameters",
+ "FLOPs",
+ ]
self.ui.dataTable.setRowCount(len(model_list))
self.ui.dataTable.setColumnCount(len(header_labels))
self.ui.dataTable.setHorizontalHeaderLabels(header_labels)
self.ui.dataTable.setSortingEnabled(False)
for row, model in enumerate(model_list):
+
item = QTableWidgetItem(str(model.model_name))
self.ui.dataTable.setItem(row, 0, item)
- item = QTableWidgetItem(str(model.opset))
+ item = QTableWidgetItem(str(model.model_type.name))
self.ui.dataTable.setItem(row, 1, item)
- item = QTableWidgetItem(str(len(model.per_node_info)))
+ if isinstance(model, DigestOnnxModel):
+ item = QTableWidgetItem(str(model.opset))
+ elif isinstance(model, DigestReportModel):
+ item = QTableWidgetItem(str(model.model_data.get("opset", "")))
+
self.ui.dataTable.setItem(row, 2, item)
- item = QTableWidgetItem(str(model.model_parameters))
+ item = QTableWidgetItem(str(len(model.node_data)))
self.ui.dataTable.setItem(row, 3, item)
- item = QTableWidgetItem(str(model.model_flops))
+ item = QTableWidgetItem(str(model.parameters))
self.ui.dataTable.setItem(row, 4, item)
+ item = QTableWidgetItem(str(model.flops))
+ self.ui.dataTable.setItem(row, 5, item)
+
self.ui.dataTable.resizeColumnsToContents()
self.ui.dataTable.resizeRowsToContents()
- node_type_counter = {}
+ # Until we use the unique_id to represent the model contents we store
+ # the entire model as the key so that we can store models that happen to have
+ # the same name. There is a guarantee that the models will not be duplicates.
+ node_type_counter: Dict[
+ Union[DigestOnnxModel, DigestReportModel], NodeTypeCounts
+ ] = {}
+
for i, digest_model in enumerate(model_list):
progress.step()
progress.setLabelText(f"Analyzing model {digest_model.model_name}")
@@ -83,39 +115,52 @@ def __init__(
if digest_model.model_name is None:
digest_model.model_name = f"model_{i}"
- if digest_model.model_proto:
- dynamic_input_dims = onnx_utils.get_dynamic_input_dims(
- digest_model.model_proto
- )
- if dynamic_input_dims:
- print(
- "Found the following non-static input dims in your model. "
- "It is recommended to make all dims static before generating reports."
+ if isinstance(digest_model, DigestOnnxModel):
+ opset = digest_model.opset
+ if digest_model.model_proto:
+ dynamic_input_dims = onnx_utils.get_dynamic_input_dims(
+ digest_model.model_proto
)
- for dynamic_shape in dynamic_input_dims:
- print(f"dim: {dynamic_shape}")
+ if dynamic_input_dims:
+ print(
+ "Found the following non-static input dims in your model. "
+ "It is recommended to make all dims static before generating reports."
+ )
+ for dynamic_shape in dynamic_input_dims:
+ print(f"dim: {dynamic_shape}")
+
+ elif isinstance(digest_model, DigestReportModel):
+ opset = digest_model.model_data.get("opset", "")
# Update the global model dictionary
- if digest_model.model_name in self.global_model_data:
+ if digest_model.unique_id in self.global_model_data:
print(
- f"Warning! {digest_model.model_name} has already been processed, "
+ f"Warning! {digest_model.model_name} with id "
+ f"{digest_model.unique_id} has already been processed, "
"skipping the duplicate model."
)
-
- self.global_model_data[digest_model.model_name] = {
- "opset": digest_model.opset,
- "parameters": digest_model.model_parameters,
- "flops": digest_model.model_flops,
+ continue
+
+ self.global_model_data[digest_model.unique_id] = {
+ "model_name": digest_model.model_name,
+ "model_type": digest_model.model_type.name,
+ "opset": opset,
+ "parameters": digest_model.parameters,
+ "flops": digest_model.flops,
}
- node_type_counter[digest_model.model_name] = (
- digest_model.get_node_type_counts()
- )
+ if digest_model in node_type_counter:
+ print(
+ f"Warning! {digest_model.model_name} with model type "
+ f"{digest_model.model_type.value} and id {digest_model.unique_id} "
+ "has already been added to the stacked histogram, skipping."
+ )
+ continue
+
+ node_type_counter[digest_model] = digest_model.node_type_counts
# Update global data structure for node type counter
- self.global_node_type_counter.update(
- node_type_counter[digest_model.model_name]
- )
+ self.global_node_type_counter.update(node_type_counter[digest_model])
node_shape_counts = digest_model.get_node_shape_counts()
@@ -133,31 +178,31 @@ def __init__(
# Create stacked op histograms
max_count = 0
top_ops = [key for key, _ in self.global_node_type_counter.most_common(20)]
- for model_name, _ in node_type_counter.items():
- max_local = Counter(node_type_counter[model_name]).most_common()[0][1]
+ for model, _ in node_type_counter.items():
+ max_local = Counter(node_type_counter[model]).most_common()[0][1]
if max_local > max_count:
max_count = max_local
- for idx, model_name in enumerate(node_type_counter):
+ for idx, model in enumerate(node_type_counter):
stacked_histogram_widget = StackedHistogramWidget()
ordered_dict = OrderedDict()
- model_counter = Counter(node_type_counter[model_name])
+ model_counter = Counter(node_type_counter[model])
for key in top_ops:
ordered_dict[key] = model_counter.get(key, 0)
title = "Stacked Op Histogram" if idx == 0 else ""
stacked_histogram_widget.set_data(
ordered_dict,
- model_name=model_name,
+ model_name=model.model_name,
y_max=max_count,
title=title,
set_ticks=False,
)
frame_layout = self.ui.stackedHistogramFrame.layout()
- frame_layout.addWidget(stacked_histogram_widget)
+ if frame_layout:
+ frame_layout.addWidget(stacked_histogram_widget)
# Add a "ghost" histogram to allow us to set the x axis label vertically
- model_name = list(node_type_counter.keys())[0]
stacked_histogram_widget = StackedHistogramWidget()
- ordered_dict = {key: 1 for key in top_ops}
+ ordered_dict = OrderedDict({key: 1 for key in top_ops})
stacked_histogram_widget.set_data(
ordered_dict,
model_name="_",
@@ -165,18 +210,39 @@ def __init__(
set_ticks=True,
)
frame_layout = self.ui.stackedHistogramFrame.layout()
- frame_layout.addWidget(stacked_histogram_widget)
+ if frame_layout:
+ frame_layout.addWidget(stacked_histogram_widget)
self.model_list = model_list
def save_reports(self):
- # Model summary text report
- save_directory = QFileDialog(self).getExistingDirectory(
+ """This function saves all available reports for the models that are opened
+ in the multi-model analysis page."""
+
+ base_directory = QFileDialog(self).getExistingDirectory(
self, "Select Directory"
)
- if not save_directory:
- return
+ # Check if the directory exists and is writable
+ if not os.path.exists(base_directory) or not os.access(base_directory, os.W_OK):
+ bad_ext_dialog = StatusDialog(
+ f"The directory {base_directory} is not valid or writable.",
+ parent=self,
+ )
+ bad_ext_dialog.show()
+
+ # Append a subdirectory to the save_directory so that all reports are co-located
+ name_id = datetime.now().strftime("%Y%m%d%H%M%S")
+ sub_directory = f"multi_model_reports_{name_id}"
+ save_directory = os.path.join(base_directory, sub_directory)
+ try:
+ os.makedirs(save_directory)
+ except OSError as os_err:
+ bad_ext_dialog = StatusDialog(
+ f"Failed to create {save_directory} with error {os_err}",
+ parent=self,
+ )
+ bad_ext_dialog.show()
save_individual_reports = self.ui.individualCheckBox.isChecked()
save_multi_reports = self.ui.multiCheckBox.isChecked()
@@ -192,29 +258,21 @@ def save_reports(self):
save_directory, f"{digest_model.model_name}_summary.txt"
)
- digest_model.save_txt_report(summary_filepath)
+ digest_model.save_text_report(summary_filepath)
# Save csv of node type counts
node_type_filepath = os.path.join(
save_directory, f"{digest_model.model_name}_node_type_counts.csv"
)
- # Save csv containing node type counter
- node_type_counter = digest_model.get_node_type_counts()
-
- if node_type_counter:
- onnx_utils.save_node_type_counts_csv_report(
- node_type_counter, node_type_filepath
- )
+ if digest_model.node_type_counts:
+ digest_model.save_node_type_counts_csv_report(node_type_filepath)
# Save csv containing node shape counts per op_type
- node_shape_counts = digest_model.get_node_shape_counts()
node_shape_filepath = os.path.join(
save_directory, f"{digest_model.model_name}_node_shape_counts.csv"
)
- onnx_utils.save_node_shape_counts_csv_report(
- node_shape_counts, node_shape_filepath
- )
+ digest_model.save_node_shape_counts_csv_report(node_shape_filepath)
# Save csv containing all node-level information
nodes_filepath = os.path.join(
@@ -231,17 +289,17 @@ def save_reports(self):
global_filepath = os.path.join(
save_directory, "global_node_type_counts.csv"
)
- global_node_type_counter = onnx_utils.NodeTypeCounts(
+ global_node_type_counter = NodeTypeCounts(
self.global_node_type_counter.most_common()
)
- onnx_utils.save_node_type_counts_csv_report(
+ save_node_type_counts_csv_report(
global_node_type_counter, global_filepath
)
global_filepath = os.path.join(
save_directory, "global_node_shape_counts.csv"
)
- onnx_utils.save_node_shape_counts_csv_report(
+ save_node_shape_counts_csv_report(
self.global_node_shape_counter, global_filepath
)
@@ -253,10 +311,18 @@ def save_reports(self):
) as csvfile:
writer = csv.writer(csvfile)
rows = [
- [model, data["opset"], data["parameters"], data["flops"]]
- for model, data in self.global_model_data.items()
+ [
+ data["model_name"],
+ data["model_type"],
+ data["opset"],
+ data["parameters"],
+ data["flops"],
+ ]
+ for _, data in self.global_model_data.items()
]
- writer.writerow(["Model", "Opset", "Parameters", "FLOPs"])
+ writer.writerow(
+ ["Model Name", "Model Type", "Opset", "Parameters", "FLOPs"]
+ )
writer.writerows(rows)
if save_individual_reports or save_multi_reports:
diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py
index d7b6a39..33fd78c 100644
--- a/src/digest/multi_model_selection_page.py
+++ b/src/digest/multi_model_selection_page.py
@@ -2,7 +2,7 @@
import os
import glob
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Union
from collections import defaultdict
from google.protobuf.message import DecodeError
import onnx
@@ -22,6 +22,8 @@
from digest.ui.multimodelselection_page_ui import Ui_MultiModelSelection
from digest.multi_model_analysis import MultiModelAnalysis
from digest.qt_utils import apply_dark_style_sheet, prompt_user_ram_limit
+from digest.model_class.digest_onnx_model import DigestOnnxModel
+from digest.model_class.digest_report_model import DigestReportModel, compare_yaml_files
from utils import onnx_utils
@@ -33,7 +35,9 @@ class AnalysisThread(QThread):
def __init__(self):
super().__init__()
- self.model_dict: Dict[str, Optional[onnx_utils.DigestOnnxModel]] = {}
+ self.model_dict: Dict[
+ str, Optional[Union[DigestOnnxModel, DigestReportModel]]
+ ] = {}
self.user_canceled = False
def run(self):
@@ -47,19 +51,21 @@ def run(self):
self.step_progress.emit()
if model:
continue
- model_name = os.path.splitext(os.path.basename(file))[0]
- model_proto = onnx_utils.load_onnx(file, False)
- self.model_dict[file] = onnx_utils.DigestOnnxModel(
- model_proto, onnx_filepath=file, model_name=model_name, save_proto=False
- )
+ model_name, file_ext = os.path.splitext(os.path.basename(file))
+ if file_ext == ".onnx":
+ model_proto = onnx_utils.load_onnx(file, False)
+ self.model_dict[file] = DigestOnnxModel(
+ model_proto,
+ onnx_filepath=file,
+ model_name=model_name,
+ save_proto=False,
+ )
+ elif file_ext == ".yaml":
+ self.model_dict[file] = DigestReportModel(file)
self.close_progress.emit()
- model_list = [
- model
- for model in self.model_dict.values()
- if isinstance(model, onnx_utils.DigestOnnxModel)
- ]
+ model_list = [model for model in self.model_dict.values()]
self.completed.emit(model_list)
@@ -82,10 +88,19 @@ def __init__(
self.ui.warningLabel.hide()
self.item_model = QStandardItemModel()
self.item_model.itemChanged.connect(self.update_num_selected_label)
- self.ui.selectAllBox.setCheckState(Qt.CheckState.Checked)
- self.ui.selectAllBox.stateChanged.connect(self.update_list_view_items)
+ self.ui.radioAll.setChecked(True)
+ self.ui.radioAll.toggled.connect(self.update_list_view_items)
+ self.ui.radioONNX.toggled.connect(self.update_list_view_items)
+ self.ui.radioReports.toggled.connect(self.update_list_view_items)
self.ui.selectFolderBtn.clicked.connect(self.openFolder)
+
+ # We want to retain the size when the duplicate label
+ # is hidden to keep the two list columns even.
+ policy = self.ui.duplicateLabel.sizePolicy()
+ policy.setRetainSizeWhenHidden(True)
+ self.ui.duplicateLabel.setSizePolicy(policy)
self.ui.duplicateLabel.hide()
+
self.ui.modelListView.setModel(self.item_model)
self.ui.modelListView.setContextMenuPolicy(
Qt.ContextMenuPolicy.CustomContextMenu
@@ -94,7 +109,9 @@ def __init__(
self.ui.openAnalysisBtn.clicked.connect(self.start_analysis)
- self.model_dict: Dict[str, Optional[onnx_utils.DigestOnnxModel]] = {}
+ self.model_dict: Dict[
+ str, Optional[Union[DigestOnnxModel, DigestReportModel]]
+ ] = {}
self.analysis_thread: Optional[AnalysisThread] = None
self.progress: Optional[ProgressDialog] = None
@@ -165,14 +182,24 @@ def update_num_selected_label(self):
self.ui.openAnalysisBtn.setEnabled(False)
def update_list_view_items(self):
- state = self.ui.selectAllBox.checkState()
+ radio_all_state = self.ui.radioAll.isChecked()
+ radio_onnx_state = self.ui.radioONNX.isChecked()
+ radio_reports_state = self.ui.radioReports.isChecked()
for row in range(self.item_model.rowCount()):
item = self.item_model.item(row)
- item.setCheckState(state)
+ value = item.data(Qt.ItemDataRole.DisplayRole)
+ if radio_all_state:
+ item.setCheckState(Qt.CheckState.Checked)
+ elif os.path.splitext(value)[-1] == ".onnx" and radio_onnx_state:
+ item.setCheckState(Qt.CheckState.Checked)
+ elif os.path.splitext(value)[-1] == ".yaml" and radio_reports_state:
+ item.setCheckState(Qt.CheckState.Checked)
+ else:
+ item.setCheckState(Qt.CheckState.Unchecked)
def set_directory(self, directory: str):
"""
- Recursively searches a directory for onnx models.
+ Recursively searches a directory for onnx models and yaml report files.
"""
if not os.path.exists(directory):
@@ -183,36 +210,57 @@ def set_directory(self, directory: str):
else:
return
- progress = ProgressDialog("Searching Directory for ONNX Files", 0, self)
+ progress = ProgressDialog("Searching directory for model files", 0, self)
+
onnx_file_list = list(
glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True)
)
+ onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list]
+
+ yaml_file_list = list(
+ glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True)
+ )
+ yaml_file_list = [os.path.normpath(model_file) for model_file in yaml_file_list]
+
+ # Filter out YAML files that are not valid reports
+ report_file_list = []
+ for yaml_file in yaml_file_list:
+ digest_report = DigestReportModel(yaml_file)
+ if digest_report.is_valid:
+ report_file_list.append(yaml_file)
+
+ total_num_models = len(onnx_file_list) + len(report_file_list)
- onnx_file_list = [os.path.normpath(onnx_file) for onnx_file in onnx_file_list]
serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list)
progress.close()
- progress = ProgressDialog("Loading ONNX Models", len(onnx_file_list), self)
+ progress = ProgressDialog("Loading models", total_num_models, self)
memory_limit_percentage = 90
models_loaded = 0
- for filepath in onnx_file_list:
+ for filepath in onnx_file_list + report_file_list:
progress.step()
if progress.user_canceled:
break
try:
models_loaded += 1
- model = onnx.load(filepath, load_external_data=False)
- dialog_msg = f"""Warning: System RAM has exceeded the threshold of {memory_limit_percentage}%.
- No further models will be loaded.
- """
+ extension = os.path.splitext(filepath)[-1]
+ if extension == ".onnx":
+ model = onnx.load(filepath, load_external_data=False)
+ serialized_models_paths[model.SerializeToString()].append(filepath)
+ elif extension == ".yaml":
+ pass
+ dialog_msg = (
+ "Warning: System RAM has exceeded the threshold of "
+ f"{memory_limit_percentage}%. No further models will be loaded. "
+ )
if prompt_user_ram_limit(
sys_ram_percent_limit=memory_limit_percentage,
message=dialog_msg,
parent=self,
):
self.update_warning_label(
- f"Loaded only {models_loaded - 1} out of {len(onnx_file_list)} models "
+ f"Loaded only {models_loaded - 1} out of {total_num_models} models "
f"as memory consumption has reached {memory_limit_percentage}% of "
"system memory. Preventing further loading of models."
)
@@ -223,15 +271,13 @@ def set_directory(self, directory: str):
break
else:
self.ui.warningLabel.hide()
- serialized_models_paths[model.SerializeToString()].append(filepath)
+
except DecodeError as error:
print(f"Error decoding model {filepath}: {error}")
progress.close()
- progress = ProgressDialog(
- "Processing ONNX Models", len(serialized_models_paths), self
- )
+ progress = ProgressDialog("Processing Models", total_num_models, self)
num_duplicates = 0
self.item_model.clear()
@@ -245,23 +291,47 @@ def set_directory(self, directory: str):
self.ui.duplicateListWidget.addItem(paths[0])
for dupe in paths[1:]:
self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}")
- item = QStandardItem(paths[0])
- item.setCheckable(True)
- item.setCheckState(Qt.CheckState.Checked)
- self.item_model.appendRow(item)
- else:
- item = QStandardItem(paths[0])
- item.setCheckable(True)
- item.setCheckState(Qt.CheckState.Checked)
- self.item_model.appendRow(item)
+ item = QStandardItem(paths[0])
+ item.setCheckable(True)
+ item.setCheckState(Qt.CheckState.Checked)
+ self.item_model.appendRow(item)
+
+ # Use a standard nested loop to detect duplicate reports
+ duplicate_reports: Dict[str, List[str]] = {}
+ processed_files = set()
+ for i in range(len(report_file_list)):
+ progress.step()
+ if progress.user_canceled:
+ break
+ path1 = report_file_list[i]
+ if path1 in processed_files:
+ continue # Skip already processed files
+
+ # We will use path1 as the unique model and save a list of duplicates
+ duplicate_reports[path1] = []
+ for j in range(i + 1, len(report_file_list)):
+ path2 = report_file_list[j]
+ if compare_yaml_files(
+ path1, path2, ["report_date", "model_files", "digest_version"]
+ ):
+ num_duplicates += 1
+ duplicate_reports[path1].append(path2)
+ processed_files.add(path2)
+
+ for path, dupes in duplicate_reports.items():
+ if dupes:
+ self.ui.duplicateListWidget.addItem(path)
+ for dupe in dupes:
+ self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}")
+ item = QStandardItem(path)
+ item.setCheckable(True)
+ item.setCheckState(Qt.CheckState.Checked)
+ self.item_model.appendRow(item)
progress.close()
if num_duplicates:
- label_text = (
- f"The following {num_duplicates} models were found to be "
- "duplicates and have been deselected from the list on the left."
- )
+ label_text = f"Ignoring {num_duplicates} duplicate model(s)."
self.ui.duplicateLabel.setText(label_text)
self.ui.duplicateLabel.show()
else:
@@ -270,7 +340,7 @@ def set_directory(self, directory: str):
self.update_num_selected_label()
self.update_message_label(
- f"Found a total of {len(onnx_file_list)} ONNX files. "
+ f"Found a total of {total_num_models} model files. "
"Right click a model below "
"to open it up in the model summary view."
)
@@ -289,7 +359,9 @@ def start_analysis(self):
self.analysis_thread.model_dict = self.model_dict
self.analysis_thread.start()
- def open_analysis(self, model_list: List[onnx_utils.DigestOnnxModel]):
+ def open_analysis(
+ self, model_list: List[Union[DigestOnnxModel, DigestReportModel]]
+ ):
multi_model_analysis = MultiModelAnalysis(model_list)
self.analysis_window.setCentralWidget(multi_model_analysis)
self.analysis_window.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))
diff --git a/src/digest/node_summary.py b/src/digest/node_summary.py
index 99eb35f..01aaf09 100644
--- a/src/digest/node_summary.py
+++ b/src/digest/node_summary.py
@@ -6,6 +6,10 @@
from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog
from digest.ui.nodessummary_ui import Ui_nodesSummary
from digest.qt_utils import apply_dark_style_sheet
+from digest.model_class.digest_model import (
+ save_node_shape_counts_csv_report,
+ save_nodes_csv_report,
+)
from utils import onnx_utils
ROOT_FOLDER = os.path.dirname(__file__)
@@ -111,8 +115,6 @@ def save_csv_file(self):
self, "Save CSV", os.getcwd(), "CSV(*.csv)"
)
if filepath and self.ui.allNodesBtn.isChecked():
- onnx_utils.save_nodes_csv_report(self.node_data, filepath)
+ save_nodes_csv_report(self.node_data, filepath)
elif filepath and self.ui.shapeCountsBtn.isChecked():
- onnx_utils.save_node_shape_counts_csv_report(
- self.node_shape_counts, filepath
- )
+ save_node_shape_counts_csv_report(self.node_shape_counts, filepath)
diff --git a/src/digest/resource_rc.py b/src/digest/resource_rc.py
index cf29584..59afc50 100644
--- a/src/digest/resource_rc.py
+++ b/src/digest/resource_rc.py
@@ -1,6 +1,6 @@
# Resource object code (Python 3)
# Created by: object code
-# Created by: The Resource Compiler for Qt version 6.8.0
+# Created by: The Resource Compiler for Qt version 6.8.1
# WARNING! All changes made in this file will be lost!
from PySide6 import QtCore
@@ -19676,39 +19676,39 @@
\x00\x00\x000\x00\x02\x00\x00\x00\x03\x00\x00\x00\x05\
\x00\x00\x00\x00\x00\x00\x00\x00\
\x00\x00\x00B\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\
-\x00\x00\x01\x93Ju\xc2>\
+\x00\x00\x01\x93K\x85\xbd\xbb\
\x00\x00\x00\xb0\x00\x00\x00\x00\x00\x01\x00\x01\x86\x02\
-\x00\x00\x01\x93Ju\xc2A\
+\x00\x00\x01\x93K\x85\xbd\xbb\
\x00\x00\x00\x84\x00\x00\x00\x00\x00\x01\x00\x01
bool:
+
+ loop = QEventLoop()
+ timer = QTimer()
+ timer.setSingleShot(True)
+ timer.timeout.connect(loop.quit)
+
+ def check_threads():
+ if all(thread.isFinished() for thread in threads):
+ loop.quit()
+
+ check_timer = QTimer()
+ check_timer.timeout.connect(check_threads)
+ check_timer.start(100) # Check every 100ms
+
+ timer.start(timeout)
+ loop.exec()
+
+ check_timer.stop()
+ timer.stop()
+
+ # Return True if all threads finished, False if timed out
+ return all(thread.isFinished() for thread in threads)
+
+
class StatsThread(QThread):
- completed = Signal(onnx_utils.DigestOnnxModel, str)
+ completed = Signal(DigestOnnxModel, str)
def __init__(
self,
@@ -31,14 +59,17 @@ def run(self):
if not self.unique_id:
raise ValueError("You must specify a unique id.")
- digest_model = onnx_utils.DigestOnnxModel(self.model, save_proto=False)
+ digest_model = DigestOnnxModel(self.model, save_proto=False)
self.completed.emit(digest_model, self.unique_id)
+ def wait(self, timeout=10000):
+ wait_threads([self], timeout)
+
class SimilarityThread(QThread):
- completed_successfully = Signal(bool, str, str, str)
+ completed_successfully = Signal(bool, str, str, str, pd.DataFrame)
def __init__(
self,
@@ -60,21 +91,72 @@ def run(self):
raise ValueError("You must set the model id")
try:
- most_similar, _ = find_match(
+ most_similar, _, df_sorted = find_match(
self.model_filepath,
- self.png_filepath,
dequantize=False,
replace=True,
- dark_mode=True,
)
most_similar = [os.path.basename(path) for path in most_similar]
- most_similar = ",".join(most_similar[1:4])
+ # We convert List[str] to str to send through the signal
+ most_similar = ",".join(most_similar)
self.completed_successfully.emit(
- True, self.model_id, most_similar, self.png_filepath
+ True, self.model_id, most_similar, self.png_filepath, df_sorted
)
except Exception as e: # pylint: disable=broad-exception-caught
most_similar = ""
self.completed_successfully.emit(
- False, self.model_id, most_similar, self.png_filepath
+ False, self.model_id, most_similar, self.png_filepath, df_sorted
)
print(f"Issue creating similarity analysis: {e}")
+
+ def wait(self, timeout=10000):
+ wait_threads([self], timeout)
+
+
+def post_process(
+ model_name: str,
+ name_list: List[str],
+ df_sorted: pd.DataFrame,
+ png_file_path: str,
+ dark_mode: bool = True,
+):
+ """Matplotlib is not thread safe so we must do post_processing on the main thread"""
+ if dark_mode:
+ plt.style.use("dark_background")
+ fig, ax = plt.subplots(figsize=(12, 10))
+ im = ax.imshow(df_sorted, cmap="viridis")
+
+ # Show all ticks and label them with the respective list entries
+ ax.set_xticks(np.arange(len(df_sorted.columns)))
+ ax.set_yticks(np.arange(len(name_list)))
+ ax.set_xticklabels([a[:5] for a in df_sorted.columns])
+ ax.set_yticklabels(name_list)
+
+ # Rotate the tick labels and set their alignment
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+ ax.set_title(f"Model Similarity Heatmap - {model_name}")
+
+ cb = plt.colorbar(
+ im,
+ ax=ax,
+ shrink=0.5,
+ format="%.2f",
+ label="Correlation Ratio",
+ orientation="vertical",
+ # pad=0.02,
+ )
+ cb.set_ticks([0, 0.5, 1]) # Set colorbar ticks at 0, 0.5, and 1
+ cb.set_ticklabels(
+ ["0.0 (Low)", "0.5 (Medium)", "1.0 (High)"]
+ ) # Set corresponding labels
+ cb.set_label("Correlation Ratio", labelpad=-100)
+
+ fig.tight_layout()
+
+ if png_file_path is None:
+ png_file_path = "heatmap.png"
+
+ fig.savefig(png_file_path)
+
+ plt.close(fig)
diff --git a/src/digest/ui/freezeinputs_ui.py b/src/digest/ui/freezeinputs_ui.py
index 3838e57..85e8211 100644
--- a/src/digest/ui/freezeinputs_ui.py
+++ b/src/digest/ui/freezeinputs_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'freezeinputs.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
diff --git a/src/digest/ui/huggingface_page_ui.py b/src/digest/ui/huggingface_page_ui.py
index a06a573..5cfcbfe 100644
--- a/src/digest/ui/huggingface_page_ui.py
+++ b/src/digest/ui/huggingface_page_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'huggingface_page.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
diff --git a/src/digest/ui/mainwindow.ui b/src/digest/ui/mainwindow.ui
index 8643efa..e7e28f3 100644
--- a/src/digest/ui/mainwindow.ui
+++ b/src/digest/ui/mainwindow.ui
@@ -179,7 +179,7 @@
Qt::FocusPolicy::NoFocus
- <html><head/><body><p>Open a local model file (Ctrl-O)</p></body></html>
+ <html><head/><body><p>Open (Ctrl-O)</p></body></html>
QPushButton {
diff --git a/src/digest/ui/mainwindow_ui.py b/src/digest/ui/mainwindow_ui.py
index 9904c77..61119a1 100644
--- a/src/digest/ui/mainwindow_ui.py
+++ b/src/digest/ui/mainwindow_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'mainwindow.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
@@ -520,7 +520,7 @@ def setupUi(self, MainWindow):
def retranslateUi(self, MainWindow):
MainWindow.setWindowTitle(QCoreApplication.translate("MainWindow", u"DigestAI", None))
#if QT_CONFIG(tooltip)
- self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"Open a local model file (Ctrl-O)
", None))
+ self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"Open (Ctrl-O)
", None))
#endif // QT_CONFIG(tooltip)
self.openFileBtn.setText("")
#if QT_CONFIG(shortcut)
diff --git a/src/digest/ui/modelsummary.ui b/src/digest/ui/modelsummary.ui
index 180fed4..737cf33 100644
--- a/src/digest/ui/modelsummary.ui
+++ b/src/digest/ui/modelsummary.ui
@@ -6,8 +6,8 @@
0
0
- 980
- 687
+ 1138
+ 837
@@ -153,11 +153,17 @@ border-top-right-radius: 10px;
0
- 0
+ -776
991
- 1453
+ 1443
+
+
+ 0
+ 0
+
+
background-color: black;
@@ -244,7 +250,7 @@ QFrame:hover {
6
- -
+
-
@@ -271,7 +277,7 @@ QFrame:hover {
- -
+
-
@@ -667,20 +673,32 @@ QFrame:hover {
-
-
+
0
0
+
+
+ 300
+ 500
+
+
-
-
+
-
-
+
0
0
+
+
+ 0
+ 0
+
+
16777215
@@ -690,6 +708,9 @@ QFrame:hover {
Loading...
+
+ false
+
Qt::AlignmentFlag::AlignCenter
@@ -834,7 +855,7 @@ QFrame:hover {
-
-
+
0
0
@@ -853,6 +874,9 @@ QFrame:hover {
-
+
+ 6
+
20
@@ -861,8 +885,14 @@ QFrame:hover {
-
-
-
+
-
+
+
+ 0
+ 0
+
+
QLabel {
font-size: 18px;
@@ -875,11 +905,11 @@ QFrame:hover {
- -
+
-
-
- 0
+
+ 1
0
@@ -975,7 +1005,7 @@ QScrollBar::handle:vertical {
- -
+
-
@@ -983,6 +1013,18 @@ QScrollBar::handle:vertical {
0
+
+
+ 0
+ 0
+
+
+
+
+ 16777215
+ 16777215
+
+
PointingHandCursor
@@ -1067,7 +1109,7 @@ QPushButton:pressed {
-
-
+
0
0
@@ -1218,7 +1260,7 @@ QScrollBar::handle:vertical {
- -
+
-
diff --git a/src/digest/ui/modelsummary_ui.py b/src/digest/ui/modelsummary_ui.py
index 1102e3a..e217372 100644
--- a/src/digest/ui/modelsummary_ui.py
+++ b/src/digest/ui/modelsummary_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'modelsummary.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
@@ -29,7 +29,7 @@ class Ui_modelSummary(object):
def setupUi(self, modelSummary):
if not modelSummary.objectName():
modelSummary.setObjectName(u"modelSummary")
- modelSummary.resize(980, 687)
+ modelSummary.resize(1138, 837)
sizePolicy = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
@@ -115,17 +115,22 @@ def setupUi(self, modelSummary):
self.scrollArea.setWidgetResizable(True)
self.scrollAreaWidgetContents = QWidget()
self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents")
- self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 991, 1453))
+ self.scrollAreaWidgetContents.setGeometry(QRect(0, -776, 991, 1443))
+ sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding)
+ sizePolicy2.setHorizontalStretch(0)
+ sizePolicy2.setVerticalStretch(0)
+ sizePolicy2.setHeightForWidth(self.scrollAreaWidgetContents.sizePolicy().hasHeightForWidth())
+ self.scrollAreaWidgetContents.setSizePolicy(sizePolicy2)
self.scrollAreaWidgetContents.setStyleSheet(u"background-color: black;")
self.verticalLayout_20 = QVBoxLayout(self.scrollAreaWidgetContents)
self.verticalLayout_20.setObjectName(u"verticalLayout_20")
self.cardFrame = QFrame(self.scrollAreaWidgetContents)
self.cardFrame.setObjectName(u"cardFrame")
- sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred)
- sizePolicy2.setHorizontalStretch(0)
- sizePolicy2.setVerticalStretch(0)
- sizePolicy2.setHeightForWidth(self.cardFrame.sizePolicy().hasHeightForWidth())
- self.cardFrame.setSizePolicy(sizePolicy2)
+ sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred)
+ sizePolicy3.setHorizontalStretch(0)
+ sizePolicy3.setVerticalStretch(0)
+ sizePolicy3.setHeightForWidth(self.cardFrame.sizePolicy().hasHeightForWidth())
+ self.cardFrame.setSizePolicy(sizePolicy3)
self.cardFrame.setStyleSheet(u"background: transparent; /*rgb(40,40,40)*/")
self.cardFrame.setFrameShape(QFrame.Shape.StyledPanel)
self.cardFrame.setFrameShadow(QFrame.Shadow.Raised)
@@ -134,19 +139,19 @@ def setupUi(self, modelSummary):
self.horizontalLayout.setContentsMargins(-1, -1, -1, 1)
self.cardWidget = QWidget(self.cardFrame)
self.cardWidget.setObjectName(u"cardWidget")
- sizePolicy2.setHeightForWidth(self.cardWidget.sizePolicy().hasHeightForWidth())
- self.cardWidget.setSizePolicy(sizePolicy2)
+ sizePolicy3.setHeightForWidth(self.cardWidget.sizePolicy().hasHeightForWidth())
+ self.cardWidget.setSizePolicy(sizePolicy3)
self.horizontalLayout_2 = QHBoxLayout(self.cardWidget)
self.horizontalLayout_2.setSpacing(13)
self.horizontalLayout_2.setObjectName(u"horizontalLayout_2")
self.horizontalLayout_2.setContentsMargins(-1, 6, 25, 35)
self.opsetFrame = QFrame(self.cardWidget)
self.opsetFrame.setObjectName(u"opsetFrame")
- sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed)
- sizePolicy3.setHorizontalStretch(0)
- sizePolicy3.setVerticalStretch(0)
- sizePolicy3.setHeightForWidth(self.opsetFrame.sizePolicy().hasHeightForWidth())
- self.opsetFrame.setSizePolicy(sizePolicy3)
+ sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed)
+ sizePolicy4.setHorizontalStretch(0)
+ sizePolicy4.setVerticalStretch(0)
+ sizePolicy4.setHeightForWidth(self.opsetFrame.sizePolicy().hasHeightForWidth())
+ self.opsetFrame.setSizePolicy(sizePolicy4)
self.opsetFrame.setMinimumSize(QSize(220, 70))
self.opsetFrame.setMaximumSize(QSize(16777215, 80))
self.opsetFrame.setStyleSheet(u"QFrame {\n"
@@ -164,11 +169,11 @@ def setupUi(self, modelSummary):
self.verticalLayout_5.setContentsMargins(-1, -1, 6, -1)
self.opsetLabel = QLabel(self.opsetFrame)
self.opsetLabel.setObjectName(u"opsetLabel")
- sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
- sizePolicy4.setHorizontalStretch(0)
- sizePolicy4.setVerticalStretch(0)
- sizePolicy4.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth())
- self.opsetLabel.setSizePolicy(sizePolicy4)
+ sizePolicy5 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
+ sizePolicy5.setHorizontalStretch(0)
+ sizePolicy5.setVerticalStretch(0)
+ sizePolicy5.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth())
+ self.opsetLabel.setSizePolicy(sizePolicy5)
self.opsetLabel.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
@@ -178,12 +183,12 @@ def setupUi(self, modelSummary):
self.opsetLabel.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.opsetLabel.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse)
- self.verticalLayout_5.addWidget(self.opsetLabel, 0, Qt.AlignmentFlag.AlignHCenter)
+ self.verticalLayout_5.addWidget(self.opsetLabel)
self.opsetVersion = QLabel(self.opsetFrame)
self.opsetVersion.setObjectName(u"opsetVersion")
- sizePolicy4.setHeightForWidth(self.opsetVersion.sizePolicy().hasHeightForWidth())
- self.opsetVersion.setSizePolicy(sizePolicy4)
+ sizePolicy5.setHeightForWidth(self.opsetVersion.sizePolicy().hasHeightForWidth())
+ self.opsetVersion.setSizePolicy(sizePolicy5)
self.opsetVersion.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
@@ -192,15 +197,15 @@ def setupUi(self, modelSummary):
self.opsetVersion.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.opsetVersion.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse)
- self.verticalLayout_5.addWidget(self.opsetVersion, 0, Qt.AlignmentFlag.AlignHCenter)
+ self.verticalLayout_5.addWidget(self.opsetVersion)
self.horizontalLayout_2.addWidget(self.opsetFrame)
self.nodesFrame = QFrame(self.cardWidget)
self.nodesFrame.setObjectName(u"nodesFrame")
- sizePolicy3.setHeightForWidth(self.nodesFrame.sizePolicy().hasHeightForWidth())
- self.nodesFrame.setSizePolicy(sizePolicy3)
+ sizePolicy4.setHeightForWidth(self.nodesFrame.sizePolicy().hasHeightForWidth())
+ self.nodesFrame.setSizePolicy(sizePolicy4)
self.nodesFrame.setMinimumSize(QSize(220, 70))
self.nodesFrame.setMaximumSize(QSize(16777215, 80))
self.nodesFrame.setStyleSheet(u"QFrame {\n"
@@ -218,8 +223,8 @@ def setupUi(self, modelSummary):
self.verticalLayout_12.setContentsMargins(-1, 9, -1, -1)
self.nodesLabel = QLabel(self.nodesFrame)
self.nodesLabel.setObjectName(u"nodesLabel")
- sizePolicy4.setHeightForWidth(self.nodesLabel.sizePolicy().hasHeightForWidth())
- self.nodesLabel.setSizePolicy(sizePolicy4)
+ sizePolicy5.setHeightForWidth(self.nodesLabel.sizePolicy().hasHeightForWidth())
+ self.nodesLabel.setSizePolicy(sizePolicy5)
self.nodesLabel.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
@@ -233,8 +238,8 @@ def setupUi(self, modelSummary):
self.nodes = QLabel(self.nodesFrame)
self.nodes.setObjectName(u"nodes")
- sizePolicy4.setHeightForWidth(self.nodes.sizePolicy().hasHeightForWidth())
- self.nodes.setSizePolicy(sizePolicy4)
+ sizePolicy5.setHeightForWidth(self.nodes.sizePolicy().hasHeightForWidth())
+ self.nodes.setSizePolicy(sizePolicy5)
self.nodes.setMinimumSize(QSize(150, 32))
self.nodes.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
@@ -254,8 +259,8 @@ def setupUi(self, modelSummary):
self.paramFrame = QFrame(self.cardWidget)
self.paramFrame.setObjectName(u"paramFrame")
- sizePolicy3.setHeightForWidth(self.paramFrame.sizePolicy().hasHeightForWidth())
- self.paramFrame.setSizePolicy(sizePolicy3)
+ sizePolicy4.setHeightForWidth(self.paramFrame.sizePolicy().hasHeightForWidth())
+ self.paramFrame.setSizePolicy(sizePolicy4)
self.paramFrame.setMinimumSize(QSize(220, 70))
self.paramFrame.setMaximumSize(QSize(16777215, 80))
self.paramFrame.setStyleSheet(u"QFrame {\n"
@@ -272,8 +277,8 @@ def setupUi(self, modelSummary):
self.verticalLayout_9.setObjectName(u"verticalLayout_9")
self.parametersLabel = QLabel(self.paramFrame)
self.parametersLabel.setObjectName(u"parametersLabel")
- sizePolicy4.setHeightForWidth(self.parametersLabel.sizePolicy().hasHeightForWidth())
- self.parametersLabel.setSizePolicy(sizePolicy4)
+ sizePolicy5.setHeightForWidth(self.parametersLabel.sizePolicy().hasHeightForWidth())
+ self.parametersLabel.setSizePolicy(sizePolicy5)
self.parametersLabel.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
@@ -287,8 +292,8 @@ def setupUi(self, modelSummary):
self.parameters = QLabel(self.paramFrame)
self.parameters.setObjectName(u"parameters")
- sizePolicy4.setHeightForWidth(self.parameters.sizePolicy().hasHeightForWidth())
- self.parameters.setSizePolicy(sizePolicy4)
+ sizePolicy5.setHeightForWidth(self.parameters.sizePolicy().hasHeightForWidth())
+ self.parameters.setSizePolicy(sizePolicy5)
self.parameters.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
@@ -304,8 +309,8 @@ def setupUi(self, modelSummary):
self.flopsFrame = QFrame(self.cardWidget)
self.flopsFrame.setObjectName(u"flopsFrame")
- sizePolicy3.setHeightForWidth(self.flopsFrame.sizePolicy().hasHeightForWidth())
- self.flopsFrame.setSizePolicy(sizePolicy3)
+ sizePolicy4.setHeightForWidth(self.flopsFrame.sizePolicy().hasHeightForWidth())
+ self.flopsFrame.setSizePolicy(sizePolicy4)
self.flopsFrame.setMinimumSize(QSize(220, 70))
self.flopsFrame.setMaximumSize(QSize(16777215, 80))
self.flopsFrame.setCursor(QCursor(Qt.CursorShape.ArrowCursor))
@@ -323,8 +328,8 @@ def setupUi(self, modelSummary):
self.verticalLayout_11.setObjectName(u"verticalLayout_11")
self.flopsLabel = QLabel(self.flopsFrame)
self.flopsLabel.setObjectName(u"flopsLabel")
- sizePolicy4.setHeightForWidth(self.flopsLabel.sizePolicy().hasHeightForWidth())
- self.flopsLabel.setSizePolicy(sizePolicy4)
+ sizePolicy5.setHeightForWidth(self.flopsLabel.sizePolicy().hasHeightForWidth())
+ self.flopsLabel.setSizePolicy(sizePolicy5)
self.flopsLabel.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
@@ -338,11 +343,11 @@ def setupUi(self, modelSummary):
self.flops = QLabel(self.flopsFrame)
self.flops.setObjectName(u"flops")
- sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Fixed)
- sizePolicy5.setHorizontalStretch(0)
- sizePolicy5.setVerticalStretch(0)
- sizePolicy5.setHeightForWidth(self.flops.sizePolicy().hasHeightForWidth())
- self.flops.setSizePolicy(sizePolicy5)
+ sizePolicy6 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Fixed)
+ sizePolicy6.setHorizontalStretch(0)
+ sizePolicy6.setVerticalStretch(0)
+ sizePolicy6.setHeightForWidth(self.flops.sizePolicy().hasHeightForWidth())
+ self.flops.setSizePolicy(sizePolicy6)
self.flops.setMinimumSize(QSize(200, 32))
self.flops.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
@@ -380,11 +385,11 @@ def setupUi(self, modelSummary):
self.parametersPieChart = PieChartWidget(self.scrollAreaWidgetContents)
self.parametersPieChart.setObjectName(u"parametersPieChart")
- sizePolicy6 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
- sizePolicy6.setHorizontalStretch(0)
- sizePolicy6.setVerticalStretch(0)
- sizePolicy6.setHeightForWidth(self.parametersPieChart.sizePolicy().hasHeightForWidth())
- self.parametersPieChart.setSizePolicy(sizePolicy6)
+ sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
+ sizePolicy7.setHorizontalStretch(0)
+ sizePolicy7.setVerticalStretch(0)
+ sizePolicy7.setHeightForWidth(self.parametersPieChart.sizePolicy().hasHeightForWidth())
+ self.parametersPieChart.setSizePolicy(sizePolicy7)
self.parametersPieChart.setMinimumSize(QSize(300, 500))
self.firstRowChartsLayout.addWidget(self.parametersPieChart)
@@ -397,23 +402,29 @@ def setupUi(self, modelSummary):
self.secondRowChartsLayout.setContentsMargins(-1, 20, -1, -1)
self.similarityWidget = QWidget(self.scrollAreaWidgetContents)
self.similarityWidget.setObjectName(u"similarityWidget")
- sizePolicy.setHeightForWidth(self.similarityWidget.sizePolicy().hasHeightForWidth())
- self.similarityWidget.setSizePolicy(sizePolicy)
+ sizePolicy7.setHeightForWidth(self.similarityWidget.sizePolicy().hasHeightForWidth())
+ self.similarityWidget.setSizePolicy(sizePolicy7)
+ self.similarityWidget.setMinimumSize(QSize(300, 500))
self.placeholderWidget = QVBoxLayout(self.similarityWidget)
self.placeholderWidget.setObjectName(u"placeholderWidget")
self.similarityImg = ClickableLabel(self.similarityWidget)
self.similarityImg.setObjectName(u"similarityImg")
- sizePolicy.setHeightForWidth(self.similarityImg.sizePolicy().hasHeightForWidth())
- self.similarityImg.setSizePolicy(sizePolicy)
+ sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
+ sizePolicy8.setHorizontalStretch(0)
+ sizePolicy8.setVerticalStretch(0)
+ sizePolicy8.setHeightForWidth(self.similarityImg.sizePolicy().hasHeightForWidth())
+ self.similarityImg.setSizePolicy(sizePolicy8)
+ self.similarityImg.setMinimumSize(QSize(0, 0))
self.similarityImg.setMaximumSize(QSize(16777215, 16777215))
+ self.similarityImg.setScaledContents(False)
self.similarityImg.setAlignment(Qt.AlignmentFlag.AlignCenter)
- self.placeholderWidget.addWidget(self.similarityImg, 0, Qt.AlignmentFlag.AlignHCenter)
+ self.placeholderWidget.addWidget(self.similarityImg)
self.similarityCorrelationStatic = QLabel(self.similarityWidget)
self.similarityCorrelationStatic.setObjectName(u"similarityCorrelationStatic")
- sizePolicy2.setHeightForWidth(self.similarityCorrelationStatic.sizePolicy().hasHeightForWidth())
- self.similarityCorrelationStatic.setSizePolicy(sizePolicy2)
+ sizePolicy3.setHeightForWidth(self.similarityCorrelationStatic.sizePolicy().hasHeightForWidth())
+ self.similarityCorrelationStatic.setSizePolicy(sizePolicy3)
self.similarityCorrelationStatic.setFont(font)
self.similarityCorrelationStatic.setAlignment(Qt.AlignmentFlag.AlignCenter)
@@ -421,8 +432,8 @@ def setupUi(self, modelSummary):
self.similarityCorrelation = QLabel(self.similarityWidget)
self.similarityCorrelation.setObjectName(u"similarityCorrelation")
- sizePolicy2.setHeightForWidth(self.similarityCorrelation.sizePolicy().hasHeightForWidth())
- self.similarityCorrelation.setSizePolicy(sizePolicy2)
+ sizePolicy3.setHeightForWidth(self.similarityCorrelation.sizePolicy().hasHeightForWidth())
+ self.similarityCorrelation.setSizePolicy(sizePolicy3)
palette = QPalette()
brush = QBrush(QColor(0, 0, 0, 255))
brush.setStyle(Qt.SolidPattern)
@@ -446,9 +457,6 @@ def setupUi(self, modelSummary):
self.flopsPieChart = PieChartWidget(self.scrollAreaWidgetContents)
self.flopsPieChart.setObjectName(u"flopsPieChart")
- sizePolicy7 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Preferred)
- sizePolicy7.setHorizontalStretch(0)
- sizePolicy7.setVerticalStretch(0)
sizePolicy7.setHeightForWidth(self.flopsPieChart.sizePolicy().hasHeightForWidth())
self.flopsPieChart.setSizePolicy(sizePolicy7)
self.flopsPieChart.setMinimumSize(QSize(300, 500))
@@ -462,19 +470,25 @@ def setupUi(self, modelSummary):
self.verticalLayout_20.addLayout(self.chartsLayout)
self.thirdRowInputsLayout = QHBoxLayout()
+ self.thirdRowInputsLayout.setSpacing(6)
self.thirdRowInputsLayout.setObjectName(u"thirdRowInputsLayout")
self.thirdRowInputsLayout.setContentsMargins(20, 30, -1, -1)
self.inputsLayout = QVBoxLayout()
self.inputsLayout.setObjectName(u"inputsLayout")
self.inputsLabel = QLabel(self.scrollAreaWidgetContents)
self.inputsLabel.setObjectName(u"inputsLabel")
+ sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Maximum)
+ sizePolicy9.setHorizontalStretch(0)
+ sizePolicy9.setVerticalStretch(0)
+ sizePolicy9.setHeightForWidth(self.inputsLabel.sizePolicy().hasHeightForWidth())
+ self.inputsLabel.setSizePolicy(sizePolicy9)
self.inputsLabel.setStyleSheet(u"QLabel {\n"
" font-size: 18px;\n"
" font-weight: bold;\n"
" background: transparent;\n"
"}")
- self.inputsLayout.addWidget(self.inputsLabel, 0, Qt.AlignmentFlag.AlignVCenter)
+ self.inputsLayout.addWidget(self.inputsLabel)
self.inputsTable = QTableWidget(self.scrollAreaWidgetContents)
if (self.inputsTable.columnCount() < 4):
@@ -488,11 +502,11 @@ def setupUi(self, modelSummary):
__qtablewidgetitem3 = QTableWidgetItem()
self.inputsTable.setHorizontalHeaderItem(3, __qtablewidgetitem3)
self.inputsTable.setObjectName(u"inputsTable")
- sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred)
- sizePolicy8.setHorizontalStretch(0)
- sizePolicy8.setVerticalStretch(0)
- sizePolicy8.setHeightForWidth(self.inputsTable.sizePolicy().hasHeightForWidth())
- self.inputsTable.setSizePolicy(sizePolicy8)
+ sizePolicy10 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding)
+ sizePolicy10.setHorizontalStretch(1)
+ sizePolicy10.setVerticalStretch(0)
+ sizePolicy10.setHeightForWidth(self.inputsTable.sizePolicy().hasHeightForWidth())
+ self.inputsTable.setSizePolicy(sizePolicy10)
self.inputsTable.setStyleSheet(u"QTableWidget {\n"
" gridline-color: #353535; /* Grid lines */\n"
" selection-background-color: #3949AB; /* Blue selection */\n"
@@ -543,18 +557,20 @@ def setupUi(self, modelSummary):
self.inputsTable.verticalHeader().setVisible(False)
self.inputsTable.verticalHeader().setHighlightSections(True)
- self.inputsLayout.addWidget(self.inputsTable, 0, Qt.AlignmentFlag.AlignVCenter)
+ self.inputsLayout.addWidget(self.inputsTable)
self.thirdRowInputsLayout.addLayout(self.inputsLayout)
self.freezeButton = QPushButton(self.scrollAreaWidgetContents)
self.freezeButton.setObjectName(u"freezeButton")
- sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)
- sizePolicy9.setHorizontalStretch(0)
- sizePolicy9.setVerticalStretch(0)
- sizePolicy9.setHeightForWidth(self.freezeButton.sizePolicy().hasHeightForWidth())
- self.freezeButton.setSizePolicy(sizePolicy9)
+ sizePolicy11 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)
+ sizePolicy11.setHorizontalStretch(0)
+ sizePolicy11.setVerticalStretch(0)
+ sizePolicy11.setHeightForWidth(self.freezeButton.sizePolicy().hasHeightForWidth())
+ self.freezeButton.setSizePolicy(sizePolicy11)
+ self.freezeButton.setMinimumSize(QSize(0, 0))
+ self.freezeButton.setMaximumSize(QSize(16777215, 16777215))
self.freezeButton.setCursor(QCursor(Qt.CursorShape.PointingHandCursor))
self.freezeButton.setStyleSheet(u"QPushButton {\n"
" color: white;\n"
@@ -580,7 +596,7 @@ def setupUi(self, modelSummary):
self.freezeButton.setIcon(icon)
self.freezeButton.setIconSize(QSize(32, 32))
- self.thirdRowInputsLayout.addWidget(self.freezeButton, 0, Qt.AlignmentFlag.AlignTop)
+ self.thirdRowInputsLayout.addWidget(self.freezeButton)
self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum)
@@ -616,8 +632,11 @@ def setupUi(self, modelSummary):
__qtablewidgetitem7 = QTableWidgetItem()
self.outputsTable.setHorizontalHeaderItem(3, __qtablewidgetitem7)
self.outputsTable.setObjectName(u"outputsTable")
- sizePolicy8.setHeightForWidth(self.outputsTable.sizePolicy().hasHeightForWidth())
- self.outputsTable.setSizePolicy(sizePolicy8)
+ sizePolicy12 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Expanding)
+ sizePolicy12.setHorizontalStretch(0)
+ sizePolicy12.setVerticalStretch(0)
+ sizePolicy12.setHeightForWidth(self.outputsTable.sizePolicy().hasHeightForWidth())
+ self.outputsTable.setSizePolicy(sizePolicy12)
self.outputsTable.setStyleSheet(u"QTableWidget {\n"
" gridline-color: #353535; /* Grid lines */\n"
" selection-background-color: #3949AB; /* Blue selection */\n"
@@ -684,8 +703,8 @@ def setupUi(self, modelSummary):
self.sidePaneFrame = QFrame(modelSummary)
self.sidePaneFrame.setObjectName(u"sidePaneFrame")
- sizePolicy2.setHeightForWidth(self.sidePaneFrame.sizePolicy().hasHeightForWidth())
- self.sidePaneFrame.setSizePolicy(sizePolicy2)
+ sizePolicy3.setHeightForWidth(self.sidePaneFrame.sizePolicy().hasHeightForWidth())
+ self.sidePaneFrame.setSizePolicy(sizePolicy3)
self.sidePaneFrame.setMinimumSize(QSize(0, 0))
self.sidePaneFrame.setStyleSheet(u"QFrame {\n"
" /*background: rgb(30,30,30);*/\n"
@@ -741,8 +760,8 @@ def setupUi(self, modelSummary):
__qtablewidgetitem21 = QTableWidgetItem()
self.modelProtoTable.setItem(3, 1, __qtablewidgetitem21)
self.modelProtoTable.setObjectName(u"modelProtoTable")
- sizePolicy2.setHeightForWidth(self.modelProtoTable.sizePolicy().hasHeightForWidth())
- self.modelProtoTable.setSizePolicy(sizePolicy2)
+ sizePolicy3.setHeightForWidth(self.modelProtoTable.sizePolicy().hasHeightForWidth())
+ self.modelProtoTable.setSizePolicy(sizePolicy3)
self.modelProtoTable.setMinimumSize(QSize(0, 0))
self.modelProtoTable.setMaximumSize(QSize(16777215, 100))
self.modelProtoTable.setStyleSheet(u"QTableWidget::item {\n"
@@ -770,7 +789,7 @@ def setupUi(self, modelSummary):
self.modelProtoTable.verticalHeader().setMinimumSectionSize(20)
self.modelProtoTable.verticalHeader().setDefaultSectionSize(20)
- self.verticalLayout_3.addWidget(self.modelProtoTable, 0, Qt.AlignmentFlag.AlignRight)
+ self.verticalLayout_3.addWidget(self.modelProtoTable)
self.importsLabel = QLabel(self.sidePaneFrame)
self.importsLabel.setObjectName(u"importsLabel")
@@ -791,11 +810,8 @@ def setupUi(self, modelSummary):
__qtablewidgetitem23 = QTableWidgetItem()
self.importsTable.setHorizontalHeaderItem(1, __qtablewidgetitem23)
self.importsTable.setObjectName(u"importsTable")
- sizePolicy10 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding)
- sizePolicy10.setHorizontalStretch(0)
- sizePolicy10.setVerticalStretch(0)
- sizePolicy10.setHeightForWidth(self.importsTable.sizePolicy().hasHeightForWidth())
- self.importsTable.setSizePolicy(sizePolicy10)
+ sizePolicy2.setHeightForWidth(self.importsTable.sizePolicy().hasHeightForWidth())
+ self.importsTable.setSizePolicy(sizePolicy2)
self.importsTable.setStyleSheet(u"QTableWidget::item {\n"
" color: white;\n"
" padding: 5px;\n"
diff --git a/src/digest/ui/multimodelanalysis.ui b/src/digest/ui/multimodelanalysis.ui
index cf044e3..16109d0 100644
--- a/src/digest/ui/multimodelanalysis.ui
+++ b/src/digest/ui/multimodelanalysis.ui
@@ -6,8 +6,8 @@
0
0
- 908
- 647
+ 1085
+ 866
@@ -51,7 +51,7 @@
QFrame::Shadow::Raised
-
-
+
-
@@ -176,7 +176,7 @@
-
-
+
0
0
@@ -198,8 +198,8 @@
0
0
- 888
- 464
+ 1065
+ 688
@@ -242,6 +242,12 @@
-
+
+
+ 0
+ 0
+
+
QFrame::Shape::StyledPanel
@@ -258,7 +264,7 @@
QFrame::Shadow::Raised
- -
+
-
@@ -279,17 +285,19 @@
-
+
+
+ 0
+ 0
+
+
QFrame::Shape::StyledPanel
QFrame::Shadow::Raised
-
-
-
-
-
-
+
diff --git a/src/digest/ui/multimodelanalysis_ui.py b/src/digest/ui/multimodelanalysis_ui.py
index 54aa6d6..9f4b359 100644
--- a/src/digest/ui/multimodelanalysis_ui.py
+++ b/src/digest/ui/multimodelanalysis_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'multimodelanalysis.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
@@ -26,7 +26,7 @@ class Ui_multiModelAnalysis(object):
def setupUi(self, multiModelAnalysis):
if not multiModelAnalysis.objectName():
multiModelAnalysis.setObjectName(u"multiModelAnalysis")
- multiModelAnalysis.resize(908, 647)
+ multiModelAnalysis.resize(1085, 866)
sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
@@ -71,7 +71,7 @@ def setupUi(self, multiModelAnalysis):
self.modelName.setIndent(5)
self.modelName.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse)
- self.verticalLayout_17.addWidget(self.modelName, 0, Qt.AlignmentFlag.AlignTop)
+ self.verticalLayout_17.addWidget(self.modelName)
self.summaryTopBannerLayout.addWidget(self.modelNameFrame)
@@ -127,7 +127,7 @@ def setupUi(self, multiModelAnalysis):
self.scrollArea = QScrollArea(multiModelAnalysis)
self.scrollArea.setObjectName(u"scrollArea")
- sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding)
+ sizePolicy4 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding)
sizePolicy4.setHorizontalStretch(0)
sizePolicy4.setVerticalStretch(0)
sizePolicy4.setHeightForWidth(self.scrollArea.sizePolicy().hasHeightForWidth())
@@ -138,7 +138,7 @@ def setupUi(self, multiModelAnalysis):
self.scrollArea.setWidgetResizable(True)
self.scrollAreaWidgetContents = QWidget()
self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents")
- self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 888, 464))
+ self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 1065, 688))
sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding)
sizePolicy5.setHorizontalStretch(0)
sizePolicy5.setVerticalStretch(100)
@@ -165,6 +165,11 @@ def setupUi(self, multiModelAnalysis):
self.frame_2 = QFrame(self.scrollAreaWidgetContents)
self.frame_2.setObjectName(u"frame_2")
+ sizePolicy7 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Preferred)
+ sizePolicy7.setHorizontalStretch(0)
+ sizePolicy7.setVerticalStretch(0)
+ sizePolicy7.setHeightForWidth(self.frame_2.sizePolicy().hasHeightForWidth())
+ self.frame_2.setSizePolicy(sizePolicy7)
self.frame_2.setFrameShape(QFrame.Shape.StyledPanel)
self.frame_2.setFrameShadow(QFrame.Shadow.Raised)
self.horizontalLayout_2 = QHBoxLayout(self.frame_2)
@@ -177,29 +182,29 @@ def setupUi(self, multiModelAnalysis):
self.verticalLayout_3.setObjectName(u"verticalLayout_3")
self.opHistogramChart = HistogramChartWidget(self.combinedHistogramFrame)
self.opHistogramChart.setObjectName(u"opHistogramChart")
- sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Minimum)
- sizePolicy7.setHorizontalStretch(0)
- sizePolicy7.setVerticalStretch(0)
- sizePolicy7.setHeightForWidth(self.opHistogramChart.sizePolicy().hasHeightForWidth())
- self.opHistogramChart.setSizePolicy(sizePolicy7)
+ sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Minimum)
+ sizePolicy8.setHorizontalStretch(0)
+ sizePolicy8.setVerticalStretch(0)
+ sizePolicy8.setHeightForWidth(self.opHistogramChart.sizePolicy().hasHeightForWidth())
+ self.opHistogramChart.setSizePolicy(sizePolicy8)
self.opHistogramChart.setMinimumSize(QSize(500, 300))
- self.verticalLayout_3.addWidget(self.opHistogramChart, 0, Qt.AlignmentFlag.AlignTop)
+ self.verticalLayout_3.addWidget(self.opHistogramChart)
self.horizontalLayout_2.addWidget(self.combinedHistogramFrame)
self.stackedHistogramFrame = QFrame(self.frame_2)
self.stackedHistogramFrame.setObjectName(u"stackedHistogramFrame")
+ sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
+ sizePolicy9.setHorizontalStretch(0)
+ sizePolicy9.setVerticalStretch(0)
+ sizePolicy9.setHeightForWidth(self.stackedHistogramFrame.sizePolicy().hasHeightForWidth())
+ self.stackedHistogramFrame.setSizePolicy(sizePolicy9)
self.stackedHistogramFrame.setFrameShape(QFrame.Shape.StyledPanel)
self.stackedHistogramFrame.setFrameShadow(QFrame.Shadow.Raised)
self.verticalLayout_5 = QVBoxLayout(self.stackedHistogramFrame)
self.verticalLayout_5.setObjectName(u"verticalLayout_5")
- self.verticalLayout_4 = QVBoxLayout()
- self.verticalLayout_4.setObjectName(u"verticalLayout_4")
-
- self.verticalLayout_5.addLayout(self.verticalLayout_4)
-
self.horizontalLayout_2.addWidget(self.stackedHistogramFrame)
diff --git a/src/digest/ui/multimodelselection_page.ui b/src/digest/ui/multimodelselection_page.ui
index c5d12f8..b40d460 100644
--- a/src/digest/ui/multimodelselection_page.ui
+++ b/src/digest/ui/multimodelselection_page.ui
@@ -52,7 +52,7 @@
-
-
-
+
-
@@ -68,7 +68,7 @@
- -
+
-
false
@@ -128,7 +128,7 @@
- Warning: The chosen folder contains more than MAX_ONNX_MODELS
+ Warning
2
@@ -141,7 +141,7 @@
-
-
-
+
0
@@ -161,44 +161,115 @@
- Select All
+ All
+
+
+ true
-
-
+
+
+
+ 0
+ 0
+
+
+
+
+ 0
+ 33
+
+
+
+ false
+
- 0 selected models
-
-
- true
+ ONNX
-
-
+
+
+
+ 0
+ 0
+
+
+
+
+ 0
+ 33
+
+
+
+ false
+
- The following models were found to be duplicates and have been deselected from the list on the left.
-
-
- true
+ Reports
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
-
-
+
-
+
+
+
+ 0
+ 0
+
+
+
+
+ 0
+ 10
+
+
+
+
+
+
+ 0 selected models
+
+
+ true
+
+
+
-
+
+
+ 0
+ 0
+
+
@@ -217,10 +288,38 @@
-
+
-
+
+
+ true
+
+
+
+ 0
+ 0
+
+
+
+
+ 0
+ 10
+
+
+
+
+
+
+ Ignoring 0 duplicate model(s).
+
+
+ true
+
+
+
-
-
+
0
0
diff --git a/src/digest/ui/multimodelselection_page_ui.py b/src/digest/ui/multimodelselection_page_ui.py
index e6acb66..02a3bfc 100644
--- a/src/digest/ui/multimodelselection_page_ui.py
+++ b/src/digest/ui/multimodelselection_page_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'multimodelselection_page.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
@@ -15,9 +15,9 @@
QFont, QFontDatabase, QGradient, QIcon,
QImage, QKeySequence, QLinearGradient, QPainter,
QPalette, QPixmap, QRadialGradient, QTransform)
-from PySide6.QtWidgets import (QAbstractItemView, QApplication, QCheckBox, QHBoxLayout,
- QLabel, QListView, QListWidget, QListWidgetItem,
- QPushButton, QSizePolicy, QSpacerItem, QVBoxLayout,
+from PySide6.QtWidgets import (QAbstractItemView, QApplication, QHBoxLayout, QLabel,
+ QListView, QListWidget, QListWidgetItem, QPushButton,
+ QRadioButton, QSizePolicy, QSpacerItem, QVBoxLayout,
QWidget)
class Ui_MultiModelSelection(object):
@@ -59,7 +59,7 @@ def setupUi(self, MultiModelSelection):
self.selectFolderBtn.setSizePolicy(sizePolicy)
self.selectFolderBtn.setStyleSheet(u"")
- self.horizontalLayout_2.addWidget(self.selectFolderBtn, 0, Qt.AlignmentFlag.AlignLeft|Qt.AlignmentFlag.AlignVCenter)
+ self.horizontalLayout_2.addWidget(self.selectFolderBtn)
self.openAnalysisBtn = QPushButton(MultiModelSelection)
self.openAnalysisBtn.setObjectName(u"openAnalysisBtn")
@@ -68,7 +68,7 @@ def setupUi(self, MultiModelSelection):
self.openAnalysisBtn.setSizePolicy(sizePolicy)
self.openAnalysisBtn.setStyleSheet(u"")
- self.horizontalLayout_2.addWidget(self.openAnalysisBtn, 0, Qt.AlignmentFlag.AlignLeft|Qt.AlignmentFlag.AlignVCenter)
+ self.horizontalLayout_2.addWidget(self.openAnalysisBtn)
self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum)
@@ -104,29 +104,40 @@ def setupUi(self, MultiModelSelection):
self.horizontalLayout_3 = QHBoxLayout()
self.horizontalLayout_3.setObjectName(u"horizontalLayout_3")
- self.selectAllBox = QCheckBox(MultiModelSelection)
- self.selectAllBox.setObjectName(u"selectAllBox")
- sizePolicy.setHeightForWidth(self.selectAllBox.sizePolicy().hasHeightForWidth())
- self.selectAllBox.setSizePolicy(sizePolicy)
- self.selectAllBox.setMinimumSize(QSize(0, 33))
- self.selectAllBox.setAutoFillBackground(False)
- self.selectAllBox.setStyleSheet(u"")
+ self.radioAll = QRadioButton(MultiModelSelection)
+ self.radioAll.setObjectName(u"radioAll")
+ sizePolicy.setHeightForWidth(self.radioAll.sizePolicy().hasHeightForWidth())
+ self.radioAll.setSizePolicy(sizePolicy)
+ self.radioAll.setMinimumSize(QSize(0, 33))
+ self.radioAll.setAutoFillBackground(False)
+ self.radioAll.setStyleSheet(u"")
+ self.radioAll.setChecked(True)
- self.horizontalLayout_3.addWidget(self.selectAllBox)
+ self.horizontalLayout_3.addWidget(self.radioAll)
- self.numSelectedLabel = QLabel(MultiModelSelection)
- self.numSelectedLabel.setObjectName(u"numSelectedLabel")
- self.numSelectedLabel.setStyleSheet(u"")
- self.numSelectedLabel.setWordWrap(True)
+ self.radioONNX = QRadioButton(MultiModelSelection)
+ self.radioONNX.setObjectName(u"radioONNX")
+ sizePolicy.setHeightForWidth(self.radioONNX.sizePolicy().hasHeightForWidth())
+ self.radioONNX.setSizePolicy(sizePolicy)
+ self.radioONNX.setMinimumSize(QSize(0, 33))
+ self.radioONNX.setAutoFillBackground(False)
+ self.radioONNX.setStyleSheet(u"")
- self.horizontalLayout_3.addWidget(self.numSelectedLabel)
+ self.horizontalLayout_3.addWidget(self.radioONNX)
- self.duplicateLabel = QLabel(MultiModelSelection)
- self.duplicateLabel.setObjectName(u"duplicateLabel")
- self.duplicateLabel.setStyleSheet(u"")
- self.duplicateLabel.setWordWrap(True)
+ self.radioReports = QRadioButton(MultiModelSelection)
+ self.radioReports.setObjectName(u"radioReports")
+ sizePolicy.setHeightForWidth(self.radioReports.sizePolicy().hasHeightForWidth())
+ self.radioReports.setSizePolicy(sizePolicy)
+ self.radioReports.setMinimumSize(QSize(0, 33))
+ self.radioReports.setAutoFillBackground(False)
+ self.radioReports.setStyleSheet(u"")
+
+ self.horizontalLayout_3.addWidget(self.radioReports)
+
+ self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum)
- self.horizontalLayout_3.addWidget(self.duplicateLabel)
+ self.horizontalLayout_3.addItem(self.horizontalSpacer_2)
self.verticalLayout.addLayout(self.horizontalLayout_3)
@@ -135,8 +146,26 @@ def setupUi(self, MultiModelSelection):
self.columnsLayout.setObjectName(u"columnsLayout")
self.leftColumnLayout = QVBoxLayout()
self.leftColumnLayout.setObjectName(u"leftColumnLayout")
+ self.numSelectedLabel = QLabel(MultiModelSelection)
+ self.numSelectedLabel.setObjectName(u"numSelectedLabel")
+ sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
+ sizePolicy2.setHorizontalStretch(0)
+ sizePolicy2.setVerticalStretch(0)
+ sizePolicy2.setHeightForWidth(self.numSelectedLabel.sizePolicy().hasHeightForWidth())
+ self.numSelectedLabel.setSizePolicy(sizePolicy2)
+ self.numSelectedLabel.setMinimumSize(QSize(0, 10))
+ self.numSelectedLabel.setStyleSheet(u"")
+ self.numSelectedLabel.setWordWrap(True)
+
+ self.leftColumnLayout.addWidget(self.numSelectedLabel)
+
self.modelListView = QListView(MultiModelSelection)
self.modelListView.setObjectName(u"modelListView")
+ sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
+ sizePolicy3.setHorizontalStretch(0)
+ sizePolicy3.setVerticalStretch(0)
+ sizePolicy3.setHeightForWidth(self.modelListView.sizePolicy().hasHeightForWidth())
+ self.modelListView.setSizePolicy(sizePolicy3)
self.modelListView.setStyleSheet(u"")
self.modelListView.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
self.modelListView.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection)
@@ -149,13 +178,24 @@ def setupUi(self, MultiModelSelection):
self.rightColumnLayout = QVBoxLayout()
self.rightColumnLayout.setObjectName(u"rightColumnLayout")
+ self.duplicateLabel = QLabel(MultiModelSelection)
+ self.duplicateLabel.setObjectName(u"duplicateLabel")
+ self.duplicateLabel.setEnabled(True)
+ sizePolicy2.setHeightForWidth(self.duplicateLabel.sizePolicy().hasHeightForWidth())
+ self.duplicateLabel.setSizePolicy(sizePolicy2)
+ self.duplicateLabel.setMinimumSize(QSize(0, 10))
+ self.duplicateLabel.setStyleSheet(u"")
+ self.duplicateLabel.setWordWrap(True)
+
+ self.rightColumnLayout.addWidget(self.duplicateLabel)
+
self.duplicateListWidget = QListWidget(MultiModelSelection)
self.duplicateListWidget.setObjectName(u"duplicateListWidget")
- sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
- sizePolicy2.setHorizontalStretch(0)
- sizePolicy2.setVerticalStretch(0)
- sizePolicy2.setHeightForWidth(self.duplicateListWidget.sizePolicy().hasHeightForWidth())
- self.duplicateListWidget.setSizePolicy(sizePolicy2)
+ sizePolicy4 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Expanding)
+ sizePolicy4.setHorizontalStretch(0)
+ sizePolicy4.setVerticalStretch(0)
+ sizePolicy4.setHeightForWidth(self.duplicateListWidget.sizePolicy().hasHeightForWidth())
+ self.duplicateListWidget.setSizePolicy(sizePolicy4)
self.duplicateListWidget.setStyleSheet(u"")
self.duplicateListWidget.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)
self.duplicateListWidget.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection)
@@ -184,9 +224,11 @@ def retranslateUi(self, MultiModelSelection):
self.selectFolderBtn.setText(QCoreApplication.translate("MultiModelSelection", u"Select Folder", None))
self.openAnalysisBtn.setText(QCoreApplication.translate("MultiModelSelection", u"Open Analysis", None))
self.infoLabel.setText("")
- self.warningLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Warning: The chosen folder contains more than MAX_ONNX_MODELS", None))
- self.selectAllBox.setText(QCoreApplication.translate("MultiModelSelection", u"Select All", None))
+ self.warningLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Warning", None))
+ self.radioAll.setText(QCoreApplication.translate("MultiModelSelection", u"All", None))
+ self.radioONNX.setText(QCoreApplication.translate("MultiModelSelection", u"ONNX", None))
+ self.radioReports.setText(QCoreApplication.translate("MultiModelSelection", u"Reports", None))
self.numSelectedLabel.setText(QCoreApplication.translate("MultiModelSelection", u"0 selected models", None))
- self.duplicateLabel.setText(QCoreApplication.translate("MultiModelSelection", u"The following models were found to be duplicates and have been deselected from the list on the left.", None))
+ self.duplicateLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Ignoring 0 duplicate model(s).", None))
# retranslateUi
diff --git a/src/digest/ui/nodessummary_ui.py b/src/digest/ui/nodessummary_ui.py
index 7efc69d..e0e400c 100644
--- a/src/digest/ui/nodessummary_ui.py
+++ b/src/digest/ui/nodessummary_ui.py
@@ -3,7 +3,7 @@
################################################################################
## Form generated from reading UI file 'nodessummary.ui'
##
-## Created by: Qt User Interface Compiler version 6.8.0
+## Created by: Qt User Interface Compiler version 6.8.1
##
## WARNING! All changes made in this file will be lost when recompiling UI file!
################################################################################
diff --git a/src/utils/onnx_utils.py b/src/utils/onnx_utils.py
index d8a6894..4d4b293 100644
--- a/src/utils/onnx_utils.py
+++ b/src/utils/onnx_utils.py
@@ -1,95 +1,19 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
import os
-import csv
import tempfile
-from uuid import uuid4
-from collections import Counter, OrderedDict, defaultdict
-from typing import List, Dict, Optional, Any, Tuple, Union, cast
-from datetime import datetime
+from collections import Counter
+from typing import List, Optional, Tuple, Union
import numpy as np
import onnx
import onnxruntime as ort
-from prettytable import PrettyTable
-
-
-class NodeParsingException(Exception):
- pass
-
-
-# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias
-class NodeShapeCounts(defaultdict[str, Counter]):
- def __init__(self):
- super().__init__(Counter) # Initialize with the Counter factory
-
-
-class NodeTypeCounts(Dict[str, int]):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
-
-class TensorInfo:
- "Used to store node input and output tensor information"
-
- def __init__(self) -> None:
- self.dtype: Optional[str] = None
- self.dtype_bytes: Optional[int] = None
- self.size_kbytes: Optional[float] = None
- self.shape: List[Union[int, str]] = []
-
-
-class TensorData(OrderedDict[str, TensorInfo]):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
-
-class NodeInfo:
- def __init__(self) -> None:
- self.flops: Optional[int] = None
- self.parameters: int = 0
- self.node_type: Optional[str] = None
- self.attributes: OrderedDict[str, Any] = OrderedDict()
- # We use an ordered dictionary because the order in which
- # the inputs and outputs are listed in the node matter.
- self.inputs = TensorData()
- self.outputs = TensorData()
-
- def get_input(self, index: int) -> TensorInfo:
- return list(self.inputs.values())[index]
-
- def get_output(self, index: int) -> TensorInfo:
- return list(self.outputs.values())[index]
-
- def __str__(self):
- """Provides a human-readable string representation of NodeInfo."""
- output = [
- f"Node Type: {self.node_type}",
- f"FLOPs: {self.flops if self.flops is not None else 'N/A'}",
- f"Parameters: {self.parameters}",
- ]
-
- if self.attributes:
- output.append("Attributes:")
- for key, value in self.attributes.items():
- output.append(f" - {key}: {value}")
-
- if self.inputs:
- output.append("Inputs:")
- for name, tensor in self.inputs.items():
- output.append(f" - {name}: {tensor}")
-
- if self.outputs:
- output.append("Outputs:")
- for name, tensor in self.outputs.items():
- output.append(f" - {name}: {tensor}")
-
- return "\n".join(output)
-
-
-# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias
-class NodeData(OrderedDict[str, NodeInfo]):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+from digest.model_class.digest_model import (
+ NodeTypeCounts,
+ NodeData,
+ NodeShapeCounts,
+ TensorData,
+ TensorInfo,
+)
# Convert tensor type to human-readable string and size in bytes
@@ -117,706 +41,6 @@ def tensor_type_to_str_and_size(elem_type) -> Tuple[str, int]:
return type_mapping.get(elem_type, ("unknown", 0))
-class DigestOnnxModel:
- def __init__(
- self,
- onnx_model: onnx.ModelProto,
- onnx_filepath: Optional[str] = None,
- model_name: Optional[str] = None,
- save_proto: bool = True,
- ) -> None:
- # Public members exposed to the API
- self.unique_id: str = str(uuid4())
- self.filepath: Optional[str] = onnx_filepath
- self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None
- self.model_name: Optional[str] = model_name
- self.model_version: Optional[int] = None
- self.graph_name: Optional[str] = None
- self.producer_name: Optional[str] = None
- self.producer_version: Optional[str] = None
- self.ir_version: Optional[int] = None
- self.opset: Optional[int] = None
- self.imports: Dict[str, int] = {}
- self.node_type_counts: NodeTypeCounts = NodeTypeCounts()
- self.model_flops: Optional[int] = None
- self.model_parameters: int = 0
- self.node_type_flops: Dict[str, int] = {}
- self.node_type_parameters: Dict[str, int] = {}
- self.per_node_info = NodeData()
- self.model_inputs = TensorData()
- self.model_outputs = TensorData()
-
- # Private members not intended to be exposed
- self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {}
- self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {}
- self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {}
- self.init_tensors_: Dict[str, onnx.TensorProto] = {}
-
- self.update_state(onnx_model)
-
- def update_state(self, model_proto: onnx.ModelProto) -> None:
- self.model_version = model_proto.model_version
- self.graph_name = model_proto.graph.name
- self.producer_name = model_proto.producer_name
- self.producer_version = model_proto.producer_version
- self.ir_version = model_proto.ir_version
- self.opset = get_opset(model_proto)
- self.imports = {
- import_.domain: import_.version for import_ in model_proto.opset_import
- }
-
- self.model_inputs = get_model_input_shapes_types(model_proto)
- self.model_outputs = get_model_output_shapes_types(model_proto)
-
- self.node_type_counts = get_node_type_counts(model_proto)
- self.parse_model_nodes(model_proto)
-
- def get_node_tensor_info_(
- self, onnx_node: onnx.NodeProto
- ) -> Tuple[TensorData, TensorData]:
- """
- This function is set to private because it is not intended to be used
- outside of the DigestOnnxModel class.
- """
-
- input_tensor_info = TensorData()
- for node_input in onnx_node.input:
- input_tensor_info[node_input] = TensorInfo()
- if (
- node_input in self.input_tensors_
- or node_input in self.value_tensors_
- or node_input in self.output_tensors_
- ):
- tensor = (
- self.input_tensors_.get(node_input)
- or self.value_tensors_.get(node_input)
- or self.output_tensors_.get(node_input)
- )
- if tensor:
- for dim in tensor.type.tensor_type.shape.dim:
- if dim.HasField("dim_value"):
- input_tensor_info[node_input].shape.append(dim.dim_value)
- elif dim.HasField("dim_param"):
- input_tensor_info[node_input].shape.append(dim.dim_param)
-
- dtype_str, dtype_bytes = tensor_type_to_str_and_size(
- tensor.type.tensor_type.elem_type
- )
- elif node_input in self.init_tensors_:
- input_tensor_info[node_input].shape.extend(
- [dim for dim in self.init_tensors_[node_input].dims]
- )
- dtype_str, dtype_bytes = tensor_type_to_str_and_size(
- self.init_tensors_[node_input].data_type
- )
- else:
- dtype_str = None
- dtype_bytes = None
-
- input_tensor_info[node_input].dtype = dtype_str
- input_tensor_info[node_input].dtype_bytes = dtype_bytes
-
- if (
- all(isinstance(s, int) for s in input_tensor_info[node_input].shape)
- and dtype_bytes
- ):
- tensor_size = float(
- np.prod(np.array(input_tensor_info[node_input].shape))
- )
- input_tensor_info[node_input].size_kbytes = (
- tensor_size * float(dtype_bytes) / 1024.0
- )
-
- output_tensor_info = TensorData()
- for node_output in onnx_node.output:
- output_tensor_info[node_output] = TensorInfo()
- if (
- node_output in self.input_tensors_
- or node_output in self.value_tensors_
- or node_output in self.output_tensors_
- ):
- tensor = (
- self.input_tensors_.get(node_output)
- or self.value_tensors_.get(node_output)
- or self.output_tensors_.get(node_output)
- )
- if tensor:
- output_tensor_info[node_output].shape.extend(
- [
- int(dim.dim_value)
- for dim in tensor.type.tensor_type.shape.dim
- ]
- )
- dtype_str, dtype_bytes = tensor_type_to_str_and_size(
- tensor.type.tensor_type.elem_type
- )
- elif node_output in self.init_tensors_:
- output_tensor_info[node_output].shape.extend(
- [dim for dim in self.init_tensors_[node_output].dims]
- )
- dtype_str, dtype_bytes = tensor_type_to_str_and_size(
- self.init_tensors_[node_output].data_type
- )
-
- else:
- dtype_str = None
- dtype_bytes = None
-
- output_tensor_info[node_output].dtype = dtype_str
- output_tensor_info[node_output].dtype_bytes = dtype_bytes
-
- if (
- all(isinstance(s, int) for s in output_tensor_info[node_output].shape)
- and dtype_bytes
- ):
- tensor_size = float(
- np.prod(np.array(output_tensor_info[node_output].shape))
- )
- output_tensor_info[node_output].size_kbytes = (
- tensor_size * float(dtype_bytes) / 1024.0
- )
-
- return input_tensor_info, output_tensor_info
-
- def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None:
- """
- Calculate total number of FLOPs found in the onnx model.
- FLOP is defined as one floating-point operation. This distinguishes
- from multiply-accumulates (MACs) where FLOPs == 2 * MACs.
- """
-
- # Initialze to zero so we can accumulate. Set to None during the
- # model FLOPs calculation if it errors out.
- self.model_flops = 0
-
- # Check to see if the model inputs have any dynamic shapes
- if get_dynamic_input_dims(onnx_model):
- self.model_flops = None
-
- try:
- onnx_model, _ = optimize_onnx_model(onnx_model)
-
- onnx_model = onnx.shape_inference.infer_shapes(
- onnx_model, strict_mode=True, data_prop=True
- )
- except Exception as e: # pylint: disable=broad-except
- print(f"ONNX utils: {str(e)}")
- self.model_flops = None
-
- # If the ONNX model contains one of the following unsupported ops, then this
- # function will return None since the FLOP total is expected to be incorrect
- unsupported_ops = [
- "Einsum",
- "RNN",
- "GRU",
- "DeformConv",
- ]
-
- if not self.input_tensors_:
- self.input_tensors_ = {
- tensor.name: tensor for tensor in onnx_model.graph.input
- }
-
- if not self.output_tensors_:
- self.output_tensors_ = {
- tensor.name: tensor for tensor in onnx_model.graph.output
- }
-
- if not self.value_tensors_:
- self.value_tensors_ = {
- tensor.name: tensor for tensor in onnx_model.graph.value_info
- }
-
- if not self.init_tensors_:
- self.init_tensors_ = {
- tensor.name: tensor for tensor in onnx_model.graph.initializer
- }
-
- for node in onnx_model.graph.node: # pylint: disable=E1101
-
- node_info = NodeInfo()
-
- # TODO: I have encountered models containing nodes with no name. It would be a good idea
- # to have this type of model info fed back to the user through a warnings section.
- if not node.name:
- node.name = f"{node.op_type}_{len(self.per_node_info)}"
-
- node_info.node_type = node.op_type
- input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node)
- node_info.inputs = input_tensor_info
- node_info.outputs = output_tensor_info
-
- # Check if this node has parameters through the init tensors
- for input_name, input_tensor in node_info.inputs.items():
- if input_name in self.init_tensors_:
- if all(isinstance(dim, int) for dim in input_tensor.shape):
- input_parameters = int(np.prod(np.array(input_tensor.shape)))
- node_info.parameters += input_parameters
- self.model_parameters += input_parameters
- self.node_type_parameters[node.op_type] = (
- self.node_type_parameters.get(node.op_type, 0)
- + input_parameters
- )
- else:
- print(f"Tensor with params has unknown shape: {input_name}")
-
- for attribute in node.attribute:
- node_info.attributes.update(attribute_to_dict(attribute))
-
- # if node.name in self.per_node_info:
- # print(f"Node name {node.name} is a duplicate.")
-
- self.per_node_info[node.name] = node_info
-
- if node.op_type in unsupported_ops:
- self.model_flops = None
- node_info.flops = None
-
- try:
-
- if (
- node.op_type == "MatMul"
- or node.op_type == "MatMulInteger"
- or node.op_type == "QLinearMatMul"
- ):
-
- input_a = node_info.get_input(0).shape
- if node.op_type == "QLinearMatMul":
- input_b = node_info.get_input(3).shape
- else:
- input_b = node_info.get_input(1).shape
-
- if not all(
- isinstance(dim, int) for dim in input_a
- ) or not isinstance(input_b[-1], int):
- node_info.flops = None
- self.model_flops = None
- continue
-
- node_info.flops = int(
- 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1]
- )
-
- elif (
- node.op_type == "Mul"
- or node.op_type == "Div"
- or node.op_type == "Add"
- ):
- input_a = node_info.get_input(0).shape
- input_b = node_info.get_input(1).shape
-
- if not all(isinstance(dim, int) for dim in input_a) or not all(
- isinstance(dim, int) for dim in input_b
- ):
- node_info.flops = None
- self.model_flops = None
- continue
-
- node_info.flops = int(
- np.prod(np.array(input_a), dtype=np.int64)
- ) + int(np.prod(np.array(input_b), dtype=np.int64))
-
- elif node.op_type == "Gemm" or node.op_type == "QGemm":
- x_shape = node_info.get_input(0).shape
- if node.op_type == "Gemm":
- w_shape = node_info.get_input(1).shape
- else:
- w_shape = node_info.get_input(3).shape
-
- if not all(isinstance(dim, int) for dim in x_shape) or not all(
- isinstance(dim, int) for dim in w_shape
- ):
- node_info.flops = None
- self.model_flops = None
- continue
-
- mm_dims = [
- (
- x_shape[0]
- if not node_info.attributes.get("transA", 0)
- else x_shape[1]
- ),
- (
- x_shape[1]
- if not node_info.attributes.get("transA", 0)
- else x_shape[0]
- ),
- (
- w_shape[1]
- if not node_info.attributes.get("transB", 0)
- else w_shape[0]
- ),
- ]
-
- node_info.flops = int(
- 2 * np.prod(np.array(mm_dims), dtype=np.int64)
- )
-
- if len(mm_dims) == 3: # if there is a bias input
- bias_shape = node_info.get_input(2).shape
- node_info.flops += int(np.prod(np.array(bias_shape)))
-
- elif (
- node.op_type == "Conv"
- or node.op_type == "ConvInteger"
- or node.op_type == "QLinearConv"
- or node.op_type == "ConvTranspose"
- ):
- # N, C, d1, ..., dn
- x_shape = node_info.get_input(0).shape
-
- # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose
- if node.op_type == "QLinearConv":
- w_shape = node_info.get_input(3).shape
- else:
- w_shape = node_info.get_input(1).shape
-
- if not all(isinstance(dim, int) for dim in x_shape):
- node_info.flops = None
- self.model_flops = None
- continue
-
- x_shape_ints = cast(List[int], x_shape)
- w_shape_ints = cast(List[int], w_shape)
-
- has_bias = False # Note, ConvInteger has no bias
- if node.op_type == "Conv" and len(node_info.inputs) == 3:
- has_bias = True
- elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9:
- has_bias = True
-
- num_dims = len(x_shape_ints) - 2
- strides = node_info.attributes.get(
- "strides", [1] * num_dims
- ) # type: List[int]
- dilation = node_info.attributes.get(
- "dilations", [1] * num_dims
- ) # type: List[int]
- kernel_shape = w_shape_ints[2:]
- batch_size = x_shape_ints[0]
- out_channels = w_shape_ints[0]
- out_dims = [batch_size, out_channels]
- output_shape = node_info.attributes.get(
- "output_shape", []
- ) # type: List[int]
-
- # If output_shape is given then we do not need to compute it ourselves
- # The output_shape attribute does not include batch_size or channels and
- # is only valid for ConvTranspose
- if output_shape:
- out_dims.extend(output_shape)
- else:
- auto_pad = node_info.attributes.get(
- "auto_pad", "NOTSET".encode()
- ).decode()
- # SAME expects padding so that the output_shape = CEIL(input_shape / stride)
- if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
- out_dims.extend(
- [x * s for x, s in zip(x_shape_ints[2:], strides)]
- )
- else:
- # NOTSET means just use pads attribute
- if auto_pad == "NOTSET":
- pads = node_info.attributes.get(
- "pads", [0] * num_dims * 2
- )
- # VALID essentially means no padding
- elif auto_pad == "VALID":
- pads = [0] * num_dims * 2
-
- for i in range(num_dims):
- dim_in = x_shape_ints[i + 2] # type: int
-
- if node.op_type == "ConvTranspose":
- out_dim = (
- strides[i] * (dim_in - 1)
- + ((kernel_shape[i] - 1) * dilation[i] + 1)
- - pads[i]
- - pads[i + num_dims]
- )
- else:
- out_dim = (
- dim_in
- + pads[i]
- + pads[i + num_dims]
- - dilation[i] * (kernel_shape[i] - 1)
- - 1
- ) // strides[i] + 1
-
- out_dims.append(out_dim)
-
- kernel_flops = int(
- np.prod(np.array(kernel_shape)) * w_shape_ints[1]
- )
- output_points = int(np.prod(np.array(out_dims)))
- bias_ops = output_points if has_bias else int(0)
- node_info.flops = 2 * kernel_flops * output_points + bias_ops
-
- elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM":
-
- x_shape = node_info.get_input(
- 0
- ).shape # seq_length, batch_size, input_dim
-
- if not all(isinstance(dim, int) for dim in x_shape):
- node_info.flops = None
- self.model_flops = None
- continue
-
- x_shape_ints = cast(List[int], x_shape)
- hidden_size = node_info.attributes["hidden_size"]
- direction = (
- 2
- if node_info.attributes.get("direction")
- == "bidirectional".encode()
- else 1
- )
-
- has_bias = True if len(node_info.inputs) >= 4 else False
- if has_bias:
- bias_shape = node_info.get_input(3).shape
- if isinstance(bias_shape[1], int):
- bias_ops = bias_shape[1]
- else:
- bias_ops = 0
- else:
- bias_ops = 0
- # seq_length, batch_size, input_dim = x_shape
- if not isinstance(bias_ops, int):
- bias_ops = int(0)
- num_gates = int(4)
- gate_input_flops = int(2 * x_shape_ints[2] * hidden_size)
- gate_hid_flops = int(2 * hidden_size * hidden_size)
- unit_flops = (
- num_gates * (gate_input_flops + gate_hid_flops) + bias_ops
- )
- node_info.flops = (
- x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops
- )
- # In this case we just hit an op that doesn't have FLOPs
- else:
- node_info.flops = None
-
- except IndexError as err:
- print(f"Error parsing node {node.name}: {err}")
- node_info.flops = None
- self.model_flops = None
- continue
-
- # Update the model level flops count
- if node_info.flops is not None and self.model_flops is not None:
- self.model_flops += node_info.flops
-
- # Update the node type flops count
- self.node_type_flops[node.op_type] = (
- self.node_type_flops.get(node.op_type, 0) + node_info.flops
- )
-
- def save_txt_report(self, filepath: str) -> None:
-
- parent_dir = os.path.dirname(os.path.abspath(filepath))
- if not os.path.exists(parent_dir):
- raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
-
- report_date = datetime.now().strftime("%B %d, %Y")
-
- with open(filepath, "w", encoding="utf-8") as f_p:
- f_p.write(f"Report created on {report_date}\n")
- if self.filepath:
- f_p.write(f"ONNX file: {self.filepath}\n")
- f_p.write(f"Name of the model: {self.model_name}\n")
- f_p.write(f"Model version: {self.model_version}\n")
- f_p.write(f"Name of the graph: {self.graph_name}\n")
- f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n")
- f_p.write(f"Ir version: {self.ir_version}\n")
- f_p.write(f"Opset: {self.opset}\n\n")
- f_p.write("Import list\n")
- for name, version in self.imports.items():
- f_p.write(f"\t{name}: {version}\n")
-
- f_p.write("\n")
- f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n")
- f_p.write(f"Number of parameters: {self.model_parameters}\n")
- if self.model_flops:
- f_p.write(f"Number of FLOPs: {self.model_flops}\n")
- f_p.write("\n")
-
- table_op_intensity = PrettyTable()
- table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"]
- for op_type, count in self.node_type_flops.items():
- if count > 0:
- table_op_intensity.add_row(
- [
- op_type,
- count,
- 100.0 * float(count) / float(self.model_flops),
- ]
- )
-
- f_p.write("Op intensity:\n")
- f_p.write(table_op_intensity.get_string())
- f_p.write("\n\n")
-
- node_counts_table = PrettyTable()
- node_counts_table.field_names = ["Node", "Occurrences"]
- for op, count in self.node_type_counts.items():
- node_counts_table.add_row([op, count])
- f_p.write("Nodes and their occurrences:\n")
- f_p.write(node_counts_table.get_string())
- f_p.write("\n\n")
-
- input_table = PrettyTable()
- input_table.field_names = [
- "Input Name",
- "Shape",
- "Type",
- "Tensor Size (KB)",
- ]
- for input_name, input_details in self.model_inputs.items():
- if input_details.size_kbytes:
- kbytes = f"{input_details.size_kbytes:.2f}"
- else:
- kbytes = ""
-
- input_table.add_row(
- [
- input_name,
- input_details.shape,
- input_details.dtype,
- kbytes,
- ]
- )
- f_p.write("Input Tensor(s) Information:\n")
- f_p.write(input_table.get_string())
- f_p.write("\n\n")
-
- output_table = PrettyTable()
- output_table.field_names = [
- "Output Name",
- "Shape",
- "Type",
- "Tensor Size (KB)",
- ]
- for output_name, output_details in self.model_outputs.items():
- if output_details.size_kbytes:
- kbytes = f"{output_details.size_kbytes:.2f}"
- else:
- kbytes = ""
-
- output_table.add_row(
- [
- output_name,
- output_details.shape,
- output_details.dtype,
- kbytes,
- ]
- )
- f_p.write("Output Tensor(s) Information:\n")
- f_p.write(output_table.get_string())
- f_p.write("\n\n")
-
- def save_nodes_csv_report(self, filepath: str) -> None:
- save_nodes_csv_report(self.per_node_info, filepath)
-
- def get_node_type_counts(self) -> Union[NodeTypeCounts, None]:
- if not self.node_type_counts and self.model_proto:
- self.node_type_counts = get_node_type_counts(self.model_proto)
- return self.node_type_counts if self.node_type_counts else None
-
- def get_node_shape_counts(self) -> NodeShapeCounts:
- tensor_shape_counter = NodeShapeCounts()
- for _, info in self.per_node_info.items():
- shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()])
- if info.node_type:
- tensor_shape_counter[info.node_type][shape_hash] += 1
- return tensor_shape_counter
-
-
-def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None:
-
- parent_dir = os.path.dirname(os.path.abspath(filepath))
- if not os.path.exists(parent_dir):
- raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
-
- flattened_data = []
- fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"]
- input_fieldnames = []
- output_fieldnames = []
- for name, node_info in node_data.items():
- row = OrderedDict()
- row["Node Name"] = name
- row["Node Type"] = str(node_info.node_type)
- row["Parameters"] = str(node_info.parameters)
- row["FLOPs"] = str(node_info.flops)
- if node_info.attributes:
- row["Attributes"] = str({k: v for k, v in node_info.attributes.items()})
- else:
- row["Attributes"] = ""
-
- for i, (input_name, input_info) in enumerate(node_info.inputs.items()):
- column_name = f"Input{i+1} (Shape, Dtype, Size (kB))"
- row[column_name] = (
- f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})"
- )
-
- # Dynamically add input column names to fieldnames if not already present
- if column_name not in input_fieldnames:
- input_fieldnames.append(column_name)
-
- for i, (output_name, output_info) in enumerate(node_info.outputs.items()):
- column_name = f"Output{i+1} (Shape, Dtype, Size (kB))"
- row[column_name] = (
- f"{output_name} ({output_info.shape}, "
- f"{output_info.dtype}, {output_info.size_kbytes})"
- )
-
- # Dynamically add input column names to fieldnames if not already present
- if column_name not in output_fieldnames:
- output_fieldnames.append(column_name)
-
- flattened_data.append(row)
-
- fieldnames = fieldnames + input_fieldnames + output_fieldnames
- with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n")
- writer.writeheader()
- writer.writerows(flattened_data)
-
-
-def save_node_type_counts_csv_report(node_data: NodeTypeCounts, filepath: str) -> None:
-
- parent_dir = os.path.dirname(os.path.abspath(filepath))
- if not os.path.exists(parent_dir):
- raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
-
- header = ["Node Type", "Count"]
-
- with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
- writer = csv.writer(csvfile, lineterminator="\n")
- writer.writerow(header)
- for node_type, node_count in node_data.items():
- writer.writerow([node_type, node_count])
-
-
-def save_node_shape_counts_csv_report(
- node_data: NodeShapeCounts, filepath: str
-) -> None:
-
- parent_dir = os.path.dirname(os.path.abspath(filepath))
- if not os.path.exists(parent_dir):
- raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
-
- header = ["Node Type", "Input Tensors Shapes", "Count"]
-
- with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
- writer = csv.writer(csvfile, dialect="excel", lineterminator="\n")
- writer.writerow(header)
- for node_type, node_info in node_data.items():
- info_iter = iter(node_info.items())
- for shape, count in info_iter:
- writer.writerow([node_type, shape, count])
-
-
def load_onnx(onnx_path: str, load_external_data: bool = True) -> onnx.ModelProto:
if os.path.exists(onnx_path):
return onnx.load(onnx_path, load_external_data=load_external_data)
diff --git a/test/resnet18_reports/resnet18_heatmap.png b/test/resnet18_reports/resnet18_heatmap.png
new file mode 100644
index 0000000..1fb614e
Binary files /dev/null and b/test/resnet18_reports/resnet18_heatmap.png differ
diff --git a/test/resnet18_reports/resnet18_histogram.png b/test/resnet18_reports/resnet18_histogram.png
new file mode 100644
index 0000000..eb13e01
Binary files /dev/null and b/test/resnet18_reports/resnet18_histogram.png differ
diff --git a/test/resnet18_reports/resnet18_node_type_counts.csv b/test/resnet18_reports/resnet18_node_type_counts.csv
new file mode 100644
index 0000000..29504ba
--- /dev/null
+++ b/test/resnet18_reports/resnet18_node_type_counts.csv
@@ -0,0 +1,8 @@
+Node Type,Count
+Conv,20
+Relu,17
+Add,8
+MaxPool,1
+GlobalAveragePool,1
+Flatten,1
+Gemm,1
diff --git a/test/resnet18_test_nodes.csv b/test/resnet18_reports/resnet18_nodes.csv
similarity index 100%
rename from test/resnet18_test_nodes.csv
rename to test/resnet18_reports/resnet18_nodes.csv
diff --git a/test/resnet18_test_summary.txt b/test/resnet18_reports/resnet18_report.txt
similarity index 88%
rename from test/resnet18_test_summary.txt
rename to test/resnet18_reports/resnet18_report.txt
index a5b4cfb..d68027e 100644
--- a/test/resnet18_test_summary.txt
+++ b/test/resnet18_reports/resnet18_report.txt
@@ -1,5 +1,5 @@
-Report created on June 02, 2024
-ONNX file: resnet18.onnx
+Report created on December 06, 2024
+ONNX file: test\resnet18.onnx
Name of the model: resnet18
Model version: 0
Name of the graph: main_graph
@@ -9,6 +9,13 @@ Opset: 17
Import list
: 17
+ ai.onnx.ml: 5
+ ai.onnx.preview.training: 1
+ ai.onnx.training: 1
+ com.microsoft: 1
+ com.microsoft.experimental: 1
+ com.microsoft.nchwc: 1
+ org.pytorch.aten: 1
Total graph nodes: 49
Number of parameters: 11684712
diff --git a/test/resnet18_reports/resnet18_report.yaml b/test/resnet18_reports/resnet18_report.yaml
new file mode 100644
index 0000000..9839302
--- /dev/null
+++ b/test/resnet18_reports/resnet18_report.yaml
@@ -0,0 +1,56 @@
+report_date: December 06, 2024
+model_file: test\resnet18.onnx
+model_type: onnx
+model_name: resnet18
+model_version: 0
+graph_name: main_graph
+producer_name: pytorch
+producer_version: 2.1.0
+ir_version: 8
+opset: 17
+import_list:
+ ? ''
+ : 17
+ ai.onnx.ml: 5
+ ai.onnx.preview.training: 1
+ ai.onnx.training: 1
+ com.microsoft: 1
+ com.microsoft.experimental: 1
+ com.microsoft.nchwc: 1
+ org.pytorch.aten: 1
+graph_nodes: 49
+parameters: 11684712
+flops: 3632136680
+node_type_counts:
+ Conv: 20
+ Relu: 17
+ Add: 8
+ MaxPool: 1
+ GlobalAveragePool: 1
+ Flatten: 1
+ Gemm: 1
+node_type_flops:
+ Conv: 3629606400
+ Add: 1505280
+ Gemm: 1025000
+node_type_parameters:
+ Conv: 11171712
+ Gemm: 513000
+input_tensors:
+ input.1:
+ dtype: float32
+ dtype_bytes: 4
+ size_kbytes: 588.0
+ shape:
+ - 1
+ - 3
+ - 224
+ - 224
+output_tensors:
+ '191':
+ dtype: float32
+ dtype_bytes: 4
+ size_kbytes: 3.90625
+ shape:
+ - 1
+ - 1000
diff --git a/test/test_gui.py b/test/test_gui.py
index 0e1d351..59fbb8f 100644
--- a/test/test_gui.py
+++ b/test/test_gui.py
@@ -8,60 +8,97 @@
# pylint: disable=no-name-in-module
from PySide6.QtTest import QTest
-from PySide6.QtCore import Qt, QDeadlineTimer
+from PySide6.QtCore import Qt
from PySide6.QtWidgets import QApplication
import digest.main
from digest.node_summary import NodeSummary
-ONNX_BASENAME = "resnet18"
-TEST_DIR = os.path.abspath(os.path.dirname(__file__))
-ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{ONNX_BASENAME}.onnx"))
-
class DigestGuiTest(unittest.TestCase):
+ MODEL_BASENAME = "resnet18"
+ TEST_DIR = os.path.abspath(os.path.dirname(__file__))
+ ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{MODEL_BASENAME}.onnx"))
+ YAML_FILEPATH = os.path.normpath(
+ os.path.join(
+ TEST_DIR, f"{MODEL_BASENAME}_reports", f"{MODEL_BASENAME}_report.yaml"
+ )
+ )
@classmethod
def setUpClass(cls):
cls.app = QApplication(sys.argv)
+ return super().setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ if isinstance(cls.app, QApplication):
+ cls.app.closeAllWindows()
+ cls.app = None
def setUp(self):
self.digest_app = digest.main.DigestApp()
self.digest_app.show()
def tearDown(self):
- self.wait_all_threads()
self.digest_app.close()
- def wait_all_threads(self):
+ def wait_all_threads(self, timeout=10000) -> bool:
+ all_threads = list(self.digest_app.model_nodes_stats_thread.values()) + list(
+ self.digest_app.model_similarity_thread.values()
+ )
- for thread in self.digest_app.model_nodes_stats_thread.values():
- thread.wait(deadline=QDeadlineTimer.Forever)
+ for thread in all_threads:
+ thread.wait(timeout)
- for thread in self.digest_app.model_similarity_thread.values():
- thread.wait(deadline=QDeadlineTimer.Forever)
+ # Return True if all threads finished, False if timed out
+ return all(thread.isFinished() for thread in all_threads)
def test_open_valid_onnx(self):
with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
mock_dialog.return_value = (
- ONNX_FILEPATH,
+ self.ONNX_FILEPATH,
+ "",
+ )
+
+ num_tabs_prior = self.digest_app.ui.tabWidget.count()
+
+ QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton)
+
+ self.assertTrue(self.wait_all_threads())
+
+ self.assertTrue(
+ self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1
+ ) # Check if a tab was added
+
+ self.digest_app.closeTab(num_tabs_prior)
+
+ def test_open_valid_yaml(self):
+ with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
+ mock_dialog.return_value = (
+ self.YAML_FILEPATH,
"",
)
+ num_tabs_prior = self.digest_app.ui.tabWidget.count()
+
QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton)
- self.wait_all_threads()
+ self.assertTrue(self.wait_all_threads())
self.assertTrue(
- self.digest_app.ui.tabWidget.count() > 0
+ self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1
) # Check if a tab was added
+ self.digest_app.closeTab(num_tabs_prior)
+
def test_open_invalid_file(self):
with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
mock_dialog.return_value = ("invalid_file.txt", "")
+ num_tabs_prior = self.digest_app.ui.tabWidget.count()
QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton)
- self.wait_all_threads()
- self.assertEqual(self.digest_app.ui.tabWidget.count(), 0)
+ self.assertTrue(self.wait_all_threads())
+ self.assertEqual(self.digest_app.ui.tabWidget.count(), num_tabs_prior)
def test_save_reports(self):
with patch(
@@ -70,7 +107,7 @@ def test_save_reports(self):
"PySide6.QtWidgets.QFileDialog.getExistingDirectory"
) as mock_save_dialog:
- mock_open_dialog.return_value = (ONNX_FILEPATH, "")
+ mock_open_dialog.return_value = (self.ONNX_FILEPATH, "")
with tempfile.TemporaryDirectory() as tmpdirname:
mock_save_dialog.return_value = tmpdirname
@@ -79,45 +116,57 @@ def test_save_reports(self):
Qt.MouseButton.LeftButton,
)
- self.wait_all_threads()
+ self.assertTrue(self.wait_all_threads())
- # This is a slight hack but the issue is that model similarity takes
- # a bit longer to complete and we must have it done before the save
- # button is enabled guaranteeing all the artifacts are saved.
- # wait_all_threads() above doesn't seem to work. The only thing that
- # does is just waiting 5 seconds.
- QTest.qWait(5000)
+ self.assertTrue(
+ self.digest_app.ui.saveBtn.isEnabled(), "Save button is disabled!"
+ )
QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton)
mock_save_dialog.assert_called_once()
- result_basepath = os.path.join(tmpdirname, f"{ONNX_BASENAME}_reports")
+ result_basepath = os.path.join(
+ tmpdirname, f"{self.MODEL_BASENAME}_reports"
+ )
# Text report test
- txt_report_filepath = os.path.join(
- result_basepath, f"{ONNX_BASENAME}_report.txt"
+ text_report_filepath = os.path.join(
+ result_basepath, f"{self.MODEL_BASENAME}_report.txt"
)
- self.assertTrue(os.path.isfile(txt_report_filepath))
+ self.assertTrue(
+ os.path.isfile(text_report_filepath),
+ f"{text_report_filepath} not found!",
+ )
+
+ # YAML report test
+ yaml_report_filepath = os.path.join(
+ result_basepath, f"{self.MODEL_BASENAME}_report.yaml"
+ )
+ self.assertTrue(os.path.isfile(yaml_report_filepath))
# Nodes test
nodes_csv_report_filepath = os.path.join(
- result_basepath, f"{ONNX_BASENAME}_nodes.csv"
+ result_basepath, f"{self.MODEL_BASENAME}_nodes.csv"
)
self.assertTrue(os.path.isfile(nodes_csv_report_filepath))
# Histogram test
histogram_filepath = os.path.join(
- result_basepath, f"{ONNX_BASENAME}_histogram.png"
+ result_basepath, f"{self.MODEL_BASENAME}_histogram.png"
)
self.assertTrue(os.path.isfile(histogram_filepath))
# Heatmap test
heatmap_filepath = os.path.join(
- result_basepath, f"{ONNX_BASENAME}_heatmap.png"
+ result_basepath, f"{self.MODEL_BASENAME}_heatmap.png"
)
self.assertTrue(os.path.isfile(heatmap_filepath))
+ num_tabs = self.digest_app.ui.tabWidget.count()
+ self.assertTrue(num_tabs == 1)
+ self.digest_app.closeTab(0)
+
def test_save_tables(self):
with patch(
"PySide6.QtWidgets.QFileDialog.getOpenFileName"
@@ -125,10 +174,10 @@ def test_save_tables(self):
"PySide6.QtWidgets.QFileDialog.getSaveFileName"
) as mock_save_dialog:
- mock_open_dialog.return_value = (ONNX_FILEPATH, "")
+ mock_open_dialog.return_value = (self.ONNX_FILEPATH, "")
with tempfile.TemporaryDirectory() as tmpdirname:
mock_save_dialog.return_value = (
- os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv"),
+ os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv"),
"",
)
@@ -136,17 +185,19 @@ def test_save_tables(self):
self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton
)
- self.wait_all_threads()
+ self.assertTrue(self.wait_all_threads())
QTest.mouseClick(
self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton
)
- # We assume there is only model loaded
+ # We assume there is only one model loaded
_, node_window = self.digest_app.nodes_window.popitem()
node_summary = node_window.main_window.centralWidget()
self.assertIsInstance(node_summary, NodeSummary)
+
+ # This line of code seems redundant but we do this to clean pylance
if isinstance(node_summary, NodeSummary):
QTest.mouseClick(
node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton
@@ -156,11 +207,15 @@ def test_save_tables(self):
self.assertTrue(
os.path.exists(
- os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv")
+ os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv")
),
"Nodes csv file not found.",
)
+ num_tabs = self.digest_app.ui.tabWidget.count()
+ self.assertTrue(num_tabs == 1)
+ self.digest_app.closeTab(0)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/test_reports.py b/test/test_reports.py
index a16c4d8..ae99ab9 100644
--- a/test/test_reports.py
+++ b/test/test_reports.py
@@ -1,17 +1,22 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
-"""Unit tests for Vitis ONNX Model Analyzer """
-
import os
import unittest
import tempfile
import csv
-from utils.onnx_utils import DigestOnnxModel, load_onnx
+import utils.onnx_utils as onnx_utils
+from digest.model_class.digest_onnx_model import DigestOnnxModel
+from digest.model_class.digest_report_model import compare_yaml_files
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_ONNX = os.path.join(TEST_DIR, "resnet18.onnx")
-TEST_SUMMARY_TXT_REPORT = os.path.join(TEST_DIR, "resnet18_test_summary.txt")
-TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_test_nodes.csv")
+TEST_SUMMARY_TEXT_REPORT = os.path.join(
+ TEST_DIR, "resnet18_reports/resnet18_report.txt"
+)
+TEST_SUMMARY_YAML_REPORT = os.path.join(
+ TEST_DIR, "resnet18_reports/resnet18_report.yaml"
+)
+TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_reports/resnet18_nodes.csv")
class TestDigestReports(unittest.TestCase):
@@ -46,27 +51,35 @@ def compare_csv_files(self, file1, file2, skip_lines=0):
self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}")
def test_against_example_reports(self):
- model_proto = load_onnx(TEST_ONNX)
+ model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False)
model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0]
+ opt_model, _ = onnx_utils.optimize_onnx_model(model_proto)
digest_model = DigestOnnxModel(
- model_proto, onnx_filepath=TEST_ONNX, model_name=model_name, save_proto=False,
+ opt_model,
+ onnx_filepath=TEST_ONNX,
+ model_name=model_name,
+ save_proto=False,
)
with tempfile.TemporaryDirectory() as tmpdir:
- # Model summary text report
- summary_filepath = os.path.join(tmpdir, f"{model_name}_summary.txt")
- digest_model.save_txt_report(summary_filepath)
-
- with self.subTest("Testing summary text file"):
- self.compare_files_line_by_line(
- TEST_SUMMARY_TXT_REPORT,
- summary_filepath,
- skip_lines=2,
+ # Model yaml report
+ yaml_report_filepath = os.path.join(tmpdir, f"{model_name}_report.yaml")
+ digest_model.save_yaml_report(yaml_report_filepath)
+ with self.subTest("Testing report yaml file"):
+ self.assertTrue(
+ compare_yaml_files(
+ TEST_SUMMARY_YAML_REPORT,
+ yaml_report_filepath,
+ skip_keys=["report_date", "model_file", "digest_version"],
+ )
)
# Save CSV containing node-level information
nodes_filepath = os.path.join(tmpdir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)
-
with self.subTest("Testing nodes csv file"):
self.compare_csv_files(TEST_NODES_CSV_REPORT, nodes_filepath)
+
+
+if __name__ == "__main__":
+ unittest.main()