diff --git a/setup.py b/setup.py index b2ad16d..d6a16f7 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="digestai", - version="1.1.0", + version="1.2.0", description="Model analysis toolkit", author="Philip Colangelo, Daniel Holanda", packages=find_packages(where="src"), @@ -25,6 +25,8 @@ "platformdirs>=4.2.2", "pyyaml>=6.0.1", "psutil>=6.0.0", + "torch", + "transformers", ], classifiers=[], entry_points={"console_scripts": ["digest = digest.main:main"]}, diff --git a/src/digest/main.py b/src/digest/main.py index 5a894b5..956e07b 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -39,8 +39,9 @@ from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog from digest.thread import StatsThread, SimilarityThread, post_process -from digest.popup_window import PopupWindow +from digest.popup_window import PopupWindow, PopupDialog from digest.huggingface_page import HuggingfacePage +from digest.pytorch_ingest import PyTorchIngest from digest.multi_model_selection_page import MultiModelSelectionPage from digest.ui.mainwindow_ui import Ui_MainWindow from digest.modelsummary import modelSummary @@ -49,6 +50,7 @@ from digest.model_class.digest_model import DigestModel from digest.model_class.digest_onnx_model import DigestOnnxModel from digest.model_class.digest_report_model import DigestReportModel +from digest.model_class.digest_pytorch_model import DigestPyTorchModel from utils import onnx_utils GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml") @@ -166,7 +168,11 @@ def __init__(self, model_file: Optional[str] = None): self.status_dialog = None self.err_open_dialog = None self.temp_dir = tempfile.TemporaryDirectory() - self.digest_models: Dict[str, Union[DigestOnnxModel, DigestReportModel]] = {} + self.digest_models: Dict[ + str, Union[DigestOnnxModel, DigestReportModel, DigestPyTorchModel] + ] = {} + + self.pytorch_ingest_window: Optional[PopupDialog] = None # QThread containers self.model_nodes_stats_thread: Dict[str, StatsThread] = {} @@ -225,6 +231,9 @@ def __init__(self, model_file: Optional[str] = None): ) self.multimodelselection_page.model_signal.connect(self.load_model) + # Set up the pyptorch ingest page + self.pytorch_ingest: Optional[PyTorchIngest] = None + # Load model file if given as input to the executable if model_file: exists = os.path.exists(model_file) @@ -287,7 +296,10 @@ def closeTab(self, index): def openFile(self): file_name, _ = QFileDialog.getOpenFileName( - self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)" + self, + "Open File", + "", + "ONNX, PyTorch, and Report Files (*.onnx *.pt *.yaml)", ) if not file_name: @@ -364,7 +376,7 @@ def update_similarity_widget( completed_successfully: bool, model_id: str, most_similar: str, - png_filepath: Optional[str] = None, + png_file_path: Optional[str] = None, df_sorted: Optional[pd.DataFrame] = None, ): widget = None @@ -388,12 +400,12 @@ def update_similarity_widget( completed_successfully and isinstance(widget, modelSummary) and digest_model - and png_filepath + and png_file_path ): if df_sorted is not None: post_process( - digest_model.model_name, most_similar_list, df_sorted, png_filepath + digest_model.model_name, most_similar_list, df_sorted, png_file_path ) widget.load_gif.stop() @@ -401,7 +413,7 @@ def update_similarity_widget( # We give the image a 10% haircut to fit it more aesthetically widget_width = widget.ui.similarityImg.width() - pixmap = QPixmap(png_filepath) + pixmap = QPixmap(png_file_path) aspect_ratio = pixmap.width() / pixmap.height() target_height = int(widget_width / aspect_ratio) pixmap_scaled = pixmap.scaled( @@ -436,12 +448,12 @@ def update_similarity_widget( # Create option to click to enlarge image widget.ui.similarityImg.mousePressEvent = ( lambda event: self.open_similarity_report( - model_id, png_filepath, most_similar_list + model_id, png_file_path, most_similar_list ) ) # Create option to click to enlarge image self.model_similarity_report[model_id] = SimilarityAnalysisReport( - png_filepath, most_similar_list + png_file_path, most_similar_list ) widget.ui.similarityCorrelation.setText(text) @@ -463,12 +475,12 @@ def update_similarity_widget( ): self.ui.saveBtn.setEnabled(True) - def load_onnx(self, filepath: str): + def load_onnx(self, file_path: str): - # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) - if not os.path.exists(filepath): + if not os.path.exists(file_path): return # Every time an onnx is loaded we should emulate a model summary button click @@ -477,7 +489,7 @@ def load_onnx(self, filepath: str): # Before opening the file, check to see if it is already opened. for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: + if isinstance(widget, modelSummary) and file_path == widget.file: self.ui.tabWidget.setCurrentIndex(index) return @@ -486,11 +498,11 @@ def load_onnx(self, filepath: str): progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self) QApplication.processEvents() # Process pending events - model = onnx_utils.load_onnx(filepath, load_external_data=False) + model = onnx_utils.load_onnx(file_path, load_external_data=False) opt_model, opt_passed = onnx_utils.optimize_onnx_model(model) progress.step() - basename = os.path.splitext(os.path.basename(filepath)) + basename = os.path.splitext(os.path.basename(file_path)) model_name = basename[0] # Save the model proto so we can use the Freeze Inputs feature @@ -534,14 +546,14 @@ def load_onnx(self, filepath: str): model_summary.ui.similarityCorrelation.hide() model_summary.ui.similarityCorrelationStatic.hide() - model_summary.file = filepath + model_summary.file = file_path model_summary.setObjectName(model_name) model_summary.ui.modelName.setText(model_name) - model_summary.ui.modelFilename.setText(filepath) + model_summary.ui.modelFilename.setText(file_path) model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) digest_model.model_name = model_name - digest_model.filepath = filepath + digest_model.file_path = file_path digest_model.model_inputs = onnx_utils.get_model_input_shapes_types( opt_model ) @@ -694,8 +706,8 @@ def load_onnx(self, filepath: str): self.model_similarity_thread[model_id].completed_successfully.connect( self.update_similarity_widget ) - self.model_similarity_thread[model_id].model_filepath = filepath - self.model_similarity_thread[model_id].png_filepath = os.path.join( + self.model_similarity_thread[model_id].model_file_path = file_path + self.model_similarity_thread[model_id].png_file_path = os.path.join( png_tmp_path, f"heatmap_{model_name}.png" ) self.model_similarity_thread[model_id].model_id = model_id @@ -706,12 +718,12 @@ def load_onnx(self, filepath: str): except FileNotFoundError as e: print(f"File not found: {e.filename}") - def load_report(self, filepath: str): + def load_report(self, file_path: str): - # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) - if not os.path.exists(filepath): + if not os.path.exists(file_path): return # Every time a report is loaded we should emulate a model summary button click @@ -720,7 +732,7 @@ def load_report(self, filepath: str): # Before opening the file, check to see if it is already opened. for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: + if isinstance(widget, modelSummary) and file_path == widget.file: self.ui.tabWidget.setCurrentIndex(index) return @@ -729,13 +741,13 @@ def load_report(self, filepath: str): progress = ProgressDialog("Loading Digest Report File...", 2, self) QApplication.processEvents() # Process pending events - digest_model = DigestReportModel(filepath) + digest_model = DigestReportModel(file_path) if not digest_model.is_valid: progress.close() invalid_yaml_dialog = StatusDialog( title="Warning", - status_message=f"YAML file {filepath} is not a valid digest report", + status_message=f"YAML file {file_path} is not a valid digest report", ) invalid_yaml_dialog.show() @@ -758,10 +770,10 @@ def load_report(self, filepath: str): model_summary.ui.similarityCorrelation.hide() model_summary.ui.similarityCorrelationStatic.hide() - model_summary.file = filepath + model_summary.file = file_path model_summary.setObjectName(digest_model.model_name) model_summary.ui.modelName.setText(digest_model.model_name) - model_summary.ui.modelFilename.setText(filepath) + model_summary.ui.modelFilename.setText(file_path) model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) model_summary.ui.parameters.setText(format(digest_model.parameters, ",")) @@ -888,7 +900,7 @@ def load_report(self, filepath: str): completed_successfully=bool(digest_model.similarity_heatmap_path), model_id=digest_model.unique_id, most_similar="", - png_filepath=digest_model.similarity_heatmap_path, + png_file_path=digest_model.similarity_heatmap_path, ) progress.close() @@ -896,9 +908,30 @@ def load_report(self, filepath: str): except FileNotFoundError as e: print(f"File not found: {e.filename}") + def load_pytorch(self, file_path: str): + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) + + if not os.path.exists(file_path): + return + + basename = os.path.splitext(os.path.basename(file_path)) + model_name = basename[0] + + self.pytorch_ingest = PyTorchIngest(file_path, model_name) + self.pytorch_ingest_window = PopupDialog( + self.pytorch_ingest, "PyTorch Ingest", self + ) + self.pytorch_ingest_window.open() + + # The above code will block until the user has completed the pytorch ingest form + # The form will exit upon a successful export at which point the path will be set + if self.pytorch_ingest.digest_pytorch_model.onnx_file_path: + self.load_onnx(self.pytorch_ingest.digest_pytorch_model.onnx_file_path) + def load_model(self, file_path: str): - # Ensure the filepath follows a standard formatting: + # Ensure the file_path follows a standard formatting: file_path = os.path.normpath(file_path) if not os.path.exists(file_path): @@ -910,6 +943,8 @@ def load_model(self, file_path: str): self.load_onnx(file_path) elif file_ext == ".yaml": self.load_report(file_path) + elif file_ext == ".pt" or file_ext == ".pth": + self.load_pytorch(file_path) else: bad_ext_dialog = StatusDialog( f"Digest does not support files with the extension {file_ext}", @@ -992,30 +1027,32 @@ def save_reports(self): ) # Save csv of node type counts - node_type_filepath = os.path.join( + node_type_file_path = os.path.join( save_directory, f"{model_name}_node_type_counts.csv" ) - digest_model.save_node_type_counts_csv_report(node_type_filepath) + digest_model.save_node_type_counts_csv_report(node_type_file_path) # Save (copy) the similarity image png_file_path = self.model_similarity_thread[ digest_model.unique_id - ].png_filepath + ].png_file_path png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png") if png_file_path and os.path.exists(png_file_path): shutil.copy(png_file_path, png_save_path) # Save the text report - txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt") - digest_model.save_text_report(txt_report_filepath) + txt_report_file_path = os.path.join(save_directory, f"{model_name}_report.txt") + digest_model.save_text_report(txt_report_file_path) # Save the yaml report - yaml_report_filepath = os.path.join(save_directory, f"{model_name}_report.yaml") - digest_model.save_yaml_report(yaml_report_filepath) + yaml_report_file_path = os.path.join( + save_directory, f"{model_name}_report.yaml" + ) + digest_model.save_yaml_report(yaml_report_file_path) # Save the node list - nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv") - self.save_nodes_csv(nodes_report_filepath, False) + nodes_report_file_path = os.path.join(save_directory, f"{model_name}_nodes.csv") + self.save_nodes_csv(nodes_report_file_path, False) self.status_dialog = StatusDialog( f"Saved reports to: \n{os.path.abspath(save_directory)}", @@ -1051,20 +1088,20 @@ def save_file_dialog( ) return path, filter_type - def save_parameters_csv(self, filepath: str, open_dialog: bool = True): - self.save_nodes_csv(filepath, open_dialog) + def save_parameters_csv(self, file_path: str, open_dialog: bool = True): + self.save_nodes_csv(file_path, open_dialog) - def save_flops_csv(self, filepath: str, open_dialog: bool = True): - self.save_nodes_csv(filepath, open_dialog) + def save_flops_csv(self, file_path: str, open_dialog: bool = True): + self.save_nodes_csv(file_path, open_dialog) - def save_nodes_csv(self, csv_filepath: Optional[str], open_dialog: bool = True): + def save_nodes_csv(self, csv_file_path: Optional[str], open_dialog: bool = True): if open_dialog: - csv_filepath, _ = self.save_file_dialog() - if not csv_filepath: - raise ValueError("A filepath must be given.") + csv_file_path, _ = self.save_file_dialog() + if not csv_file_path: + raise ValueError("A file_path must be given.") current_tab = self.ui.tabWidget.currentWidget() if isinstance(current_tab, modelSummary): - current_tab.digest_model.save_nodes_csv_report(csv_filepath) + current_tab.digest_model.save_nodes_csv_report(csv_file_path) def save_chart(self, chart_view): path, _ = self.save_file_dialog("Save PNG", "PNG(*.png)") diff --git a/src/digest/model_class/digest_model.py b/src/digest/model_class/digest_model.py index 9064184..bd38d3a 100644 --- a/src/digest/model_class/digest_model.py +++ b/src/digest/model_class/digest_model.py @@ -13,6 +13,7 @@ class SupportedModelTypes(Enum): ONNX = "onnx" REPORT = "report" + PYTORCH = "pytorch" class NodeParsingException(Exception): @@ -94,10 +95,12 @@ def __init__(self, *args, **kwargs): class DigestModel(ABC): - def __init__(self, filepath: str, model_name: str, model_type: SupportedModelTypes): + def __init__( + self, file_path: str, model_name: str, model_type: SupportedModelTypes + ): # Public members exposed to the API self.unique_id: str = str(uuid4()) - self.filepath: Optional[str] = filepath + self.file_path: Optional[str] = os.path.abspath(file_path) self.model_name: str = model_name self.model_type: SupportedModelTypes = model_type self.node_type_counts: NodeTypeCounts = NodeTypeCounts() @@ -122,27 +125,27 @@ def parse_model_nodes(self, *args, **kwargs) -> None: pass @abstractmethod - def save_yaml_report(self, filepath: str) -> None: + def save_yaml_report(self, file_path: str) -> None: pass @abstractmethod - def save_text_report(self, filepath: str) -> None: + def save_text_report(self, file_path: str) -> None: pass - def save_nodes_csv_report(self, filepath: str) -> None: - save_nodes_csv_report(self.node_data, filepath) + def save_nodes_csv_report(self, file_path: str) -> None: + save_nodes_csv_report(self.node_data, file_path) - def save_node_type_counts_csv_report(self, filepath: str) -> None: + def save_node_type_counts_csv_report(self, file_path: str) -> None: if self.node_type_counts: - save_node_type_counts_csv_report(self.node_type_counts, filepath) + save_node_type_counts_csv_report(self.node_type_counts, file_path) - def save_node_shape_counts_csv_report(self, filepath: str) -> None: - save_node_shape_counts_csv_report(self.get_node_shape_counts(), filepath) + def save_node_shape_counts_csv_report(self, file_path: str) -> None: + save_node_shape_counts_csv_report(self.get_node_shape_counts(), file_path) -def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None: +def save_nodes_csv_report(node_data: NodeData, file_path: str) -> None: - parent_dir = os.path.dirname(os.path.abspath(filepath)) + parent_dir = os.path.dirname(os.path.abspath(file_path)) if not os.path.exists(parent_dir): raise FileNotFoundError(f"Directory {parent_dir} does not exist.") @@ -185,23 +188,23 @@ def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None: flattened_data.append(row) fieldnames = fieldnames + input_fieldnames + output_fieldnames - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: + with open(file_path, "w", encoding="utf-8", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") writer.writeheader() writer.writerows(flattened_data) def save_node_type_counts_csv_report( - node_type_counts: NodeTypeCounts, filepath: str + node_type_counts: NodeTypeCounts, file_path: str ) -> None: - parent_dir = os.path.dirname(os.path.abspath(filepath)) + parent_dir = os.path.dirname(os.path.abspath(file_path)) if not os.path.exists(parent_dir): raise FileNotFoundError(f"Directory {parent_dir} does not exist.") header = ["Node Type", "Count"] - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: + with open(file_path, "w", encoding="utf-8", newline="") as csvfile: writer = csv.writer(csvfile, lineterminator="\n") writer.writerow(header) for node_type, node_count in node_type_counts.items(): @@ -209,16 +212,16 @@ def save_node_type_counts_csv_report( def save_node_shape_counts_csv_report( - node_shape_counts: NodeShapeCounts, filepath: str + node_shape_counts: NodeShapeCounts, file_path: str ) -> None: - parent_dir = os.path.dirname(os.path.abspath(filepath)) + parent_dir = os.path.dirname(os.path.abspath(file_path)) if not os.path.exists(parent_dir): raise FileNotFoundError(f"Directory {parent_dir} does not exist.") header = ["Node Type", "Input Tensors Shapes", "Count"] - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: + with open(file_path, "w", encoding="utf-8", newline="") as csvfile: writer = csv.writer(csvfile, dialect="excel", lineterminator="\n") writer.writerow(header) for node_type, node_info in node_shape_counts.items(): diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py index 35aad1d..5f55636 100644 --- a/src/digest/model_class/digest_onnx_model.py +++ b/src/digest/model_class/digest_onnx_model.py @@ -22,13 +22,11 @@ class DigestOnnxModel(DigestModel): def __init__( self, onnx_model: onnx.ModelProto, - onnx_filepath: str = "", + onnx_file_path: str = "", model_name: str = "", save_proto: bool = True, ) -> None: - super().__init__(onnx_filepath, model_name, SupportedModelTypes.ONNX) - - self.model_type = SupportedModelTypes.ONNX + super().__init__(onnx_file_path, model_name, SupportedModelTypes.ONNX) # Public members exposed to the API self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None @@ -509,9 +507,9 @@ def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None: self.node_type_flops.get(node.op_type, 0) + node_info.flops ) - def save_yaml_report(self, filepath: str) -> None: + def save_yaml_report(self, file_path: str) -> None: - parent_dir = os.path.dirname(os.path.abspath(filepath)) + parent_dir = os.path.dirname(os.path.abspath(file_path)) if not os.path.exists(parent_dir): raise FileNotFoundError(f"Directory {parent_dir} does not exist.") @@ -523,7 +521,7 @@ def save_yaml_report(self, filepath: str) -> None: yaml_data = { "report_date": report_date, "model_type": self.model_type.value, - "model_file": self.filepath, + "model_file": self.file_path, "model_name": self.model_name, "model_version": self.model_version, "graph_name": self.graph_name, @@ -542,22 +540,22 @@ def save_yaml_report(self, filepath: str) -> None: "output_tensors": output_tensors, } - with open(filepath, "w", encoding="utf-8") as f_p: + with open(file_path, "w", encoding="utf-8") as f_p: yaml.dump(yaml_data, f_p, sort_keys=False) - def save_text_report(self, filepath: str) -> None: + def save_text_report(self, file_path: str) -> None: - parent_dir = os.path.dirname(os.path.abspath(filepath)) + parent_dir = os.path.dirname(os.path.abspath(file_path)) if not os.path.exists(parent_dir): raise FileNotFoundError(f"Directory {parent_dir} does not exist.") report_date = datetime.now().strftime("%B %d, %Y") - with open(filepath, "w", encoding="utf-8") as f_p: + with open(file_path, "w", encoding="utf-8") as f_p: f_p.write(f"Report created on {report_date}\n") f_p.write(f"Model type: {self.model_type.name}\n") - if self.filepath: - f_p.write(f"ONNX file: {self.filepath}\n") + if self.file_path: + f_p.write(f"ONNX file: {self.file_path}\n") f_p.write(f"Name of the model: {self.model_name}\n") f_p.write(f"Model version: {self.model_version}\n") f_p.write(f"Name of the graph: {self.graph_name}\n") diff --git a/src/digest/model_class/digest_pytorch_model.py b/src/digest/model_class/digest_pytorch_model.py new file mode 100644 index 0000000..68b1a76 --- /dev/null +++ b/src/digest/model_class/digest_pytorch_model.py @@ -0,0 +1,102 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. + +import os +from collections import OrderedDict +from typing import List, Tuple, Optional, Any, Union +import inspect +import onnx +import torch +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_model import ( + DigestModel, + SupportedModelTypes, +) + + +class DigestPyTorchModel(DigestModel): + """The idea of this class is to first support PyTorch models by converting them to ONNX + Eventually, we will want to support a PyTorch specific interface that has a custom GUI. + To facilitate this process, it makes the most sense to use this class as helper class + to convert the PyTorch model to ONNX and store the ONNX info in a member DigestOnnxModel + object. We can also store various PyTorch specific details in this class as well. + """ + + def __init__( + self, + pytorch_file_path: str = "", + model_name: str = "", + ) -> None: + super().__init__(pytorch_file_path, model_name, SupportedModelTypes.PYTORCH) + + assert os.path.exists( + pytorch_file_path + ), f"PyTorch file {pytorch_file_path} does not exist." + + # Default opset value + self.opset = 17 + + # Input dictionary to contain the names and shapes + # required for exporting the ONNX model + self.input_tensor_info: OrderedDict[str, List[Any]] = OrderedDict() + + self.pytorch_model = torch.load(pytorch_file_path) + + # Data needed for exporting to ONNX + self.do_constant_folding = True + self.export_params = True + + self.onnx_file_path: Optional[str] = None + + self.digest_onnx_model: Optional[DigestOnnxModel] = None + + def parse_model_nodes(self) -> None: + """This will be done in the DigestOnnxModel""" + + def save_yaml_report(self, file_path: str) -> None: + """This will be done in the DigestOnnxModel""" + + def save_text_report(self, file_path: str) -> None: + """This will be done in the DigestOnnxModel""" + + def generate_random_tensor(self, shape: List[Union[str, int]]): + static_shape = [dim if isinstance(dim, int) else 1 for dim in shape] + return torch.rand(static_shape) + + def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]: + + dummy_input_names: List[str] = list(self.input_tensor_info.keys()) + dummy_inputs: List[torch.Tensor] = [] + + for shape in self.input_tensor_info.values(): + dummy_inputs.append(self.generate_random_tensor(shape)) + + dynamic_axes = { + name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)} + for name, shape in self.input_tensor_info.items() + } + + try: + torch.onnx.export( + self.pytorch_model, + tuple(dummy_inputs), + output_onnx_path, + input_names=dummy_input_names, + do_constant_folding=self.do_constant_folding, + export_params=self.export_params, + opset_version=self.opset, + dynamic_axes=dynamic_axes, + verbose=False, + ) + + self.onnx_file_path = output_onnx_path + + return onnx.load(output_onnx_path) + + except (TypeError, RuntimeError) as err: + print(f"Failed to export ONNX: {err}") + raise + + +def get_model_fwd_parameters(torch_file_path): + torch_model = torch.load(torch_file_path) + return inspect.signature(torch_model.forward).parameters diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py index 4478285..d99a1ed 100644 --- a/src/digest/model_class/digest_report_model.py +++ b/src/digest/model_class/digest_report_model.py @@ -44,30 +44,30 @@ def parse_tensor_info(csv_tensor_cell_value) -> Tuple[str, list, str, float]: class DigestReportModel(DigestModel): def __init__( self, - report_filepath: str, + report_file_path: str, ) -> None: self.model_type = SupportedModelTypes.REPORT - self.is_valid = self.validate_yaml(report_filepath) + self.is_valid = self.validate_yaml(report_file_path) if not self.is_valid: - print(f"The yaml file {report_filepath} is not a valid digest report.") + print(f"The yaml file {report_file_path} is not a valid digest report.") return self.model_data = OrderedDict() - with open(report_filepath, "r", encoding="utf-8") as yaml_f: + with open(report_file_path, "r", encoding="utf-8") as yaml_f: self.model_data = yaml.safe_load(yaml_f) model_name = self.model_data["model_name"] - super().__init__(report_filepath, model_name, SupportedModelTypes.REPORT) + super().__init__(report_file_path, model_name, SupportedModelTypes.REPORT) self.similarity_heatmap_path: Optional[str] = None self.node_data = NodeData() # Given the path to the digest report, let's check if its a complete cache # and we can grab the nodes csv data and the similarity heatmap - cache_dir = os.path.dirname(os.path.abspath(report_filepath)) + cache_dir = os.path.dirname(os.path.abspath(report_file_path)) expected_heatmap_file = os.path.join(cache_dir, f"{model_name}_heatmap.png") if os.path.exists(expected_heatmap_file): self.similarity_heatmap_path = expected_heatmap_file @@ -169,8 +169,8 @@ def validate_yaml(self, report_file_path: str) -> bool: def parse_model_nodes(self) -> None: """There are no model nodes to parse""" - def save_yaml_report(self, filepath: str) -> None: + def save_yaml_report(self, file_path: str) -> None: """Report models are not intended to be saved""" - def save_text_report(self, filepath: str) -> None: + def save_text_report(self, file_path: str) -> None: """Report models are not intended to be saved""" diff --git a/src/digest/popup_window.py b/src/digest/popup_window.py index 09d1971..6e4e5ea 100644 --- a/src/digest/popup_window.py +++ b/src/digest/popup_window.py @@ -1,11 +1,14 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. # pylint: disable=no-name-in-module -from PySide6.QtWidgets import QApplication, QMainWindow, QWidget +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QApplication, QMainWindow, QWidget, QDialog, QVBoxLayout from PySide6.QtGui import QIcon class PopupWindow(QWidget): + """Opens new window that runs separate from the main digest window""" + def __init__(self, widget: QWidget, window_title: str = "", parent=None): super().__init__(parent) @@ -24,3 +27,36 @@ def open(self): def close(self): self.main_window.close() + + +class PopupDialog(QDialog): + """Opens a new window that takes focus and must be closed before returning + to the main digest window""" + + def __init__(self, widget: QWidget, window_title: str = "", parent=None): + super().__init__(parent) + + if hasattr(widget, "close_signal"): + widget.close_signal.connect(self.on_widget_closed) # type: ignore + + self.setWindowModality(Qt.WindowModality.WindowModal) + self.setWindowFlags(Qt.WindowType.Window) + + layout = QVBoxLayout() + layout.addWidget(widget) + self.setLayout(layout) + + self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) + self.setWindowTitle(window_title) + screen = QApplication.primaryScreen() + screen_geometry = screen.geometry() + self.resize( + int(screen_geometry.width() / 1.5), int(screen_geometry.height() * 0.80) + ) + + def open(self): + self.show() + self.exec() + + def on_widget_closed(self): + self.close() diff --git a/src/digest/pytorch_ingest.py b/src/digest/pytorch_ingest.py new file mode 100644 index 0000000..0ddd802 --- /dev/null +++ b/src/digest/pytorch_ingest.py @@ -0,0 +1,280 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. + +import os +from collections import OrderedDict +from typing import Optional, Callable, Union +from platformdirs import user_cache_dir + +# pylint: disable=no-name-in-module +from PySide6.QtWidgets import ( + QWidget, + QLabel, + QLineEdit, + QSizePolicy, + QFormLayout, + QFileDialog, + QHBoxLayout, +) +from PySide6.QtGui import QFont +from PySide6.QtCore import Qt, Signal +from utils import onnx_utils +from digest.ui.pytorchingest_ui import Ui_pytorchIngest +from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_pytorch_model import ( + get_model_fwd_parameters, + DigestPyTorchModel, +) + + +class UserInputFormWithInfo: + def __init__(self, form_layout: QFormLayout): + self.form_layout = form_layout + self.num_rows = 0 + + def add_row( + self, + label_text: str, + edit_text: str, + text_width: int, + info_text: str, + edit_finished_fnc: Optional[Callable] = None, + ) -> int: + + font = QFont("Inter", 10) + label = QLabel(f"{label_text}:") + label.setContentsMargins(0, 0, 0, 0) + label.setFont(font) + + line_edit = QLineEdit() + line_edit.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) + line_edit.setMinimumWidth(text_width) + line_edit.setMinimumHeight(20) + line_edit.setText(edit_text) + if edit_finished_fnc: + line_edit.editingFinished.connect(edit_finished_fnc) + + info_label = QLabel() + info_label.setText(info_text) + font = QFont("Arial", 10, italic=True) + info_label.setFont(font) + info_label.setContentsMargins(10, 0, 0, 0) + + row_layout = QHBoxLayout() + row_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) + row_layout.setSpacing(5) + row_layout.setObjectName(f"row{self.num_rows}_layout") + row_layout.addWidget(label, alignment=Qt.AlignmentFlag.AlignHCenter) + row_layout.addWidget(line_edit, alignment=Qt.AlignmentFlag.AlignHCenter) + row_layout.addWidget(info_label, alignment=Qt.AlignmentFlag.AlignHCenter) + + self.num_rows += 1 + self.form_layout.addRow(row_layout) + + return self.num_rows + + def get_row_label(self, row_idx: int) -> str: + form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) + if form_item: + row_layout = form_item.layout() + if isinstance(row_layout, QHBoxLayout): + line_edit_item = row_layout.itemAt(0) + if line_edit_item: + line_edit_widget = line_edit_item.widget() + if isinstance(line_edit_widget, QLabel): + return line_edit_widget.text() + return "" + + def get_row_line_edit(self, row_idx: int) -> str: + form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) + if form_item: + row_layout = form_item.layout() + if isinstance(row_layout, QHBoxLayout): + line_edit_item = row_layout.itemAt(1) + if line_edit_item: + line_edit_widget = line_edit_item.widget() + if isinstance(line_edit_widget, QLineEdit): + return line_edit_widget.text() + return "" + + def get_row_line_edit_widget(self, row_idx: int) -> Union[QLineEdit, None]: + form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) + if form_item: + row_layout = form_item.layout() + if isinstance(row_layout, QHBoxLayout): + line_edit_item = row_layout.itemAt(1) + if line_edit_item: + line_edit_widget = line_edit_item.widget() + if isinstance(line_edit_widget, QLineEdit): + return line_edit_widget + return None + + +class PyTorchIngest(QWidget): + """PyTorchIngest is the pop up window that enables users to set static shapes and export + PyTorch models to ONNX models.""" + + # This enables the widget to close the parent window + close_signal = Signal() + + def __init__( + self, + model_file: str, + model_name: str, + parent=None, + ): + super().__init__(parent) + self.ui = Ui_pytorchIngest() + self.ui.setupUi(self) + apply_dark_style_sheet(self) + + self.ui.exportWarningLabel.hide() + + # We use a cache dir to save the exported ONNX model + # Users have the option to choose a different location + # if they wish to keep the exported model. + user_cache_directory = user_cache_dir("digest") + os.makedirs(user_cache_directory, exist_ok=True) + self.save_directory: str = user_cache_directory + + self.ui.selectDirBtn.clicked.connect(self.select_directory) + self.ui.exportOnnxBtn.clicked.connect(self.export_onnx) + + self.ui.modelName.setText(str(model_name)) + + self.ui.modelFilename.setText(str(model_file)) + + self.ui.foldingCheckBox.stateChanged.connect(self.on_checkbox_folding_changed) + self.ui.exportParamsCheckBox.stateChanged.connect( + self.on_checkbox_export_params_changed + ) + + self.digest_pytorch_model = DigestPyTorchModel(model_file, model_name) + self.digest_pytorch_model.do_constant_folding = ( + self.ui.foldingCheckBox.isChecked() + ) + self.digest_pytorch_model.export_params = ( + self.ui.exportParamsCheckBox.isChecked() + ) + + self.user_input_form = UserInputFormWithInfo(self.ui.inputsFormLayout) + + # Set up the opset form + self.lowest_supported_opset = 7 # this requirement came from pytorch + self.supported_opset_version = onnx_utils.get_supported_opset() + self.ui.opsetLineEdit.setText(str(self.digest_pytorch_model.opset)) + self.ui.opsetInfoLabel.setStyleSheet("color: grey;") + self.ui.opsetInfoLabel.setText( + f"(accepted range is {self.lowest_supported_opset} - {self.supported_opset_version}):" + ) + self.ui.opsetLineEdit.editingFinished.connect(self.update_opset_version) + + # Present each input in the forward function + self.fwd_parameters = OrderedDict(get_model_fwd_parameters(model_file)) + for val in self.fwd_parameters.values(): + self.user_input_form.add_row( + str(val), + "", + 250, + "", + self.update_input_shape, + ) + + def set_widget_invalid(self, widget: QWidget): + widget.setStyleSheet("border: 1px solid red;") + + def set_widget_valid(self, widget: QWidget): + widget.setStyleSheet("") + + def on_checkbox_folding_changed(self): + self.digest_pytorch_model.do_constant_folding = ( + self.ui.foldingCheckBox.isChecked() + ) + + def on_checkbox_export_params_changed(self): + self.digest_pytorch_model.export_params = ( + self.ui.exportParamsCheckBox.isChecked() + ) + + def select_directory(self): + dir = QFileDialog(self).getExistingDirectory(self, "Select Directory") + if os.path.exists(dir): + self.save_directory = dir + info_message = f"The ONNX model will be exported to {self.save_directory}" + self.update_message_label(info_message=info_message) + + def update_message_label( + self, info_message: Optional[str] = None, warn_message: Optional[str] = None + ) -> None: + if info_message: + message = f"ℹ️ {info_message}" + elif warn_message: + message = f"⚠️ {warn_message}" + + self.ui.selectDirLabel.setText(message) + + def update_opset_version(self): + opset_text_item = self.ui.opsetLineEdit.text() + if all(char.isdigit() for char in opset_text_item): + opset_text_item = int(opset_text_item) + if ( + opset_text_item + and opset_text_item < self.lowest_supported_opset + or opset_text_item > self.supported_opset_version + ): + self.set_widget_invalid(self.ui.opsetLineEdit) + else: + self.digest_pytorch_model.opset = opset_text_item + self.set_widget_valid(self.ui.opsetLineEdit) + + def update_input_shape(self): + """Because this is an external function to the UserInputFormWithInfo class + we go through each input everytime there is an update.""" + for row_idx in range(self.user_input_form.form_layout.rowCount()): + label_text = self.user_input_form.get_row_label(row_idx) + line_edit_text = self.user_input_form.get_row_line_edit(row_idx) + if label_text and line_edit_text: + tensor_name = label_text.split(":")[0] + if tensor_name in self.digest_pytorch_model.input_tensor_info: + self.digest_pytorch_model.input_tensor_info[tensor_name].clear() + else: + self.digest_pytorch_model.input_tensor_info[tensor_name] = [] + shape_list = line_edit_text.split(",") + try: + for dim in shape_list: + dim = dim.strip() + # Integer based shape + if all(char.isdigit() for char in dim): + self.digest_pytorch_model.input_tensor_info[ + tensor_name + ].append(int(dim)) + # Symbolic shape + else: + self.digest_pytorch_model.input_tensor_info[ + tensor_name + ].append(dim) + except ValueError as err: + print(f"Malformed shape: {err}") + widget = self.user_input_form.get_row_line_edit_widget(row_idx) + if widget: + self.set_widget_invalid(widget) + else: + widget = self.user_input_form.get_row_line_edit_widget(row_idx) + if widget: + self.set_widget_valid(widget) + + def export_onnx(self): + onnx_file_path = os.path.join( + self.save_directory, f"{self.digest_pytorch_model.model_name}.onnx" + ) + try: + self.digest_pytorch_model.export_to_onnx(onnx_file_path) + except (TypeError, RuntimeError) as err: + self.ui.exportWarningLabel.setText(f"Failed to export ONNX: {err}") + self.ui.exportWarningLabel.show() + else: + self.ui.exportWarningLabel.hide() + self.close_widget() + + def close_widget(self): + self.close_signal.emit() + self.close() diff --git a/src/digest/resource.qrc b/src/digest/resource.qrc index 5a70586..6d2e347 100644 --- a/src/digest/resource.qrc +++ b/src/digest/resource.qrc @@ -1,21 +1,22 @@ - - assets/icons/close-window-64.ico - assets/icons/info.png - assets/icons/open.png - assets/icons/digest_logo.ico - assets/images/digest_logo_500.jpg - assets/images/remove_background_500_zoom.png - assets/images/remove_background_200_zoom.png - assets/icons/huggingface.png - assets/icons/huggingface_64px.png - assets/gifs/load.gif - assets/icons/save.png - assets/icons/node_list.png - assets/icons/search.png - assets/icons/models.png - assets/icons/file.png - assets/icons/freeze.png - assets/icons/summary.png - + + assets/icons/64px-PyTorch_logo_icon.svg.png + assets/icons/close-window-64.ico + assets/icons/info.png + assets/icons/open.png + assets/icons/digest_logo.ico + assets/images/digest_logo_500.jpg + assets/images/remove_background_500_zoom.png + assets/images/remove_background_200_zoom.png + assets/icons/huggingface.png + assets/icons/huggingface_64px.png + assets/gifs/load.gif + assets/icons/save.png + assets/icons/node_list.png + assets/icons/search.png + assets/icons/models.png + assets/icons/file.png + assets/icons/freeze.png + assets/icons/summary.png + diff --git a/src/digest/resource_rc.py b/src/digest/resource_rc.py index 59afc50..79c2adf 100644 --- a/src/digest/resource_rc.py +++ b/src/digest/resource_rc.py @@ -19134,6 +19134,125 @@ \x00\x9dOif\xf4\x11\xbb\xa4\xfbG\xfe\xfb\x7f\x8c \ \xf7\xde\xa1\x08\xbb~\x00\x00\x00\x00IEND\xaeB\ `\x82\ +\x00\x00\x07A\ +\x89\ +PNG\x0d\x0a\x1a\x0a\x00\x00\x00\x0dIHDR\x00\ +\x00\x00@\x00\x00\x00N\x08\x03\x00\x00\x00\xa7\xbd\xe0\x9c\ +\x00\x00\x00\x04gAMA\x00\x00\xb1\x8f\x0b\xfca\x05\ +\x00\x00\x00 cHRM\x00\x00z&\x00\x00\x80\x84\ +\x00\x00\xfa\x00\x00\x00\x80\xe8\x00\x00u0\x00\x00\xea`\ +\x00\x00:\x98\x00\x00\x17p\x9c\xbaQ<\x00\x00\x02(\ +PLTE\x00\x00\x00\xff\x00\x00\xeeK,\xe3U9\ +\xeeL-\xeeK-\xeeK,\xff@@\xf1L)\xec\ +L+\xefP0\xffUU\xefN*\xeeL,\xeeM\ +,\xeeL,\xdf@ \xf0M.\xeeL,\xefM,\ +\xeeM+\xf0K-\xeeK,\xefL,\xeeL,\xee\ +L+\xefM,\xeeL,\xecM-\xf3I1\xeeL\ +,\xedL,\xedO,\xedL,\xefL,\xeeL,\ +\xff\x80\x00\xeaJ+\xeeL+\xecK-\xedM+\xed\ +L-\xedL,\xeeM+\xeeM-\xedK,\xeeL\ +,\xedK,\xeeK+\xf1G+\xeeL,\xf0L.\ +\xefL-\xeeL-\xeeM,\xedL-\xffI$\xec\ +L/\xeeL,\xeeL,\xff33\xeeM,\xeeD\ +3\xf0M.\xeeL,\xeeK,\xeeL,\xeeL,\ +\xf0J,\xeeL,\xedM,\xedL-\xeeM,\xee\ +K+\xedL+\xe6M3\xeeL,\xeeL,\xebG\ +)\xedK,\xeeL,\xefP0\xedM+\xeeL,\ +\xffU+\xefL+\xeeM-\xefL,\xefL,\xee\ +K,\xeeL,\xebJ/\xedM.\xeeL,\xeeL\ +,\xedL,\xefL+\xeeL-\xeeL,\xeeK.\ +\xeeL,\xefL,\xedM,\xedL,\xeeL,\xee\ +L,\xefJ+\xeeL,\xeeL,\xf0I,\xeeM\ +,\xeeM+\xeeL,\xeeL,\xeeL,\xedM*\ +\xeeL,\xeeL,\xeeM-\xeeK,\xeeM,\xee\ +L+\xefM,\xeeM,\xefK,\xedK+\xefL\ +,\xedK-\xe9N,\xedJ+\xefK,\xeeL,\ +\xefK-\xf0K-\xeeL,\xeeL-\xefL+\xef\ +L,\xeeM-\xedL+\xefM-\xeeL+\xefM\ +-\xedL,\xecJ*\xf2Q(\xeeM,\xeeL,\ +\xedM-\xeaU+\xeeL,\xeeM+\xefK+\xed\ +L,\xefL,\xeeL,\xeeL,\xeeM,\xeeL\ +-\xeeL,\xefJ-\xeeL,\xefL-\xf1N+\ +\xeeL,\xeeM,\xf0M+\xedM+\xeeL,\xee\ +L,\xeeL,\xefL+\xeeM+\xeeM,\xeeM\ +,\xeeJ-\xeeL,\xeeL,\xeeM,\xedM,\ +\xf0K-\xebN'\xeeL,\xff\xff\xff\x8c_ $\ +\x00\x00\x00\xb6tRNS\x00\x01\x95\x09\x9a\xa6\xa9\x04\ +%6\x10\x031\xcf\xfa\x97\x08!\xef\xaa\x993\xd5o\ +\xd1k\x8c\xfd(\x15\xe2\x91\x1d\xae\xe8u\x02\x18\xa5D\ +\x82r\x80jgs\xfb:\xdf\x12\xa7C\xab\xd4\xd2\xbd\ +\x07\x1b\xe0\xd6\x05\xdc\x0f2\xf8i\xdb\xf94\xddV\x83\ +\xc1X\xbc\x0a\xfc\xc3\x19b\xfe q\xe4\x06/x\xbb\ +\x7f\xc4\xf4&8\xd0h\x90\x8d\x94\xb3,\xbf\xcaF\xda\ +\xd7\xf30\xe7\xe3#\xe6\x1e\xf7\xd8\xf1+\xeb\xcc\xb7\x5c\ +\x92Mn\x85\xba\x81@U\x17Hm\xa1}\x11\xf2J\ +|\xc9\xc8e`\x93~\xe97\x13\xbe\xde\x8f\x0c\xb5Y\ +p\x9e\x9d\xed\xb2L\x89\xa2>\xc6a$\xa8{BS\ +\xf6\xea\xee^\x88\xf0\x96-\x86\xec\xa3\xad\x22\x0d@!\ +\x93S\x00\x00\x00\x01bKGD\xb7\xdd\x00;g\x00\ +\x00\x00\x07tIME\x07\xe8\x07\x08\x10*\x1dQ\x1c\ +\xc41\x00\x00\x03TIDATX\xc3\xa5V\xf9C\ +LQ\x14>\xd3\x944K\x85\xb4\xa8\x946\xb4\x8e\xb2\ +f\x10)\xb4\x0e\x8d(\xfb\x12\x225-H\x8aP\x96\ +\xac\xa5\x12\xca.\xd9I\xee\xdfg\xe6\xbds\xdf{\xb3\ +\xf4\xde\xbbw\xeeO\xe7|\xf7|\xdf\xbd\xf7\xdd\xfb\xce\ +9\x00\xaa\xc3\x10b\x84`\x86!\x84\x18\x83\xe4\x07%\ +\xe0\xe1\x07# \xf0\x83\x10\x10\xf9\xfc\x02\xc8\xe7\x16\xa0\ +|^\x01\x89\xcf) \xf3\xf9\x04\x14|.\x01C\xa8\ +\xcc\x0f3\x04\xb7\xfe\x82p\x19_\x18ab_\xdf,\ +\xc1\x16+!\x91Q\xd1\xdc\xeb/Z, K\xb8\xf7\ +\x1f\x83\xd8RN>\xc4\x22\x18\xc7\xc9\x87xD\x138\ +\xf9\xb0LD\x13\x938\xf9\x90\xbc\x5c\x80Sx\xf9\x00\ +\xa9+\x22IZz\x067\xdf=\x8c\x99\x19\xaa\x1f \ +J\xe6g\x99\xb5n{\xa5?\xb4Jc}\xaf\xb1\x9a\ +d\xfbB9,\xeb\xa7\xbb\xa3r\xbd\xa1\x88H\x86\xf5\ +\xf3\xae\x98&\ +\xef\xde\xfb\xf4[\x1f\x06Z\x94\xf3\xd6\x00\xed\x98\xf3\xa3\ +2\x82\xb8z?M\xa7\x0a\xdbt&}\x1e\xb5z\xcd\ +\x91\xea\xc1@G\xccH!\xbec\xe6K\x99\xedk\x95\ +\x1f<9_\x1d\xfc\xd6Lt\x0c\xc7\xf7\xf9_|\x92\ +U\x9b\xdf:\xad\xfa\xd0\xe2~\xa8\xd3\xf3c\x865\xde\ +j\xffO\x97\xca\xee'\x7f\xe9\xf8\xdd\xfa\x7f\xff\x09L\ +/\x99\xd5\xdb\x85\xd8\xa7\xb2\xfcNRZQkb\xc9\ +\xdbf\xcb\xc4\xec\x10\xfeb3\xf1\x7f\xfb\xa6+\x81g\ +D\x0f\xcf\xcde\xfeS\x0d\xf9\x0f|\xbd\x92\x22\xc7k\ +h\x91\x00\x00\x00%tEXtdate:c\ +reate\x002024-07-08\ +T16:42:28+00:00\x91\ +\x1c5\xf8\x00\x00\x00%tEXtdate:\ +modify\x002024-07-0\ +8T16:42:28+00:00\ +\xe0A\x8dD\x00\x00\x00\x00IEND\xaeB`\x82\ +\ \x00\x00\x171\ \x89\ PNG\x0d\x0a\x1a\x0a\x00\x00\x00\x0dIHDR\x00\ @@ -19653,6 +19772,11 @@ \x04\xd2YG\ \x00i\ \x00n\x00f\x00o\x00.\x00p\x00n\x00g\ +\x00\x1e\ +\x01_\x1b\xa7\ +\x006\ +\x004\x00p\x00x\x00-\x00P\x00y\x00T\x00o\x00r\x00c\x00h\x00_\x00l\x00o\x00g\x00o\ +\x00_\x00i\x00c\x00o\x00n\x00.\x00s\x00v\x00g\x00.\x00p\x00n\x00g\ \x00\x0f\ \x0cYr'\ \x00h\ @@ -19669,9 +19793,9 @@ \x00\x00\x00\x00\x00\x00\x00\x00\ \x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x02\ \x00\x00\x00\x00\x00\x00\x00\x00\ -\x00\x00\x00\x22\x00\x02\x00\x00\x00\x01\x00\x00\x00\x15\ +\x00\x00\x00\x22\x00\x02\x00\x00\x00\x01\x00\x00\x00\x16\ \x00\x00\x00\x00\x00\x00\x00\x00\ -\x00\x00\x00\x12\x00\x02\x00\x00\x00\x0d\x00\x00\x00\x08\ +\x00\x00\x00\x12\x00\x02\x00\x00\x00\x0e\x00\x00\x00\x08\ \x00\x00\x00\x00\x00\x00\x00\x00\ \x00\x00\x000\x00\x02\x00\x00\x00\x03\x00\x00\x00\x05\ \x00\x00\x00\x00\x00\x00\x00\x00\ @@ -19683,7 +19807,9 @@ \x00\x00\x01\x93K\x85\xbd\xbb\ \x00\x00\x02\x10\x00\x00\x00\x00\x00\x01\x00\x04\x5c\xf1\ \x00\x00\x01\x93K\x85\xbd\xa9\ -\x00\x00\x02`\x00\x01\x00\x00\x00\x01\x00\x04\xc0\x80\ +\x00\x00\x02<\x00\x00\x00\x00\x00\x01\x00\x04\xa9K\ +\x00\x00\x01\x93\xc1\xeeE~\ +\x00\x00\x02\xa2\x00\x01\x00\x00\x00\x01\x00\x04\xc7\xc5\ \x00\x00\x01\x93K\x85\xbd\xa8\ \x00\x00\x02&\x00\x00\x00\x00\x00\x01\x00\x04c$\ \x00\x00\x01\x93K\x85\xbd\xa9\ @@ -19699,7 +19825,7 @@ \x00\x00\x01\x93K\x85\xbd\xa9\ \x00\x00\x01l\x00\x00\x00\x00\x00\x01\x00\x03E\x14\ \x00\x00\x01\x93K\x85\xbd\xa9\ -\x00\x00\x02<\x00\x00\x00\x00\x00\x01\x00\x04\xa9K\ +\x00\x00\x02~\x00\x00\x00\x00\x00\x01\x00\x04\xb0\x90\ \x00\x00\x01\x93K\x85\xbd\xa9\ \x00\x00\x01\x08\x00\x00\x00\x00\x00\x01\x00\x03\x18\xb3\ \x00\x00\x01\x93K\x85\xbd\xa9\ diff --git a/src/digest/styles/darkstyle.qss b/src/digest/styles/darkstyle.qss index 29c9bbd..9d8b92e 100644 --- a/src/digest/styles/darkstyle.qss +++ b/src/digest/styles/darkstyle.qss @@ -7,6 +7,39 @@ QFrame { background-color: transparent; } +QTextEdit { + background-color: #1e1e1e; + color: #cecece; + border: 1px solid #333333; + border-radius: 3px; + padding: 2px; +} + +QTextEdit::disabled { + background-color: #404040; + color: #656565; + border-color: #333333; +} + +QTextEdit::selection { + background-color: #264f78; + color: #ffffff; +} + +QGroupBox{ + background-color: transparent; + border: 1px solid #282c34; + border-radius: 5px; + margin-top: 2ex; +} + +QGroupBox::title { + color: lightgrey; + subcontrol-origin: margin; + left: 7px; + padding: 0px 5px 0px 5px; +} + QMenu { background-color: #333333; color: #DDDDDD; diff --git a/src/digest/thread.py b/src/digest/thread.py index bf9c546..0d9d6ea 100644 --- a/src/digest/thread.py +++ b/src/digest/thread.py @@ -73,26 +73,26 @@ class SimilarityThread(QThread): def __init__( self, - model_filepath: Optional[str] = None, - png_filepath: Optional[str] = None, + model_file_path: Optional[str] = None, + png_file_path: Optional[str] = None, model_id: Optional[str] = None, ): super().__init__() - self.model_filepath = model_filepath - self.png_filepath = png_filepath + self.model_file_path = model_file_path + self.png_file_path = png_file_path self.model_id = model_id def run(self): - if not self.model_filepath: - raise ValueError("You must set the model filepath") - if not self.png_filepath: - raise ValueError("You must set the png filepath") + if not self.model_file_path: + raise ValueError("You must set the model file_path") + if not self.png_file_path: + raise ValueError("You must set the png file_path") if not self.model_id: raise ValueError("You must set the model id") try: most_similar, _, df_sorted = find_match( - self.model_filepath, + self.model_file_path, dequantize=False, replace=True, ) @@ -100,12 +100,12 @@ def run(self): # We convert List[str] to str to send through the signal most_similar = ",".join(most_similar) self.completed_successfully.emit( - True, self.model_id, most_similar, self.png_filepath, df_sorted + True, self.model_id, most_similar, self.png_file_path, df_sorted ) except Exception as e: # pylint: disable=broad-exception-caught most_similar = "" self.completed_successfully.emit( - False, self.model_id, most_similar, self.png_filepath, df_sorted + False, self.model_id, most_similar, self.png_file_path, df_sorted ) print(f"Issue creating similarity analysis: {e}") diff --git a/src/digest/ui/pytorchingest.ui b/src/digest/ui/pytorchingest.ui new file mode 100644 index 0000000..abc6a5e --- /dev/null +++ b/src/digest/ui/pytorchingest.ui @@ -0,0 +1,560 @@ + + + pytorchIngest + + + + 0 + 0 + 1060 + 748 + + + + + 0 + 0 + + + + Form + + + + :/assets/images/digest_logo_500.jpg:/assets/images/digest_logo_500.jpg + + + + + + + + + + 0 + 0 + + + + + + + + 0 + 0 + + + + + 16777215 + 16777215 + + + + + + + :/assets/icons/64px-PyTorch_logo_icon.svg.png + + + true + + + 5 + + + + + + + + 0 + 0 + + + + false + + + + + + QFrame::Shape::NoFrame + + + QFrame::Shadow::Raised + + + + + + + 0 + 0 + + + + + true + + + + + + + PyTorch Ingest + + + true + + + 1 + + + 5 + + + Qt::TextInteractionFlag::LinksAccessibleByMouse|Qt::TextInteractionFlag::TextSelectableByKeyboard|Qt::TextInteractionFlag::TextSelectableByMouse + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + + + + + 0 + 0 + + + + Qt::ScrollBarPolicy::ScrollBarAsNeeded + + + Qt::ScrollBarPolicy::ScrollBarAsNeeded + + + QAbstractScrollArea::SizeAdjustPolicy::AdjustToContents + + + true + + + + + 0 + 0 + 1040 + 616 + + + + + 0 + 100 + + + + + + + + 10 + + + + + + 0 + 0 + + + + QLabel { + font-size: 28px; + font-weight: bold; + margin-bottom: -5px; +} + + + model name + + + + + + + path to the model file + + + 5 + + + + + + + 20 + + + 10 + + + + + + 0 + 0 + + + + PointingHandCursor + + + + + + Select Directory + + + false + + + + + + + + + + Select a directory if you would like to save the ONNX model file + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + 0 + 0 + + + + + 13 + + + + + + + + Export Options + + + + 15 + + + 35 + + + 9 + + + + + 0 + + + 0 + + + + + + 0 + 0 + + + + + 10 + + + + Do constant folding + + + true + + + + + + + + + 10 + + + + + + 0 + 0 + + + + + 10 + + + + Export params + + + true + + + + + + + + + + + + 0 + 0 + + + + + 12 + false + + + + Opset + + + + + + + + 0 + 0 + + + + + 10 + false + + + + (accepted range is 7 - 21): + + + 0 + + + + + + + + 0 + 0 + + + + + 35 + 16777215 + + + + + 10 + + + + 17 + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + + + + 14 + + + + Inputs + + + + 15 + + + 25 + + + + + + 12 + + + + color: lightgrey; + + + The following inputs were taken from the PyTorch model's forward function. Please set the dimensions for each input needed. Dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244. + + + true + + + 5 + + + + + + + 20 + + + + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 10px; + background-color: #FFCC00; + border: 1px solid #996600; + color: #333333; + font-weight: bold; + border-radius: 0px; +} + + + <html><head/><body><p>This is a warning message that we can use for now to prompt the user.</p></body></html> + + + 5 + + + + + + + + 0 + 0 + + + + PointingHandCursor + + + + + + Export ONNX + + + false + + + + + + + Qt::Orientation::Vertical + + + + 20 + 40 + + + + + + + + + + + + + + + diff --git a/src/digest/ui/pytorchingest_ui.py b/src/digest/ui/pytorchingest_ui.py new file mode 100644 index 0000000..c9a761e --- /dev/null +++ b/src/digest/ui/pytorchingest_ui.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- + +################################################################################ +## Form generated from reading UI file 'pytorchingest.ui' +## +## Created by: Qt User Interface Compiler version 6.8.1 +## +## WARNING! All changes made in this file will be lost when recompiling UI file! +################################################################################ + +from PySide6.QtCore import (QCoreApplication, QDate, QDateTime, QLocale, + QMetaObject, QObject, QPoint, QRect, + QSize, QTime, QUrl, Qt) +from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor, + QFont, QFontDatabase, QGradient, QIcon, + QImage, QKeySequence, QLinearGradient, QPainter, + QPalette, QPixmap, QRadialGradient, QTransform) +from PySide6.QtWidgets import (QAbstractScrollArea, QApplication, QCheckBox, QFormLayout, + QFrame, QGroupBox, QHBoxLayout, QLabel, + QLineEdit, QPushButton, QScrollArea, QSizePolicy, + QSpacerItem, QVBoxLayout, QWidget) +import resource_rc + +class Ui_pytorchIngest(object): + def setupUi(self, pytorchIngest): + if not pytorchIngest.objectName(): + pytorchIngest.setObjectName(u"pytorchIngest") + pytorchIngest.resize(1060, 748) + sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(pytorchIngest.sizePolicy().hasHeightForWidth()) + pytorchIngest.setSizePolicy(sizePolicy) + icon = QIcon() + icon.addFile(u":/assets/images/digest_logo_500.jpg", QSize(), QIcon.Mode.Normal, QIcon.State.Off) + pytorchIngest.setWindowIcon(icon) + pytorchIngest.setStyleSheet(u"") + self.verticalLayout = QVBoxLayout(pytorchIngest) + self.verticalLayout.setObjectName(u"verticalLayout") + self.summaryTopBanner = QWidget(pytorchIngest) + self.summaryTopBanner.setObjectName(u"summaryTopBanner") + sizePolicy1 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum) + sizePolicy1.setHorizontalStretch(0) + sizePolicy1.setVerticalStretch(0) + sizePolicy1.setHeightForWidth(self.summaryTopBanner.sizePolicy().hasHeightForWidth()) + self.summaryTopBanner.setSizePolicy(sizePolicy1) + self.summaryTopBannerLayout = QHBoxLayout(self.summaryTopBanner) + self.summaryTopBannerLayout.setObjectName(u"summaryTopBannerLayout") + self.pytorchLogo = QLabel(self.summaryTopBanner) + self.pytorchLogo.setObjectName(u"pytorchLogo") + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.pytorchLogo.sizePolicy().hasHeightForWidth()) + self.pytorchLogo.setSizePolicy(sizePolicy2) + self.pytorchLogo.setMaximumSize(QSize(16777215, 16777215)) + self.pytorchLogo.setPixmap(QPixmap(u":/assets/icons/64px-PyTorch_logo_icon.svg.png")) + self.pytorchLogo.setScaledContents(True) + self.pytorchLogo.setMargin(5) + + self.summaryTopBannerLayout.addWidget(self.pytorchLogo) + + self.headerFrame = QFrame(self.summaryTopBanner) + self.headerFrame.setObjectName(u"headerFrame") + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.headerFrame.sizePolicy().hasHeightForWidth()) + self.headerFrame.setSizePolicy(sizePolicy3) + self.headerFrame.setAutoFillBackground(False) + self.headerFrame.setStyleSheet(u"") + self.headerFrame.setFrameShape(QFrame.Shape.NoFrame) + self.headerFrame.setFrameShadow(QFrame.Shadow.Raised) + self.horizontalLayout = QHBoxLayout(self.headerFrame) + self.horizontalLayout.setObjectName(u"horizontalLayout") + self.titleLabel = QLabel(self.headerFrame) + self.titleLabel.setObjectName(u"titleLabel") + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy4.setHorizontalStretch(0) + sizePolicy4.setVerticalStretch(0) + sizePolicy4.setHeightForWidth(self.titleLabel.sizePolicy().hasHeightForWidth()) + self.titleLabel.setSizePolicy(sizePolicy4) + font = QFont() + font.setBold(True) + self.titleLabel.setFont(font) + self.titleLabel.setStyleSheet(u"") + self.titleLabel.setWordWrap(True) + self.titleLabel.setMargin(1) + self.titleLabel.setIndent(5) + self.titleLabel.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse) + + self.horizontalLayout.addWidget(self.titleLabel) + + self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.horizontalLayout.addItem(self.horizontalSpacer) + + + self.summaryTopBannerLayout.addWidget(self.headerFrame) + + + self.verticalLayout.addWidget(self.summaryTopBanner) + + self.scrollArea = QScrollArea(pytorchIngest) + self.scrollArea.setObjectName(u"scrollArea") + sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) + sizePolicy5.setHorizontalStretch(0) + sizePolicy5.setVerticalStretch(0) + sizePolicy5.setHeightForWidth(self.scrollArea.sizePolicy().hasHeightForWidth()) + self.scrollArea.setSizePolicy(sizePolicy5) + self.scrollArea.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.scrollArea.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.scrollArea.setSizeAdjustPolicy(QAbstractScrollArea.SizeAdjustPolicy.AdjustToContents) + self.scrollArea.setWidgetResizable(True) + self.scrollAreaWidgetContents = QWidget() + self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") + self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 1040, 616)) + sizePolicy6 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) + sizePolicy6.setHorizontalStretch(0) + sizePolicy6.setVerticalStretch(100) + sizePolicy6.setHeightForWidth(self.scrollAreaWidgetContents.sizePolicy().hasHeightForWidth()) + self.scrollAreaWidgetContents.setSizePolicy(sizePolicy6) + self.scrollAreaWidgetContents.setStyleSheet(u"") + self.verticalLayout_20 = QVBoxLayout(self.scrollAreaWidgetContents) + self.verticalLayout_20.setSpacing(10) + self.verticalLayout_20.setObjectName(u"verticalLayout_20") + self.modelName = QLabel(self.scrollAreaWidgetContents) + self.modelName.setObjectName(u"modelName") + sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) + sizePolicy7.setHorizontalStretch(0) + sizePolicy7.setVerticalStretch(0) + sizePolicy7.setHeightForWidth(self.modelName.sizePolicy().hasHeightForWidth()) + self.modelName.setSizePolicy(sizePolicy7) + self.modelName.setStyleSheet(u"QLabel {\n" +" font-size: 28px;\n" +" font-weight: bold;\n" +" margin-bottom: -5px;\n" +"}") + + self.verticalLayout_20.addWidget(self.modelName) + + self.modelFilename = QLabel(self.scrollAreaWidgetContents) + self.modelFilename.setObjectName(u"modelFilename") + self.modelFilename.setMargin(5) + + self.verticalLayout_20.addWidget(self.modelFilename) + + self.selectDirLayout = QHBoxLayout() + self.selectDirLayout.setSpacing(20) + self.selectDirLayout.setObjectName(u"selectDirLayout") + self.selectDirLayout.setContentsMargins(-1, -1, -1, 10) + self.selectDirBtn = QPushButton(self.scrollAreaWidgetContents) + self.selectDirBtn.setObjectName(u"selectDirBtn") + sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed) + sizePolicy8.setHorizontalStretch(0) + sizePolicy8.setVerticalStretch(0) + sizePolicy8.setHeightForWidth(self.selectDirBtn.sizePolicy().hasHeightForWidth()) + self.selectDirBtn.setSizePolicy(sizePolicy8) + self.selectDirBtn.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) + self.selectDirBtn.setStyleSheet(u"") + self.selectDirBtn.setAutoExclusive(False) + + self.selectDirLayout.addWidget(self.selectDirBtn) + + self.selectDirLabel = QLabel(self.scrollAreaWidgetContents) + self.selectDirLabel.setObjectName(u"selectDirLabel") + self.selectDirLabel.setStyleSheet(u"") + + self.selectDirLayout.addWidget(self.selectDirLabel) + + self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.selectDirLayout.addItem(self.horizontalSpacer_2) + + + self.verticalLayout_20.addLayout(self.selectDirLayout) + + self.exportOptionsGroupBox = QGroupBox(self.scrollAreaWidgetContents) + self.exportOptionsGroupBox.setObjectName(u"exportOptionsGroupBox") + sizePolicy4.setHeightForWidth(self.exportOptionsGroupBox.sizePolicy().hasHeightForWidth()) + self.exportOptionsGroupBox.setSizePolicy(sizePolicy4) + font1 = QFont() + font1.setPointSize(13) + self.exportOptionsGroupBox.setFont(font1) + self.exportOptionsGroupBox.setStyleSheet(u"") + self.verticalLayout_2 = QVBoxLayout(self.exportOptionsGroupBox) + self.verticalLayout_2.setSpacing(15) + self.verticalLayout_2.setObjectName(u"verticalLayout_2") + self.verticalLayout_2.setContentsMargins(-1, 35, -1, 9) + self.horizontalLayout_3 = QHBoxLayout() + self.horizontalLayout_3.setSpacing(0) + self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") + self.horizontalLayout_3.setContentsMargins(-1, 0, -1, -1) + self.foldingCheckBox = QCheckBox(self.exportOptionsGroupBox) + self.foldingCheckBox.setObjectName(u"foldingCheckBox") + sizePolicy4.setHeightForWidth(self.foldingCheckBox.sizePolicy().hasHeightForWidth()) + self.foldingCheckBox.setSizePolicy(sizePolicy4) + font2 = QFont() + font2.setPointSize(10) + self.foldingCheckBox.setFont(font2) + self.foldingCheckBox.setChecked(True) + + self.horizontalLayout_3.addWidget(self.foldingCheckBox) + + + self.verticalLayout_2.addLayout(self.horizontalLayout_3) + + self.horizontalLayout_4 = QHBoxLayout() + self.horizontalLayout_4.setSpacing(10) + self.horizontalLayout_4.setObjectName(u"horizontalLayout_4") + self.exportParamsCheckBox = QCheckBox(self.exportOptionsGroupBox) + self.exportParamsCheckBox.setObjectName(u"exportParamsCheckBox") + sizePolicy4.setHeightForWidth(self.exportParamsCheckBox.sizePolicy().hasHeightForWidth()) + self.exportParamsCheckBox.setSizePolicy(sizePolicy4) + self.exportParamsCheckBox.setFont(font2) + self.exportParamsCheckBox.setChecked(True) + + self.horizontalLayout_4.addWidget(self.exportParamsCheckBox) + + + self.verticalLayout_2.addLayout(self.horizontalLayout_4) + + self.opsetLayout = QHBoxLayout() + self.opsetLayout.setObjectName(u"opsetLayout") + self.opsetLabel = QLabel(self.exportOptionsGroupBox) + self.opsetLabel.setObjectName(u"opsetLabel") + sizePolicy2.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth()) + self.opsetLabel.setSizePolicy(sizePolicy2) + font3 = QFont() + font3.setPointSize(12) + font3.setBold(False) + self.opsetLabel.setFont(font3) + + self.opsetLayout.addWidget(self.opsetLabel) + + self.opsetInfoLabel = QLabel(self.exportOptionsGroupBox) + self.opsetInfoLabel.setObjectName(u"opsetInfoLabel") + sizePolicy2.setHeightForWidth(self.opsetInfoLabel.sizePolicy().hasHeightForWidth()) + self.opsetInfoLabel.setSizePolicy(sizePolicy2) + font4 = QFont() + font4.setPointSize(10) + font4.setItalic(False) + self.opsetInfoLabel.setFont(font4) + self.opsetInfoLabel.setMargin(0) + + self.opsetLayout.addWidget(self.opsetInfoLabel) + + self.opsetLineEdit = QLineEdit(self.exportOptionsGroupBox) + self.opsetLineEdit.setObjectName(u"opsetLineEdit") + sizePolicy2.setHeightForWidth(self.opsetLineEdit.sizePolicy().hasHeightForWidth()) + self.opsetLineEdit.setSizePolicy(sizePolicy2) + self.opsetLineEdit.setMaximumSize(QSize(35, 16777215)) + self.opsetLineEdit.setFont(font2) + + self.opsetLayout.addWidget(self.opsetLineEdit) + + self.horizontalSpacer_4 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.opsetLayout.addItem(self.horizontalSpacer_4) + + + self.verticalLayout_2.addLayout(self.opsetLayout) + + + self.verticalLayout_20.addWidget(self.exportOptionsGroupBox) + + self.inputsGroupBox = QGroupBox(self.scrollAreaWidgetContents) + self.inputsGroupBox.setObjectName(u"inputsGroupBox") + font5 = QFont() + font5.setPointSize(14) + self.inputsGroupBox.setFont(font5) + self.verticalLayout_3 = QVBoxLayout(self.inputsGroupBox) + self.verticalLayout_3.setSpacing(15) + self.verticalLayout_3.setObjectName(u"verticalLayout_3") + self.verticalLayout_3.setContentsMargins(-1, 25, -1, -1) + self.label = QLabel(self.inputsGroupBox) + self.label.setObjectName(u"label") + font6 = QFont() + font6.setPointSize(12) + self.label.setFont(font6) + self.label.setStyleSheet(u"color: lightgrey;") + self.label.setWordWrap(True) + self.label.setMargin(5) + + self.verticalLayout_3.addWidget(self.label) + + self.inputsFormLayout = QFormLayout() + self.inputsFormLayout.setObjectName(u"inputsFormLayout") + self.inputsFormLayout.setContentsMargins(20, -1, -1, -1) + + self.verticalLayout_3.addLayout(self.inputsFormLayout) + + + self.verticalLayout_20.addWidget(self.inputsGroupBox) + + self.exportWarningLabel = QLabel(self.scrollAreaWidgetContents) + self.exportWarningLabel.setObjectName(u"exportWarningLabel") + sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Preferred) + sizePolicy9.setHorizontalStretch(0) + sizePolicy9.setVerticalStretch(0) + sizePolicy9.setHeightForWidth(self.exportWarningLabel.sizePolicy().hasHeightForWidth()) + self.exportWarningLabel.setSizePolicy(sizePolicy9) + self.exportWarningLabel.setStyleSheet(u"QLabel {\n" +" font-size: 10px;\n" +" background-color: #FFCC00; \n" +" border: 1px solid #996600; \n" +" color: #333333;\n" +" font-weight: bold;\n" +" border-radius: 0px;\n" +"}") + self.exportWarningLabel.setMargin(5) + + self.verticalLayout_20.addWidget(self.exportWarningLabel) + + self.exportOnnxBtn = QPushButton(self.scrollAreaWidgetContents) + self.exportOnnxBtn.setObjectName(u"exportOnnxBtn") + sizePolicy8.setHeightForWidth(self.exportOnnxBtn.sizePolicy().hasHeightForWidth()) + self.exportOnnxBtn.setSizePolicy(sizePolicy8) + self.exportOnnxBtn.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) + self.exportOnnxBtn.setStyleSheet(u"") + self.exportOnnxBtn.setAutoExclusive(False) + + self.verticalLayout_20.addWidget(self.exportOnnxBtn) + + self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding) + + self.verticalLayout_20.addItem(self.verticalSpacer) + + self.scrollArea.setWidget(self.scrollAreaWidgetContents) + + self.verticalLayout.addWidget(self.scrollArea) + + + self.retranslateUi(pytorchIngest) + + QMetaObject.connectSlotsByName(pytorchIngest) + # setupUi + + def retranslateUi(self, pytorchIngest): + pytorchIngest.setWindowTitle(QCoreApplication.translate("pytorchIngest", u"Form", None)) + self.pytorchLogo.setText("") + self.titleLabel.setText(QCoreApplication.translate("pytorchIngest", u"PyTorch Ingest", None)) + self.modelName.setText(QCoreApplication.translate("pytorchIngest", u"model name", None)) + self.modelFilename.setText(QCoreApplication.translate("pytorchIngest", u"path to the model file", None)) + self.selectDirBtn.setText(QCoreApplication.translate("pytorchIngest", u"Select Directory", None)) + self.selectDirLabel.setText(QCoreApplication.translate("pytorchIngest", u"Select a directory if you would like to save the ONNX model file", None)) + self.exportOptionsGroupBox.setTitle(QCoreApplication.translate("pytorchIngest", u"Export Options", None)) + self.foldingCheckBox.setText(QCoreApplication.translate("pytorchIngest", u"Do constant folding", None)) + self.exportParamsCheckBox.setText(QCoreApplication.translate("pytorchIngest", u"Export params", None)) + self.opsetLabel.setText(QCoreApplication.translate("pytorchIngest", u"Opset", None)) + self.opsetInfoLabel.setText(QCoreApplication.translate("pytorchIngest", u"(accepted range is 7 - 21):", None)) + self.opsetLineEdit.setText(QCoreApplication.translate("pytorchIngest", u"17", None)) + self.inputsGroupBox.setTitle(QCoreApplication.translate("pytorchIngest", u"Inputs", None)) + self.label.setText(QCoreApplication.translate("pytorchIngest", u"The following inputs were taken from the PyTorch model's forward function. Please set the dimensions for each input needed. Dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244.", None)) + self.exportWarningLabel.setText(QCoreApplication.translate("pytorchIngest", u"

This is a warning message that we can use for now to prompt the user.

", None)) + self.exportOnnxBtn.setText(QCoreApplication.translate("pytorchIngest", u"Export ONNX", None)) + # retranslateUi + diff --git a/src/utils/onnx_utils.py b/src/utils/onnx_utils.py index 4d4b293..9b92be1 100644 --- a/src/utils/onnx_utils.py +++ b/src/utils/onnx_utils.py @@ -211,3 +211,9 @@ def optimize_onnx_model( except onnx.checker.ValidationError: print("Model did not pass checker!") return model_proto, False + + +def get_supported_opset() -> int: + """This function will return the opset version associated + with the currently installed ONNX library""" + return onnx.defs.onnx_opset_version() diff --git a/test/test_gui.py b/test/test_gui.py index 59fbb8f..0308ec7 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -5,23 +5,44 @@ import tempfile import unittest from unittest.mock import patch +import timm +import torch # pylint: disable=no-name-in-module from PySide6.QtTest import QTest -from PySide6.QtCore import Qt +from PySide6.QtCore import Qt, QTimer, QEventLoop from PySide6.QtWidgets import QApplication import digest.main from digest.node_summary import NodeSummary +from digest.model_class.digest_pytorch_model import DigestPyTorchModel +from digest.pytorch_ingest import PyTorchIngest + + +def save_resnet18_pt(directory: str) -> str: + """Simply saves a PyTorch resnet18 model and returns its file path""" + model = timm.models.create_model("resnet18", pretrained=True) # type: ignore + model.eval() + file_path = os.path.join(directory, "resnet18.pt") + # Save the model + try: + torch.save(model, file_path) + return file_path + except Exception as e: # pylint: disable=broad-exception-caught + print(f"Error saving model: {e}") + return "" class DigestGuiTest(unittest.TestCase): - MODEL_BASENAME = "resnet18" + RESNET18_BASENAME = "resnet18" + TEST_DIR = os.path.abspath(os.path.dirname(__file__)) - ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{MODEL_BASENAME}.onnx")) - YAML_FILEPATH = os.path.normpath( + ONNX_FILE_PATH = os.path.normpath( + os.path.join(TEST_DIR, f"{RESNET18_BASENAME}.onnx") + ) + YAML_FILE_PATH = os.path.normpath( os.path.join( - TEST_DIR, f"{MODEL_BASENAME}_reports", f"{MODEL_BASENAME}_report.yaml" + TEST_DIR, f"{RESNET18_BASENAME}_reports", f"{RESNET18_BASENAME}_report.yaml" ) ) @@ -57,7 +78,7 @@ def wait_all_threads(self, timeout=10000) -> bool: def test_open_valid_onnx(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ( - self.ONNX_FILEPATH, + self.ONNX_FILE_PATH, "", ) @@ -76,7 +97,7 @@ def test_open_valid_onnx(self): def test_open_valid_yaml(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ( - self.YAML_FILEPATH, + self.YAML_FILE_PATH, "", ) @@ -92,6 +113,51 @@ def test_open_valid_yaml(self): self.digest_app.closeTab(num_tabs_prior) + def test_open_valid_pytorch(self): + """We test the PyTorch path slightly different than the others + since Digest opens a modal window that blocks the main thread. This makes it difficult + to interact with the Window in this test.""" + + with tempfile.TemporaryDirectory() as tmpdir: + pt_file_path = save_resnet18_pt(tmpdir) + self.assertTrue(os.path.exists(tmpdir)) + basename = os.path.splitext(os.path.basename(pt_file_path)) + model_name = basename[0] + digest_model = DigestPyTorchModel(pt_file_path, model_name) + self.assertTrue(isinstance(digest_model.file_path, str)) + pytorch_ingest = PyTorchIngest(pt_file_path, digest_model.model_name) + pytorch_ingest.show() + + input_shape_edit = pytorch_ingest.user_input_form.get_row_line_edit_widget( + 0 + ) + + assert input_shape_edit + input_shape_edit.setText("batch_size, 3, 224, 224") + pytorch_ingest.update_input_shape() + + with patch( + "PySide6.QtWidgets.QFileDialog.getExistingDirectory" + ) as mock_save_dialog: + print("TMPDIR", tmpdir) + mock_save_dialog.return_value = tmpdir + pytorch_ingest.select_directory() + + pytorch_ingest.ui.exportOnnxBtn.click() + + timeout_ms = 10000 + interval_ms = 100 + for _ in range(timeout_ms // interval_ms): + QTest.qWait(interval_ms) + onnx_file_path = pytorch_ingest.digest_pytorch_model.onnx_file_path + if onnx_file_path and os.path.exists(onnx_file_path): + break # File found! + + assert isinstance(pytorch_ingest.digest_pytorch_model.onnx_file_path, str) + self.assertTrue( + os.path.exists(pytorch_ingest.digest_pytorch_model.onnx_file_path) + ) + def test_open_invalid_file(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ("invalid_file.txt", "") @@ -107,7 +173,7 @@ def test_save_reports(self): "PySide6.QtWidgets.QFileDialog.getExistingDirectory" ) as mock_save_dialog: - mock_open_dialog.return_value = (self.ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = tmpdirname @@ -127,41 +193,41 @@ def test_save_reports(self): mock_save_dialog.assert_called_once() result_basepath = os.path.join( - tmpdirname, f"{self.MODEL_BASENAME}_reports" + tmpdirname, f"{self.RESNET18_BASENAME}_reports" ) # Text report test - text_report_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_report.txt" + text_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_report.txt" ) self.assertTrue( - os.path.isfile(text_report_filepath), - f"{text_report_filepath} not found!", + os.path.isfile(text_report_FILE_PATH), + f"{text_report_FILE_PATH} not found!", ) # YAML report test - yaml_report_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_report.yaml" + yaml_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_report.yaml" ) - self.assertTrue(os.path.isfile(yaml_report_filepath)) + self.assertTrue(os.path.isfile(yaml_report_FILE_PATH)) # Nodes test - nodes_csv_report_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_nodes.csv" + nodes_csv_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_nodes.csv" ) - self.assertTrue(os.path.isfile(nodes_csv_report_filepath)) + self.assertTrue(os.path.isfile(nodes_csv_report_FILE_PATH)) # Histogram test - histogram_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_histogram.png" + histogram_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_histogram.png" ) - self.assertTrue(os.path.isfile(histogram_filepath)) + self.assertTrue(os.path.isfile(histogram_FILE_PATH)) # Heatmap test - heatmap_filepath = os.path.join( - result_basepath, f"{self.MODEL_BASENAME}_heatmap.png" + heatmap_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_heatmap.png" ) - self.assertTrue(os.path.isfile(heatmap_filepath)) + self.assertTrue(os.path.isfile(heatmap_FILE_PATH)) num_tabs = self.digest_app.ui.tabWidget.count() self.assertTrue(num_tabs == 1) @@ -174,10 +240,10 @@ def test_save_tables(self): "PySide6.QtWidgets.QFileDialog.getSaveFileName" ) as mock_save_dialog: - mock_open_dialog.return_value = (self.ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = ( - os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv"), + os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv"), "", ) @@ -207,7 +273,7 @@ def test_save_tables(self): self.assertTrue( os.path.exists( - os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv") + os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv") ), "Nodes csv file not found.", )