From 261e4edc72b2fce019777d96e3f8a011435a5b24 Mon Sep 17 00:00:00 2001 From: Philip Colangelo Date: Wed, 29 Jan 2025 20:55:06 -0500 Subject: [PATCH] fixed issue with prompts checking if user canceled --- pyproject.toml | 3 +++ setup.py | 2 +- src/digest/dialog.py | 5 ----- src/digest/freeze_inputs.py | 5 ++--- src/digest/multi_model_analysis.py | 19 ++++++++++++++----- src/digest/multi_model_selection_page.py | 14 ++++---------- 6 files changed, 24 insertions(+), 24 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..737e278 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.py b/setup.py index b2ad16d..0e48553 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="digestai", - version="1.1.0", + version="1.1.1", description="Model analysis toolkit", author="Philip Colangelo, Daniel Holanda", packages=find_packages(where="src"), diff --git a/src/digest/dialog.py b/src/digest/dialog.py index ae9986d..cd848a6 100644 --- a/src/digest/dialog.py +++ b/src/digest/dialog.py @@ -34,8 +34,6 @@ def __init__(self, label: str, num_steps: int, parent=None): self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) self.setValue(1) - self.user_canceled = False - self.canceled.connect(self.cancel) self.step_size = 1 self.current_step = 0 self.num_steps = num_steps @@ -46,9 +44,6 @@ def step(self): self.current_step = self.num_steps self.setValue(self.current_step) - def cancel(self): - self.user_canceled = True - class InfoDialog(QDialog): """This is a specific dialog class used to display the package information diff --git a/src/digest/freeze_inputs.py b/src/digest/freeze_inputs.py index 947cf11..1abcd77 100644 --- a/src/digest/freeze_inputs.py +++ b/src/digest/freeze_inputs.py @@ -154,7 +154,7 @@ def apply_static_shapes(self) -> None: # to arrive in this function with shapes for each dynamic dim. dims: Dict[str, int] = {} for i in range(self.ui.formLayout.rowCount()): - if status.user_canceled: + if status.wasCanceled(): break status.step() label_item = self.ui.formLayout.itemAt(i, QFormLayout.ItemRole.LabelRole) @@ -168,7 +168,7 @@ def apply_static_shapes(self) -> None: dims[label_text] = line_edit_value for tensor in self.model_proto.graph.input: - if status.user_canceled: + if status.wasCanceled(): break status.step() tensor_shape = [] @@ -200,7 +200,6 @@ def apply_static_shapes(self) -> None: except checker.ValidationError as e: self.show_warning_and_disable_page() print(f"Model did not pass checker: {e}") - finally: status.close() def show_warning_and_disable_page(self): diff --git a/src/digest/multi_model_analysis.py b/src/digest/multi_model_analysis.py index 607ff5e..a6f70b7 100644 --- a/src/digest/multi_model_analysis.py +++ b/src/digest/multi_model_analysis.py @@ -258,6 +258,9 @@ def save_reports(self): for digest_model in self.model_list: progress.step() + if progress.wasCanceled(): + break + model_save_dir = find_available_save_path( os.path.join(save_directory, digest_model.model_name) ) @@ -290,9 +293,9 @@ def save_reports(self): ) digest_model.save_nodes_csv_report(nodes_filepath) - progress.close() + # progress.close() - if save_multi_reports: + if save_multi_reports and not progress.wasCanceled(): # Save all the global model analysis reports if len(self.model_list) > 1: @@ -336,9 +339,15 @@ def save_reports(self): writer.writerows(rows) if save_individual_reports or save_multi_reports: - self.status_dialog = StatusDialog( - f"Saved reports to {save_directory}", parent=self - ) + if not progress.wasCanceled(): + self.status_dialog = StatusDialog( + f"Saved reports to {save_directory}", parent=self + ) + else: + self.status_dialog = StatusDialog( + f"User canceled saving reports, but some have been saved to {save_directory}", + parent=self, + ) self.status_dialog.show() def check_box_changed(self): diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index 5f87edb..a637f84 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -212,13 +212,12 @@ def set_directory(self, directory: str): else: return - progress = ProgressDialog("Searching directory for model files", 0, self) - onnx_file_list = list( glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True) ) onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list] + # TODO Move to another thread and see if we can capture tqdm yaml_file_list = list( glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True) ) @@ -235,14 +234,13 @@ def set_directory(self, directory: str): serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list) - progress.close() progress = ProgressDialog("Loading models", total_num_models, self) memory_limit_percentage = 90 models_loaded = 0 for filepath in onnx_file_list + report_file_list: progress.step() - if progress.user_canceled: + if progress.wasCanceled(): break try: models_loaded += 1 @@ -277,8 +275,6 @@ def set_directory(self, directory: str): except DecodeError as error: print(f"Error decoding model {filepath}: {error}") - progress.close() - progress = ProgressDialog("Processing Models", total_num_models, self) num_duplicates = 0 @@ -286,7 +282,7 @@ def set_directory(self, directory: str): self.ui.duplicateListWidget.clear() for paths in serialized_models_paths.values(): progress.step() - if progress.user_canceled: + if progress.wasCanceled(): break if len(paths) > 1: num_duplicates += 1 @@ -303,7 +299,7 @@ def set_directory(self, directory: str): processed_files = set() for i in range(len(report_file_list)): progress.step() - if progress.user_canceled: + if progress.wasCanceled(): break path1 = report_file_list[i] if path1 in processed_files: @@ -330,8 +326,6 @@ def set_directory(self, directory: str): item.setCheckState(Qt.CheckState.Checked) self.item_model.appendRow(item) - progress.close() - if num_duplicates: label_text = f"Ignoring {num_duplicates} duplicate model(s)." self.ui.duplicateLabel.setText(label_text)