Skip to content

Commit

Permalink
fixed issue with prompts checking if user canceled
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Colangelo committed Jan 30, 2025
1 parent 821d840 commit 261e4ed
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.1.0",
version="1.1.1",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
5 changes: 0 additions & 5 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, label: str, num_steps: int, parent=None):
self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))
self.setValue(1)

self.user_canceled = False
self.canceled.connect(self.cancel)
self.step_size = 1
self.current_step = 0
self.num_steps = num_steps
Expand All @@ -46,9 +44,6 @@ def step(self):
self.current_step = self.num_steps
self.setValue(self.current_step)

def cancel(self):
self.user_canceled = True


class InfoDialog(QDialog):
"""This is a specific dialog class used to display the package information
Expand Down
5 changes: 2 additions & 3 deletions src/digest/freeze_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def apply_static_shapes(self) -> None:
# to arrive in this function with shapes for each dynamic dim.
dims: Dict[str, int] = {}
for i in range(self.ui.formLayout.rowCount()):
if status.user_canceled:
if status.wasCanceled():
break
status.step()
label_item = self.ui.formLayout.itemAt(i, QFormLayout.ItemRole.LabelRole)
Expand All @@ -168,7 +168,7 @@ def apply_static_shapes(self) -> None:
dims[label_text] = line_edit_value

for tensor in self.model_proto.graph.input:
if status.user_canceled:
if status.wasCanceled():
break
status.step()
tensor_shape = []
Expand Down Expand Up @@ -200,7 +200,6 @@ def apply_static_shapes(self) -> None:
except checker.ValidationError as e:
self.show_warning_and_disable_page()
print(f"Model did not pass checker: {e}")
finally:
status.close()

def show_warning_and_disable_page(self):
Expand Down
19 changes: 14 additions & 5 deletions src/digest/multi_model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def save_reports(self):
for digest_model in self.model_list:
progress.step()

if progress.wasCanceled():
break

model_save_dir = find_available_save_path(
os.path.join(save_directory, digest_model.model_name)
)
Expand Down Expand Up @@ -290,9 +293,9 @@ def save_reports(self):
)
digest_model.save_nodes_csv_report(nodes_filepath)

progress.close()
# progress.close()

if save_multi_reports:
if save_multi_reports and not progress.wasCanceled():

# Save all the global model analysis reports
if len(self.model_list) > 1:
Expand Down Expand Up @@ -336,9 +339,15 @@ def save_reports(self):
writer.writerows(rows)

if save_individual_reports or save_multi_reports:
self.status_dialog = StatusDialog(
f"Saved reports to {save_directory}", parent=self
)
if not progress.wasCanceled():
self.status_dialog = StatusDialog(
f"Saved reports to {save_directory}", parent=self
)
else:
self.status_dialog = StatusDialog(
f"User canceled saving reports, but some have been saved to {save_directory}",
parent=self,
)
self.status_dialog.show()

def check_box_changed(self):
Expand Down
14 changes: 4 additions & 10 deletions src/digest/multi_model_selection_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,12 @@ def set_directory(self, directory: str):
else:
return

progress = ProgressDialog("Searching directory for model files", 0, self)

onnx_file_list = list(
glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True)
)
onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list]

# TODO Move to another thread and see if we can capture tqdm
yaml_file_list = list(
glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True)
)
Expand All @@ -235,14 +234,13 @@ def set_directory(self, directory: str):

serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list)

progress.close()
progress = ProgressDialog("Loading models", total_num_models, self)

memory_limit_percentage = 90
models_loaded = 0
for filepath in onnx_file_list + report_file_list:
progress.step()
if progress.user_canceled:
if progress.wasCanceled():
break
try:
models_loaded += 1
Expand Down Expand Up @@ -277,16 +275,14 @@ def set_directory(self, directory: str):
except DecodeError as error:
print(f"Error decoding model {filepath}: {error}")

progress.close()

progress = ProgressDialog("Processing Models", total_num_models, self)

num_duplicates = 0
self.item_model.clear()
self.ui.duplicateListWidget.clear()
for paths in serialized_models_paths.values():
progress.step()
if progress.user_canceled:
if progress.wasCanceled():
break
if len(paths) > 1:
num_duplicates += 1
Expand All @@ -303,7 +299,7 @@ def set_directory(self, directory: str):
processed_files = set()
for i in range(len(report_file_list)):
progress.step()
if progress.user_canceled:
if progress.wasCanceled():
break
path1 = report_file_list[i]
if path1 in processed_files:
Expand All @@ -330,8 +326,6 @@ def set_directory(self, directory: str):
item.setCheckState(Qt.CheckState.Checked)
self.item_model.appendRow(item)

progress.close()

if num_duplicates:
label_text = f"Ignoring {num_duplicates} duplicate model(s)."
self.ui.duplicateLabel.setText(label_text)
Expand Down

0 comments on commit 261e4ed

Please sign in to comment.