Skip to content

Commit

Permalink
Resolves: Multi-Model Analysis Fix and Improvement (#17)
Browse files Browse the repository at this point in the history
* Multi-Model Analysis Fix and Improvements

---------

Co-authored-by: Philip Colangelo <[email protected]>
  • Loading branch information
pcolange and Philip Colangelo authored Jan 30, 2025
1 parent cc55d21 commit f4337d8
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 35 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.1.0",
version="1.1.1",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
5 changes: 0 additions & 5 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, label: str, num_steps: int, parent=None):
self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))
self.setValue(1)

self.user_canceled = False
self.canceled.connect(self.cancel)
self.step_size = 1
self.current_step = 0
self.num_steps = num_steps
Expand All @@ -46,9 +44,6 @@ def step(self):
self.current_step = self.num_steps
self.setValue(self.current_step)

def cancel(self):
self.user_canceled = True


class InfoDialog(QDialog):
"""This is a specific dialog class used to display the package information
Expand Down
5 changes: 2 additions & 3 deletions src/digest/freeze_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def apply_static_shapes(self) -> None:
# to arrive in this function with shapes for each dynamic dim.
dims: Dict[str, int] = {}
for i in range(self.ui.formLayout.rowCount()):
if status.user_canceled:
if status.wasCanceled():
break
status.step()
label_item = self.ui.formLayout.itemAt(i, QFormLayout.ItemRole.LabelRole)
Expand All @@ -168,7 +168,7 @@ def apply_static_shapes(self) -> None:
dims[label_text] = line_edit_value

for tensor in self.model_proto.graph.input:
if status.user_canceled:
if status.wasCanceled():
break
status.step()
tensor_shape = []
Expand Down Expand Up @@ -200,7 +200,6 @@ def apply_static_shapes(self) -> None:
except checker.ValidationError as e:
self.show_warning_and_disable_page()
print(f"Model did not pass checker: {e}")
finally:
status.close()

def show_warning_and_disable_page(self):
Expand Down
3 changes: 3 additions & 0 deletions src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,9 @@ def save_reports(self):
self, "Select Directory"
)

if not save_directory:
return

# Check if the directory exists and is writable
if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK):
self.show_warning_dialog(
Expand Down
8 changes: 6 additions & 2 deletions src/digest/model_class/digest_report_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
from collections import OrderedDict
import csv
import ast
Expand Down Expand Up @@ -141,8 +142,11 @@ def parse_model_nodes(self) -> None:
return

def save_yaml_report(self, filepath: str) -> None:
"""Report models are not intended to be saved"""
return
"""self.filepath, if it exists, is the path to the yaml file that was loaded.
Although saving a yaml report for a single model is not currently supported,
we can offer this feature for multi model analysis."""
if self.filepath and os.path.exists(self.filepath):
shutil.copy(self.filepath, filepath)

def save_text_report(self, filepath: str) -> None:
"""Report models are not intended to be saved"""
Expand Down
46 changes: 34 additions & 12 deletions src/digest/multi_model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from datetime import datetime
import csv
from typing import List, Dict, Union
from typing import List, Dict, Union, Optional
from collections import Counter, defaultdict, OrderedDict

# pylint: disable=no-name-in-module
Expand All @@ -12,7 +12,7 @@
from digest.dialog import ProgressDialog, StatusDialog
from digest.ui.multimodelanalysis_ui import Ui_multiModelAnalysis
from digest.histogramchartwidget import StackedHistogramWidget
from digest.qt_utils import apply_dark_style_sheet
from digest.qt_utils import apply_dark_style_sheet, find_available_save_path
from digest.model_class.digest_model import (
NodeTypeCounts,
NodeShapeCounts,
Expand Down Expand Up @@ -215,6 +215,8 @@ def __init__(

self.model_list = model_list

self.status_dialog: Optional[StatusDialog] = None

def save_reports(self):
"""This function saves all available reports for the models that are opened
in the multi-model analysis page."""
Expand All @@ -223,6 +225,9 @@ def save_reports(self):
self, "Select Directory"
)

if not base_directory:
return

# Check if the directory exists and is writable
if not os.path.exists(base_directory) or not os.access(base_directory, os.W_OK):
bad_ext_dialog = StatusDialog(
Expand All @@ -234,7 +239,7 @@ def save_reports(self):
# Append a subdirectory to the save_directory so that all reports are co-located
name_id = datetime.now().strftime("%Y%m%d%H%M%S")
sub_directory = f"multi_model_reports_{name_id}"
save_directory = os.path.join(base_directory, sub_directory)
save_directory = os.path.normpath(os.path.join(base_directory, sub_directory))
try:
os.makedirs(save_directory)
except OSError as os_err:
Expand All @@ -253,36 +258,44 @@ def save_reports(self):
for digest_model in self.model_list:
progress.step()

# Save the text report for the model
if progress.wasCanceled():
break

model_save_dir = find_available_save_path(
os.path.join(save_directory, digest_model.model_name)
)
os.makedirs(model_save_dir, exist_ok=True)

# Save the yaml report for the model
summary_filepath = os.path.join(
save_directory, f"{digest_model.model_name}_summary.txt"
model_save_dir, f"{digest_model.model_name}_summary.yaml"
)

digest_model.save_text_report(summary_filepath)
digest_model.save_yaml_report(summary_filepath)

# Save csv of node type counts
node_type_filepath = os.path.join(
save_directory, f"{digest_model.model_name}_node_type_counts.csv"
model_save_dir, f"{digest_model.model_name}_node_type_counts.csv"
)

if digest_model.node_type_counts:
digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Save csv containing node shape counts per op_type
node_shape_filepath = os.path.join(
save_directory, f"{digest_model.model_name}_node_shape_counts.csv"
model_save_dir, f"{digest_model.model_name}_node_shape_counts.csv"
)
digest_model.save_node_shape_counts_csv_report(node_shape_filepath)

# Save csv containing all node-level information
nodes_filepath = os.path.join(
save_directory, f"{digest_model.model_name}_nodes.csv"
model_save_dir, f"{digest_model.model_name}_nodes.csv"
)
digest_model.save_nodes_csv_report(nodes_filepath)

progress.close()
# progress.close()

if save_multi_reports:
if save_multi_reports and not progress.wasCanceled():

# Save all the global model analysis reports
if len(self.model_list) > 1:
Expand Down Expand Up @@ -326,7 +339,16 @@ def save_reports(self):
writer.writerows(rows)

if save_individual_reports or save_multi_reports:
StatusDialog(f"Saved reports to {save_directory}")
if not progress.wasCanceled():
self.status_dialog = StatusDialog(
f"Saved reports to {save_directory}", parent=self
)
else:
self.status_dialog = StatusDialog(
f"User canceled saving reports, but some have been saved to {save_directory}",
parent=self,
)
self.status_dialog.show()

def check_box_changed(self):
if self.ui.individualCheckBox.isChecked() or self.ui.multiCheckBox.isChecked():
Expand Down
23 changes: 11 additions & 12 deletions src/digest/multi_model_selection_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def run(self):

for file, model in self.model_dict.items():
if self.user_canceled:
break
self.close_progress.emit()
self.completed.emit([])
return
self.step_progress.emit()
if model:
continue
Expand All @@ -65,7 +67,7 @@ def run(self):

self.close_progress.emit()

model_list = [model for model in self.model_dict.values()]
model_list = list(self.model_dict.values())

self.completed.emit(model_list)

Expand Down Expand Up @@ -210,13 +212,12 @@ def set_directory(self, directory: str):
else:
return

progress = ProgressDialog("Searching directory for model files", 0, self)

onnx_file_list = list(
glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True)
)
onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list]

# TODO Move to another thread and see if we can capture tqdm
yaml_file_list = list(
glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True)
)
Expand All @@ -233,14 +234,13 @@ def set_directory(self, directory: str):

serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list)

progress.close()
progress = ProgressDialog("Loading models", total_num_models, self)

memory_limit_percentage = 90
models_loaded = 0
for filepath in onnx_file_list + report_file_list:
progress.step()
if progress.user_canceled:
if progress.wasCanceled():
break
try:
models_loaded += 1
Expand Down Expand Up @@ -275,16 +275,14 @@ def set_directory(self, directory: str):
except DecodeError as error:
print(f"Error decoding model {filepath}: {error}")

progress.close()

progress = ProgressDialog("Processing Models", total_num_models, self)

num_duplicates = 0
self.item_model.clear()
self.ui.duplicateListWidget.clear()
for paths in serialized_models_paths.values():
progress.step()
if progress.user_canceled:
if progress.wasCanceled():
break
if len(paths) > 1:
num_duplicates += 1
Expand All @@ -301,7 +299,7 @@ def set_directory(self, directory: str):
processed_files = set()
for i in range(len(report_file_list)):
progress.step()
if progress.user_canceled:
if progress.wasCanceled():
break
path1 = report_file_list[i]
if path1 in processed_files:
Expand All @@ -328,8 +326,6 @@ def set_directory(self, directory: str):
item.setCheckState(Qt.CheckState.Checked)
self.item_model.appendRow(item)

progress.close()

if num_duplicates:
label_text = f"Ignoring {num_duplicates} duplicate model(s)."
self.ui.duplicateLabel.setText(label_text)
Expand Down Expand Up @@ -362,6 +358,9 @@ def start_analysis(self):
def open_analysis(
self, model_list: List[Union[DigestOnnxModel, DigestReportModel]]
):
if not model_list:
return

multi_model_analysis = MultiModelAnalysis(model_list)
self.analysis_window.setCentralWidget(multi_model_analysis)
self.analysis_window.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))
Expand Down
14 changes: 14 additions & 0 deletions src/digest/qt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,17 @@ def apply_multiple_style_sheets(
style_stream += QTextStream(style_qfile).readAll()

widget.setStyleSheet(style_stream)


def find_available_save_path(save_path: str) -> str:
    """Return the first path derived from ``save_path`` that does not exist.

    If ``save_path`` itself does not exist it is returned unchanged.
    Otherwise a counter suffix is inserted between the base name and its
    extension and incremented until an unused path is found. For example,
    if ``my/dir`` already exists this returns the first available location
    out of ``my/dir(1)``, ``my/dir(2)``, etc., and ``report.csv`` becomes
    ``report(1).csv``.

    Args:
        save_path: The desired file or directory path.

    Returns:
        ``save_path`` if it is free, otherwise the first free numbered
        sibling path.
    """
    if not os.path.exists(save_path):
        return save_path

    # A trailing separator ("my/dir/") would make os.path.split() yield an
    # empty base name, producing "my/dir/(1)" *inside* the existing directory
    # instead of the sibling "my/dir(1)". Strip it first, but keep the
    # original string when stripping would leave nothing (e.g. "/").
    separators = os.sep + (os.altsep or "")
    trimmed_path = save_path.rstrip(separators) or save_path

    # These pieces do not change between attempts, so compute them once.
    base_dir, base_name = os.path.split(trimmed_path)
    name, ext = os.path.splitext(base_name)

    counter = 1
    while True:
        candidate = os.path.join(base_dir, f"{name}({counter}){ext}")
        if not os.path.exists(candidate):
            return candidate
        counter += 1

0 comments on commit f4337d8

Please sign in to comment.