From f4337d8249224f7ad10d75d56ffc764a6aa81c46 Mon Sep 17 00:00:00 2001 From: Philip Date: Thu, 30 Jan 2025 08:12:28 -0800 Subject: [PATCH] Resolves: Multi-Model Analysis Fix and Improvement (#17) * Multi-Model Analysis Fix and Improvements --------- Co-authored-by: Philip Colangelo --- pyproject.toml | 3 ++ setup.py | 2 +- src/digest/dialog.py | 5 -- src/digest/freeze_inputs.py | 5 +- src/digest/main.py | 3 ++ src/digest/model_class/digest_report_model.py | 8 +++- src/digest/multi_model_analysis.py | 46 ++++++++++++++----- src/digest/multi_model_selection_page.py | 23 +++++----- src/digest/qt_utils.py | 14 ++++++ 9 files changed, 74 insertions(+), 35 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..737e278 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.py b/setup.py index b2ad16d..0e48553 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="digestai", - version="1.1.0", + version="1.1.1", description="Model analysis toolkit", author="Philip Colangelo, Daniel Holanda", packages=find_packages(where="src"), diff --git a/src/digest/dialog.py b/src/digest/dialog.py index ae9986d..cd848a6 100644 --- a/src/digest/dialog.py +++ b/src/digest/dialog.py @@ -34,8 +34,6 @@ def __init__(self, label: str, num_steps: int, parent=None): self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) self.setValue(1) - self.user_canceled = False - self.canceled.connect(self.cancel) self.step_size = 1 self.current_step = 0 self.num_steps = num_steps @@ -46,9 +44,6 @@ def step(self): self.current_step = self.num_steps self.setValue(self.current_step) - def cancel(self): - self.user_canceled = True - class InfoDialog(QDialog): """This is a specific dialog class used to display the package information diff --git a/src/digest/freeze_inputs.py b/src/digest/freeze_inputs.py index 947cf11..1abcd77 100644 --- a/src/digest/freeze_inputs.py +++ b/src/digest/freeze_inputs.py @@ -154,7 +154,7 @@ def apply_static_shapes(self) -> None: # to arrive in this function with shapes for each dynamic dim. dims: Dict[str, int] = {} for i in range(self.ui.formLayout.rowCount()): - if status.user_canceled: + if status.wasCanceled(): break status.step() label_item = self.ui.formLayout.itemAt(i, QFormLayout.ItemRole.LabelRole) @@ -168,7 +168,7 @@ def apply_static_shapes(self) -> None: dims[label_text] = line_edit_value for tensor in self.model_proto.graph.input: - if status.user_canceled: + if status.wasCanceled(): break status.step() tensor_shape = [] @@ -200,7 +200,6 @@ def apply_static_shapes(self) -> None: except checker.ValidationError as e: self.show_warning_and_disable_page() print(f"Model did not pass checker: {e}") - finally: status.close() def show_warning_and_disable_page(self): diff --git a/src/digest/main.py b/src/digest/main.py index db4c685..79e55c3 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -973,6 +973,9 @@ def save_reports(self): self, "Select Directory" ) + if not save_directory: + return + # Check if the directory exists and is writable if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK): self.show_warning_dialog( diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py index f2ccd26..c84ef20 100644 --- a/src/digest/model_class/digest_report_model.py +++ b/src/digest/model_class/digest_report_model.py @@ -1,4 +1,5 @@ import os +import shutil from collections import OrderedDict import csv import ast @@ -141,8 +142,11 @@ def parse_model_nodes(self) -> None: return def save_yaml_report(self, filepath: str) -> None: - """Report models are not intended to be saved""" - return + """self.filepath if exists is the path to the yaml file that was loaded + Despite saving a single model yaml report is not currently support, we can + offer this feature for multi model analysis.""" + if self.filepath and os.path.exists(self.filepath): + shutil.copy(self.filepath, filepath) def save_text_report(self, filepath: str) -> None: """Report models are not intended to be saved""" diff --git a/src/digest/multi_model_analysis.py b/src/digest/multi_model_analysis.py index 1c09905..a6f70b7 100644 --- a/src/digest/multi_model_analysis.py +++ b/src/digest/multi_model_analysis.py @@ -3,7 +3,7 @@ import os from datetime import datetime import csv -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional from collections import Counter, defaultdict, OrderedDict # pylint: disable=no-name-in-module @@ -12,7 +12,7 @@ from digest.dialog import ProgressDialog, StatusDialog from digest.ui.multimodelanalysis_ui import Ui_multiModelAnalysis from digest.histogramchartwidget import StackedHistogramWidget -from digest.qt_utils import apply_dark_style_sheet +from digest.qt_utils import apply_dark_style_sheet, find_available_save_path from digest.model_class.digest_model import ( NodeTypeCounts, NodeShapeCounts, @@ -215,6 +215,8 @@ def __init__( self.model_list = model_list + self.status_dialog: Optional[StatusDialog] = None + def save_reports(self): """This function saves all available reports for the models that are opened in the multi-model analysis page.""" @@ -223,6 +225,9 @@ def save_reports(self): self, "Select Directory" ) + if not base_directory: + return + # Check if the directory exists and is writable if not os.path.exists(base_directory) or not os.access(base_directory, os.W_OK): bad_ext_dialog = StatusDialog( @@ -234,7 +239,7 @@ def save_reports(self): # Append a subdirectory to the save_directory so that all reports are co-located name_id = datetime.now().strftime("%Y%m%d%H%M%S") sub_directory = f"multi_model_reports_{name_id}" - save_directory = os.path.join(base_directory, sub_directory) + save_directory = os.path.normpath(os.path.join(base_directory, sub_directory)) try: os.makedirs(save_directory) except OSError as os_err: @@ -253,16 +258,24 @@ def save_reports(self): for digest_model in self.model_list: progress.step() - # Save the text report for the model + if progress.wasCanceled(): + break + + model_save_dir = find_available_save_path( + os.path.join(save_directory, digest_model.model_name) + ) + os.makedirs(model_save_dir, exist_ok=True) + + # Save the yaml report for the model summary_filepath = os.path.join( - save_directory, f"{digest_model.model_name}_summary.txt" + model_save_dir, f"{digest_model.model_name}_summary.yaml" ) - digest_model.save_text_report(summary_filepath) + digest_model.save_yaml_report(summary_filepath) # Save csv of node type counts node_type_filepath = os.path.join( - save_directory, f"{digest_model.model_name}_node_type_counts.csv" + model_save_dir, f"{digest_model.model_name}_node_type_counts.csv" ) if digest_model.node_type_counts: @@ -270,19 +283,19 @@ def save_reports(self): # Save csv containing node shape counts per op_type node_shape_filepath = os.path.join( - save_directory, f"{digest_model.model_name}_node_shape_counts.csv" + model_save_dir, f"{digest_model.model_name}_node_shape_counts.csv" ) digest_model.save_node_shape_counts_csv_report(node_shape_filepath) # Save csv containing all node-level information nodes_filepath = os.path.join( - save_directory, f"{digest_model.model_name}_nodes.csv" + model_save_dir, f"{digest_model.model_name}_nodes.csv" ) digest_model.save_nodes_csv_report(nodes_filepath) - progress.close() + # progress.close() - if save_multi_reports: + if save_multi_reports and not progress.wasCanceled(): # Save all the global model analysis reports if len(self.model_list) > 1: @@ -326,7 +339,16 @@ def save_reports(self): writer.writerows(rows) if save_individual_reports or save_multi_reports: - StatusDialog(f"Saved reports to {save_directory}") + if not progress.wasCanceled(): + self.status_dialog = StatusDialog( + f"Saved reports to {save_directory}", parent=self + ) + else: + self.status_dialog = StatusDialog( + f"User canceled saving reports, but some have been saved to {save_directory}", + parent=self, + ) + self.status_dialog.show() def check_box_changed(self): if self.ui.individualCheckBox.isChecked() or self.ui.multiCheckBox.isChecked(): diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index 33fd78c..a637f84 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -47,7 +47,9 @@ def run(self): for file, model in self.model_dict.items(): if self.user_canceled: - break + self.close_progress.emit() + self.completed.emit([]) + return self.step_progress.emit() if model: continue @@ -65,7 +67,7 @@ def run(self): self.close_progress.emit() - model_list = [model for model in self.model_dict.values()] + model_list = list(self.model_dict.values()) self.completed.emit(model_list) @@ -210,13 +212,12 @@ def set_directory(self, directory: str): else: return - progress = ProgressDialog("Searching directory for model files", 0, self) - onnx_file_list = list( glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True) ) onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list] + # TODO Move to another thread and see if we can capture tqdm yaml_file_list = list( glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True) ) @@ -233,14 +234,13 @@ def set_directory(self, directory: str): serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list) - progress.close() progress = ProgressDialog("Loading models", total_num_models, self) memory_limit_percentage = 90 models_loaded = 0 for filepath in onnx_file_list + report_file_list: progress.step() - if progress.user_canceled: + if progress.wasCanceled(): break try: models_loaded += 1 @@ -275,8 +275,6 @@ def set_directory(self, directory: str): except DecodeError as error: print(f"Error decoding model {filepath}: {error}") - progress.close() - progress = ProgressDialog("Processing Models", total_num_models, self) num_duplicates = 0 @@ -284,7 +282,7 @@ def set_directory(self, directory: str): self.ui.duplicateListWidget.clear() for paths in serialized_models_paths.values(): progress.step() - if progress.user_canceled: + if progress.wasCanceled(): break if len(paths) > 1: num_duplicates += 1 @@ -301,7 +299,7 @@ def set_directory(self, directory: str): processed_files = set() for i in range(len(report_file_list)): progress.step() - if progress.user_canceled: + if progress.wasCanceled(): break path1 = report_file_list[i] if path1 in processed_files: @@ -328,8 +326,6 @@ def set_directory(self, directory: str): item.setCheckState(Qt.CheckState.Checked) self.item_model.appendRow(item) - progress.close() - if num_duplicates: label_text = f"Ignoring {num_duplicates} duplicate model(s)." self.ui.duplicateLabel.setText(label_text) @@ -362,6 +358,9 @@ def start_analysis(self): def open_analysis( self, model_list: List[Union[DigestOnnxModel, DigestReportModel]] ): + if not model_list: + return + multi_model_analysis = MultiModelAnalysis(model_list) self.analysis_window.setCentralWidget(multi_model_analysis) self.analysis_window.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) diff --git a/src/digest/qt_utils.py b/src/digest/qt_utils.py index 4e1149e..1015844 100644 --- a/src/digest/qt_utils.py +++ b/src/digest/qt_utils.py @@ -65,3 +65,17 @@ def apply_multiple_style_sheets( style_stream += QTextStream(style_qfile).readAll() widget.setStyleSheet(style_stream) + + +def find_available_save_path(save_path: str) -> str: + """Increments a counter until it finds a suitable save location + For example, if my/dir already exists this function will return the first + available location out of my/dir(1) or my/dir(2) etc...""" + counter = 1 + new_path = save_path + while os.path.exists(new_path): + base_dir, base_name = os.path.split(save_path) + name, ext = os.path.splitext(base_name) + new_path = os.path.join(base_dir, f"{name}({counter}){ext}") + counter += 1 + return new_path