diff --git a/setup.py b/setup.py
index b2ad16d..d6a16f7 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
setup(
name="digestai",
- version="1.1.0",
+ version="1.2.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
@@ -25,6 +25,8 @@
"platformdirs>=4.2.2",
"pyyaml>=6.0.1",
"psutil>=6.0.0",
+ "torch",
+ "transformers",
],
classifiers=[],
entry_points={"console_scripts": ["digest = digest.main:main"]},
diff --git a/src/digest/main.py b/src/digest/main.py
index 5a894b5..956e07b 100644
--- a/src/digest/main.py
+++ b/src/digest/main.py
@@ -39,8 +39,9 @@
from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog
from digest.thread import StatsThread, SimilarityThread, post_process
-from digest.popup_window import PopupWindow
+from digest.popup_window import PopupWindow, PopupDialog
from digest.huggingface_page import HuggingfacePage
+from digest.pytorch_ingest import PyTorchIngest
from digest.multi_model_selection_page import MultiModelSelectionPage
from digest.ui.mainwindow_ui import Ui_MainWindow
from digest.modelsummary import modelSummary
@@ -49,6 +50,7 @@
from digest.model_class.digest_model import DigestModel
from digest.model_class.digest_onnx_model import DigestOnnxModel
from digest.model_class.digest_report_model import DigestReportModel
+from digest.model_class.digest_pytorch_model import DigestPyTorchModel
from utils import onnx_utils
GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml")
@@ -166,7 +168,11 @@ def __init__(self, model_file: Optional[str] = None):
self.status_dialog = None
self.err_open_dialog = None
self.temp_dir = tempfile.TemporaryDirectory()
- self.digest_models: Dict[str, Union[DigestOnnxModel, DigestReportModel]] = {}
+ self.digest_models: Dict[
+ str, Union[DigestOnnxModel, DigestReportModel, DigestPyTorchModel]
+ ] = {}
+
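+ # Modal dialog created in load_pytorch() to collect the ONNX export settings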
+ self.pytorch_ingest_window: Optional[PopupDialog] = None
# QThread containers
self.model_nodes_stats_thread: Dict[str, StatsThread] = {}
@@ -225,6 +231,9 @@ def __init__(self, model_file: Optional[str] = None):
)
self.multimodelselection_page.model_signal.connect(self.load_model)
+ # Set up the PyTorch ingest page
+ self.pytorch_ingest: Optional[PyTorchIngest] = None
+
# Load model file if given as input to the executable
if model_file:
exists = os.path.exists(model_file)
@@ -287,7 +296,10 @@ def closeTab(self, index):
def openFile(self):
file_name, _ = QFileDialog.getOpenFileName(
- self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)"
+ self,
+ "Open File",
+ "",
+ "ONNX, PyTorch, and Report Files (*.onnx *.pt *.yaml)",
)
if not file_name:
@@ -364,7 +376,7 @@ def update_similarity_widget(
completed_successfully: bool,
model_id: str,
most_similar: str,
- png_filepath: Optional[str] = None,
+ png_file_path: Optional[str] = None,
df_sorted: Optional[pd.DataFrame] = None,
):
widget = None
@@ -388,12 +400,12 @@ def update_similarity_widget(
completed_successfully
and isinstance(widget, modelSummary)
and digest_model
- and png_filepath
+ and png_file_path
):
if df_sorted is not None:
post_process(
- digest_model.model_name, most_similar_list, df_sorted, png_filepath
+ digest_model.model_name, most_similar_list, df_sorted, png_file_path
)
widget.load_gif.stop()
@@ -401,7 +413,7 @@ def update_similarity_widget(
# We give the image a 10% haircut to fit it more aesthetically
widget_width = widget.ui.similarityImg.width()
- pixmap = QPixmap(png_filepath)
+ pixmap = QPixmap(png_file_path)
aspect_ratio = pixmap.width() / pixmap.height()
target_height = int(widget_width / aspect_ratio)
pixmap_scaled = pixmap.scaled(
@@ -436,12 +448,12 @@ def update_similarity_widget(
# Create option to click to enlarge image
widget.ui.similarityImg.mousePressEvent = (
lambda event: self.open_similarity_report(
- model_id, png_filepath, most_similar_list
+ model_id, png_file_path, most_similar_list
)
)
# Create option to click to enlarge image
self.model_similarity_report[model_id] = SimilarityAnalysisReport(
- png_filepath, most_similar_list
+ png_file_path, most_similar_list
)
widget.ui.similarityCorrelation.setText(text)
@@ -463,12 +475,12 @@ def update_similarity_widget(
):
self.ui.saveBtn.setEnabled(True)
- def load_onnx(self, filepath: str):
+ def load_onnx(self, file_path: str):
- # Ensure the filepath follows a standard formatting:
- filepath = os.path.normpath(filepath)
+ # Ensure the file_path follows standard formatting:
+ file_path = os.path.normpath(file_path)
- if not os.path.exists(filepath):
+ if not os.path.exists(file_path):
return
# Every time an onnx is loaded we should emulate a model summary button click
@@ -477,7 +489,7 @@ def load_onnx(self, filepath: str):
# Before opening the file, check to see if it is already opened.
for index in range(self.ui.tabWidget.count()):
widget = self.ui.tabWidget.widget(index)
- if isinstance(widget, modelSummary) and filepath == widget.file:
+ if isinstance(widget, modelSummary) and file_path == widget.file:
self.ui.tabWidget.setCurrentIndex(index)
return
@@ -486,11 +498,11 @@ def load_onnx(self, filepath: str):
progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self)
QApplication.processEvents() # Process pending events
- model = onnx_utils.load_onnx(filepath, load_external_data=False)
+ model = onnx_utils.load_onnx(file_path, load_external_data=False)
opt_model, opt_passed = onnx_utils.optimize_onnx_model(model)
progress.step()
- basename = os.path.splitext(os.path.basename(filepath))
+ basename = os.path.splitext(os.path.basename(file_path))
model_name = basename[0]
# Save the model proto so we can use the Freeze Inputs feature
@@ -534,14 +546,14 @@ def load_onnx(self, filepath: str):
model_summary.ui.similarityCorrelation.hide()
model_summary.ui.similarityCorrelationStatic.hide()
- model_summary.file = filepath
+ model_summary.file = file_path
model_summary.setObjectName(model_name)
model_summary.ui.modelName.setText(model_name)
- model_summary.ui.modelFilename.setText(filepath)
+ model_summary.ui.modelFilename.setText(file_path)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))
digest_model.model_name = model_name
- digest_model.filepath = filepath
+ digest_model.file_path = file_path
digest_model.model_inputs = onnx_utils.get_model_input_shapes_types(
opt_model
)
@@ -694,8 +706,8 @@ def load_onnx(self, filepath: str):
self.model_similarity_thread[model_id].completed_successfully.connect(
self.update_similarity_widget
)
- self.model_similarity_thread[model_id].model_filepath = filepath
- self.model_similarity_thread[model_id].png_filepath = os.path.join(
+ self.model_similarity_thread[model_id].model_file_path = file_path
+ self.model_similarity_thread[model_id].png_file_path = os.path.join(
png_tmp_path, f"heatmap_{model_name}.png"
)
self.model_similarity_thread[model_id].model_id = model_id
@@ -706,12 +718,12 @@ def load_onnx(self, filepath: str):
except FileNotFoundError as e:
print(f"File not found: {e.filename}")
- def load_report(self, filepath: str):
+ def load_report(self, file_path: str):
- # Ensure the filepath follows a standard formatting:
- filepath = os.path.normpath(filepath)
+ # Ensure the file_path follows standard formatting:
+ file_path = os.path.normpath(file_path)
- if not os.path.exists(filepath):
+ if not os.path.exists(file_path):
return
# Every time a report is loaded we should emulate a model summary button click
@@ -720,7 +732,7 @@ def load_report(self, filepath: str):
# Before opening the file, check to see if it is already opened.
for index in range(self.ui.tabWidget.count()):
widget = self.ui.tabWidget.widget(index)
- if isinstance(widget, modelSummary) and filepath == widget.file:
+ if isinstance(widget, modelSummary) and file_path == widget.file:
self.ui.tabWidget.setCurrentIndex(index)
return
@@ -729,13 +741,13 @@ def load_report(self, filepath: str):
progress = ProgressDialog("Loading Digest Report File...", 2, self)
QApplication.processEvents() # Process pending events
- digest_model = DigestReportModel(filepath)
+ digest_model = DigestReportModel(file_path)
if not digest_model.is_valid:
progress.close()
invalid_yaml_dialog = StatusDialog(
title="Warning",
- status_message=f"YAML file {filepath} is not a valid digest report",
+ status_message=f"YAML file {file_path} is not a valid digest report",
)
invalid_yaml_dialog.show()
@@ -758,10 +770,10 @@ def load_report(self, filepath: str):
model_summary.ui.similarityCorrelation.hide()
model_summary.ui.similarityCorrelationStatic.hide()
- model_summary.file = filepath
+ model_summary.file = file_path
model_summary.setObjectName(digest_model.model_name)
model_summary.ui.modelName.setText(digest_model.model_name)
- model_summary.ui.modelFilename.setText(filepath)
+ model_summary.ui.modelFilename.setText(file_path)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))
model_summary.ui.parameters.setText(format(digest_model.parameters, ","))
@@ -888,7 +900,7 @@ def load_report(self, filepath: str):
completed_successfully=bool(digest_model.similarity_heatmap_path),
model_id=digest_model.unique_id,
most_similar="",
- png_filepath=digest_model.similarity_heatmap_path,
+ png_file_path=digest_model.similarity_heatmap_path,
)
progress.close()
@@ -896,9 +908,30 @@ def load_report(self, filepath: str):
except FileNotFoundError as e:
print(f"File not found: {e.filename}")
+ def load_pytorch(self, file_path: str):
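+ """Opens the PyTorch ingest dialog so the user can provide input shapes and
+ export the model to ONNX, then loads the exported ONNX model."""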
+ # Ensure the file_path follows standard formatting:
+ file_path = os.path.normpath(file_path)
+
+ if not os.path.exists(file_path):
+ return
+
+ basename = os.path.splitext(os.path.basename(file_path))
+ model_name = basename[0]
+
+ self.pytorch_ingest = PyTorchIngest(file_path, model_name)
+ self.pytorch_ingest_window = PopupDialog(
+ self.pytorch_ingest, "PyTorch Ingest", self
+ )
+ self.pytorch_ingest_window.open()
+
+ # The dialog above blocks until the user has completed the PyTorch ingest form.
+ # The form exits upon a successful export, at which point the ONNX path will be set.
+ if self.pytorch_ingest.digest_pytorch_model.onnx_file_path:
+ self.load_onnx(self.pytorch_ingest.digest_pytorch_model.onnx_file_path)
+
def load_model(self, file_path: str):
- # Ensure the filepath follows a standard formatting:
+ # Ensure the file_path follows standard formatting:
file_path = os.path.normpath(file_path)
if not os.path.exists(file_path):
@@ -910,6 +943,8 @@ def load_model(self, file_path: str):
self.load_onnx(file_path)
elif file_ext == ".yaml":
self.load_report(file_path)
+ elif file_ext == ".pt" or file_ext == ".pth":
+ self.load_pytorch(file_path)
else:
bad_ext_dialog = StatusDialog(
f"Digest does not support files with the extension {file_ext}",
@@ -992,30 +1027,32 @@ def save_reports(self):
)
# Save csv of node type counts
- node_type_filepath = os.path.join(
+ node_type_file_path = os.path.join(
save_directory, f"{model_name}_node_type_counts.csv"
)
- digest_model.save_node_type_counts_csv_report(node_type_filepath)
+ digest_model.save_node_type_counts_csv_report(node_type_file_path)
# Save (copy) the similarity image
png_file_path = self.model_similarity_thread[
digest_model.unique_id
- ].png_filepath
+ ].png_file_path
png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png")
if png_file_path and os.path.exists(png_file_path):
shutil.copy(png_file_path, png_save_path)
# Save the text report
- txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt")
- digest_model.save_text_report(txt_report_filepath)
+ txt_report_file_path = os.path.join(save_directory, f"{model_name}_report.txt")
+ digest_model.save_text_report(txt_report_file_path)
# Save the yaml report
- yaml_report_filepath = os.path.join(save_directory, f"{model_name}_report.yaml")
- digest_model.save_yaml_report(yaml_report_filepath)
+ yaml_report_file_path = os.path.join(
+ save_directory, f"{model_name}_report.yaml"
+ )
+ digest_model.save_yaml_report(yaml_report_file_path)
# Save the node list
- nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv")
- self.save_nodes_csv(nodes_report_filepath, False)
+ nodes_report_file_path = os.path.join(save_directory, f"{model_name}_nodes.csv")
+ self.save_nodes_csv(nodes_report_file_path, False)
self.status_dialog = StatusDialog(
f"Saved reports to: \n{os.path.abspath(save_directory)}",
@@ -1051,20 +1088,20 @@ def save_file_dialog(
)
return path, filter_type
- def save_parameters_csv(self, filepath: str, open_dialog: bool = True):
- self.save_nodes_csv(filepath, open_dialog)
+ def save_parameters_csv(self, file_path: str, open_dialog: bool = True):
+ self.save_nodes_csv(file_path, open_dialog)
- def save_flops_csv(self, filepath: str, open_dialog: bool = True):
- self.save_nodes_csv(filepath, open_dialog)
+ def save_flops_csv(self, file_path: str, open_dialog: bool = True):
+ self.save_nodes_csv(file_path, open_dialog)
- def save_nodes_csv(self, csv_filepath: Optional[str], open_dialog: bool = True):
+ def save_nodes_csv(self, csv_file_path: Optional[str], open_dialog: bool = True):
if open_dialog:
- csv_filepath, _ = self.save_file_dialog()
- if not csv_filepath:
- raise ValueError("A filepath must be given.")
+ csv_file_path, _ = self.save_file_dialog()
+ if not csv_file_path:
+ raise ValueError("A file path must be given.")
current_tab = self.ui.tabWidget.currentWidget()
if isinstance(current_tab, modelSummary):
- current_tab.digest_model.save_nodes_csv_report(csv_filepath)
+ current_tab.digest_model.save_nodes_csv_report(csv_file_path)
def save_chart(self, chart_view):
path, _ = self.save_file_dialog("Save PNG", "PNG(*.png)")
diff --git a/src/digest/model_class/digest_model.py b/src/digest/model_class/digest_model.py
index 9064184..bd38d3a 100644
--- a/src/digest/model_class/digest_model.py
+++ b/src/digest/model_class/digest_model.py
@@ -13,6 +13,7 @@
class SupportedModelTypes(Enum):
ONNX = "onnx"
REPORT = "report"
+ PYTORCH = "pytorch"
class NodeParsingException(Exception):
@@ -94,10 +95,12 @@ def __init__(self, *args, **kwargs):
class DigestModel(ABC):
- def __init__(self, filepath: str, model_name: str, model_type: SupportedModelTypes):
+ def __init__(
+ self, file_path: str, model_name: str, model_type: SupportedModelTypes
+ ):
# Public members exposed to the API
self.unique_id: str = str(uuid4())
- self.filepath: Optional[str] = filepath
+ self.file_path: Optional[str] = os.path.abspath(file_path)
self.model_name: str = model_name
self.model_type: SupportedModelTypes = model_type
self.node_type_counts: NodeTypeCounts = NodeTypeCounts()
@@ -122,27 +125,27 @@ def parse_model_nodes(self, *args, **kwargs) -> None:
pass
@abstractmethod
- def save_yaml_report(self, filepath: str) -> None:
+ def save_yaml_report(self, file_path: str) -> None:
pass
@abstractmethod
- def save_text_report(self, filepath: str) -> None:
+ def save_text_report(self, file_path: str) -> None:
pass
- def save_nodes_csv_report(self, filepath: str) -> None:
- save_nodes_csv_report(self.node_data, filepath)
+ def save_nodes_csv_report(self, file_path: str) -> None:
+ save_nodes_csv_report(self.node_data, file_path)
- def save_node_type_counts_csv_report(self, filepath: str) -> None:
+ def save_node_type_counts_csv_report(self, file_path: str) -> None:
if self.node_type_counts:
- save_node_type_counts_csv_report(self.node_type_counts, filepath)
+ save_node_type_counts_csv_report(self.node_type_counts, file_path)
- def save_node_shape_counts_csv_report(self, filepath: str) -> None:
- save_node_shape_counts_csv_report(self.get_node_shape_counts(), filepath)
+ def save_node_shape_counts_csv_report(self, file_path: str) -> None:
+ save_node_shape_counts_csv_report(self.get_node_shape_counts(), file_path)
-def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None:
+def save_nodes_csv_report(node_data: NodeData, file_path: str) -> None:
- parent_dir = os.path.dirname(os.path.abspath(filepath))
+ parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
@@ -185,23 +188,23 @@ def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None:
flattened_data.append(row)
fieldnames = fieldnames + input_fieldnames + output_fieldnames
- with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
+ with open(file_path, "w", encoding="utf-8", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
writer.writerows(flattened_data)
def save_node_type_counts_csv_report(
- node_type_counts: NodeTypeCounts, filepath: str
+ node_type_counts: NodeTypeCounts, file_path: str
) -> None:
- parent_dir = os.path.dirname(os.path.abspath(filepath))
+ parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
header = ["Node Type", "Count"]
- with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
+ with open(file_path, "w", encoding="utf-8", newline="") as csvfile:
writer = csv.writer(csvfile, lineterminator="\n")
writer.writerow(header)
for node_type, node_count in node_type_counts.items():
@@ -209,16 +212,16 @@ def save_node_type_counts_csv_report(
def save_node_shape_counts_csv_report(
- node_shape_counts: NodeShapeCounts, filepath: str
+ node_shape_counts: NodeShapeCounts, file_path: str
) -> None:
- parent_dir = os.path.dirname(os.path.abspath(filepath))
+ parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
header = ["Node Type", "Input Tensors Shapes", "Count"]
- with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
+ with open(file_path, "w", encoding="utf-8", newline="") as csvfile:
writer = csv.writer(csvfile, dialect="excel", lineterminator="\n")
writer.writerow(header)
for node_type, node_info in node_shape_counts.items():
diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py
index 35aad1d..5f55636 100644
--- a/src/digest/model_class/digest_onnx_model.py
+++ b/src/digest/model_class/digest_onnx_model.py
@@ -22,13 +22,11 @@ class DigestOnnxModel(DigestModel):
def __init__(
self,
onnx_model: onnx.ModelProto,
- onnx_filepath: str = "",
+ onnx_file_path: str = "",
model_name: str = "",
save_proto: bool = True,
) -> None:
- super().__init__(onnx_filepath, model_name, SupportedModelTypes.ONNX)
-
- self.model_type = SupportedModelTypes.ONNX
+ super().__init__(onnx_file_path, model_name, SupportedModelTypes.ONNX)
# Public members exposed to the API
self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None
@@ -509,9 +507,9 @@ def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None:
self.node_type_flops.get(node.op_type, 0) + node_info.flops
)
- def save_yaml_report(self, filepath: str) -> None:
+ def save_yaml_report(self, file_path: str) -> None:
- parent_dir = os.path.dirname(os.path.abspath(filepath))
+ parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
@@ -523,7 +521,7 @@ def save_yaml_report(self, filepath: str) -> None:
yaml_data = {
"report_date": report_date,
"model_type": self.model_type.value,
- "model_file": self.filepath,
+ "model_file": self.file_path,
"model_name": self.model_name,
"model_version": self.model_version,
"graph_name": self.graph_name,
@@ -542,22 +540,22 @@ def save_yaml_report(self, filepath: str) -> None:
"output_tensors": output_tensors,
}
- with open(filepath, "w", encoding="utf-8") as f_p:
+ with open(file_path, "w", encoding="utf-8") as f_p:
yaml.dump(yaml_data, f_p, sort_keys=False)
- def save_text_report(self, filepath: str) -> None:
+ def save_text_report(self, file_path: str) -> None:
- parent_dir = os.path.dirname(os.path.abspath(filepath))
+ parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")
report_date = datetime.now().strftime("%B %d, %Y")
- with open(filepath, "w", encoding="utf-8") as f_p:
+ with open(file_path, "w", encoding="utf-8") as f_p:
f_p.write(f"Report created on {report_date}\n")
f_p.write(f"Model type: {self.model_type.name}\n")
- if self.filepath:
- f_p.write(f"ONNX file: {self.filepath}\n")
+ if self.file_path:
+ f_p.write(f"ONNX file: {self.file_path}\n")
f_p.write(f"Name of the model: {self.model_name}\n")
f_p.write(f"Model version: {self.model_version}\n")
f_p.write(f"Name of the graph: {self.graph_name}\n")
diff --git a/src/digest/model_class/digest_pytorch_model.py b/src/digest/model_class/digest_pytorch_model.py
new file mode 100644
index 0000000..68b1a76
--- /dev/null
+++ b/src/digest/model_class/digest_pytorch_model.py
@@ -0,0 +1,102 @@
+# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
+
+import os
+from collections import OrderedDict
+from typing import List, Tuple, Optional, Any, Union
+import inspect
+import onnx
+import torch
+from digest.model_class.digest_onnx_model import DigestOnnxModel
+from digest.model_class.digest_model import (
+ DigestModel,
+ SupportedModelTypes,
+)
+
+
+class DigestPyTorchModel(DigestModel):
+ """The idea of this class is to first support PyTorch models by converting them to ONNX.
+ Eventually, we will want to support a PyTorch-specific interface that has a custom GUI.
+ To facilitate this, it makes the most sense to use this class as a helper that converts
+ the PyTorch model to ONNX and stores the ONNX info in a member DigestOnnxModel
+ object. We can also store various PyTorch-specific details in this class as well.
+ """
+
+ def __init__(
+ self,
+ pytorch_file_path: str = "",
+ model_name: str = "",
+ ) -> None:
+ super().__init__(pytorch_file_path, model_name, SupportedModelTypes.PYTORCH)
+
+ assert os.path.exists(
+ pytorch_file_path
+ ), f"PyTorch file {pytorch_file_path} does not exist."
+
+ # Default opset value
+ self.opset = 17
+
+ # Input dictionary to contain the names and shapes
+ # required for exporting the ONNX model
+ self.input_tensor_info: OrderedDict[str, List[Any]] = OrderedDict()
+
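+ # Assumes the file holds a full pickled nn.Module (not just a state_dict);
+ # depending on the installed torch version, torch.load may need
+ # weights_only=False to unpickle arbitrary modules.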
+ self.pytorch_model = torch.load(pytorch_file_path)
+
+ # Data needed for exporting to ONNX
+ self.do_constant_folding = True
+ self.export_params = True
+
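+ # Set by export_to_onnx() on success; the main window checks this attribute
+ # to know whether an exported ONNX model is available to load.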
+ self.onnx_file_path: Optional[str] = None
+
+ self.digest_onnx_model: Optional[DigestOnnxModel] = None
+
+ def parse_model_nodes(self) -> None:
+ """This will be done in the DigestOnnxModel"""
+
+ def save_yaml_report(self, file_path: str) -> None:
+ """This will be done in the DigestOnnxModel"""
+
+ def save_text_report(self, file_path: str) -> None:
+ """This will be done in the DigestOnnxModel"""
+
+ def generate_random_tensor(self, shape: List[Union[str, int]]):
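+ # Build a concrete dummy tensor: symbolic (string) dims are replaced with 1.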
+ static_shape = [dim if isinstance(dim, int) else 1 for dim in shape]
+ return torch.rand(static_shape)
+
+ def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]:
+
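+ # Builds dummy inputs from the user-provided shapes, marks symbolic dims as
+ # dynamic axes, calls torch.onnx.export, and returns the loaded model proto.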
+ dummy_input_names: List[str] = list(self.input_tensor_info.keys())
+ dummy_inputs: List[torch.Tensor] = []
+
+ for shape in self.input_tensor_info.values():
+ dummy_inputs.append(self.generate_random_tensor(shape))
+
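+ # Any symbolic (string) dimension becomes a dynamic axis for the export, e.g. a
+ # shape entered as ["batch_size", 3, 224, 224] yields {"<input name>": {0: "batch_size"}}
+ # (the input name here is illustrative).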
+ dynamic_axes = {
+ name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)}
+ for name, shape in self.input_tensor_info.items()
+ }
+
+ try:
+ torch.onnx.export(
+ self.pytorch_model,
+ tuple(dummy_inputs),
+ output_onnx_path,
+ input_names=dummy_input_names,
+ do_constant_folding=self.do_constant_folding,
+ export_params=self.export_params,
+ opset_version=self.opset,
+ dynamic_axes=dynamic_axes,
+ verbose=False,
+ )
+
+ self.onnx_file_path = output_onnx_path
+
+ return onnx.load(output_onnx_path)
+
+ except (TypeError, RuntimeError) as err:
+ print(f"Failed to export ONNX: {err}")
+ raise
+
+
+def get_model_fwd_parameters(torch_file_path):
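+ """Load the serialized model and return the parameters of its forward() signature;
+ the ingest UI builds one shape-entry row per parameter."""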
+ torch_model = torch.load(torch_file_path)
+ return inspect.signature(torch_model.forward).parameters
diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py
index 4478285..d99a1ed 100644
--- a/src/digest/model_class/digest_report_model.py
+++ b/src/digest/model_class/digest_report_model.py
@@ -44,30 +44,30 @@ def parse_tensor_info(csv_tensor_cell_value) -> Tuple[str, list, str, float]:
class DigestReportModel(DigestModel):
def __init__(
self,
- report_filepath: str,
+ report_file_path: str,
) -> None:
self.model_type = SupportedModelTypes.REPORT
- self.is_valid = self.validate_yaml(report_filepath)
+ self.is_valid = self.validate_yaml(report_file_path)
if not self.is_valid:
- print(f"The yaml file {report_filepath} is not a valid digest report.")
+ print(f"The yaml file {report_file_path} is not a valid digest report.")
return
self.model_data = OrderedDict()
- with open(report_filepath, "r", encoding="utf-8") as yaml_f:
+ with open(report_file_path, "r", encoding="utf-8") as yaml_f:
self.model_data = yaml.safe_load(yaml_f)
model_name = self.model_data["model_name"]
- super().__init__(report_filepath, model_name, SupportedModelTypes.REPORT)
+ super().__init__(report_file_path, model_name, SupportedModelTypes.REPORT)
self.similarity_heatmap_path: Optional[str] = None
self.node_data = NodeData()
# Given the path to the digest report, let's check if its a complete cache
# and we can grab the nodes csv data and the similarity heatmap
- cache_dir = os.path.dirname(os.path.abspath(report_filepath))
+ cache_dir = os.path.dirname(os.path.abspath(report_file_path))
expected_heatmap_file = os.path.join(cache_dir, f"{model_name}_heatmap.png")
if os.path.exists(expected_heatmap_file):
self.similarity_heatmap_path = expected_heatmap_file
@@ -169,8 +169,8 @@ def validate_yaml(self, report_file_path: str) -> bool:
def parse_model_nodes(self) -> None:
"""There are no model nodes to parse"""
- def save_yaml_report(self, filepath: str) -> None:
+ def save_yaml_report(self, file_path: str) -> None:
"""Report models are not intended to be saved"""
- def save_text_report(self, filepath: str) -> None:
+ def save_text_report(self, file_path: str) -> None:
"""Report models are not intended to be saved"""
diff --git a/src/digest/popup_window.py b/src/digest/popup_window.py
index 09d1971..6e4e5ea 100644
--- a/src/digest/popup_window.py
+++ b/src/digest/popup_window.py
@@ -1,11 +1,14 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
# pylint: disable=no-name-in-module
-from PySide6.QtWidgets import QApplication, QMainWindow, QWidget
+from PySide6.QtCore import Qt
+from PySide6.QtWidgets import QApplication, QMainWindow, QWidget, QDialog, QVBoxLayout
from PySide6.QtGui import QIcon
class PopupWindow(QWidget):
+ """Opens a new window that runs separately from the main digest window"""
+
def __init__(self, widget: QWidget, window_title: str = "", parent=None):
super().__init__(parent)
@@ -24,3 +27,36 @@ def open(self):
def close(self):
self.main_window.close()
+
+
+class PopupDialog(QDialog):
+ """Opens a new window that takes focus and must be closed before returning
+ to the main digest window"""
+
+ def __init__(self, widget: QWidget, window_title: str = "", parent=None):
+ super().__init__(parent)
+
+ if hasattr(widget, "close_signal"):
+ widget.close_signal.connect(self.on_widget_closed) # type: ignore
+
+ self.setWindowModality(Qt.WindowModality.WindowModal)
+ self.setWindowFlags(Qt.WindowType.Window)
+
+ layout = QVBoxLayout()
+ layout.addWidget(widget)
+ self.setLayout(layout)
+
+ self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))
+ self.setWindowTitle(window_title)
+ screen = QApplication.primaryScreen()
+ screen_geometry = screen.geometry()
+ self.resize(
+ int(screen_geometry.width() / 1.5), int(screen_geometry.height() * 0.80)
+ )
+
+ def open(self):
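+ # exec() blocks until the dialog is closed, which is what lets load_pytorch()
+ # wait for the ingest form to finish before continuing.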
+ self.show()
+ self.exec()
+
+ def on_widget_closed(self):
+ self.close()
diff --git a/src/digest/pytorch_ingest.py b/src/digest/pytorch_ingest.py
new file mode 100644
index 0000000..0ddd802
--- /dev/null
+++ b/src/digest/pytorch_ingest.py
@@ -0,0 +1,280 @@
+# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
+
+import os
+from collections import OrderedDict
+from typing import Optional, Callable, Union
+from platformdirs import user_cache_dir
+
+# pylint: disable=no-name-in-module
+from PySide6.QtWidgets import (
+ QWidget,
+ QLabel,
+ QLineEdit,
+ QSizePolicy,
+ QFormLayout,
+ QFileDialog,
+ QHBoxLayout,
+)
+from PySide6.QtGui import QFont
+from PySide6.QtCore import Qt, Signal
+from utils import onnx_utils
+from digest.ui.pytorchingest_ui import Ui_pytorchIngest
+from digest.qt_utils import apply_dark_style_sheet
+from digest.model_class.digest_pytorch_model import (
+ get_model_fwd_parameters,
+ DigestPyTorchModel,
+)
+
+
+class UserInputFormWithInfo:
+ def __init__(self, form_layout: QFormLayout):
+ self.form_layout = form_layout
+ self.num_rows = 0
+
+ def add_row(
+ self,
+ label_text: str,
+ edit_text: str,
+ text_width: int,
+ info_text: str,
+ edit_finished_fnc: Optional[Callable] = None,
+ ) -> int:
+
+ font = QFont("Inter", 10)
+ label = QLabel(f"{label_text}:")
+ label.setContentsMargins(0, 0, 0, 0)
+ label.setFont(font)
+
+ line_edit = QLineEdit()
+ line_edit.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed)
+ line_edit.setMinimumWidth(text_width)
+ line_edit.setMinimumHeight(20)
+ line_edit.setText(edit_text)
+ if edit_finished_fnc:
+ line_edit.editingFinished.connect(edit_finished_fnc)
+
+ info_label = QLabel()
+ info_label.setText(info_text)
+ font = QFont("Arial", 10, italic=True)
+ info_label.setFont(font)
+ info_label.setContentsMargins(10, 0, 0, 0)
+
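+ # Each row is laid out as [label, line edit, info label]; the get_row_* helpers
+ # below rely on this ordering (itemAt(0) is the label, itemAt(1) the line edit).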
+ row_layout = QHBoxLayout()
+ row_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
+ row_layout.setSpacing(5)
+ row_layout.setObjectName(f"row{self.num_rows}_layout")
+ row_layout.addWidget(label, alignment=Qt.AlignmentFlag.AlignHCenter)
+ row_layout.addWidget(line_edit, alignment=Qt.AlignmentFlag.AlignHCenter)
+ row_layout.addWidget(info_label, alignment=Qt.AlignmentFlag.AlignHCenter)
+
+ self.num_rows += 1
+ self.form_layout.addRow(row_layout)
+
+ return self.num_rows
+
+ def get_row_label(self, row_idx: int) -> str:
+ form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole)
+ if form_item:
+ row_layout = form_item.layout()
+ if isinstance(row_layout, QHBoxLayout):
+ label_item = row_layout.itemAt(0)
+ if label_item:
+ label_widget = label_item.widget()
+ if isinstance(label_widget, QLabel):
+ return label_widget.text()
+ return ""
+
+ def get_row_line_edit(self, row_idx: int) -> str:
+ form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole)
+ if form_item:
+ row_layout = form_item.layout()
+ if isinstance(row_layout, QHBoxLayout):
+ line_edit_item = row_layout.itemAt(1)
+ if line_edit_item:
+ line_edit_widget = line_edit_item.widget()
+ if isinstance(line_edit_widget, QLineEdit):
+ return line_edit_widget.text()
+ return ""
+
+ def get_row_line_edit_widget(self, row_idx: int) -> Union[QLineEdit, None]:
+ form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole)
+ if form_item:
+ row_layout = form_item.layout()
+ if isinstance(row_layout, QHBoxLayout):
+ line_edit_item = row_layout.itemAt(1)
+ if line_edit_item:
+ line_edit_widget = line_edit_item.widget()
+ if isinstance(line_edit_widget, QLineEdit):
+ return line_edit_widget
+ return None
+
+
+class PyTorchIngest(QWidget):
+ """PyTorchIngest is the pop up window that enables users to set static shapes and export
+ PyTorch models to ONNX models."""
+
+ # This enables the widget to close the parent window
+ close_signal = Signal()
+
+ def __init__(
+ self,
+ model_file: str,
+ model_name: str,
+ parent=None,
+ ):
+ super().__init__(parent)
+ self.ui = Ui_pytorchIngest()
+ self.ui.setupUi(self)
+ apply_dark_style_sheet(self)
+
+ self.ui.exportWarningLabel.hide()
+
+ # We use a cache dir to save the exported ONNX model
+ # Users have the option to choose a different location
+ # if they wish to keep the exported model.
+ user_cache_directory = user_cache_dir("digest")
+ os.makedirs(user_cache_directory, exist_ok=True)
+ self.save_directory: str = user_cache_directory
+
+ self.ui.selectDirBtn.clicked.connect(self.select_directory)
+ self.ui.exportOnnxBtn.clicked.connect(self.export_onnx)
+
+ self.ui.modelName.setText(str(model_name))
+
+ self.ui.modelFilename.setText(str(model_file))
+
+ self.ui.foldingCheckBox.stateChanged.connect(self.on_checkbox_folding_changed)
+ self.ui.exportParamsCheckBox.stateChanged.connect(
+ self.on_checkbox_export_params_changed
+ )
+
+ self.digest_pytorch_model = DigestPyTorchModel(model_file, model_name)
+ self.digest_pytorch_model.do_constant_folding = (
+ self.ui.foldingCheckBox.isChecked()
+ )
+ self.digest_pytorch_model.export_params = (
+ self.ui.exportParamsCheckBox.isChecked()
+ )
+
+ self.user_input_form = UserInputFormWithInfo(self.ui.inputsFormLayout)
+
+ # Set up the opset form
+ self.lowest_supported_opset = 7 # this requirement came from pytorch
+ self.supported_opset_version = onnx_utils.get_supported_opset()
+ self.ui.opsetLineEdit.setText(str(self.digest_pytorch_model.opset))
+ self.ui.opsetInfoLabel.setStyleSheet("color: grey;")
+ self.ui.opsetInfoLabel.setText(
+ f"(accepted range is {self.lowest_supported_opset} - {self.supported_opset_version}):"
+ )
+ self.ui.opsetLineEdit.editingFinished.connect(self.update_opset_version)
+
+ # Present each input in the forward function
+ self.fwd_parameters = OrderedDict(get_model_fwd_parameters(model_file))
+ for val in self.fwd_parameters.values():
+ self.user_input_form.add_row(
+ str(val),
+ "",
+ 250,
+ "",
+ self.update_input_shape,
+ )
+
+ def set_widget_invalid(self, widget: QWidget):
+ widget.setStyleSheet("border: 1px solid red;")
+
+ def set_widget_valid(self, widget: QWidget):
+ widget.setStyleSheet("")
+
+ def on_checkbox_folding_changed(self):
+ self.digest_pytorch_model.do_constant_folding = (
+ self.ui.foldingCheckBox.isChecked()
+ )
+
+ def on_checkbox_export_params_changed(self):
+ self.digest_pytorch_model.export_params = (
+ self.ui.exportParamsCheckBox.isChecked()
+ )
+
+ def select_directory(self):
+ directory = QFileDialog(self).getExistingDirectory(self, "Select Directory")
+ if os.path.exists(directory):
+ self.save_directory = directory
+ info_message = f"The ONNX model will be exported to {self.save_directory}"
+ self.update_message_label(info_message=info_message)
+
+ def update_message_label(
+ self, info_message: Optional[str] = None, warn_message: Optional[str] = None
+ ) -> None:
+ if info_message:
+ message = f"ℹ️ {info_message}"
+ elif warn_message:
+ message = f"⚠️ {warn_message}"
+ else:
+ message = ""
+
+ self.ui.selectDirLabel.setText(message)
+
+ def update_opset_version(self):
+ opset_text_item = self.ui.opsetLineEdit.text()
+ # isdigit() also rejects an empty string, which avoids int("") errors
+ if opset_text_item.isdigit():
+     opset_value = int(opset_text_item)
+     if (
+         opset_value < self.lowest_supported_opset
+         or opset_value > self.supported_opset_version
+     ):
+         self.set_widget_invalid(self.ui.opsetLineEdit)
+     else:
+         self.digest_pytorch_model.opset = opset_value
+         self.set_widget_valid(self.ui.opsetLineEdit)
+ else:
+     self.set_widget_invalid(self.ui.opsetLineEdit)
+
+ def update_input_shape(self):
+ """Because this is an external function to the UserInputFormWithInfo class
+ we go through each input every time there is an update."""
+ for row_idx in range(self.user_input_form.form_layout.rowCount()):
+ label_text = self.user_input_form.get_row_label(row_idx)
+ line_edit_text = self.user_input_form.get_row_line_edit(row_idx)
+ if label_text and line_edit_text:
+ tensor_name = label_text.split(":")[0]
+ if tensor_name in self.digest_pytorch_model.input_tensor_info:
+ self.digest_pytorch_model.input_tensor_info[tensor_name].clear()
+ else:
+ self.digest_pytorch_model.input_tensor_info[tensor_name] = []
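+ # The user enters a comma-separated shape, e.g. "batch_size, 3, 224, 224",
+ # which becomes ["batch_size", 3, 224, 224] with symbolic dims kept as strings.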
+ shape_list = line_edit_text.split(",")
+ try:
+ for dim in shape_list:
+ dim = dim.strip()
+ # Integer based shape
+ if all(char.isdigit() for char in dim):
+ self.digest_pytorch_model.input_tensor_info[
+ tensor_name
+ ].append(int(dim))
+ # Symbolic shape
+ else:
+ self.digest_pytorch_model.input_tensor_info[
+ tensor_name
+ ].append(dim)
+ except ValueError as err:
+ print(f"Malformed shape: {err}")
+ widget = self.user_input_form.get_row_line_edit_widget(row_idx)
+ if widget:
+ self.set_widget_invalid(widget)
+ else:
+ widget = self.user_input_form.get_row_line_edit_widget(row_idx)
+ if widget:
+ self.set_widget_valid(widget)
+
+ def export_onnx(self):
+ onnx_file_path = os.path.join(
+ self.save_directory, f"{self.digest_pytorch_model.model_name}.onnx"
+ )
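+ # A successful export sets digest_pytorch_model.onnx_file_path, which the
+ # main window then loads through the regular ONNX path.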
+ try:
+ self.digest_pytorch_model.export_to_onnx(onnx_file_path)
+ except (TypeError, RuntimeError) as err:
+ self.ui.exportWarningLabel.setText(f"Failed to export ONNX: {err}")
+ self.ui.exportWarningLabel.show()
+ else:
+ self.ui.exportWarningLabel.hide()
+ self.close_widget()
+
+ def close_widget(self):
+ self.close_signal.emit()
+ self.close()
diff --git a/src/digest/resource.qrc b/src/digest/resource.qrc
index 5a70586..6d2e347 100644
--- a/src/digest/resource.qrc
+++ b/src/digest/resource.qrc
@@ -1,21 +1,22 @@
diff --git a/src/utils/onnx_utils.py b/src/utils/onnx_utils.py
index 4d4b293..9b92be1 100644
--- a/src/utils/onnx_utils.py
+++ b/src/utils/onnx_utils.py
@@ -211,3 +211,9 @@ def optimize_onnx_model(
     except onnx.checker.ValidationError:
         print("Model did not pass checker!")
         return model_proto, False
+
+
+def get_supported_opset() -> int:
+    """Returns the opset version associated with
+    the currently installed ONNX library."""
+    return onnx.defs.onnx_opset_version()
diff --git a/test/test_gui.py b/test/test_gui.py
index 59fbb8f..0308ec7 100644
--- a/test/test_gui.py
+++ b/test/test_gui.py
@@ -5,23 +5,44 @@
 import tempfile
 import unittest
 from unittest.mock import patch
+import timm
+import torch
 
 # pylint: disable=no-name-in-module
 from PySide6.QtTest import QTest
-from PySide6.QtCore import Qt
+from PySide6.QtCore import Qt, QTimer, QEventLoop
 from PySide6.QtWidgets import QApplication
 
 import digest.main
 from digest.node_summary import NodeSummary
+from digest.model_class.digest_pytorch_model import DigestPyTorchModel
+from digest.pytorch_ingest import PyTorchIngest
+
+
+def save_resnet18_pt(directory: str) -> str:
+    """Simply saves a PyTorch resnet18 model and returns its file path"""
+    model = timm.models.create_model("resnet18", pretrained=True)  # type: ignore
+    model.eval()
+    file_path = os.path.join(directory, "resnet18.pt")
+    # Save the model
+    try:
+        torch.save(model, file_path)
+        return file_path
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        print(f"Error saving model: {e}")
+        return ""
 
 
 class DigestGuiTest(unittest.TestCase):
-    MODEL_BASENAME = "resnet18"
+    RESNET18_BASENAME = "resnet18"
+
     TEST_DIR = os.path.abspath(os.path.dirname(__file__))
-    ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{MODEL_BASENAME}.onnx"))
-    YAML_FILEPATH = os.path.normpath(
+    ONNX_FILE_PATH = os.path.normpath(
+        os.path.join(TEST_DIR, f"{RESNET18_BASENAME}.onnx")
+    )
+    YAML_FILE_PATH = os.path.normpath(
         os.path.join(
-            TEST_DIR, f"{MODEL_BASENAME}_reports", f"{MODEL_BASENAME}_report.yaml"
+            TEST_DIR, f"{RESNET18_BASENAME}_reports", f"{RESNET18_BASENAME}_report.yaml"
        )
    )
 
@@ -57,7 +78,7 @@ def wait_all_threads(self, timeout=10000) -> bool:
     def test_open_valid_onnx(self):
         with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
             mock_dialog.return_value = (
-                self.ONNX_FILEPATH,
+                self.ONNX_FILE_PATH,
                 "",
             )
 
@@ -76,7 +97,7 @@ def test_open_valid_onnx(self):
     def test_open_valid_yaml(self):
         with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
             mock_dialog.return_value = (
-                self.YAML_FILEPATH,
+                self.YAML_FILE_PATH,
                 "",
             )
 
@@ -92,6 +113,51 @@ def test_open_valid_yaml(self):
 
         self.digest_app.closeTab(num_tabs_prior)
 
+    def test_open_valid_pytorch(self):
+        """We test the PyTorch path slightly differently than the others
+        since Digest opens a modal window that blocks the main thread. This makes it
+        difficult to interact with the window in this test."""
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pt_file_path = save_resnet18_pt(tmpdir)
+            self.assertTrue(os.path.exists(tmpdir))
+            basename = os.path.splitext(os.path.basename(pt_file_path))
+            model_name = basename[0]
+            digest_model = DigestPyTorchModel(pt_file_path, model_name)
+            self.assertTrue(isinstance(digest_model.file_path, str))
+            pytorch_ingest = PyTorchIngest(pt_file_path, digest_model.model_name)
+            pytorch_ingest.show()
+
+            input_shape_edit = pytorch_ingest.user_input_form.get_row_line_edit_widget(
+                0
+            )
+
+            assert input_shape_edit
+            input_shape_edit.setText("batch_size, 3, 224, 224")
+            pytorch_ingest.update_input_shape()
+
+            with patch(
+                "PySide6.QtWidgets.QFileDialog.getExistingDirectory"
+            ) as mock_save_dialog:
+                print("TMPDIR", tmpdir)
+                mock_save_dialog.return_value = tmpdir
+                pytorch_ingest.select_directory()
+
+            pytorch_ingest.ui.exportOnnxBtn.click()
+
+            timeout_ms = 10000
+            interval_ms = 100
+            for _ in range(timeout_ms // interval_ms):
+                QTest.qWait(interval_ms)
+                onnx_file_path = pytorch_ingest.digest_pytorch_model.onnx_file_path
+                if onnx_file_path and os.path.exists(onnx_file_path):
+                    break  # File found!
+
+            assert isinstance(pytorch_ingest.digest_pytorch_model.onnx_file_path, str)
+            self.assertTrue(
+                os.path.exists(pytorch_ingest.digest_pytorch_model.onnx_file_path)
+            )
+
     def test_open_invalid_file(self):
         with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog:
             mock_dialog.return_value = ("invalid_file.txt", "")
@@ -107,7 +173,7 @@ def test_save_reports(self):
             "PySide6.QtWidgets.QFileDialog.getExistingDirectory"
         ) as mock_save_dialog:
 
-            mock_open_dialog.return_value = (self.ONNX_FILEPATH, "")
+            mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "")
             with tempfile.TemporaryDirectory() as tmpdirname:
                 mock_save_dialog.return_value = tmpdirname
 
@@ -127,41 +193,41 @@ def test_save_reports(self):
                 mock_save_dialog.assert_called_once()
 
                 result_basepath = os.path.join(
-                    tmpdirname, f"{self.MODEL_BASENAME}_reports"
+                    tmpdirname, f"{self.RESNET18_BASENAME}_reports"
                 )
 
                 # Text report test
-                text_report_filepath = os.path.join(
-                    result_basepath, f"{self.MODEL_BASENAME}_report.txt"
+                text_report_file_path = os.path.join(
+                    result_basepath, f"{self.RESNET18_BASENAME}_report.txt"
                 )
                 self.assertTrue(
-                    os.path.isfile(text_report_filepath),
-                    f"{text_report_filepath} not found!",
+                    os.path.isfile(text_report_file_path),
+                    f"{text_report_file_path} not found!",
                 )
 
                 # YAML report test
-                yaml_report_filepath = os.path.join(
-                    result_basepath, f"{self.MODEL_BASENAME}_report.yaml"
+                yaml_report_file_path = os.path.join(
+                    result_basepath, f"{self.RESNET18_BASENAME}_report.yaml"
                 )
-                self.assertTrue(os.path.isfile(yaml_report_filepath))
+                self.assertTrue(os.path.isfile(yaml_report_file_path))
 
                 # Nodes test
-                nodes_csv_report_filepath = os.path.join(
-                    result_basepath, f"{self.MODEL_BASENAME}_nodes.csv"
+                nodes_csv_report_file_path = os.path.join(
+                    result_basepath, f"{self.RESNET18_BASENAME}_nodes.csv"
                 )
-                self.assertTrue(os.path.isfile(nodes_csv_report_filepath))
+                self.assertTrue(os.path.isfile(nodes_csv_report_file_path))
 
                 # Histogram test
-                histogram_filepath = os.path.join(
-                    result_basepath, f"{self.MODEL_BASENAME}_histogram.png"
+                histogram_file_path = os.path.join(
+                    result_basepath, f"{self.RESNET18_BASENAME}_histogram.png"
                 )
-                self.assertTrue(os.path.isfile(histogram_filepath))
+                self.assertTrue(os.path.isfile(histogram_file_path))
 
                 # Heatmap test
-                heatmap_filepath = os.path.join(
-                    result_basepath, f"{self.MODEL_BASENAME}_heatmap.png"
+                heatmap_file_path = os.path.join(
+                    result_basepath, f"{self.RESNET18_BASENAME}_heatmap.png"
                 )
-                self.assertTrue(os.path.isfile(heatmap_filepath))
+                self.assertTrue(os.path.isfile(heatmap_file_path))
 
                 num_tabs = self.digest_app.ui.tabWidget.count()
                 self.assertTrue(num_tabs == 1)
@@ -174,10 +240,10 @@ def test_save_tables(self):
             "PySide6.QtWidgets.QFileDialog.getSaveFileName"
         ) as mock_save_dialog:
 
-            mock_open_dialog.return_value = (self.ONNX_FILEPATH, "")
+            mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "")
             with tempfile.TemporaryDirectory() as tmpdirname:
                 mock_save_dialog.return_value = (
-                    os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv"),
+                    os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv"),
                     "",
                 )
 
@@ -207,7 +273,7 @@ def test_save_tables(self):
                 self.assertTrue(
                     os.path.exists(
-                        os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv")
+                        os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv")
                     ),
                     "Nodes csv file not found.",
                 )