diff --git a/.pylintrc b/.pylintrc index b831756..efb14e7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -81,7 +81,6 @@ enable = expression-not-assigned, confusing-with-statement, unnecessary-lambda, - assign-to-new-keyword, redeclared-assigned-name, pointless-statement, pointless-string-statement, @@ -123,7 +122,6 @@ enable = invalid-length-returned, protected-access, attribute-defined-outside-init, - no-init, abstract-method, invalid-overridden-method, arguments-differ, @@ -165,9 +163,7 @@ enable = ### format # Line length, indentation, whitespace: bad-indentation, - mixed-indentation, unnecessary-semicolon, - bad-whitespace, missing-final-newline, line-too-long, mixed-line-endings, @@ -187,7 +183,6 @@ enable = import-self, preferred-module, reimported, - relative-import, deprecated-module, wildcard-import, misplaced-future, @@ -282,12 +277,6 @@ indent-string = ' ' # black doesn't always obey its own limit. See pyproject.toml. max-line-length = 100 -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check = - # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt = no diff --git a/examples/analysis.py b/examples/analysis.py index e9b9c63..da89068 100644 --- a/examples/analysis.py +++ b/examples/analysis.py @@ -6,14 +6,16 @@ import csv from collections import Counter, defaultdict from tqdm import tqdm +from digest.model_class.digest_model import ( + NodeShapeCounts, + NodeTypeCounts, + save_node_shape_counts_csv_report, + save_node_type_counts_csv_report, +) +from digest.model_class.digest_onnx_model import DigestOnnxModel from utils.onnx_utils import ( get_dynamic_input_dims, load_onnx, - DigestOnnxModel, - save_node_shape_counts_csv_report, - save_node_type_counts_csv_report, - NodeTypeCounts, - NodeShapeCounts, ) GLOBAL_MODEL_HEADERS = [ @@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str): global_model_data[model_name] = { "opset": digest_model.opset, - "parameters": digest_model.model_parameters, - "flops": digest_model.model_flops, + "parameters": digest_model.parameters, + "flops": digest_model.flops, } # Model summary text report summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt") - digest_model.save_txt_report(summary_filepath) + digest_model.save_text_report(summary_filepath) + + # Model summary yaml report + summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml") + digest_model.save_yaml_report(summary_filepath) # Save csv containing node-level information nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv") digest_model.save_nodes_csv_report(nodes_filepath) # Save csv containing node type counter - node_type_counter = digest_model.get_node_type_counts() node_type_filepath = os.path.join( output_dir, f"{model_name}_node_type_counts.csv" ) - if node_type_counter: - save_node_type_counts_csv_report(node_type_counter, node_type_filepath) + + digest_model.save_node_type_counts_csv_report(node_type_filepath) # Update global data structure for node type counter - global_node_type_counter.update(node_type_counter) + global_node_type_counter.update(digest_model.node_type_counts) # Save csv containing node shape counts per op_type - node_shape_counts = digest_model.get_node_shape_counts() node_shape_filepath = os.path.join( output_dir, f"{model_name}_node_shape_counts.csv" ) - save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath) + digest_model.save_node_shape_counts_csv_report(node_shape_filepath) # Update global data structure for node shape counter - for node_type, shape_counts in node_shape_counts.items(): + for node_type, shape_counts in digest_model.get_node_shape_counts().items(): global_node_shape_counter[node_type].update(shape_counts) if len(onnx_file_list) > 1: global_filepath = os.path.join(output_dir, "global_node_type_counts.csv") - global_node_type_counter = NodeTypeCounts( - global_node_type_counter.most_common() - ) - save_node_type_counts_csv_report(global_node_type_counter, global_filepath) + global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common()) + save_node_type_counts_csv_report(global_node_type_counts, global_filepath) global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv") save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath) diff --git a/setup.py b/setup.py index ca21f4a..b2ad16d 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="digestai", - version="1.0.0", + version="1.1.0", description="Model analysis toolkit", author="Philip Colangelo, Daniel Holanda", packages=find_packages(where="src"), diff --git a/src/digest/dialog.py b/src/digest/dialog.py index d2f834e..ae9986d 100644 --- a/src/digest/dialog.py +++ b/src/digest/dialog.py @@ -125,13 +125,23 @@ class WarnDialog(QDialog): def __init__(self, warning_message: str, parent=None): super().__init__(parent) - self.setWindowTitle("Warning Message") + self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) + + self.setWindowTitle("Warning Message") + self.setWindowFlags(Qt.WindowType.Dialog) self.setMinimumWidth(300) + self.setWindowModality(Qt.WindowModality.WindowModal) + layout = QVBoxLayout() # Application Version - layout.addWidget(QLabel("Something went wrong")) + layout.addWidget(QLabel("Warning")) layout.addWidget(QLabel(warning_message)) + + ok_button = QPushButton("OK") + ok_button.clicked.connect(self.accept) # Close dialog when clicked + layout.addWidget(ok_button) + self.setLayout(layout) diff --git a/src/digest/gui_config.yaml b/src/digest/gui_config.yaml index baffd47..dbd1c08 100644 --- a/src/digest/gui_config.yaml +++ b/src/digest/gui_config.yaml @@ -2,4 +2,4 @@ # For EXE releases we can block certain features e.g. to customers modules: - huggingface: false \ No newline at end of file + huggingface: true \ No newline at end of file diff --git a/src/digest/histogramchartwidget.py b/src/digest/histogramchartwidget.py index 97d5f16..f72befb 100644 --- a/src/digest/histogramchartwidget.py +++ b/src/digest/histogramchartwidget.py @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs): super(StackedHistogramWidget, self).__init__(*args, **kwargs) self.plot_widget = pg.PlotWidget() - self.plot_widget.setMaximumHeight(150) + self.plot_widget.setMaximumHeight(200) plot_item = self.plot_widget.getPlotItem() if plot_item: plot_item.setContentsMargins(0, 0, 0, 0) @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs): self.bar_spacing = 25 def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False): - title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)" self.plot_widget.setLabel( "left", @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal x_positions = list(range(len(op_count))) total_count = sum(op_count) width = 0.6 - self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing) + self.plot_widget.setFixedWidth(500) + for count, x_pos, tick in zip(op_count, x_positions, op_names): x0 = x_pos - width / 2 y0 = 0 diff --git a/src/digest/main.py b/src/digest/main.py index 08c401a..db4c685 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -3,11 +3,13 @@ import os import sys +import shutil import argparse from datetime import datetime -from typing import Dict, Tuple, Optional +from typing import Dict, Tuple, Optional, Union import tempfile from enum import IntEnum +import pandas as pd import yaml # This is a temporary workaround since the Qt designer generated files @@ -33,10 +35,10 @@ QMenu, ) from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QMovie, QIcon, QFont -from PySide6.QtCore import Qt, QDir +from PySide6.QtCore import Qt, QSize from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog -from digest.thread import StatsThread, SimilarityThread +from digest.thread import StatsThread, SimilarityThread, post_process from digest.popup_window import PopupWindow from digest.huggingface_page import HuggingfacePage from digest.multi_model_selection_page import MultiModelSelectionPage @@ -44,6 +46,9 @@ from digest.modelsummary import modelSummary from digest.node_summary import NodeSummary from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_model import DigestModel +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel from utils import onnx_utils GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml") @@ -161,11 +166,12 @@ def __init__(self, model_file: Optional[str] = None): self.status_dialog = None self.err_open_dialog = None self.temp_dir = tempfile.TemporaryDirectory() - self.digest_models: Dict[str, onnx_utils.DigestOnnxModel] = {} + self.digest_models: Dict[str, Union[DigestOnnxModel, DigestReportModel]] = {} # QThread containers self.model_nodes_stats_thread: Dict[str, StatsThread] = {} self.model_similarity_thread: Dict[str, SimilarityThread] = {} + self.model_similarity_report: Dict[str, SimilarityAnalysisReport] = {} self.ui.singleModelWidget.hide() @@ -209,7 +215,7 @@ def __init__(self, model_file: Optional[str] = None): # Set up the HUGGINGFACE Page huggingface_page = HuggingfacePage() - huggingface_page.model_signal.connect(self.load_onnx) + huggingface_page.model_signal.connect(self.load_model) self.ui.stackedWidget.insertWidget(self.Page.HUGGINGFACE, huggingface_page) # Set up the multi model page and relevant button @@ -217,15 +223,16 @@ def __init__(self, model_file: Optional[str] = None): self.ui.stackedWidget.insertWidget( self.Page.MULTIMODEL, self.multimodelselection_page ) - self.multimodelselection_page.model_signal.connect(self.load_onnx) + self.multimodelselection_page.model_signal.connect(self.load_model) # Load model file if given as input to the executable if model_file: - if ( - os.path.exists(model_file) - and os.path.splitext(model_file)[-1] == ".onnx" - ): + exists = os.path.exists(model_file) + ext = os.path.splitext(model_file)[-1] + if exists and ext == ".onnx": self.load_onnx(model_file) + elif exists and ext == ".yaml": + self.load_report(model_file) else: self.err_open_dialog = StatusDialog( f"Could not open {model_file}", parent=self @@ -243,10 +250,11 @@ def uncheck_ingest_buttons(self): def tab_focused(self, index): widget = self.ui.tabWidget.widget(index) if isinstance(widget, modelSummary): - model_id = widget.digest_model.unique_id + unique_id = widget.digest_model.unique_id if ( - self.stats_save_button_flag[model_id] - and self.similarity_save_button_flag[model_id] + self.stats_save_button_flag[unique_id] + and self.similarity_save_button_flag[unique_id] + and not isinstance(widget.digest_model, DigestReportModel) ): self.ui.saveBtn.setEnabled(True) else: @@ -257,11 +265,17 @@ def closeTab(self, index): if isinstance(summary_widget, modelSummary): unique_id = summary_widget.digest_model.unique_id summary_widget.deleteLater() - tab_thread = self.model_nodes_stats_thread[unique_id] + + tab_thread = self.model_nodes_stats_thread.get(unique_id) if tab_thread: tab_thread.exit() + tab_thread.wait(5000) + if not tab_thread.isRunning(): del self.model_nodes_stats_thread[unique_id] + else: + print(f"Warning: Thread for {unique_id} did not finish in time") + # delete the digest model to free up used memory if unique_id in self.digest_models: del self.digest_models[unique_id] @@ -272,40 +286,41 @@ def closeTab(self, index): self.ui.singleModelWidget.hide() def openFile(self): - filename, _ = QFileDialog.getOpenFileName( - self, "Open File", "", "ONNX Files (*.onnx)" + file_name, _ = QFileDialog.getOpenFileName( + self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)" ) - if ( - filename and os.path.splitext(filename)[-1] == ".onnx" - ): # Only if user selects a file and clicks OK - self.load_onnx(filename) + if not file_name: + return - def update_flops_label( + self.load_model(file_name) + + def update_cards( self, - digest_model: onnx_utils.DigestOnnxModel, + digest_model: DigestModel, unique_id: str, ): - self.digest_models[unique_id].model_flops = digest_model.model_flops + self.digest_models[unique_id].flops = digest_model.flops self.digest_models[unique_id].node_type_flops = digest_model.node_type_flops - self.digest_models[unique_id].model_parameters = digest_model.model_parameters + self.digest_models[unique_id].parameters = digest_model.parameters self.digest_models[unique_id].node_type_parameters = ( digest_model.node_type_parameters ) - self.digest_models[unique_id].per_node_info = digest_model.per_node_info + self.digest_models[unique_id].node_data = digest_model.node_data # We must iterate over the tabWidget and match to the tab_name because the user # may have switched the currentTab during the threads execution. + curr_index = -1 for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) if ( isinstance(widget, modelSummary) and widget.digest_model.unique_id == unique_id ): - if digest_model.model_flops is None: + if digest_model.flops is None: flops_str = "--" else: - flops_str = format(digest_model.model_flops, ",") + flops_str = format(digest_model.flops, ",") # Set up the pie chart pie_chart_labels, pie_chart_data = zip( @@ -328,11 +343,14 @@ def update_flops_label( pie_chart_labels, pie_chart_data, ) + curr_index = index break self.stats_save_button_flag[unique_id] = True - if self.ui.tabWidget.currentIndex() == index: - if self.similarity_save_button_flag[unique_id]: + if self.ui.tabWidget.currentIndex() == curr_index: + if self.similarity_save_button_flag[unique_id] and not isinstance( + digest_model, DigestReportModel + ): self.ui.saveBtn.setEnabled(True) def open_similarity_report(self, model_id: str, image_path, most_similar_models): @@ -346,10 +364,12 @@ def update_similarity_widget( completed_successfully: bool, model_id: str, most_similar: str, - png_filepath: str, + png_filepath: Optional[str] = None, + df_sorted: Optional[pd.DataFrame] = None, ): - widget = None + digest_model = None + curr_index = -1 for index in range(self.ui.tabWidget.count()): tab_widget = self.ui.tabWidget.widget(index) if ( @@ -357,49 +377,90 @@ def update_similarity_widget( and tab_widget.digest_model.unique_id == model_id ): widget = tab_widget + digest_model = tab_widget.digest_model + curr_index = index break - if completed_successfully and isinstance(widget, modelSummary): + # convert back to a List[str] + most_similar_list = most_similar.split(",") + + if ( + completed_successfully + and isinstance(widget, modelSummary) + and digest_model + and png_filepath + ): + + if df_sorted is not None: + post_process( + digest_model.model_name, most_similar_list, df_sorted, png_filepath + ) + + widget.load_gif.stop() + widget.ui.similarityImg.clear() + # We give the image a 10% haircut to fit it more aesthetically widget_width = widget.ui.similarityImg.width() - widget.ui.similarityImg.setPixmap( - QPixmap(png_filepath).scaledToWidth(widget_width) + + pixmap = QPixmap(png_filepath) + aspect_ratio = pixmap.width() / pixmap.height() + target_height = int(widget_width / aspect_ratio) + pixmap_scaled = pixmap.scaled( + QSize(widget_width, target_height), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, ) + + widget.ui.similarityImg.setPixmap(pixmap_scaled) widget.ui.similarityImg.setText("") widget.ui.similarityImg.setCursor(Qt.CursorShape.PointingHandCursor) # Show most correlated models widget.ui.similarityCorrelation.show() widget.ui.similarityCorrelationStatic.show() - most_similar_models = most_similar.split(",") - text = ( - "\n" - f"{most_similar_models[0]}, {most_similar_models[1]}, and {most_similar_models[2]}." - "" - ) + + most_similar_list = most_similar_list[1:4] + if most_similar: + text = ( + "\n" + f"{most_similar_list[0]}, {most_similar_list[1]}, " + f"and {most_similar_list[2]}. " + "" + ) + else: + # currently the similarity widget expects the most_similar_models + # to allows contains 3 models. For now we will just send three empty + # strings but at some point we should handle an arbitrary case. + most_similar_list = ["", "", ""] + text = "NTD" # Create option to click to enlarge image widget.ui.similarityImg.mousePressEvent = ( lambda event: self.open_similarity_report( - model_id, png_filepath, most_similar_models + model_id, png_filepath, most_similar_list ) ) # Create option to click to enlarge image self.model_similarity_report[model_id] = SimilarityAnalysisReport( - png_filepath, most_similar_models + png_filepath, most_similar_list ) widget.ui.similarityCorrelation.setText(text) elif isinstance(widget, modelSummary): # Remove animation and set text to failing message - widget.ui.similarityImg.setMovie(QMovie(None)) + widget.load_gif.stop() + widget.ui.similarityImg.clear() widget.ui.similarityImg.setText("Failed to perform similarity analysis") else: - print("Tab widget is not of type modelSummary which is unexpected.") + print( + f"Tab widget is of type {type(widget)} and not of type modelSummary " + "which is unexpected." + ) - # self.similarity_save_button_flag[model_id] = True - if self.ui.tabWidget.currentIndex() == index: - if self.stats_save_button_flag[model_id]: + if self.ui.tabWidget.currentIndex() == curr_index: + if self.stats_save_button_flag[model_id] and not isinstance( + digest_model, DigestReportModel + ): self.ui.saveBtn.setEnabled(True) def load_onnx(self, filepath: str): @@ -432,8 +493,9 @@ def load_onnx(self, filepath: str): basename = os.path.splitext(os.path.basename(filepath)) model_name = basename[0] - digest_model = onnx_utils.DigestOnnxModel( - onnx_model=model, model_name=model_name, save_proto=False + # Save the model proto so we can use the Freeze Inputs feature + digest_model = DigestOnnxModel( + onnx_model=opt_model, model_name=model_name, save_proto=True ) model_id = digest_model.unique_id @@ -442,11 +504,9 @@ def load_onnx(self, filepath: str): self.digest_models[model_id] = digest_model - # We must set the proto for the model_summary freeze_inputs - self.digest_models[model_id].model_proto = opt_model - - model_summary = modelSummary(self.digest_models[model_id]) - model_summary.freeze_inputs.complete_signal.connect(self.load_onnx) + model_summary = modelSummary(digest_model) + if model_summary.freeze_inputs: + model_summary.freeze_inputs.complete_signal.connect(self.load_onnx) dynamic_input_dims = onnx_utils.get_dynamic_input_dims(opt_model) if dynamic_input_dims: @@ -480,14 +540,13 @@ def load_onnx(self, filepath: str): model_summary.ui.modelFilename.setText(filepath) model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) - self.digest_models[model_id].model_name = model_name - self.digest_models[model_id].filepath = filepath - - self.digest_models[model_id].model_inputs = ( - onnx_utils.get_model_input_shapes_types(opt_model) + digest_model.model_name = model_name + digest_model.filepath = filepath + digest_model.model_inputs = onnx_utils.get_model_input_shapes_types( + opt_model ) - self.digest_models[model_id].model_outputs = ( - onnx_utils.get_model_output_shapes_types(opt_model) + digest_model.model_outputs = onnx_utils.get_model_output_shapes_types( + opt_model ) progress.step() @@ -498,9 +557,7 @@ def load_onnx(self, filepath: str): # Kick off model stats thread self.model_nodes_stats_thread[model_id] = StatsThread() - self.model_nodes_stats_thread[model_id].completed.connect( - self.update_flops_label - ) + self.model_nodes_stats_thread[model_id].completed.connect(self.update_cards) self.model_nodes_stats_thread[model_id].model = opt_model self.model_nodes_stats_thread[model_id].tab_name = model_name @@ -518,7 +575,7 @@ def load_onnx(self, filepath: str): model_summary.ui.opHistogramChart.bar_spacing = bar_spacing model_summary.ui.opHistogramChart.set_data(node_type_counts) model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) - self.digest_models[model_id].node_type_counts = node_type_counts + digest_model.node_type_counts = node_type_counts progress.step() progress.setLabelText("Gathering Model Inputs and Outputs") @@ -577,24 +634,24 @@ def load_onnx(self, filepath: str): model_summary.ui.modelProtoTable.setItem( 0, 1, QTableWidgetItem(str(opt_model.model_version)) ) - self.digest_models[model_id].model_version = opt_model.model_version + digest_model.model_version = opt_model.model_version model_summary.ui.modelProtoTable.setItem( 1, 1, QTableWidgetItem(str(opt_model.graph.name)) ) - self.digest_models[model_id].graph_name = opt_model.graph.name + digest_model.graph_name = opt_model.graph.name producer_txt = f"{opt_model.producer_name} {opt_model.producer_version}" model_summary.ui.modelProtoTable.setItem( 2, 1, QTableWidgetItem(producer_txt) ) - self.digest_models[model_id].producer_name = opt_model.producer_name - self.digest_models[model_id].producer_version = opt_model.producer_version + digest_model.producer_name = opt_model.producer_name + digest_model.producer_version = opt_model.producer_version model_summary.ui.modelProtoTable.setItem( 3, 1, QTableWidgetItem(str(opt_model.ir_version)) ) - self.digest_models[model_id].ir_version = opt_model.ir_version + digest_model.ir_version = opt_model.ir_version for imp in opt_model.opset_import: row_idx = model_summary.ui.importsTable.rowCount() @@ -602,7 +659,7 @@ def load_onnx(self, filepath: str): if imp.domain == "" or imp.domain == "ai.onnx": model_summary.ui.opsetVersion.setText(str(imp.version)) domain = "ai.onnx" - self.digest_models[model_id].opset = imp.version + digest_model.opset = imp.version else: domain = imp.domain model_summary.ui.importsTable.setItem( @@ -613,7 +670,7 @@ def load_onnx(self, filepath: str): ) row_idx += 1 - self.digest_models[model_id].imports[imp.domain] = imp.version + digest_model.imports[imp.domain] = imp.version progress.step() progress.setLabelText("Wrapping Up Model Analysis") @@ -628,14 +685,11 @@ def load_onnx(self, filepath: str): self.ui.singleModelWidget.show() progress.step() - movie = QMovie(":/assets/gifs/load.gif") - model_summary.ui.similarityImg.setMovie(movie) - movie.start() - # Start similarity Analysis # Note: Should only be started after the model tab has been created png_tmp_path = os.path.join(self.temp_dir.name, model_id) os.makedirs(png_tmp_path, exist_ok=True) + assert os.path.exists(png_tmp_path), f"Error with creating {png_tmp_path}" self.model_similarity_thread[model_id] = SimilarityThread() self.model_similarity_thread[model_id].completed_successfully.connect( self.update_similarity_widget @@ -652,6 +706,217 @@ def load_onnx(self, filepath: str): except FileNotFoundError as e: print(f"File not found: {e.filename}") + def load_report(self, filepath: str): + + # Ensure the filepath follows a standard formatting: + filepath = os.path.normpath(filepath) + + if not os.path.exists(filepath): + return + + # Every time a report is loaded we should emulate a model summary button click + self.summary_clicked() + + # Before opening the file, check to see if it is already opened. + for index in range(self.ui.tabWidget.count()): + widget = self.ui.tabWidget.widget(index) + if isinstance(widget, modelSummary) and filepath == widget.file: + self.ui.tabWidget.setCurrentIndex(index) + return + + try: + + progress = ProgressDialog("Loading Digest Report File...", 2, self) + QApplication.processEvents() # Process pending events + + digest_model = DigestReportModel(filepath) + + if not digest_model.is_valid: + progress.close() + invalid_yaml_dialog = StatusDialog( + title="Warning", + status_message=f"YAML file {filepath} is not a valid digest report", + ) + invalid_yaml_dialog.show() + + return + + model_id = digest_model.unique_id + + # There is no sense in offering to save the report + self.stats_save_button_flag[model_id] = False + self.similarity_save_button_flag[model_id] = False + + self.digest_models[model_id] = digest_model + + model_summary = modelSummary(digest_model) + + self.ui.tabWidget.addTab(model_summary, "") + model_summary.ui.flops.setText("Loading...") + + # Hide some of the components + model_summary.ui.similarityCorrelation.hide() + model_summary.ui.similarityCorrelationStatic.hide() + + model_summary.file = filepath + model_summary.setObjectName(digest_model.model_name) + model_summary.ui.modelName.setText(digest_model.model_name) + model_summary.ui.modelFilename.setText(filepath) + model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) + + model_summary.ui.parameters.setText(format(digest_model.parameters, ",")) + + node_type_counts = digest_model.node_type_counts + if len(node_type_counts) < 15: + bar_spacing = 40 + else: + bar_spacing = 20 + + model_summary.ui.opHistogramChart.bar_spacing = bar_spacing + model_summary.ui.opHistogramChart.set_data(node_type_counts) + model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) + + progress.step() + progress.setLabelText("Gathering Model Inputs and Outputs") + + # Inputs Table + model_summary.ui.inputsTable.setRowCount( + len(self.digest_models[model_id].model_inputs) + ) + + for row_idx, (input_name, input_info) in enumerate( + self.digest_models[model_id].model_inputs.items() + ): + model_summary.ui.inputsTable.setItem( + row_idx, 0, QTableWidgetItem(input_name) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(input_info.shape)) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(input_info.dtype)) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) + ) + + model_summary.ui.inputsTable.resizeColumnsToContents() + model_summary.ui.inputsTable.resizeRowsToContents() + + # Outputs Table + model_summary.ui.outputsTable.setRowCount( + len(self.digest_models[model_id].model_outputs) + ) + for row_idx, (output_name, output_info) in enumerate( + self.digest_models[model_id].model_outputs.items() + ): + model_summary.ui.outputsTable.setItem( + row_idx, 0, QTableWidgetItem(output_name) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(output_info.shape)) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(output_info.dtype)) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) + ) + + model_summary.ui.outputsTable.resizeColumnsToContents() + model_summary.ui.outputsTable.resizeRowsToContents() + + progress.step() + progress.setLabelText("Gathering Model Proto Data") + + # ModelProto Info + model_summary.ui.modelProtoTable.setItem( + 0, 1, QTableWidgetItem(str(digest_model.model_data["model_version"])) + ) + + model_summary.ui.modelProtoTable.setItem( + 1, 1, QTableWidgetItem(str(digest_model.model_data["graph_name"])) + ) + + producer_txt = ( + f"{digest_model.model_data['producer_name']} " + f"{digest_model.model_data['producer_version']}" + ) + model_summary.ui.modelProtoTable.setItem( + 2, 1, QTableWidgetItem(producer_txt) + ) + + model_summary.ui.modelProtoTable.setItem( + 3, 1, QTableWidgetItem(str(digest_model.model_data["ir_version"])) + ) + + for domain, version in digest_model.model_data["import_list"].items(): + row_idx = model_summary.ui.importsTable.rowCount() + model_summary.ui.importsTable.insertRow(row_idx) + if domain == "" or domain == "ai.onnx": + model_summary.ui.opsetVersion.setText(str(version)) + domain = "ai.onnx" + + model_summary.ui.importsTable.setItem( + row_idx, 0, QTableWidgetItem(str(domain)) + ) + model_summary.ui.importsTable.setItem( + row_idx, 1, QTableWidgetItem(str(version)) + ) + row_idx += 1 + + progress.step() + progress.setLabelText("Wrapping Up Model Analysis") + + model_summary.ui.importsTable.resizeColumnsToContents() + model_summary.ui.modelProtoTable.resizeColumnsToContents() + model_summary.setObjectName(digest_model.model_name) + new_tab_idx = self.ui.tabWidget.count() - 1 + self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name)) + self.ui.tabWidget.setCurrentIndex(new_tab_idx) + self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) + self.ui.singleModelWidget.show() + progress.step() + + self.update_cards(digest_model, digest_model.unique_id) + + movie = QMovie(":/assets/gifs/load.gif") + model_summary.ui.similarityImg.setMovie(movie) + movie.start() + + self.update_similarity_widget( + completed_successfully=bool(digest_model.similarity_heatmap_path), + model_id=digest_model.unique_id, + most_similar="", + png_filepath=digest_model.similarity_heatmap_path, + ) + + progress.close() + + except FileNotFoundError as e: + print(f"File not found: {e.filename}") + + def load_model(self, file_path: str): + + # Ensure the filepath follows a standard formatting: + file_path = os.path.normpath(file_path) + + if not os.path.exists(file_path): + return + + file_ext = os.path.splitext(file_path)[-1] + + if file_ext == ".onnx": + self.load_onnx(file_path) + elif file_ext == ".yaml": + self.load_report(file_path) + else: + bad_ext_dialog = StatusDialog( + f"Digest does not support files with the extension {file_ext}", + parent=self, + ) + bad_ext_dialog.show() + def dragEnterEvent(self, event: QDragEnterEvent): if event.mimeData().hasUrls(): event.acceptProposedAction() @@ -660,9 +925,7 @@ def dropEvent(self, event: QDropEvent): if event.mimeData().hasUrls(): for url in event.mimeData().urls(): file_path = url.toLocalFile() - if file_path.endswith(".onnx"): - self.load_onnx(file_path) - break + self.load_model(file_path) ## functions for changing menu page def logo_clicked(self): @@ -710,14 +973,8 @@ def save_reports(self): self, "Select Directory" ) - if not save_directory: - return - - # Create a QDir object - directory = QDir(save_directory) - # Check if the directory exists and is writable - if not directory.exists() and directory.isWritable(): # type: ignore + if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK): self.show_warning_dialog( f"The directory {save_directory} is not valid or writable." ) @@ -726,43 +983,56 @@ def save_reports(self): save_directory, str(digest_model.model_name) + "_reports" ) - os.makedirs(save_directory, exist_ok=True) - - # Save the node histogram image - node_histogram = current_tab.ui.opHistogramChart.grab() - node_histogram.save( - os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG" - ) + try: + os.makedirs(save_directory, exist_ok=True) - # Save csv of node type counts - node_type_filepath = os.path.join( - save_directory, f"{model_name}_node_type_counts.csv" - ) - node_counter = digest_model.get_node_type_counts() - if node_counter: - onnx_utils.save_node_type_counts_csv_report( - node_counter, node_type_filepath + # Save the node histogram image + node_histogram = current_tab.ui.opHistogramChart.grab() + node_histogram.save( + os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG" ) - # Save the similarity image - similarity_png = self.model_similarity_report[digest_model.unique_id].grab() - similarity_png.save( - os.path.join(save_directory, f"{model_name}_heatmap.png"), "PNG" - ) + # Save csv of node type counts + node_type_filepath = os.path.join( + save_directory, f"{model_name}_node_type_counts.csv" + ) + digest_model.save_node_type_counts_csv_report(node_type_filepath) + + # Save (copy) the similarity image + png_file_path = self.model_similarity_thread[ + digest_model.unique_id + ].png_filepath + png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png") + if png_file_path and os.path.exists(png_file_path): + shutil.copy(png_file_path, png_save_path) + + # Save the text report + txt_report_filepath = os.path.join( + save_directory, f"{model_name}_report.txt" + ) + digest_model.save_text_report(txt_report_filepath) - # Save the text report - txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt") - digest_model.save_txt_report(txt_report_filepath) + # Save the yaml report + yaml_report_filepath = os.path.join( + save_directory, f"{model_name}_report.yaml" + ) + digest_model.save_yaml_report(yaml_report_filepath) - # Save the node list - nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv") - self.save_nodes_csv(nodes_report_filepath, False) + # Save the node list + nodes_report_filepath = os.path.join( + save_directory, f"{model_name}_nodes.csv" + ) - self.status_dialog = StatusDialog( - f"Saved reports to: \n{os.path.abspath(save_directory)}", - "Successfully saved reports!", - ) - self.status_dialog.show() + self.save_nodes_csv(nodes_report_filepath, False) + except Exception as exception: # pylint: disable=broad-exception-caught + self.status_dialog = StatusDialog(f"{exception}") + self.status_dialog.show() + else: + self.status_dialog = StatusDialog( + f"Saved reports to: \n{os.path.abspath(save_directory)}", + "Successfully saved reports!", + ) + self.status_dialog.show() def on_dialog_closed(self): self.infoDialog = None @@ -829,7 +1099,7 @@ def open_node_summary(self): digest_models = self.digest_models[model_id] node_summary = NodeSummary( - model_name=model_name, node_data=digest_models.per_node_info + model_name=model_name, node_data=digest_models.node_data ) self.nodes_window[model_id] = PopupWindow( diff --git a/src/digest/model_class/digest_model.py b/src/digest/model_class/digest_model.py new file mode 100644 index 0000000..87f399d --- /dev/null +++ b/src/digest/model_class/digest_model.py @@ -0,0 +1,232 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. + +import os +import csv +from enum import Enum +from dataclasses import dataclass, field +from uuid import uuid4 +from abc import ABC, abstractmethod +from collections import Counter, OrderedDict, defaultdict +from typing import List, Dict, Optional, Any, Union + + +class SupportedModelTypes(Enum): + ONNX = "onnx" + REPORT = "report" + + +class NodeParsingException(Exception): + pass + + +# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias +class NodeShapeCounts(defaultdict[str, Counter]): + def __init__(self): + super().__init__(Counter) # Initialize with the Counter factory + + +class NodeTypeCounts(Dict[str, int]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@dataclass +class TensorInfo: + "Used to store node input and output tensor information" + dtype: Optional[str] = None + dtype_bytes: Optional[int] = None + size_kbytes: Optional[float] = None + shape: List[Union[int, str]] = field(default_factory=list) + + +class TensorData(OrderedDict[str, TensorInfo]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class NodeInfo: + def __init__(self) -> None: + self.flops: Optional[int] = None + self.parameters: int = 0 # TODO: should we make this Optional[int] = None? + self.node_type: Optional[str] = None + self.attributes: OrderedDict[str, Any] = OrderedDict() + # We use an ordered dictionary because the order in which + # the inputs and outputs are listed in the node matter. + self.inputs = TensorData() + self.outputs = TensorData() + + def get_input(self, index: int) -> TensorInfo: + return list(self.inputs.values())[index] + + def get_output(self, index: int) -> TensorInfo: + return list(self.outputs.values())[index] + + def __str__(self): + """Provides a human-readable string representation of NodeInfo.""" + output = [ + f"Node Type: {self.node_type}", + f"FLOPs: {self.flops if self.flops is not None else 'N/A'}", + f"Parameters: {self.parameters}", + ] + + if self.attributes: + output.append("Attributes:") + for key, value in self.attributes.items(): + output.append(f" - {key}: {value}") + + if self.inputs: + output.append("Inputs:") + for name, tensor in self.inputs.items(): + output.append(f" - {name}: {tensor}") + + if self.outputs: + output.append("Outputs:") + for name, tensor in self.outputs.items(): + output.append(f" - {name}: {tensor}") + + return "\n".join(output) + + +# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias +class NodeData(OrderedDict[str, NodeInfo]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DigestModel(ABC): + def __init__(self, filepath: str, model_name: str, model_type: SupportedModelTypes): + # Public members exposed to the API + self.unique_id: str = str(uuid4()) + self.filepath: Optional[str] = filepath + self.model_name: str = model_name + self.model_type: SupportedModelTypes = model_type + self.node_type_counts: NodeTypeCounts = NodeTypeCounts() + self.flops: Optional[int] = None + self.parameters: int = 0 + self.node_type_flops: Dict[str, int] = {} + self.node_type_parameters: Dict[str, int] = {} + self.node_data = NodeData() + self.model_inputs = TensorData() + self.model_outputs = TensorData() + + def get_node_shape_counts(self) -> NodeShapeCounts: + tensor_shape_counter = NodeShapeCounts() + for _, info in self.node_data.items(): + shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()]) + if info.node_type: + tensor_shape_counter[info.node_type][shape_hash] += 1 + return tensor_shape_counter + + @abstractmethod + def parse_model_nodes(self, *args, **kwargs) -> None: + pass + + @abstractmethod + def save_yaml_report(self, filepath: str) -> None: + pass + + @abstractmethod + def save_text_report(self, filepath: str) -> None: + pass + + def save_nodes_csv_report(self, filepath: str) -> None: + save_nodes_csv_report(self.node_data, filepath) + + def save_node_type_counts_csv_report(self, filepath: str) -> None: + if self.node_type_counts: + save_node_type_counts_csv_report(self.node_type_counts, filepath) + + def save_node_shape_counts_csv_report(self, filepath: str) -> None: + save_node_shape_counts_csv_report(self.get_node_shape_counts(), filepath) + + +def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None: + + parent_dir = os.path.dirname(os.path.abspath(filepath)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + flattened_data = [] + fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"] + input_fieldnames = [] + output_fieldnames = [] + for name, node_info in node_data.items(): + row = OrderedDict() + row["Node Name"] = name + row["Node Type"] = str(node_info.node_type) + row["Parameters"] = str(node_info.parameters) + row["FLOPs"] = str(node_info.flops) + if node_info.attributes: + row["Attributes"] = str({k: v for k, v in node_info.attributes.items()}) + else: + row["Attributes"] = "" + + for i, (input_name, input_info) in enumerate(node_info.inputs.items()): + column_name = f"Input{i+1} (Shape, Dtype, Size (kB))" + row[column_name] = ( + f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})" + ) + + # Dynamically add input column names to fieldnames if not already present + if column_name not in input_fieldnames: + input_fieldnames.append(column_name) + + for i, (output_name, output_info) in enumerate(node_info.outputs.items()): + column_name = f"Output{i+1} (Shape, Dtype, Size (kB))" + row[column_name] = ( + f"{output_name} ({output_info.shape}, " + f"{output_info.dtype}, {output_info.size_kbytes})" + ) + + # Dynamically add input column names to fieldnames if not already present + if column_name not in output_fieldnames: + output_fieldnames.append(column_name) + + flattened_data.append(row) + + fieldnames = fieldnames + input_fieldnames + output_fieldnames + try: + with open(filepath, "w", encoding="utf-8", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") + writer.writeheader() + writer.writerows(flattened_data) + except PermissionError as exception: + raise PermissionError( + f"Saving reports to {filepath} failed with error {exception}" + ) + + +def save_node_type_counts_csv_report( + node_type_counts: NodeTypeCounts, filepath: str +) -> None: + + parent_dir = os.path.dirname(os.path.abspath(filepath)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + header = ["Node Type", "Count"] + + with open(filepath, "w", encoding="utf-8", newline="") as csvfile: + writer = csv.writer(csvfile, lineterminator="\n") + writer.writerow(header) + for node_type, node_count in node_type_counts.items(): + writer.writerow([node_type, node_count]) + + +def save_node_shape_counts_csv_report( + node_shape_counts: NodeShapeCounts, filepath: str +) -> None: + + parent_dir = os.path.dirname(os.path.abspath(filepath)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + header = ["Node Type", "Input Tensors Shapes", "Count"] + + with open(filepath, "w", encoding="utf-8", newline="") as csvfile: + writer = csv.writer(csvfile, dialect="excel", lineterminator="\n") + writer.writerow(header) + for node_type, node_info in node_shape_counts.items(): + info_iter = iter(node_info.items()) + for shape, count in info_iter: + writer.writerow([node_type, shape, count]) diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py new file mode 100644 index 0000000..8c8dd7f --- /dev/null +++ b/src/digest/model_class/digest_onnx_model.py @@ -0,0 +1,656 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. + +import os +from typing import List, Dict, Optional, Tuple, cast +from datetime import datetime +import importlib.metadata +from collections import OrderedDict +import yaml +import numpy as np +import onnx +from prettytable import PrettyTable +from digest.model_class.digest_model import ( + DigestModel, + SupportedModelTypes, + NodeInfo, + TensorData, + TensorInfo, +) +import utils.onnx_utils as onnx_utils + + +class DigestOnnxModel(DigestModel): + def __init__( + self, + onnx_model: onnx.ModelProto, + onnx_filepath: str = "", + model_name: str = "", + save_proto: bool = True, + ) -> None: + super().__init__(onnx_filepath, model_name, SupportedModelTypes.ONNX) + + self.model_type = SupportedModelTypes.ONNX + + # Public members exposed to the API + self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None + self.model_version: Optional[int] = None + self.graph_name: Optional[str] = None + self.producer_name: Optional[str] = None + self.producer_version: Optional[str] = None + self.ir_version: Optional[int] = None + self.opset: Optional[int] = None + self.imports: OrderedDict[str, int] = OrderedDict() + + # Private members not intended to be exposed + self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {} + self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {} + self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {} + self.init_tensors_: Dict[str, onnx.TensorProto] = {} + + self.update_state(onnx_model) + + def update_state(self, model_proto: onnx.ModelProto) -> None: + self.model_version = model_proto.model_version + self.graph_name = model_proto.graph.name + self.producer_name = model_proto.producer_name + self.producer_version = model_proto.producer_version + self.ir_version = model_proto.ir_version + self.opset = onnx_utils.get_opset(model_proto) + self.imports = OrderedDict( + sorted( + (import_.domain, import_.version) + for import_ in model_proto.opset_import + ) + ) + + self.model_inputs = onnx_utils.get_model_input_shapes_types(model_proto) + self.model_outputs = onnx_utils.get_model_output_shapes_types(model_proto) + + self.node_type_counts = onnx_utils.get_node_type_counts(model_proto) + self.parse_model_nodes(model_proto) + + def get_node_tensor_info_( + self, onnx_node: onnx.NodeProto + ) -> Tuple[TensorData, TensorData]: + """ + This function is set to private because it is not intended to be used + outside of the DigestOnnxModel class. + """ + + input_tensor_info = TensorData() + for node_input in onnx_node.input: + input_tensor_info[node_input] = TensorInfo() + if ( + node_input in self.input_tensors_ + or node_input in self.value_tensors_ + or node_input in self.output_tensors_ + ): + tensor = ( + self.input_tensors_.get(node_input) + or self.value_tensors_.get(node_input) + or self.output_tensors_.get(node_input) + ) + if tensor: + for dim in tensor.type.tensor_type.shape.dim: + if dim.HasField("dim_value"): + input_tensor_info[node_input].shape.append(dim.dim_value) + elif dim.HasField("dim_param"): + input_tensor_info[node_input].shape.append(dim.dim_param) + + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + tensor.type.tensor_type.elem_type + ) + elif node_input in self.init_tensors_: + input_tensor_info[node_input].shape.extend( + [dim for dim in self.init_tensors_[node_input].dims] + ) + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + self.init_tensors_[node_input].data_type + ) + else: + dtype_str = None + dtype_bytes = None + + input_tensor_info[node_input].dtype = dtype_str + input_tensor_info[node_input].dtype_bytes = dtype_bytes + + if ( + all(isinstance(s, int) for s in input_tensor_info[node_input].shape) + and dtype_bytes + ): + tensor_size = float( + np.prod(np.array(input_tensor_info[node_input].shape)) + ) + input_tensor_info[node_input].size_kbytes = ( + tensor_size * float(dtype_bytes) / 1024.0 + ) + + output_tensor_info = TensorData() + for node_output in onnx_node.output: + output_tensor_info[node_output] = TensorInfo() + if ( + node_output in self.input_tensors_ + or node_output in self.value_tensors_ + or node_output in self.output_tensors_ + ): + tensor = ( + self.input_tensors_.get(node_output) + or self.value_tensors_.get(node_output) + or self.output_tensors_.get(node_output) + ) + if tensor: + output_tensor_info[node_output].shape.extend( + [ + int(dim.dim_value) + for dim in tensor.type.tensor_type.shape.dim + ] + ) + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + tensor.type.tensor_type.elem_type + ) + elif node_output in self.init_tensors_: + output_tensor_info[node_output].shape.extend( + [dim for dim in self.init_tensors_[node_output].dims] + ) + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + self.init_tensors_[node_output].data_type + ) + + else: + dtype_str = None + dtype_bytes = None + + output_tensor_info[node_output].dtype = dtype_str + output_tensor_info[node_output].dtype_bytes = dtype_bytes + + if ( + all(isinstance(s, int) for s in output_tensor_info[node_output].shape) + and dtype_bytes + ): + tensor_size = float( + np.prod(np.array(output_tensor_info[node_output].shape)) + ) + output_tensor_info[node_output].size_kbytes = ( + tensor_size * float(dtype_bytes) / 1024.0 + ) + + return input_tensor_info, output_tensor_info + + def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None: + """ + Calculate total number of FLOPs found in the onnx model. + FLOP is defined as one floating-point operation. This distinguishes + from multiply-accumulates (MACs) where FLOPs == 2 * MACs. + """ + + # Initialze to zero so we can accumulate. Set to None during the + # model FLOPs calculation if it errors out. + self.flops = 0 + + # Check to see if the model inputs have any dynamic shapes + if onnx_utils.get_dynamic_input_dims(onnx_model): + self.flops = None + + try: + onnx_model, _ = onnx_utils.optimize_onnx_model(onnx_model) + + onnx_model = onnx.shape_inference.infer_shapes( + onnx_model, strict_mode=True, data_prop=True + ) + except Exception as e: # pylint: disable=broad-except + print(f"ONNX utils: {str(e)}") + self.flops = None + + # If the ONNX model contains one of the following unsupported ops, then this + # function will return None since the FLOP total is expected to be incorrect + unsupported_ops = [ + "Einsum", + "RNN", + "GRU", + "DeformConv", + ] + + if not self.input_tensors_: + self.input_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.input + } + + if not self.output_tensors_: + self.output_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.output + } + + if not self.value_tensors_: + self.value_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.value_info + } + + if not self.init_tensors_: + self.init_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.initializer + } + + for node in onnx_model.graph.node: # pylint: disable=E1101 + + node_info = NodeInfo() + + # TODO: I have encountered models containing nodes with no name. It would be a good idea + # to have this type of model info fed back to the user through a warnings section. + if not node.name: + node.name = f"{node.op_type}_{len(self.node_data)}" + + node_info.node_type = node.op_type + input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node) + node_info.inputs = input_tensor_info + node_info.outputs = output_tensor_info + + # Check if this node has parameters through the init tensors + for input_name, input_tensor in node_info.inputs.items(): + if input_name in self.init_tensors_: + if all(isinstance(dim, int) for dim in input_tensor.shape): + input_parameters = int(np.prod(np.array(input_tensor.shape))) + node_info.parameters += input_parameters + self.parameters += input_parameters + self.node_type_parameters[node.op_type] = ( + self.node_type_parameters.get(node.op_type, 0) + + input_parameters + ) + else: + print(f"Tensor with params has unknown shape: {input_name}") + + for attribute in node.attribute: + node_info.attributes.update(onnx_utils.attribute_to_dict(attribute)) + + # if node.name in self.node_data: + # print(f"Node name {node.name} is a duplicate.") + + self.node_data[node.name] = node_info + + if node.op_type in unsupported_ops: + self.flops = None + node_info.flops = None + + try: + + if ( + node.op_type == "MatMul" + or node.op_type == "MatMulInteger" + or node.op_type == "QLinearMatMul" + ): + + input_a = node_info.get_input(0).shape + if node.op_type == "QLinearMatMul": + input_b = node_info.get_input(3).shape + else: + input_b = node_info.get_input(1).shape + + if not all( + isinstance(dim, int) for dim in input_a + ) or not isinstance(input_b[-1], int): + node_info.flops = None + self.flops = None + continue + + node_info.flops = int( + 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1] + ) + + elif ( + node.op_type == "Mul" + or node.op_type == "Div" + or node.op_type == "Add" + ): + input_a = node_info.get_input(0).shape + input_b = node_info.get_input(1).shape + + if not all(isinstance(dim, int) for dim in input_a) or not all( + isinstance(dim, int) for dim in input_b + ): + node_info.flops = None + self.flops = None + continue + + node_info.flops = int( + np.prod(np.array(input_a), dtype=np.int64) + ) + int(np.prod(np.array(input_b), dtype=np.int64)) + + elif node.op_type == "Gemm" or node.op_type == "QGemm": + x_shape = node_info.get_input(0).shape + if node.op_type == "Gemm": + w_shape = node_info.get_input(1).shape + else: + w_shape = node_info.get_input(3).shape + + if not all(isinstance(dim, int) for dim in x_shape) or not all( + isinstance(dim, int) for dim in w_shape + ): + node_info.flops = None + self.flops = None + continue + + mm_dims = [ + ( + x_shape[0] + if not node_info.attributes.get("transA", 0) + else x_shape[1] + ), + ( + x_shape[1] + if not node_info.attributes.get("transA", 0) + else x_shape[0] + ), + ( + w_shape[1] + if not node_info.attributes.get("transB", 0) + else w_shape[0] + ), + ] + + node_info.flops = int( + 2 * np.prod(np.array(mm_dims), dtype=np.int64) + ) + + if len(mm_dims) == 3: # if there is a bias input + bias_shape = node_info.get_input(2).shape + node_info.flops += int(np.prod(np.array(bias_shape))) + + elif ( + node.op_type == "Conv" + or node.op_type == "ConvInteger" + or node.op_type == "QLinearConv" + or node.op_type == "ConvTranspose" + ): + # N, C, d1, ..., dn + x_shape = node_info.get_input(0).shape + + # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose + if node.op_type == "QLinearConv": + w_shape = node_info.get_input(3).shape + else: + w_shape = node_info.get_input(1).shape + + if not all(isinstance(dim, int) for dim in x_shape): + node_info.flops = None + self.flops = None + continue + + x_shape_ints = cast(List[int], x_shape) + w_shape_ints = cast(List[int], w_shape) + + has_bias = False # Note, ConvInteger has no bias + if node.op_type == "Conv" and len(node_info.inputs) == 3: + has_bias = True + elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9: + has_bias = True + + num_dims = len(x_shape_ints) - 2 + strides = node_info.attributes.get( + "strides", [1] * num_dims + ) # type: List[int] + dilation = node_info.attributes.get( + "dilations", [1] * num_dims + ) # type: List[int] + kernel_shape = w_shape_ints[2:] + batch_size = x_shape_ints[0] + out_channels = w_shape_ints[0] + out_dims = [batch_size, out_channels] + output_shape = node_info.attributes.get( + "output_shape", [] + ) # type: List[int] + + # If output_shape is given then we do not need to compute it ourselves + # The output_shape attribute does not include batch_size or channels and + # is only valid for ConvTranspose + if output_shape: + out_dims.extend(output_shape) + else: + auto_pad = node_info.attributes.get( + "auto_pad", "NOTSET".encode() + ).decode() + # SAME expects padding so that the output_shape = CEIL(input_shape / stride) + if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": + out_dims.extend( + [x * s for x, s in zip(x_shape_ints[2:], strides)] + ) + else: + # NOTSET means just use pads attribute + if auto_pad == "NOTSET": + pads = node_info.attributes.get( + "pads", [0] * num_dims * 2 + ) + # VALID essentially means no padding + elif auto_pad == "VALID": + pads = [0] * num_dims * 2 + + for i in range(num_dims): + dim_in = x_shape_ints[i + 2] # type: int + + if node.op_type == "ConvTranspose": + out_dim = ( + strides[i] * (dim_in - 1) + + ((kernel_shape[i] - 1) * dilation[i] + 1) + - pads[i] + - pads[i + num_dims] + ) + else: + out_dim = ( + dim_in + + pads[i] + + pads[i + num_dims] + - dilation[i] * (kernel_shape[i] - 1) + - 1 + ) // strides[i] + 1 + + out_dims.append(out_dim) + + kernel_flops = int( + np.prod(np.array(kernel_shape)) * w_shape_ints[1] + ) + output_points = int(np.prod(np.array(out_dims))) + bias_ops = output_points if has_bias else int(0) + node_info.flops = 2 * kernel_flops * output_points + bias_ops + + elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM": + + x_shape = node_info.get_input( + 0 + ).shape # seq_length, batch_size, input_dim + + if not all(isinstance(dim, int) for dim in x_shape): + node_info.flops = None + self.flops = None + continue + + x_shape_ints = cast(List[int], x_shape) + hidden_size = node_info.attributes["hidden_size"] + direction = ( + 2 + if node_info.attributes.get("direction") + == "bidirectional".encode() + else 1 + ) + + has_bias = True if len(node_info.inputs) >= 4 else False + if has_bias: + bias_shape = node_info.get_input(3).shape + if isinstance(bias_shape[1], int): + bias_ops = bias_shape[1] + else: + bias_ops = 0 + else: + bias_ops = 0 + # seq_length, batch_size, input_dim = x_shape + if not isinstance(bias_ops, int): + bias_ops = int(0) + num_gates = int(4) + gate_input_flops = int(2 * x_shape_ints[2] * hidden_size) + gate_hid_flops = int(2 * hidden_size * hidden_size) + unit_flops = ( + num_gates * (gate_input_flops + gate_hid_flops) + bias_ops + ) + node_info.flops = ( + x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops + ) + # In this case we just hit an op that doesn't have FLOPs + else: + node_info.flops = None + + except IndexError as err: + print(f"Error parsing node {node.name}: {err}") + node_info.flops = None + self.flops = None + continue + + # Update the model level flops count + if node_info.flops is not None and self.flops is not None: + self.flops += node_info.flops + + # Update the node type flops count + self.node_type_flops[node.op_type] = ( + self.node_type_flops.get(node.op_type, 0) + node_info.flops + ) + + def save_yaml_report(self, filepath: str) -> None: + + parent_dir = os.path.dirname(os.path.abspath(filepath)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + report_date = datetime.now().strftime("%B %d, %Y") + + input_tensors = dict({k: vars(v) for k, v in self.model_inputs.items()}) + output_tensors = dict({k: vars(v) for k, v in self.model_outputs.items()}) + digest_version = importlib.metadata.version("digestai") + + yaml_data = { + "report_date": report_date, + "digest_version": digest_version, + "model_type": self.model_type.value, + "model_file": self.filepath, + "model_name": self.model_name, + "model_version": self.model_version, + "graph_name": self.graph_name, + "producer_name": self.producer_name, + "producer_version": self.producer_version, + "ir_version": self.ir_version, + "opset": self.opset, + "import_list": dict(self.imports), + "graph_nodes": sum(self.node_type_counts.values()), + "parameters": self.parameters, + "flops": self.flops, + "node_type_counts": dict(self.node_type_counts), + "node_type_flops": dict(self.node_type_flops), + "node_type_parameters": self.node_type_parameters, + "input_tensors": input_tensors, + "output_tensors": output_tensors, + } + + with open(filepath, "w", encoding="utf-8") as f_p: + yaml.dump(yaml_data, f_p, sort_keys=False) + + def save_text_report(self, filepath: str) -> None: + + parent_dir = os.path.dirname(os.path.abspath(filepath)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + report_date = datetime.now().strftime("%B %d, %Y") + + digest_version = importlib.metadata.version("digestai") + + with open(filepath, "w", encoding="utf-8") as f_p: + f_p.write(f"Report created on {report_date}\n") + f_p.write(f"Digest version: {digest_version}\n") + f_p.write(f"Model type: {self.model_type.name}\n") + if self.filepath: + f_p.write(f"ONNX file: {self.filepath}\n") + f_p.write(f"Name of the model: {self.model_name}\n") + f_p.write(f"Model version: {self.model_version}\n") + f_p.write(f"Name of the graph: {self.graph_name}\n") + f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n") + f_p.write(f"Ir version: {self.ir_version}\n") + f_p.write(f"Opset: {self.opset}\n\n") + f_p.write("Import list\n") + for name, version in self.imports.items(): + f_p.write(f"\t{name}: {version}\n") + + f_p.write("\n") + f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n") + f_p.write(f"Number of parameters: {self.parameters}\n") + if self.flops: + f_p.write(f"Number of FLOPs: {self.flops}\n") + f_p.write("\n") + + table_op_intensity = PrettyTable() + table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"] + for op_type, count in self.node_type_flops.items(): + if count > 0: + table_op_intensity.add_row( + [ + op_type, + count, + 100.0 * float(count) / float(self.flops), + ] + ) + + f_p.write("Op intensity:\n") + f_p.write(table_op_intensity.get_string()) + f_p.write("\n\n") + + node_counts_table = PrettyTable() + node_counts_table.field_names = ["Node", "Occurrences"] + for op, count in self.node_type_counts.items(): + node_counts_table.add_row([op, count]) + f_p.write("Nodes and their occurrences:\n") + f_p.write(node_counts_table.get_string()) + f_p.write("\n\n") + + input_table = PrettyTable() + input_table.field_names = [ + "Input Name", + "Shape", + "Type", + "Tensor Size (KB)", + ] + for input_name, input_details in self.model_inputs.items(): + if input_details.size_kbytes: + kbytes = f"{input_details.size_kbytes:.2f}" + else: + kbytes = "" + + input_table.add_row( + [ + input_name, + input_details.shape, + input_details.dtype, + kbytes, + ] + ) + f_p.write("Input Tensor(s) Information:\n") + f_p.write(input_table.get_string()) + f_p.write("\n\n") + + output_table = PrettyTable() + output_table.field_names = [ + "Output Name", + "Shape", + "Type", + "Tensor Size (KB)", + ] + for output_name, output_details in self.model_outputs.items(): + if output_details.size_kbytes: + kbytes = f"{output_details.size_kbytes:.2f}" + else: + kbytes = "" + + output_table.add_row( + [ + output_name, + output_details.shape, + output_details.dtype, + kbytes, + ] + ) + f_p.write("Output Tensor(s) Information:\n") + f_p.write(output_table.get_string()) + f_p.write("\n\n") diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py new file mode 100644 index 0000000..50e76de --- /dev/null +++ b/src/digest/model_class/digest_report_model.py @@ -0,0 +1,240 @@ +import os +from collections import OrderedDict +import csv +import ast +import re +from typing import Tuple, Optional, List, Dict, Any, Union +import yaml +from digest.model_class.digest_model import ( + DigestModel, + SupportedModelTypes, + NodeData, + NodeInfo, + TensorData, + TensorInfo, +) + + +def parse_tensor_info( + csv_tensor_cell_value, +) -> Tuple[str, list, str, Union[str, float]]: + """This is a helper function that expects the input to come from parsing + the nodes csv and extracting either an input or output tensor.""" + + # Use regex to split the string into name and details + match = re.match(r"(.*?)\s*\((.*)\)$", csv_tensor_cell_value) + if not match: + raise ValueError(f"Invalid format for tensor info: {csv_tensor_cell_value}") + + name, details = match.groups() + + # Split details, but keep the shape as a single item + match = re.match(r"(\[.*?\])\s*,\s*(.*?)\s*,\s*(.*)", details) + if not match: + raise ValueError(f"Invalid format for tensor details: {details}") + + shape_str, dtype, size = match.groups() + + # Ensure shape is stored as a list + shape = ast.literal_eval(shape_str) + if not isinstance(shape, list): + shape = list(shape) + + if size != "None": + size = float(size.split()[0]) + + return name.strip(), shape, dtype.strip(), size + + +class DigestReportModel(DigestModel): + def __init__( + self, + report_filepath: str, + ) -> None: + + self.model_type = SupportedModelTypes.REPORT + + self.is_valid = validate_yaml(report_filepath) + + if not self.is_valid: + print(f"The yaml file {report_filepath} is not a valid digest report.") + return + + self.model_data = OrderedDict() + with open(report_filepath, "r", encoding="utf-8") as yaml_f: + self.model_data = yaml.safe_load(yaml_f) + + model_name = self.model_data["model_name"] + super().__init__(report_filepath, model_name, SupportedModelTypes.REPORT) + + self.similarity_heatmap_path: Optional[str] = None + self.node_data = NodeData() + + # Given the path to the digest report, let's check if its a complete cache + # and we can grab the nodes csv data and the similarity heatmap + cache_dir = os.path.dirname(os.path.abspath(report_filepath)) + expected_heatmap_file = os.path.join(cache_dir, f"{model_name}_heatmap.png") + if os.path.exists(expected_heatmap_file): + self.similarity_heatmap_path = expected_heatmap_file + + expected_nodes_file = os.path.join(cache_dir, f"{model_name}_nodes.csv") + if os.path.exists(expected_nodes_file): + with open(expected_nodes_file, "r", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + node_name = row["Node Name"] + node_info = NodeInfo() + node_info.node_type = row["Node Type"] + if row["Parameters"]: + node_info.parameters = int(row["Parameters"]) + if ast.literal_eval(row["FLOPs"]): + node_info.flops = int(row["FLOPs"]) + node_info.attributes = ( + OrderedDict(ast.literal_eval(row["Attributes"])) + if row["Attributes"] + else OrderedDict() + ) + + node_info.inputs = TensorData() + node_info.outputs = TensorData() + + # Process inputs and outputs + for key, value in row.items(): + if key.startswith("Input") and value: + input_name, shape, dtype, size = parse_tensor_info(value) + node_info.inputs[input_name] = TensorInfo() + node_info.inputs[input_name].shape = shape + node_info.inputs[input_name].dtype = dtype + node_info.inputs[input_name].size_kbytes = size + + elif key.startswith("Output") and value: + output_name, shape, dtype, size = parse_tensor_info(value) + node_info.outputs[output_name] = TensorInfo() + node_info.outputs[output_name].shape = shape + node_info.outputs[output_name].dtype = dtype + node_info.outputs[output_name].size_kbytes = size + + self.node_data[node_name] = node_info + + # Unpack the model type agnostic values + self.flops = self.model_data["flops"] + self.parameters = self.model_data["parameters"] + self.node_type_flops = self.model_data["node_type_flops"] + self.node_type_parameters = self.model_data["node_type_parameters"] + self.node_type_counts = self.model_data["node_type_counts"] + + self.model_inputs = TensorData( + { + key: TensorInfo(**val) + for key, val in self.model_data["input_tensors"].items() + } + ) + self.model_outputs = TensorData( + { + key: TensorInfo(**val) + for key, val in self.model_data["output_tensors"].items() + } + ) + + def parse_model_nodes(self) -> None: + """There are no model nodes to parse""" + + def save_yaml_report(self, filepath: str) -> None: + """Report models are not intended to be saved""" + + def save_text_report(self, filepath: str) -> None: + """Report models are not intended to be saved""" + + +def validate_yaml(report_file_path: str) -> bool: + """Check that the provided yaml file is indeed a Digest Report file.""" + expected_keys = [ + "report_date", + "model_file", + "model_type", + "model_name", + "flops", + "node_type_flops", + "node_type_parameters", + "node_type_counts", + "input_tensors", + "output_tensors", + ] + try: + with open(report_file_path, "r", encoding="utf-8") as file: + yaml_content = yaml.safe_load(file) + + if not isinstance(yaml_content, dict): + print("Error: YAML content is not a dictionary") + return False + + for key in expected_keys: + if key not in yaml_content: + # print(f"Error: Missing required key '{key}'") + return False + + return True + except yaml.YAMLError as _: + # print(f"Error parsing YAML file: {e}") + return False + except IOError as _: + # print(f"Error reading file: {e}") + return False + + +def compare_yaml_files( + file1: str, file2: str, skip_keys: Optional[List[str]] = None +) -> bool: + """ + Compare two YAML files, ignoring specified keys. + + :param file1: Path to the first YAML file + :param file2: Path to the second YAML file + :param skip_keys: List of keys to ignore in the comparison + :return: True if the files are equal (ignoring specified keys), False otherwise + """ + + def load_yaml(file_path: str) -> Dict[str, Any]: + with open(file_path, "r", encoding="utf-8") as file: + return yaml.safe_load(file) + + def compare_dicts( + dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = "" + ) -> List[str]: + differences = [] + all_keys = set(dict1.keys()) | set(dict2.keys()) + + for key in all_keys: + if skip_keys and key in skip_keys: + continue + + current_path = f"{path}.{key}" if path else key + + if key not in dict1: + differences.append(f"Key '{current_path}' is missing in the first file") + elif key not in dict2: + differences.append( + f"Key '{current_path}' is missing in the second file" + ) + elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + differences.extend(compare_dicts(dict1[key], dict2[key], current_path)) + elif dict1[key] != dict2[key]: + differences.append( + f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}" + ) + + return differences + + yaml1 = load_yaml(file1) + yaml2 = load_yaml(file2) + + differences = compare_dicts(yaml1, yaml2) + + if differences: + # print("Differences found:") + # for diff in differences: + # print(f"- {diff}") + return False + else: + # print("No differences found.") + return True diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index 1e3872e..a92b756 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -3,10 +3,12 @@ import os # pylint: disable=invalid-name -from typing import Optional +from typing import Optional, Union # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget +from PySide6.QtGui import QMovie +from PySide6.QtCore import QSize from onnx import ModelProto @@ -14,37 +16,54 @@ from digest.freeze_inputs import FreezeInputs from digest.popup_window import PopupWindow from digest.qt_utils import apply_dark_style_sheet -from utils import onnx_utils +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel + ROOT_FOLDER = os.path.dirname(os.path.abspath(__file__)) class modelSummary(QWidget): - def __init__(self, digest_model: onnx_utils.DigestOnnxModel, parent=None): + def __init__( + self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None + ): super().__init__(parent) self.ui = Ui_modelSummary() self.ui.setupUi(self) apply_dark_style_sheet(self) self.file: Optional[str] = None - self.ui.freezeButton.setVisible(False) - self.ui.freezeButton.clicked.connect(self.open_freeze_inputs) self.ui.warningLabel.hide() self.digest_model = digest_model - self.model_proto: ModelProto = ( - digest_model.model_proto if digest_model.model_proto else ModelProto() - ) + self.model_proto: Optional[ModelProto] = None model_name: str = digest_model.model_name if digest_model.model_name else "" - self.freeze_inputs = FreezeInputs(self.model_proto, model_name) - self.freeze_inputs.complete_signal.connect(self.close_freeze_window) + + self.load_gif = QMovie(":/assets/gifs/load.gif") + # We set the size of the GIF to half the original + self.load_gif.setScaledSize(QSize(214, 120)) + self.ui.similarityImg.setMovie(self.load_gif) + self.load_gif.start() + + # There is no freezing if the model is not ONNX + self.ui.freezeButton.setVisible(False) + self.freeze_inputs: Optional[FreezeInputs] = None self.freeze_window: Optional[QWidget] = None + if isinstance(digest_model, DigestOnnxModel): + self.model_proto = ( + digest_model.model_proto if digest_model.model_proto else ModelProto() + ) + self.freeze_inputs = FreezeInputs(self.model_proto, model_name) + self.ui.freezeButton.clicked.connect(self.open_freeze_inputs) + self.freeze_inputs.complete_signal.connect(self.close_freeze_window) + def open_freeze_inputs(self): - self.freeze_window = PopupWindow( - self.freeze_inputs, "Freeze Model Inputs", self - ) - self.freeze_window.open() + if self.freeze_inputs: + self.freeze_window = PopupWindow( + self.freeze_inputs, "Freeze Model Inputs", self + ) + self.freeze_window.open() def close_freeze_window(self): if self.freeze_window: diff --git a/src/digest/multi_model_analysis.py b/src/digest/multi_model_analysis.py index d7f6bab..1c09905 100644 --- a/src/digest/multi_model_analysis.py +++ b/src/digest/multi_model_analysis.py @@ -1,17 +1,27 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. import os +from datetime import datetime import csv from typing import List, Dict, Union from collections import Counter, defaultdict, OrderedDict # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog +from PySide6.QtCore import Qt from digest.dialog import ProgressDialog, StatusDialog from digest.ui.multimodelanalysis_ui import Ui_multiModelAnalysis from digest.histogramchartwidget import StackedHistogramWidget from digest.qt_utils import apply_dark_style_sheet -from utils import onnx_utils +from digest.model_class.digest_model import ( + NodeTypeCounts, + NodeShapeCounts, + save_node_shape_counts_csv_report, + save_node_type_counts_csv_report, +) +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel +import utils.onnx_utils as onnx_utils ROOT_FOLDER = os.path.dirname(__file__) @@ -21,7 +31,7 @@ class MultiModelAnalysis(QWidget): def __init__( self, - model_list: List[onnx_utils.DigestOnnxModel], + model_list: List[Union[DigestOnnxModel, DigestReportModel]], parent=None, ): super().__init__(parent) @@ -34,6 +44,9 @@ def __init__( self.ui.individualCheckBox.stateChanged.connect(self.check_box_changed) self.ui.multiCheckBox.stateChanged.connect(self.check_box_changed) + # For some reason setting alignments in designer lead to bugs in *ui.py files + self.ui.opHistogramChart.layout().setAlignment(Qt.AlignmentFlag.AlignTop) + if not model_list: return @@ -41,41 +54,60 @@ def __init__( self.global_node_type_counter: Counter[str] = Counter() # Holds the data for node shape counts across all models - self.global_node_shape_counter: onnx_utils.NodeShapeCounts = defaultdict( - Counter - ) + self.global_node_shape_counter: NodeShapeCounts = defaultdict(Counter) # Holds the data for all models statistics - self.global_model_data: Dict[str, Dict[str, Union[int, None]]] = {} + self.global_model_data: Dict[str, Dict[str, Union[int, str, None]]] = {} progress = ProgressDialog("", len(model_list), self) - header_labels = ["Model", "Opset", "Total Nodes", "Parameters", "FLOPs"] + header_labels = [ + "Model Name", + "Model Type", + "Opset", + "Total Nodes", + "Parameters", + "FLOPs", + ] self.ui.dataTable.setRowCount(len(model_list)) self.ui.dataTable.setColumnCount(len(header_labels)) self.ui.dataTable.setHorizontalHeaderLabels(header_labels) self.ui.dataTable.setSortingEnabled(False) for row, model in enumerate(model_list): + item = QTableWidgetItem(str(model.model_name)) self.ui.dataTable.setItem(row, 0, item) - item = QTableWidgetItem(str(model.opset)) + item = QTableWidgetItem(str(model.model_type.name)) self.ui.dataTable.setItem(row, 1, item) - item = QTableWidgetItem(str(len(model.per_node_info))) + if isinstance(model, DigestOnnxModel): + item = QTableWidgetItem(str(model.opset)) + elif isinstance(model, DigestReportModel): + item = QTableWidgetItem(str(model.model_data.get("opset", ""))) + self.ui.dataTable.setItem(row, 2, item) - item = QTableWidgetItem(str(model.model_parameters)) + item = QTableWidgetItem(str(len(model.node_data))) self.ui.dataTable.setItem(row, 3, item) - item = QTableWidgetItem(str(model.model_flops)) + item = QTableWidgetItem(str(model.parameters)) self.ui.dataTable.setItem(row, 4, item) + item = QTableWidgetItem(str(model.flops)) + self.ui.dataTable.setItem(row, 5, item) + self.ui.dataTable.resizeColumnsToContents() self.ui.dataTable.resizeRowsToContents() - node_type_counter = {} + # Until we use the unique_id to represent the model contents we store + # the entire model as the key so that we can store models that happen to have + # the same name. There is a guarantee that the models will not be duplicates. + node_type_counter: Dict[ + Union[DigestOnnxModel, DigestReportModel], NodeTypeCounts + ] = {} + for i, digest_model in enumerate(model_list): progress.step() progress.setLabelText(f"Analyzing model {digest_model.model_name}") @@ -83,39 +115,52 @@ def __init__( if digest_model.model_name is None: digest_model.model_name = f"model_{i}" - if digest_model.model_proto: - dynamic_input_dims = onnx_utils.get_dynamic_input_dims( - digest_model.model_proto - ) - if dynamic_input_dims: - print( - "Found the following non-static input dims in your model. " - "It is recommended to make all dims static before generating reports." + if isinstance(digest_model, DigestOnnxModel): + opset = digest_model.opset + if digest_model.model_proto: + dynamic_input_dims = onnx_utils.get_dynamic_input_dims( + digest_model.model_proto ) - for dynamic_shape in dynamic_input_dims: - print(f"dim: {dynamic_shape}") + if dynamic_input_dims: + print( + "Found the following non-static input dims in your model. " + "It is recommended to make all dims static before generating reports." + ) + for dynamic_shape in dynamic_input_dims: + print(f"dim: {dynamic_shape}") + + elif isinstance(digest_model, DigestReportModel): + opset = digest_model.model_data.get("opset", "") # Update the global model dictionary - if digest_model.model_name in self.global_model_data: + if digest_model.unique_id in self.global_model_data: print( - f"Warning! {digest_model.model_name} has already been processed, " + f"Warning! {digest_model.model_name} with id " + f"{digest_model.unique_id} has already been processed, " "skipping the duplicate model." ) - - self.global_model_data[digest_model.model_name] = { - "opset": digest_model.opset, - "parameters": digest_model.model_parameters, - "flops": digest_model.model_flops, + continue + + self.global_model_data[digest_model.unique_id] = { + "model_name": digest_model.model_name, + "model_type": digest_model.model_type.name, + "opset": opset, + "parameters": digest_model.parameters, + "flops": digest_model.flops, } - node_type_counter[digest_model.model_name] = ( - digest_model.get_node_type_counts() - ) + if digest_model in node_type_counter: + print( + f"Warning! {digest_model.model_name} with model type " + f"{digest_model.model_type.value} and id {digest_model.unique_id} " + "has already been added to the stacked histogram, skipping." + ) + continue + + node_type_counter[digest_model] = digest_model.node_type_counts # Update global data structure for node type counter - self.global_node_type_counter.update( - node_type_counter[digest_model.model_name] - ) + self.global_node_type_counter.update(node_type_counter[digest_model]) node_shape_counts = digest_model.get_node_shape_counts() @@ -133,31 +178,31 @@ def __init__( # Create stacked op histograms max_count = 0 top_ops = [key for key, _ in self.global_node_type_counter.most_common(20)] - for model_name, _ in node_type_counter.items(): - max_local = Counter(node_type_counter[model_name]).most_common()[0][1] + for model, _ in node_type_counter.items(): + max_local = Counter(node_type_counter[model]).most_common()[0][1] if max_local > max_count: max_count = max_local - for idx, model_name in enumerate(node_type_counter): + for idx, model in enumerate(node_type_counter): stacked_histogram_widget = StackedHistogramWidget() ordered_dict = OrderedDict() - model_counter = Counter(node_type_counter[model_name]) + model_counter = Counter(node_type_counter[model]) for key in top_ops: ordered_dict[key] = model_counter.get(key, 0) title = "Stacked Op Histogram" if idx == 0 else "" stacked_histogram_widget.set_data( ordered_dict, - model_name=model_name, + model_name=model.model_name, y_max=max_count, title=title, set_ticks=False, ) frame_layout = self.ui.stackedHistogramFrame.layout() - frame_layout.addWidget(stacked_histogram_widget) + if frame_layout: + frame_layout.addWidget(stacked_histogram_widget) # Add a "ghost" histogram to allow us to set the x axis label vertically - model_name = list(node_type_counter.keys())[0] stacked_histogram_widget = StackedHistogramWidget() - ordered_dict = {key: 1 for key in top_ops} + ordered_dict = OrderedDict({key: 1 for key in top_ops}) stacked_histogram_widget.set_data( ordered_dict, model_name="_", @@ -165,18 +210,39 @@ def __init__( set_ticks=True, ) frame_layout = self.ui.stackedHistogramFrame.layout() - frame_layout.addWidget(stacked_histogram_widget) + if frame_layout: + frame_layout.addWidget(stacked_histogram_widget) self.model_list = model_list def save_reports(self): - # Model summary text report - save_directory = QFileDialog(self).getExistingDirectory( + """This function saves all available reports for the models that are opened + in the multi-model analysis page.""" + + base_directory = QFileDialog(self).getExistingDirectory( self, "Select Directory" ) - if not save_directory: - return + # Check if the directory exists and is writable + if not os.path.exists(base_directory) or not os.access(base_directory, os.W_OK): + bad_ext_dialog = StatusDialog( + f"The directory {base_directory} is not valid or writable.", + parent=self, + ) + bad_ext_dialog.show() + + # Append a subdirectory to the save_directory so that all reports are co-located + name_id = datetime.now().strftime("%Y%m%d%H%M%S") + sub_directory = f"multi_model_reports_{name_id}" + save_directory = os.path.join(base_directory, sub_directory) + try: + os.makedirs(save_directory) + except OSError as os_err: + bad_ext_dialog = StatusDialog( + f"Failed to create {save_directory} with error {os_err}", + parent=self, + ) + bad_ext_dialog.show() save_individual_reports = self.ui.individualCheckBox.isChecked() save_multi_reports = self.ui.multiCheckBox.isChecked() @@ -192,29 +258,21 @@ def save_reports(self): save_directory, f"{digest_model.model_name}_summary.txt" ) - digest_model.save_txt_report(summary_filepath) + digest_model.save_text_report(summary_filepath) # Save csv of node type counts node_type_filepath = os.path.join( save_directory, f"{digest_model.model_name}_node_type_counts.csv" ) - # Save csv containing node type counter - node_type_counter = digest_model.get_node_type_counts() - - if node_type_counter: - onnx_utils.save_node_type_counts_csv_report( - node_type_counter, node_type_filepath - ) + if digest_model.node_type_counts: + digest_model.save_node_type_counts_csv_report(node_type_filepath) # Save csv containing node shape counts per op_type - node_shape_counts = digest_model.get_node_shape_counts() node_shape_filepath = os.path.join( save_directory, f"{digest_model.model_name}_node_shape_counts.csv" ) - onnx_utils.save_node_shape_counts_csv_report( - node_shape_counts, node_shape_filepath - ) + digest_model.save_node_shape_counts_csv_report(node_shape_filepath) # Save csv containing all node-level information nodes_filepath = os.path.join( @@ -231,17 +289,17 @@ def save_reports(self): global_filepath = os.path.join( save_directory, "global_node_type_counts.csv" ) - global_node_type_counter = onnx_utils.NodeTypeCounts( + global_node_type_counter = NodeTypeCounts( self.global_node_type_counter.most_common() ) - onnx_utils.save_node_type_counts_csv_report( + save_node_type_counts_csv_report( global_node_type_counter, global_filepath ) global_filepath = os.path.join( save_directory, "global_node_shape_counts.csv" ) - onnx_utils.save_node_shape_counts_csv_report( + save_node_shape_counts_csv_report( self.global_node_shape_counter, global_filepath ) @@ -253,10 +311,18 @@ def save_reports(self): ) as csvfile: writer = csv.writer(csvfile) rows = [ - [model, data["opset"], data["parameters"], data["flops"]] - for model, data in self.global_model_data.items() + [ + data["model_name"], + data["model_type"], + data["opset"], + data["parameters"], + data["flops"], + ] + for _, data in self.global_model_data.items() ] - writer.writerow(["Model", "Opset", "Parameters", "FLOPs"]) + writer.writerow( + ["Model Name", "Model Type", "Opset", "Parameters", "FLOPs"] + ) writer.writerows(rows) if save_individual_reports or save_multi_reports: diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index d7b6a39..e9d5c2b 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -2,7 +2,7 @@ import os import glob -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Union from collections import defaultdict from google.protobuf.message import DecodeError import onnx @@ -22,6 +22,8 @@ from digest.ui.multimodelselection_page_ui import Ui_MultiModelSelection from digest.multi_model_analysis import MultiModelAnalysis from digest.qt_utils import apply_dark_style_sheet, prompt_user_ram_limit +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel, compare_yaml_files from utils import onnx_utils @@ -33,7 +35,9 @@ class AnalysisThread(QThread): def __init__(self): super().__init__() - self.model_dict: Dict[str, Optional[onnx_utils.DigestOnnxModel]] = {} + self.model_dict: Dict[ + str, Optional[Union[DigestOnnxModel, DigestReportModel]] + ] = {} self.user_canceled = False def run(self): @@ -47,19 +51,21 @@ def run(self): self.step_progress.emit() if model: continue - model_name = os.path.splitext(os.path.basename(file))[0] - model_proto = onnx_utils.load_onnx(file, False) - self.model_dict[file] = onnx_utils.DigestOnnxModel( - model_proto, onnx_filepath=file, model_name=model_name, save_proto=False - ) + model_name, file_ext = os.path.splitext(os.path.basename(file)) + if file_ext == ".onnx": + model_proto = onnx_utils.load_onnx(file, False) + self.model_dict[file] = DigestOnnxModel( + model_proto, + onnx_filepath=file, + model_name=model_name, + save_proto=False, + ) + elif file_ext == ".yaml": + self.model_dict[file] = DigestReportModel(file) self.close_progress.emit() - model_list = [ - model - for model in self.model_dict.values() - if isinstance(model, onnx_utils.DigestOnnxModel) - ] + model_list = [model for model in self.model_dict.values()] self.completed.emit(model_list) @@ -82,8 +88,10 @@ def __init__( self.ui.warningLabel.hide() self.item_model = QStandardItemModel() self.item_model.itemChanged.connect(self.update_num_selected_label) - self.ui.selectAllBox.setCheckState(Qt.CheckState.Checked) - self.ui.selectAllBox.stateChanged.connect(self.update_list_view_items) + self.ui.radioAll.setChecked(True) + self.ui.radioAll.toggled.connect(self.update_list_view_items) + self.ui.radioONNX.toggled.connect(self.update_list_view_items) + self.ui.radioReports.toggled.connect(self.update_list_view_items) self.ui.selectFolderBtn.clicked.connect(self.openFolder) self.ui.duplicateLabel.hide() self.ui.modelListView.setModel(self.item_model) @@ -94,7 +102,9 @@ def __init__( self.ui.openAnalysisBtn.clicked.connect(self.start_analysis) - self.model_dict: Dict[str, Optional[onnx_utils.DigestOnnxModel]] = {} + self.model_dict: Dict[ + str, Optional[Union[DigestOnnxModel, DigestReportModel]] + ] = {} self.analysis_thread: Optional[AnalysisThread] = None self.progress: Optional[ProgressDialog] = None @@ -165,14 +175,24 @@ def update_num_selected_label(self): self.ui.openAnalysisBtn.setEnabled(False) def update_list_view_items(self): - state = self.ui.selectAllBox.checkState() + radio_all_state = self.ui.radioAll.isChecked() + radio_onnx_state = self.ui.radioONNX.isChecked() + radio_reports_state = self.ui.radioReports.isChecked() for row in range(self.item_model.rowCount()): item = self.item_model.item(row) - item.setCheckState(state) + value = item.data(Qt.ItemDataRole.DisplayRole) + if radio_all_state: + item.setCheckState(Qt.CheckState.Checked) + elif os.path.splitext(value)[-1] == ".onnx" and radio_onnx_state: + item.setCheckState(Qt.CheckState.Checked) + elif os.path.splitext(value)[-1] == ".yaml" and radio_reports_state: + item.setCheckState(Qt.CheckState.Checked) + else: + item.setCheckState(Qt.CheckState.Unchecked) def set_directory(self, directory: str): """ - Recursively searches a directory for onnx models. + Recursively searches a directory for onnx models and yaml report files. """ if not os.path.exists(directory): @@ -183,36 +203,57 @@ def set_directory(self, directory: str): else: return - progress = ProgressDialog("Searching Directory for ONNX Files", 0, self) + progress = ProgressDialog("Searching directory for model files", 0, self) + onnx_file_list = list( glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True) ) + onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list] + + yaml_file_list = list( + glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True) + ) + yaml_file_list = [os.path.normpath(model_file) for model_file in yaml_file_list] + + # Filter out YAML files that are not valid reports + report_file_list = [] + for yaml_file in yaml_file_list: + digest_report = DigestReportModel(yaml_file) + if digest_report.is_valid: + report_file_list.append(yaml_file) + + total_num_models = len(onnx_file_list) + len(report_file_list) - onnx_file_list = [os.path.normpath(onnx_file) for onnx_file in onnx_file_list] serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list) progress.close() - progress = ProgressDialog("Loading ONNX Models", len(onnx_file_list), self) + progress = ProgressDialog("Loading models", total_num_models, self) memory_limit_percentage = 90 models_loaded = 0 - for filepath in onnx_file_list: + for filepath in onnx_file_list + report_file_list: progress.step() if progress.user_canceled: break try: models_loaded += 1 - model = onnx.load(filepath, load_external_data=False) - dialog_msg = f"""Warning: System RAM has exceeded the threshold of {memory_limit_percentage}%. - No further models will be loaded. - """ + extension = os.path.splitext(filepath)[-1] + if extension == ".onnx": + model = onnx.load(filepath, load_external_data=False) + serialized_models_paths[model.SerializeToString()].append(filepath) + elif extension == ".yaml": + pass + dialog_msg = ( + "Warning: System RAM has exceeded the threshold of " + f"{memory_limit_percentage}%. No further models will be loaded. " + ) if prompt_user_ram_limit( sys_ram_percent_limit=memory_limit_percentage, message=dialog_msg, parent=self, ): self.update_warning_label( - f"Loaded only {models_loaded - 1} out of {len(onnx_file_list)} models " + f"Loaded only {models_loaded - 1} out of {total_num_models} models " f"as memory consumption has reached {memory_limit_percentage}% of " "system memory. Preventing further loading of models." ) @@ -223,15 +264,13 @@ def set_directory(self, directory: str): break else: self.ui.warningLabel.hide() - serialized_models_paths[model.SerializeToString()].append(filepath) + except DecodeError as error: print(f"Error decoding model {filepath}: {error}") progress.close() - progress = ProgressDialog( - "Processing ONNX Models", len(serialized_models_paths), self - ) + progress = ProgressDialog("Processing Models", total_num_models, self) num_duplicates = 0 self.item_model.clear() @@ -245,15 +284,42 @@ def set_directory(self, directory: str): self.ui.duplicateListWidget.addItem(paths[0]) for dupe in paths[1:]: self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}") - item = QStandardItem(paths[0]) - item.setCheckable(True) - item.setCheckState(Qt.CheckState.Checked) - self.item_model.appendRow(item) - else: - item = QStandardItem(paths[0]) - item.setCheckable(True) - item.setCheckState(Qt.CheckState.Checked) - self.item_model.appendRow(item) + item = QStandardItem(paths[0]) + item.setCheckable(True) + item.setCheckState(Qt.CheckState.Checked) + self.item_model.appendRow(item) + + # Use a standard nested loop to detect duplicate reports + duplicate_reports: Dict[str, List[str]] = {} + processed_files = set() + for i in range(len(report_file_list)): + progress.step() + if progress.user_canceled: + break + path1 = report_file_list[i] + if path1 in processed_files: + continue # Skip already processed files + + # We will use path1 as the unique model and save a list of duplicates + duplicate_reports[path1] = [] + for j in range(i + 1, len(report_file_list)): + path2 = report_file_list[j] + if compare_yaml_files( + path1, path2, ["report_date", "model_files", "digest_version"] + ): + num_duplicates += 1 + duplicate_reports[path1].append(path2) + processed_files.add(path2) + + for path, dupes in duplicate_reports.items(): + if dupes: + self.ui.duplicateListWidget.addItem(path) + for dupe in dupes: + self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}") + item = QStandardItem(path) + item.setCheckable(True) + item.setCheckState(Qt.CheckState.Checked) + self.item_model.appendRow(item) progress.close() @@ -270,7 +336,7 @@ def set_directory(self, directory: str): self.update_num_selected_label() self.update_message_label( - f"Found a total of {len(onnx_file_list)} ONNX files. " + f"Found a total of {total_num_models} model files. " "Right click a model below " "to open it up in the model summary view." ) @@ -289,7 +355,9 @@ def start_analysis(self): self.analysis_thread.model_dict = self.model_dict self.analysis_thread.start() - def open_analysis(self, model_list: List[onnx_utils.DigestOnnxModel]): + def open_analysis( + self, model_list: List[Union[DigestOnnxModel, DigestReportModel]] + ): multi_model_analysis = MultiModelAnalysis(model_list) self.analysis_window.setCentralWidget(multi_model_analysis) self.analysis_window.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) diff --git a/src/digest/node_summary.py b/src/digest/node_summary.py index 99eb35f..01aaf09 100644 --- a/src/digest/node_summary.py +++ b/src/digest/node_summary.py @@ -6,6 +6,10 @@ from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog from digest.ui.nodessummary_ui import Ui_nodesSummary from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_model import ( + save_node_shape_counts_csv_report, + save_nodes_csv_report, +) from utils import onnx_utils ROOT_FOLDER = os.path.dirname(__file__) @@ -111,8 +115,6 @@ def save_csv_file(self): self, "Save CSV", os.getcwd(), "CSV(*.csv)" ) if filepath and self.ui.allNodesBtn.isChecked(): - onnx_utils.save_nodes_csv_report(self.node_data, filepath) + save_nodes_csv_report(self.node_data, filepath) elif filepath and self.ui.shapeCountsBtn.isChecked(): - onnx_utils.save_node_shape_counts_csv_report( - self.node_shape_counts, filepath - ) + save_node_shape_counts_csv_report(self.node_shape_counts, filepath) diff --git a/src/digest/resource_rc.py b/src/digest/resource_rc.py index cf29584..59afc50 100644 --- a/src/digest/resource_rc.py +++ b/src/digest/resource_rc.py @@ -1,6 +1,6 @@ # Resource object code (Python 3) # Created by: object code -# Created by: The Resource Compiler for Qt version 6.8.0 +# Created by: The Resource Compiler for Qt version 6.8.1 # WARNING! All changes made in this file will be lost! from PySide6 import QtCore @@ -19676,39 +19676,39 @@ \x00\x00\x000\x00\x02\x00\x00\x00\x03\x00\x00\x00\x05\ \x00\x00\x00\x00\x00\x00\x00\x00\ \x00\x00\x00B\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\ -\x00\x00\x01\x93Ju\xc2>\ +\x00\x00\x01\x93K\x85\xbd\xbb\ \x00\x00\x00\xb0\x00\x00\x00\x00\x00\x01\x00\x01\x86\x02\ -\x00\x00\x01\x93Ju\xc2A\ +\x00\x00\x01\x93K\x85\xbd\xbb\ \x00\x00\x00\x84\x00\x00\x00\x00\x00\x01\x00\x01 bool: + + loop = QEventLoop() + timer = QTimer() + timer.setSingleShot(True) + timer.timeout.connect(loop.quit) + + def check_threads(): + if all(thread.isFinished() for thread in threads): + loop.quit() + + check_timer = QTimer() + check_timer.timeout.connect(check_threads) + check_timer.start(100) # Check every 100ms + + timer.start(timeout) + loop.exec() + + check_timer.stop() + timer.stop() + + # Return True if all threads finished, False if timed out + return all(thread.isFinished() for thread in threads) + + class StatsThread(QThread): - completed = Signal(onnx_utils.DigestOnnxModel, str) + completed = Signal(DigestOnnxModel, str) def __init__( self, @@ -31,14 +59,17 @@ def run(self): if not self.unique_id: raise ValueError("You must specify a unique id.") - digest_model = onnx_utils.DigestOnnxModel(self.model, save_proto=False) + digest_model = DigestOnnxModel(self.model, save_proto=False) self.completed.emit(digest_model, self.unique_id) + def wait(self, timeout=10000): + wait_threads([self], timeout) + class SimilarityThread(QThread): - completed_successfully = Signal(bool, str, str, str) + completed_successfully = Signal(bool, str, str, str, pd.DataFrame) def __init__( self, @@ -60,21 +91,72 @@ def run(self): raise ValueError("You must set the model id") try: - most_similar, _ = find_match( + most_similar, _, df_sorted = find_match( self.model_filepath, - self.png_filepath, dequantize=False, replace=True, - dark_mode=True, ) most_similar = [os.path.basename(path) for path in most_similar] - most_similar = ",".join(most_similar[1:4]) + # We convert List[str] to str to send through the signal + most_similar = ",".join(most_similar) self.completed_successfully.emit( - True, self.model_id, most_similar, self.png_filepath + True, self.model_id, most_similar, self.png_filepath, df_sorted ) except Exception as e: # pylint: disable=broad-exception-caught most_similar = "" self.completed_successfully.emit( - False, self.model_id, most_similar, self.png_filepath + False, self.model_id, most_similar, self.png_filepath, df_sorted ) print(f"Issue creating similarity analysis: {e}") + + def wait(self, timeout=10000): + wait_threads([self], timeout) + + +def post_process( + model_name: str, + name_list: List[str], + df_sorted: pd.DataFrame, + png_file_path: str, + dark_mode: bool = True, +): + """Matplotlib is not thread safe so we must do post_processing on the main thread""" + if dark_mode: + plt.style.use("dark_background") + fig, ax = plt.subplots(figsize=(12, 10)) + im = ax.imshow(df_sorted, cmap="viridis") + + # Show all ticks and label them with the respective list entries + ax.set_xticks(np.arange(len(df_sorted.columns))) + ax.set_yticks(np.arange(len(name_list))) + ax.set_xticklabels([a[:5] for a in df_sorted.columns]) + ax.set_yticklabels(name_list) + + # Rotate the tick labels and set their alignment + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + ax.set_title(f"Model Similarity Heatmap - {model_name}") + + cb = plt.colorbar( + im, + ax=ax, + shrink=0.5, + format="%.2f", + label="Correlation Ratio", + orientation="vertical", + # pad=0.02, + ) + cb.set_ticks([0, 0.5, 1]) # Set colorbar ticks at 0, 0.5, and 1 + cb.set_ticklabels( + ["0.0 (Low)", "0.5 (Medium)", "1.0 (High)"] + ) # Set corresponding labels + cb.set_label("Correlation Ratio", labelpad=-100) + + fig.tight_layout() + + if png_file_path is None: + png_file_path = "heatmap.png" + + fig.savefig(png_file_path) + + plt.close(fig) diff --git a/src/digest/ui/freezeinputs_ui.py b/src/digest/ui/freezeinputs_ui.py index 3838e57..85e8211 100644 --- a/src/digest/ui/freezeinputs_ui.py +++ b/src/digest/ui/freezeinputs_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'freezeinputs.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ diff --git a/src/digest/ui/huggingface_page_ui.py b/src/digest/ui/huggingface_page_ui.py index a06a573..5cfcbfe 100644 --- a/src/digest/ui/huggingface_page_ui.py +++ b/src/digest/ui/huggingface_page_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'huggingface_page.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ diff --git a/src/digest/ui/mainwindow.ui b/src/digest/ui/mainwindow.ui index 8643efa..e7e28f3 100644 --- a/src/digest/ui/mainwindow.ui +++ b/src/digest/ui/mainwindow.ui @@ -179,7 +179,7 @@ Qt::FocusPolicy::NoFocus - <html><head/><body><p>Open a local model file (Ctrl-O)</p></body></html> + <html><head/><body><p>Open (Ctrl-O)</p></body></html> QPushButton { diff --git a/src/digest/ui/mainwindow_ui.py b/src/digest/ui/mainwindow_ui.py index 9904c77..61119a1 100644 --- a/src/digest/ui/mainwindow_ui.py +++ b/src/digest/ui/mainwindow_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'mainwindow.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -520,7 +520,7 @@ def setupUi(self, MainWindow): def retranslateUi(self, MainWindow): MainWindow.setWindowTitle(QCoreApplication.translate("MainWindow", u"DigestAI", None)) #if QT_CONFIG(tooltip) - self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"

Open a local model file (Ctrl-O)

", None)) + self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"

Open (Ctrl-O)

", None)) #endif // QT_CONFIG(tooltip) self.openFileBtn.setText("") #if QT_CONFIG(shortcut) diff --git a/src/digest/ui/modelsummary.ui b/src/digest/ui/modelsummary.ui index 180fed4..737cf33 100644 --- a/src/digest/ui/modelsummary.ui +++ b/src/digest/ui/modelsummary.ui @@ -6,8 +6,8 @@ 0 0 - 980 - 687 + 1138 + 837
@@ -153,11 +153,17 @@ border-top-right-radius: 10px; 0 - 0 + -776 991 - 1453 + 1443 + + + 0 + 0 + + background-color: black; @@ -244,7 +250,7 @@ QFrame:hover { 6 - + @@ -271,7 +277,7 @@ QFrame:hover { - + @@ -667,20 +673,32 @@ QFrame:hover { - + 0 0 + + + 300 + 500 + + - + - + 0 0 + + + 0 + 0 + + 16777215 @@ -690,6 +708,9 @@ QFrame:hover { Loading... + + false + Qt::AlignmentFlag::AlignCenter @@ -834,7 +855,7 @@ QFrame:hover { - + 0 0 @@ -853,6 +874,9 @@ QFrame:hover { + + 6 + 20 @@ -861,8 +885,14 @@ QFrame:hover { - + + + + 0 + 0 + + QLabel { font-size: 18px; @@ -875,11 +905,11 @@ QFrame:hover { - + - - 0 + + 1 0 @@ -975,7 +1005,7 @@ QScrollBar::handle:vertical { - + @@ -983,6 +1013,18 @@ QScrollBar::handle:vertical { 0 + + + 0 + 0 + + + + + 16777215 + 16777215 + + PointingHandCursor @@ -1067,7 +1109,7 @@ QPushButton:pressed { - + 0 0 @@ -1218,7 +1260,7 @@ QScrollBar::handle:vertical { - + diff --git a/src/digest/ui/modelsummary_ui.py b/src/digest/ui/modelsummary_ui.py index 1102e3a..e217372 100644 --- a/src/digest/ui/modelsummary_ui.py +++ b/src/digest/ui/modelsummary_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'modelsummary.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -29,7 +29,7 @@ class Ui_modelSummary(object): def setupUi(self, modelSummary): if not modelSummary.objectName(): modelSummary.setObjectName(u"modelSummary") - modelSummary.resize(980, 687) + modelSummary.resize(1138, 837) sizePolicy = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -115,17 +115,22 @@ def setupUi(self, modelSummary): self.scrollArea.setWidgetResizable(True) self.scrollAreaWidgetContents = QWidget() self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") - self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 991, 1453)) + self.scrollAreaWidgetContents.setGeometry(QRect(0, -776, 991, 1443)) + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.scrollAreaWidgetContents.sizePolicy().hasHeightForWidth()) + self.scrollAreaWidgetContents.setSizePolicy(sizePolicy2) self.scrollAreaWidgetContents.setStyleSheet(u"background-color: black;") self.verticalLayout_20 = QVBoxLayout(self.scrollAreaWidgetContents) self.verticalLayout_20.setObjectName(u"verticalLayout_20") self.cardFrame = QFrame(self.scrollAreaWidgetContents) self.cardFrame.setObjectName(u"cardFrame") - sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) - sizePolicy2.setHorizontalStretch(0) - sizePolicy2.setVerticalStretch(0) - sizePolicy2.setHeightForWidth(self.cardFrame.sizePolicy().hasHeightForWidth()) - self.cardFrame.setSizePolicy(sizePolicy2) + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.cardFrame.sizePolicy().hasHeightForWidth()) + self.cardFrame.setSizePolicy(sizePolicy3) self.cardFrame.setStyleSheet(u"background: transparent; /*rgb(40,40,40)*/") self.cardFrame.setFrameShape(QFrame.Shape.StyledPanel) self.cardFrame.setFrameShadow(QFrame.Shadow.Raised) @@ -134,19 +139,19 @@ def setupUi(self, modelSummary): self.horizontalLayout.setContentsMargins(-1, -1, -1, 1) self.cardWidget = QWidget(self.cardFrame) self.cardWidget.setObjectName(u"cardWidget") - sizePolicy2.setHeightForWidth(self.cardWidget.sizePolicy().hasHeightForWidth()) - self.cardWidget.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.cardWidget.sizePolicy().hasHeightForWidth()) + self.cardWidget.setSizePolicy(sizePolicy3) self.horizontalLayout_2 = QHBoxLayout(self.cardWidget) self.horizontalLayout_2.setSpacing(13) self.horizontalLayout_2.setObjectName(u"horizontalLayout_2") self.horizontalLayout_2.setContentsMargins(-1, 6, 25, 35) self.opsetFrame = QFrame(self.cardWidget) self.opsetFrame.setObjectName(u"opsetFrame") - sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed) - sizePolicy3.setHorizontalStretch(0) - sizePolicy3.setVerticalStretch(0) - sizePolicy3.setHeightForWidth(self.opsetFrame.sizePolicy().hasHeightForWidth()) - self.opsetFrame.setSizePolicy(sizePolicy3) + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed) + sizePolicy4.setHorizontalStretch(0) + sizePolicy4.setVerticalStretch(0) + sizePolicy4.setHeightForWidth(self.opsetFrame.sizePolicy().hasHeightForWidth()) + self.opsetFrame.setSizePolicy(sizePolicy4) self.opsetFrame.setMinimumSize(QSize(220, 70)) self.opsetFrame.setMaximumSize(QSize(16777215, 80)) self.opsetFrame.setStyleSheet(u"QFrame {\n" @@ -164,11 +169,11 @@ def setupUi(self, modelSummary): self.verticalLayout_5.setContentsMargins(-1, -1, 6, -1) self.opsetLabel = QLabel(self.opsetFrame) self.opsetLabel.setObjectName(u"opsetLabel") - sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) - sizePolicy4.setHorizontalStretch(0) - sizePolicy4.setVerticalStretch(0) - sizePolicy4.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth()) - self.opsetLabel.setSizePolicy(sizePolicy4) + sizePolicy5 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) + sizePolicy5.setHorizontalStretch(0) + sizePolicy5.setVerticalStretch(0) + sizePolicy5.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth()) + self.opsetLabel.setSizePolicy(sizePolicy5) self.opsetLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -178,12 +183,12 @@ def setupUi(self, modelSummary): self.opsetLabel.setAlignment(Qt.AlignmentFlag.AlignCenter) self.opsetLabel.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) - self.verticalLayout_5.addWidget(self.opsetLabel, 0, Qt.AlignmentFlag.AlignHCenter) + self.verticalLayout_5.addWidget(self.opsetLabel) self.opsetVersion = QLabel(self.opsetFrame) self.opsetVersion.setObjectName(u"opsetVersion") - sizePolicy4.setHeightForWidth(self.opsetVersion.sizePolicy().hasHeightForWidth()) - self.opsetVersion.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.opsetVersion.sizePolicy().hasHeightForWidth()) + self.opsetVersion.setSizePolicy(sizePolicy5) self.opsetVersion.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -192,15 +197,15 @@ def setupUi(self, modelSummary): self.opsetVersion.setAlignment(Qt.AlignmentFlag.AlignCenter) self.opsetVersion.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse) - self.verticalLayout_5.addWidget(self.opsetVersion, 0, Qt.AlignmentFlag.AlignHCenter) + self.verticalLayout_5.addWidget(self.opsetVersion) self.horizontalLayout_2.addWidget(self.opsetFrame) self.nodesFrame = QFrame(self.cardWidget) self.nodesFrame.setObjectName(u"nodesFrame") - sizePolicy3.setHeightForWidth(self.nodesFrame.sizePolicy().hasHeightForWidth()) - self.nodesFrame.setSizePolicy(sizePolicy3) + sizePolicy4.setHeightForWidth(self.nodesFrame.sizePolicy().hasHeightForWidth()) + self.nodesFrame.setSizePolicy(sizePolicy4) self.nodesFrame.setMinimumSize(QSize(220, 70)) self.nodesFrame.setMaximumSize(QSize(16777215, 80)) self.nodesFrame.setStyleSheet(u"QFrame {\n" @@ -218,8 +223,8 @@ def setupUi(self, modelSummary): self.verticalLayout_12.setContentsMargins(-1, 9, -1, -1) self.nodesLabel = QLabel(self.nodesFrame) self.nodesLabel.setObjectName(u"nodesLabel") - sizePolicy4.setHeightForWidth(self.nodesLabel.sizePolicy().hasHeightForWidth()) - self.nodesLabel.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.nodesLabel.sizePolicy().hasHeightForWidth()) + self.nodesLabel.setSizePolicy(sizePolicy5) self.nodesLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -233,8 +238,8 @@ def setupUi(self, modelSummary): self.nodes = QLabel(self.nodesFrame) self.nodes.setObjectName(u"nodes") - sizePolicy4.setHeightForWidth(self.nodes.sizePolicy().hasHeightForWidth()) - self.nodes.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.nodes.sizePolicy().hasHeightForWidth()) + self.nodes.setSizePolicy(sizePolicy5) self.nodes.setMinimumSize(QSize(150, 32)) self.nodes.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" @@ -254,8 +259,8 @@ def setupUi(self, modelSummary): self.paramFrame = QFrame(self.cardWidget) self.paramFrame.setObjectName(u"paramFrame") - sizePolicy3.setHeightForWidth(self.paramFrame.sizePolicy().hasHeightForWidth()) - self.paramFrame.setSizePolicy(sizePolicy3) + sizePolicy4.setHeightForWidth(self.paramFrame.sizePolicy().hasHeightForWidth()) + self.paramFrame.setSizePolicy(sizePolicy4) self.paramFrame.setMinimumSize(QSize(220, 70)) self.paramFrame.setMaximumSize(QSize(16777215, 80)) self.paramFrame.setStyleSheet(u"QFrame {\n" @@ -272,8 +277,8 @@ def setupUi(self, modelSummary): self.verticalLayout_9.setObjectName(u"verticalLayout_9") self.parametersLabel = QLabel(self.paramFrame) self.parametersLabel.setObjectName(u"parametersLabel") - sizePolicy4.setHeightForWidth(self.parametersLabel.sizePolicy().hasHeightForWidth()) - self.parametersLabel.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.parametersLabel.sizePolicy().hasHeightForWidth()) + self.parametersLabel.setSizePolicy(sizePolicy5) self.parametersLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -287,8 +292,8 @@ def setupUi(self, modelSummary): self.parameters = QLabel(self.paramFrame) self.parameters.setObjectName(u"parameters") - sizePolicy4.setHeightForWidth(self.parameters.sizePolicy().hasHeightForWidth()) - self.parameters.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.parameters.sizePolicy().hasHeightForWidth()) + self.parameters.setSizePolicy(sizePolicy5) self.parameters.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -304,8 +309,8 @@ def setupUi(self, modelSummary): self.flopsFrame = QFrame(self.cardWidget) self.flopsFrame.setObjectName(u"flopsFrame") - sizePolicy3.setHeightForWidth(self.flopsFrame.sizePolicy().hasHeightForWidth()) - self.flopsFrame.setSizePolicy(sizePolicy3) + sizePolicy4.setHeightForWidth(self.flopsFrame.sizePolicy().hasHeightForWidth()) + self.flopsFrame.setSizePolicy(sizePolicy4) self.flopsFrame.setMinimumSize(QSize(220, 70)) self.flopsFrame.setMaximumSize(QSize(16777215, 80)) self.flopsFrame.setCursor(QCursor(Qt.CursorShape.ArrowCursor)) @@ -323,8 +328,8 @@ def setupUi(self, modelSummary): self.verticalLayout_11.setObjectName(u"verticalLayout_11") self.flopsLabel = QLabel(self.flopsFrame) self.flopsLabel.setObjectName(u"flopsLabel") - sizePolicy4.setHeightForWidth(self.flopsLabel.sizePolicy().hasHeightForWidth()) - self.flopsLabel.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.flopsLabel.sizePolicy().hasHeightForWidth()) + self.flopsLabel.setSizePolicy(sizePolicy5) self.flopsLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -338,11 +343,11 @@ def setupUi(self, modelSummary): self.flops = QLabel(self.flopsFrame) self.flops.setObjectName(u"flops") - sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Fixed) - sizePolicy5.setHorizontalStretch(0) - sizePolicy5.setVerticalStretch(0) - sizePolicy5.setHeightForWidth(self.flops.sizePolicy().hasHeightForWidth()) - self.flops.setSizePolicy(sizePolicy5) + sizePolicy6 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Fixed) + sizePolicy6.setHorizontalStretch(0) + sizePolicy6.setVerticalStretch(0) + sizePolicy6.setHeightForWidth(self.flops.sizePolicy().hasHeightForWidth()) + self.flops.setSizePolicy(sizePolicy6) self.flops.setMinimumSize(QSize(200, 32)) self.flops.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" @@ -380,11 +385,11 @@ def setupUi(self, modelSummary): self.parametersPieChart = PieChartWidget(self.scrollAreaWidgetContents) self.parametersPieChart.setObjectName(u"parametersPieChart") - sizePolicy6 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) - sizePolicy6.setHorizontalStretch(0) - sizePolicy6.setVerticalStretch(0) - sizePolicy6.setHeightForWidth(self.parametersPieChart.sizePolicy().hasHeightForWidth()) - self.parametersPieChart.setSizePolicy(sizePolicy6) + sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + sizePolicy7.setHorizontalStretch(0) + sizePolicy7.setVerticalStretch(0) + sizePolicy7.setHeightForWidth(self.parametersPieChart.sizePolicy().hasHeightForWidth()) + self.parametersPieChart.setSizePolicy(sizePolicy7) self.parametersPieChart.setMinimumSize(QSize(300, 500)) self.firstRowChartsLayout.addWidget(self.parametersPieChart) @@ -397,23 +402,29 @@ def setupUi(self, modelSummary): self.secondRowChartsLayout.setContentsMargins(-1, 20, -1, -1) self.similarityWidget = QWidget(self.scrollAreaWidgetContents) self.similarityWidget.setObjectName(u"similarityWidget") - sizePolicy.setHeightForWidth(self.similarityWidget.sizePolicy().hasHeightForWidth()) - self.similarityWidget.setSizePolicy(sizePolicy) + sizePolicy7.setHeightForWidth(self.similarityWidget.sizePolicy().hasHeightForWidth()) + self.similarityWidget.setSizePolicy(sizePolicy7) + self.similarityWidget.setMinimumSize(QSize(300, 500)) self.placeholderWidget = QVBoxLayout(self.similarityWidget) self.placeholderWidget.setObjectName(u"placeholderWidget") self.similarityImg = ClickableLabel(self.similarityWidget) self.similarityImg.setObjectName(u"similarityImg") - sizePolicy.setHeightForWidth(self.similarityImg.sizePolicy().hasHeightForWidth()) - self.similarityImg.setSizePolicy(sizePolicy) + sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + sizePolicy8.setHorizontalStretch(0) + sizePolicy8.setVerticalStretch(0) + sizePolicy8.setHeightForWidth(self.similarityImg.sizePolicy().hasHeightForWidth()) + self.similarityImg.setSizePolicy(sizePolicy8) + self.similarityImg.setMinimumSize(QSize(0, 0)) self.similarityImg.setMaximumSize(QSize(16777215, 16777215)) + self.similarityImg.setScaledContents(False) self.similarityImg.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.placeholderWidget.addWidget(self.similarityImg, 0, Qt.AlignmentFlag.AlignHCenter) + self.placeholderWidget.addWidget(self.similarityImg) self.similarityCorrelationStatic = QLabel(self.similarityWidget) self.similarityCorrelationStatic.setObjectName(u"similarityCorrelationStatic") - sizePolicy2.setHeightForWidth(self.similarityCorrelationStatic.sizePolicy().hasHeightForWidth()) - self.similarityCorrelationStatic.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.similarityCorrelationStatic.sizePolicy().hasHeightForWidth()) + self.similarityCorrelationStatic.setSizePolicy(sizePolicy3) self.similarityCorrelationStatic.setFont(font) self.similarityCorrelationStatic.setAlignment(Qt.AlignmentFlag.AlignCenter) @@ -421,8 +432,8 @@ def setupUi(self, modelSummary): self.similarityCorrelation = QLabel(self.similarityWidget) self.similarityCorrelation.setObjectName(u"similarityCorrelation") - sizePolicy2.setHeightForWidth(self.similarityCorrelation.sizePolicy().hasHeightForWidth()) - self.similarityCorrelation.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.similarityCorrelation.sizePolicy().hasHeightForWidth()) + self.similarityCorrelation.setSizePolicy(sizePolicy3) palette = QPalette() brush = QBrush(QColor(0, 0, 0, 255)) brush.setStyle(Qt.SolidPattern) @@ -446,9 +457,6 @@ def setupUi(self, modelSummary): self.flopsPieChart = PieChartWidget(self.scrollAreaWidgetContents) self.flopsPieChart.setObjectName(u"flopsPieChart") - sizePolicy7 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Preferred) - sizePolicy7.setHorizontalStretch(0) - sizePolicy7.setVerticalStretch(0) sizePolicy7.setHeightForWidth(self.flopsPieChart.sizePolicy().hasHeightForWidth()) self.flopsPieChart.setSizePolicy(sizePolicy7) self.flopsPieChart.setMinimumSize(QSize(300, 500)) @@ -462,19 +470,25 @@ def setupUi(self, modelSummary): self.verticalLayout_20.addLayout(self.chartsLayout) self.thirdRowInputsLayout = QHBoxLayout() + self.thirdRowInputsLayout.setSpacing(6) self.thirdRowInputsLayout.setObjectName(u"thirdRowInputsLayout") self.thirdRowInputsLayout.setContentsMargins(20, 30, -1, -1) self.inputsLayout = QVBoxLayout() self.inputsLayout.setObjectName(u"inputsLayout") self.inputsLabel = QLabel(self.scrollAreaWidgetContents) self.inputsLabel.setObjectName(u"inputsLabel") + sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Maximum) + sizePolicy9.setHorizontalStretch(0) + sizePolicy9.setVerticalStretch(0) + sizePolicy9.setHeightForWidth(self.inputsLabel.sizePolicy().hasHeightForWidth()) + self.inputsLabel.setSizePolicy(sizePolicy9) self.inputsLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" " background: transparent;\n" "}") - self.inputsLayout.addWidget(self.inputsLabel, 0, Qt.AlignmentFlag.AlignVCenter) + self.inputsLayout.addWidget(self.inputsLabel) self.inputsTable = QTableWidget(self.scrollAreaWidgetContents) if (self.inputsTable.columnCount() < 4): @@ -488,11 +502,11 @@ def setupUi(self, modelSummary): __qtablewidgetitem3 = QTableWidgetItem() self.inputsTable.setHorizontalHeaderItem(3, __qtablewidgetitem3) self.inputsTable.setObjectName(u"inputsTable") - sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred) - sizePolicy8.setHorizontalStretch(0) - sizePolicy8.setVerticalStretch(0) - sizePolicy8.setHeightForWidth(self.inputsTable.sizePolicy().hasHeightForWidth()) - self.inputsTable.setSizePolicy(sizePolicy8) + sizePolicy10 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding) + sizePolicy10.setHorizontalStretch(1) + sizePolicy10.setVerticalStretch(0) + sizePolicy10.setHeightForWidth(self.inputsTable.sizePolicy().hasHeightForWidth()) + self.inputsTable.setSizePolicy(sizePolicy10) self.inputsTable.setStyleSheet(u"QTableWidget {\n" " gridline-color: #353535; /* Grid lines */\n" " selection-background-color: #3949AB; /* Blue selection */\n" @@ -543,18 +557,20 @@ def setupUi(self, modelSummary): self.inputsTable.verticalHeader().setVisible(False) self.inputsTable.verticalHeader().setHighlightSections(True) - self.inputsLayout.addWidget(self.inputsTable, 0, Qt.AlignmentFlag.AlignVCenter) + self.inputsLayout.addWidget(self.inputsTable) self.thirdRowInputsLayout.addLayout(self.inputsLayout) self.freezeButton = QPushButton(self.scrollAreaWidgetContents) self.freezeButton.setObjectName(u"freezeButton") - sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) - sizePolicy9.setHorizontalStretch(0) - sizePolicy9.setVerticalStretch(0) - sizePolicy9.setHeightForWidth(self.freezeButton.sizePolicy().hasHeightForWidth()) - self.freezeButton.setSizePolicy(sizePolicy9) + sizePolicy11 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + sizePolicy11.setHorizontalStretch(0) + sizePolicy11.setVerticalStretch(0) + sizePolicy11.setHeightForWidth(self.freezeButton.sizePolicy().hasHeightForWidth()) + self.freezeButton.setSizePolicy(sizePolicy11) + self.freezeButton.setMinimumSize(QSize(0, 0)) + self.freezeButton.setMaximumSize(QSize(16777215, 16777215)) self.freezeButton.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) self.freezeButton.setStyleSheet(u"QPushButton {\n" " color: white;\n" @@ -580,7 +596,7 @@ def setupUi(self, modelSummary): self.freezeButton.setIcon(icon) self.freezeButton.setIconSize(QSize(32, 32)) - self.thirdRowInputsLayout.addWidget(self.freezeButton, 0, Qt.AlignmentFlag.AlignTop) + self.thirdRowInputsLayout.addWidget(self.freezeButton) self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) @@ -616,8 +632,11 @@ def setupUi(self, modelSummary): __qtablewidgetitem7 = QTableWidgetItem() self.outputsTable.setHorizontalHeaderItem(3, __qtablewidgetitem7) self.outputsTable.setObjectName(u"outputsTable") - sizePolicy8.setHeightForWidth(self.outputsTable.sizePolicy().hasHeightForWidth()) - self.outputsTable.setSizePolicy(sizePolicy8) + sizePolicy12 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Expanding) + sizePolicy12.setHorizontalStretch(0) + sizePolicy12.setVerticalStretch(0) + sizePolicy12.setHeightForWidth(self.outputsTable.sizePolicy().hasHeightForWidth()) + self.outputsTable.setSizePolicy(sizePolicy12) self.outputsTable.setStyleSheet(u"QTableWidget {\n" " gridline-color: #353535; /* Grid lines */\n" " selection-background-color: #3949AB; /* Blue selection */\n" @@ -684,8 +703,8 @@ def setupUi(self, modelSummary): self.sidePaneFrame = QFrame(modelSummary) self.sidePaneFrame.setObjectName(u"sidePaneFrame") - sizePolicy2.setHeightForWidth(self.sidePaneFrame.sizePolicy().hasHeightForWidth()) - self.sidePaneFrame.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.sidePaneFrame.sizePolicy().hasHeightForWidth()) + self.sidePaneFrame.setSizePolicy(sizePolicy3) self.sidePaneFrame.setMinimumSize(QSize(0, 0)) self.sidePaneFrame.setStyleSheet(u"QFrame {\n" " /*background: rgb(30,30,30);*/\n" @@ -741,8 +760,8 @@ def setupUi(self, modelSummary): __qtablewidgetitem21 = QTableWidgetItem() self.modelProtoTable.setItem(3, 1, __qtablewidgetitem21) self.modelProtoTable.setObjectName(u"modelProtoTable") - sizePolicy2.setHeightForWidth(self.modelProtoTable.sizePolicy().hasHeightForWidth()) - self.modelProtoTable.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.modelProtoTable.sizePolicy().hasHeightForWidth()) + self.modelProtoTable.setSizePolicy(sizePolicy3) self.modelProtoTable.setMinimumSize(QSize(0, 0)) self.modelProtoTable.setMaximumSize(QSize(16777215, 100)) self.modelProtoTable.setStyleSheet(u"QTableWidget::item {\n" @@ -770,7 +789,7 @@ def setupUi(self, modelSummary): self.modelProtoTable.verticalHeader().setMinimumSectionSize(20) self.modelProtoTable.verticalHeader().setDefaultSectionSize(20) - self.verticalLayout_3.addWidget(self.modelProtoTable, 0, Qt.AlignmentFlag.AlignRight) + self.verticalLayout_3.addWidget(self.modelProtoTable) self.importsLabel = QLabel(self.sidePaneFrame) self.importsLabel.setObjectName(u"importsLabel") @@ -791,11 +810,8 @@ def setupUi(self, modelSummary): __qtablewidgetitem23 = QTableWidgetItem() self.importsTable.setHorizontalHeaderItem(1, __qtablewidgetitem23) self.importsTable.setObjectName(u"importsTable") - sizePolicy10 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding) - sizePolicy10.setHorizontalStretch(0) - sizePolicy10.setVerticalStretch(0) - sizePolicy10.setHeightForWidth(self.importsTable.sizePolicy().hasHeightForWidth()) - self.importsTable.setSizePolicy(sizePolicy10) + sizePolicy2.setHeightForWidth(self.importsTable.sizePolicy().hasHeightForWidth()) + self.importsTable.setSizePolicy(sizePolicy2) self.importsTable.setStyleSheet(u"QTableWidget::item {\n" " color: white;\n" " padding: 5px;\n" diff --git a/src/digest/ui/multimodelanalysis.ui b/src/digest/ui/multimodelanalysis.ui index cf044e3..16109d0 100644 --- a/src/digest/ui/multimodelanalysis.ui +++ b/src/digest/ui/multimodelanalysis.ui @@ -6,8 +6,8 @@ 0 0 - 908 - 647 + 1085 + 866 @@ -51,7 +51,7 @@ QFrame::Shadow::Raised - + @@ -176,7 +176,7 @@ - + 0 0 @@ -198,8 +198,8 @@ 0 0 - 888 - 464 + 1065 + 688 @@ -242,6 +242,12 @@ + + + 0 + 0 + + QFrame::Shape::StyledPanel @@ -258,7 +264,7 @@ QFrame::Shadow::Raised - + @@ -279,17 +285,19 @@ + + + 0 + 0 + + QFrame::Shape::StyledPanel QFrame::Shadow::Raised - - - - - + diff --git a/src/digest/ui/multimodelanalysis_ui.py b/src/digest/ui/multimodelanalysis_ui.py index 54aa6d6..9f4b359 100644 --- a/src/digest/ui/multimodelanalysis_ui.py +++ b/src/digest/ui/multimodelanalysis_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'multimodelanalysis.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -26,7 +26,7 @@ class Ui_multiModelAnalysis(object): def setupUi(self, multiModelAnalysis): if not multiModelAnalysis.objectName(): multiModelAnalysis.setObjectName(u"multiModelAnalysis") - multiModelAnalysis.resize(908, 647) + multiModelAnalysis.resize(1085, 866) sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -71,7 +71,7 @@ def setupUi(self, multiModelAnalysis): self.modelName.setIndent(5) self.modelName.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse) - self.verticalLayout_17.addWidget(self.modelName, 0, Qt.AlignmentFlag.AlignTop) + self.verticalLayout_17.addWidget(self.modelName) self.summaryTopBannerLayout.addWidget(self.modelNameFrame) @@ -127,7 +127,7 @@ def setupUi(self, multiModelAnalysis): self.scrollArea = QScrollArea(multiModelAnalysis) self.scrollArea.setObjectName(u"scrollArea") - sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding) + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy4.setHorizontalStretch(0) sizePolicy4.setVerticalStretch(0) sizePolicy4.setHeightForWidth(self.scrollArea.sizePolicy().hasHeightForWidth()) @@ -138,7 +138,7 @@ def setupUi(self, multiModelAnalysis): self.scrollArea.setWidgetResizable(True) self.scrollAreaWidgetContents = QWidget() self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") - self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 888, 464)) + self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 1065, 688)) sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy5.setHorizontalStretch(0) sizePolicy5.setVerticalStretch(100) @@ -165,6 +165,11 @@ def setupUi(self, multiModelAnalysis): self.frame_2 = QFrame(self.scrollAreaWidgetContents) self.frame_2.setObjectName(u"frame_2") + sizePolicy7 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Preferred) + sizePolicy7.setHorizontalStretch(0) + sizePolicy7.setVerticalStretch(0) + sizePolicy7.setHeightForWidth(self.frame_2.sizePolicy().hasHeightForWidth()) + self.frame_2.setSizePolicy(sizePolicy7) self.frame_2.setFrameShape(QFrame.Shape.StyledPanel) self.frame_2.setFrameShadow(QFrame.Shadow.Raised) self.horizontalLayout_2 = QHBoxLayout(self.frame_2) @@ -177,29 +182,29 @@ def setupUi(self, multiModelAnalysis): self.verticalLayout_3.setObjectName(u"verticalLayout_3") self.opHistogramChart = HistogramChartWidget(self.combinedHistogramFrame) self.opHistogramChart.setObjectName(u"opHistogramChart") - sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Minimum) - sizePolicy7.setHorizontalStretch(0) - sizePolicy7.setVerticalStretch(0) - sizePolicy7.setHeightForWidth(self.opHistogramChart.sizePolicy().hasHeightForWidth()) - self.opHistogramChart.setSizePolicy(sizePolicy7) + sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Minimum) + sizePolicy8.setHorizontalStretch(0) + sizePolicy8.setVerticalStretch(0) + sizePolicy8.setHeightForWidth(self.opHistogramChart.sizePolicy().hasHeightForWidth()) + self.opHistogramChart.setSizePolicy(sizePolicy8) self.opHistogramChart.setMinimumSize(QSize(500, 300)) - self.verticalLayout_3.addWidget(self.opHistogramChart, 0, Qt.AlignmentFlag.AlignTop) + self.verticalLayout_3.addWidget(self.opHistogramChart) self.horizontalLayout_2.addWidget(self.combinedHistogramFrame) self.stackedHistogramFrame = QFrame(self.frame_2) self.stackedHistogramFrame.setObjectName(u"stackedHistogramFrame") + sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + sizePolicy9.setHorizontalStretch(0) + sizePolicy9.setVerticalStretch(0) + sizePolicy9.setHeightForWidth(self.stackedHistogramFrame.sizePolicy().hasHeightForWidth()) + self.stackedHistogramFrame.setSizePolicy(sizePolicy9) self.stackedHistogramFrame.setFrameShape(QFrame.Shape.StyledPanel) self.stackedHistogramFrame.setFrameShadow(QFrame.Shadow.Raised) self.verticalLayout_5 = QVBoxLayout(self.stackedHistogramFrame) self.verticalLayout_5.setObjectName(u"verticalLayout_5") - self.verticalLayout_4 = QVBoxLayout() - self.verticalLayout_4.setObjectName(u"verticalLayout_4") - - self.verticalLayout_5.addLayout(self.verticalLayout_4) - self.horizontalLayout_2.addWidget(self.stackedHistogramFrame) diff --git a/src/digest/ui/multimodelselection_page.ui b/src/digest/ui/multimodelselection_page.ui index c5d12f8..034ed88 100644 --- a/src/digest/ui/multimodelselection_page.ui +++ b/src/digest/ui/multimodelselection_page.ui @@ -52,7 +52,7 @@ - + @@ -68,7 +68,7 @@ - + false @@ -128,7 +128,7 @@ - Warning: The chosen folder contains more than MAX_ONNX_MODELS + Warning 2 @@ -141,7 +141,7 @@ - + 0 @@ -161,25 +161,77 @@ - Select All + All + + + true - + + + + 0 + 0 + + + + + 0 + 33 + + + + false + - 0 selected models - - - true + ONNX + + + + 0 + 0 + + + + + 0 + 33 + + + + false + + + + + + Reports + + + + + + + 0 + 0 + + + + + 550 + 0 + + @@ -191,8 +243,34 @@ + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + 0 selected models + + + true + + + diff --git a/src/digest/ui/multimodelselection_page_ui.py b/src/digest/ui/multimodelselection_page_ui.py index e6acb66..0e25178 100644 --- a/src/digest/ui/multimodelselection_page_ui.py +++ b/src/digest/ui/multimodelselection_page_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'multimodelselection_page.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -15,9 +15,9 @@ QFont, QFontDatabase, QGradient, QIcon, QImage, QKeySequence, QLinearGradient, QPainter, QPalette, QPixmap, QRadialGradient, QTransform) -from PySide6.QtWidgets import (QAbstractItemView, QApplication, QCheckBox, QHBoxLayout, - QLabel, QListView, QListWidget, QListWidgetItem, - QPushButton, QSizePolicy, QSpacerItem, QVBoxLayout, +from PySide6.QtWidgets import (QAbstractItemView, QApplication, QHBoxLayout, QLabel, + QListView, QListWidget, QListWidgetItem, QPushButton, + QRadioButton, QSizePolicy, QSpacerItem, QVBoxLayout, QWidget) class Ui_MultiModelSelection(object): @@ -59,7 +59,7 @@ def setupUi(self, MultiModelSelection): self.selectFolderBtn.setSizePolicy(sizePolicy) self.selectFolderBtn.setStyleSheet(u"") - self.horizontalLayout_2.addWidget(self.selectFolderBtn, 0, Qt.AlignmentFlag.AlignLeft|Qt.AlignmentFlag.AlignVCenter) + self.horizontalLayout_2.addWidget(self.selectFolderBtn) self.openAnalysisBtn = QPushButton(MultiModelSelection) self.openAnalysisBtn.setObjectName(u"openAnalysisBtn") @@ -68,7 +68,7 @@ def setupUi(self, MultiModelSelection): self.openAnalysisBtn.setSizePolicy(sizePolicy) self.openAnalysisBtn.setStyleSheet(u"") - self.horizontalLayout_2.addWidget(self.openAnalysisBtn, 0, Qt.AlignmentFlag.AlignLeft|Qt.AlignmentFlag.AlignVCenter) + self.horizontalLayout_2.addWidget(self.openAnalysisBtn) self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) @@ -104,33 +104,64 @@ def setupUi(self, MultiModelSelection): self.horizontalLayout_3 = QHBoxLayout() self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") - self.selectAllBox = QCheckBox(MultiModelSelection) - self.selectAllBox.setObjectName(u"selectAllBox") - sizePolicy.setHeightForWidth(self.selectAllBox.sizePolicy().hasHeightForWidth()) - self.selectAllBox.setSizePolicy(sizePolicy) - self.selectAllBox.setMinimumSize(QSize(0, 33)) - self.selectAllBox.setAutoFillBackground(False) - self.selectAllBox.setStyleSheet(u"") - - self.horizontalLayout_3.addWidget(self.selectAllBox) - - self.numSelectedLabel = QLabel(MultiModelSelection) - self.numSelectedLabel.setObjectName(u"numSelectedLabel") - self.numSelectedLabel.setStyleSheet(u"") - self.numSelectedLabel.setWordWrap(True) - - self.horizontalLayout_3.addWidget(self.numSelectedLabel) + self.radioAll = QRadioButton(MultiModelSelection) + self.radioAll.setObjectName(u"radioAll") + sizePolicy.setHeightForWidth(self.radioAll.sizePolicy().hasHeightForWidth()) + self.radioAll.setSizePolicy(sizePolicy) + self.radioAll.setMinimumSize(QSize(0, 33)) + self.radioAll.setAutoFillBackground(False) + self.radioAll.setStyleSheet(u"") + self.radioAll.setChecked(True) + + self.horizontalLayout_3.addWidget(self.radioAll) + + self.radioONNX = QRadioButton(MultiModelSelection) + self.radioONNX.setObjectName(u"radioONNX") + sizePolicy.setHeightForWidth(self.radioONNX.sizePolicy().hasHeightForWidth()) + self.radioONNX.setSizePolicy(sizePolicy) + self.radioONNX.setMinimumSize(QSize(0, 33)) + self.radioONNX.setAutoFillBackground(False) + self.radioONNX.setStyleSheet(u"") + + self.horizontalLayout_3.addWidget(self.radioONNX) + + self.radioReports = QRadioButton(MultiModelSelection) + self.radioReports.setObjectName(u"radioReports") + sizePolicy.setHeightForWidth(self.radioReports.sizePolicy().hasHeightForWidth()) + self.radioReports.setSizePolicy(sizePolicy) + self.radioReports.setMinimumSize(QSize(0, 33)) + self.radioReports.setAutoFillBackground(False) + self.radioReports.setStyleSheet(u"") + + self.horizontalLayout_3.addWidget(self.radioReports) self.duplicateLabel = QLabel(MultiModelSelection) self.duplicateLabel.setObjectName(u"duplicateLabel") + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.duplicateLabel.sizePolicy().hasHeightForWidth()) + self.duplicateLabel.setSizePolicy(sizePolicy2) + self.duplicateLabel.setMinimumSize(QSize(550, 0)) self.duplicateLabel.setStyleSheet(u"") self.duplicateLabel.setWordWrap(True) self.horizontalLayout_3.addWidget(self.duplicateLabel) + self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.horizontalLayout_3.addItem(self.horizontalSpacer_2) + self.verticalLayout.addLayout(self.horizontalLayout_3) + self.numSelectedLabel = QLabel(MultiModelSelection) + self.numSelectedLabel.setObjectName(u"numSelectedLabel") + self.numSelectedLabel.setStyleSheet(u"") + self.numSelectedLabel.setWordWrap(True) + + self.verticalLayout.addWidget(self.numSelectedLabel) + self.columnsLayout = QHBoxLayout() self.columnsLayout.setObjectName(u"columnsLayout") self.leftColumnLayout = QVBoxLayout() @@ -151,11 +182,11 @@ def setupUi(self, MultiModelSelection): self.rightColumnLayout.setObjectName(u"rightColumnLayout") self.duplicateListWidget = QListWidget(MultiModelSelection) self.duplicateListWidget.setObjectName(u"duplicateListWidget") - sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - sizePolicy2.setHorizontalStretch(0) - sizePolicy2.setVerticalStretch(0) - sizePolicy2.setHeightForWidth(self.duplicateListWidget.sizePolicy().hasHeightForWidth()) - self.duplicateListWidget.setSizePolicy(sizePolicy2) + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.duplicateListWidget.sizePolicy().hasHeightForWidth()) + self.duplicateListWidget.setSizePolicy(sizePolicy3) self.duplicateListWidget.setStyleSheet(u"") self.duplicateListWidget.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) self.duplicateListWidget.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection) @@ -184,9 +215,11 @@ def retranslateUi(self, MultiModelSelection): self.selectFolderBtn.setText(QCoreApplication.translate("MultiModelSelection", u"Select Folder", None)) self.openAnalysisBtn.setText(QCoreApplication.translate("MultiModelSelection", u"Open Analysis", None)) self.infoLabel.setText("") - self.warningLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Warning: The chosen folder contains more than MAX_ONNX_MODELS", None)) - self.selectAllBox.setText(QCoreApplication.translate("MultiModelSelection", u"Select All", None)) - self.numSelectedLabel.setText(QCoreApplication.translate("MultiModelSelection", u"0 selected models", None)) + self.warningLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Warning", None)) + self.radioAll.setText(QCoreApplication.translate("MultiModelSelection", u"All", None)) + self.radioONNX.setText(QCoreApplication.translate("MultiModelSelection", u"ONNX", None)) + self.radioReports.setText(QCoreApplication.translate("MultiModelSelection", u"Reports", None)) self.duplicateLabel.setText(QCoreApplication.translate("MultiModelSelection", u"The following models were found to be duplicates and have been deselected from the list on the left.", None)) + self.numSelectedLabel.setText(QCoreApplication.translate("MultiModelSelection", u"0 selected models", None)) # retranslateUi diff --git a/src/digest/ui/nodessummary_ui.py b/src/digest/ui/nodessummary_ui.py index 7efc69d..e0e400c 100644 --- a/src/digest/ui/nodessummary_ui.py +++ b/src/digest/ui/nodessummary_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'nodessummary.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ diff --git a/src/utils/onnx_utils.py b/src/utils/onnx_utils.py index d8a6894..4d4b293 100644 --- a/src/utils/onnx_utils.py +++ b/src/utils/onnx_utils.py @@ -1,95 +1,19 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. import os -import csv import tempfile -from uuid import uuid4 -from collections import Counter, OrderedDict, defaultdict -from typing import List, Dict, Optional, Any, Tuple, Union, cast -from datetime import datetime +from collections import Counter +from typing import List, Optional, Tuple, Union import numpy as np import onnx import onnxruntime as ort -from prettytable import PrettyTable - - -class NodeParsingException(Exception): - pass - - -# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias -class NodeShapeCounts(defaultdict[str, Counter]): - def __init__(self): - super().__init__(Counter) # Initialize with the Counter factory - - -class NodeTypeCounts(Dict[str, int]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class TensorInfo: - "Used to store node input and output tensor information" - - def __init__(self) -> None: - self.dtype: Optional[str] = None - self.dtype_bytes: Optional[int] = None - self.size_kbytes: Optional[float] = None - self.shape: List[Union[int, str]] = [] - - -class TensorData(OrderedDict[str, TensorInfo]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class NodeInfo: - def __init__(self) -> None: - self.flops: Optional[int] = None - self.parameters: int = 0 - self.node_type: Optional[str] = None - self.attributes: OrderedDict[str, Any] = OrderedDict() - # We use an ordered dictionary because the order in which - # the inputs and outputs are listed in the node matter. - self.inputs = TensorData() - self.outputs = TensorData() - - def get_input(self, index: int) -> TensorInfo: - return list(self.inputs.values())[index] - - def get_output(self, index: int) -> TensorInfo: - return list(self.outputs.values())[index] - - def __str__(self): - """Provides a human-readable string representation of NodeInfo.""" - output = [ - f"Node Type: {self.node_type}", - f"FLOPs: {self.flops if self.flops is not None else 'N/A'}", - f"Parameters: {self.parameters}", - ] - - if self.attributes: - output.append("Attributes:") - for key, value in self.attributes.items(): - output.append(f" - {key}: {value}") - - if self.inputs: - output.append("Inputs:") - for name, tensor in self.inputs.items(): - output.append(f" - {name}: {tensor}") - - if self.outputs: - output.append("Outputs:") - for name, tensor in self.outputs.items(): - output.append(f" - {name}: {tensor}") - - return "\n".join(output) - - -# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias -class NodeData(OrderedDict[str, NodeInfo]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +from digest.model_class.digest_model import ( + NodeTypeCounts, + NodeData, + NodeShapeCounts, + TensorData, + TensorInfo, +) # Convert tensor type to human-readable string and size in bytes @@ -117,706 +41,6 @@ def tensor_type_to_str_and_size(elem_type) -> Tuple[str, int]: return type_mapping.get(elem_type, ("unknown", 0)) -class DigestOnnxModel: - def __init__( - self, - onnx_model: onnx.ModelProto, - onnx_filepath: Optional[str] = None, - model_name: Optional[str] = None, - save_proto: bool = True, - ) -> None: - # Public members exposed to the API - self.unique_id: str = str(uuid4()) - self.filepath: Optional[str] = onnx_filepath - self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None - self.model_name: Optional[str] = model_name - self.model_version: Optional[int] = None - self.graph_name: Optional[str] = None - self.producer_name: Optional[str] = None - self.producer_version: Optional[str] = None - self.ir_version: Optional[int] = None - self.opset: Optional[int] = None - self.imports: Dict[str, int] = {} - self.node_type_counts: NodeTypeCounts = NodeTypeCounts() - self.model_flops: Optional[int] = None - self.model_parameters: int = 0 - self.node_type_flops: Dict[str, int] = {} - self.node_type_parameters: Dict[str, int] = {} - self.per_node_info = NodeData() - self.model_inputs = TensorData() - self.model_outputs = TensorData() - - # Private members not intended to be exposed - self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.init_tensors_: Dict[str, onnx.TensorProto] = {} - - self.update_state(onnx_model) - - def update_state(self, model_proto: onnx.ModelProto) -> None: - self.model_version = model_proto.model_version - self.graph_name = model_proto.graph.name - self.producer_name = model_proto.producer_name - self.producer_version = model_proto.producer_version - self.ir_version = model_proto.ir_version - self.opset = get_opset(model_proto) - self.imports = { - import_.domain: import_.version for import_ in model_proto.opset_import - } - - self.model_inputs = get_model_input_shapes_types(model_proto) - self.model_outputs = get_model_output_shapes_types(model_proto) - - self.node_type_counts = get_node_type_counts(model_proto) - self.parse_model_nodes(model_proto) - - def get_node_tensor_info_( - self, onnx_node: onnx.NodeProto - ) -> Tuple[TensorData, TensorData]: - """ - This function is set to private because it is not intended to be used - outside of the DigestOnnxModel class. - """ - - input_tensor_info = TensorData() - for node_input in onnx_node.input: - input_tensor_info[node_input] = TensorInfo() - if ( - node_input in self.input_tensors_ - or node_input in self.value_tensors_ - or node_input in self.output_tensors_ - ): - tensor = ( - self.input_tensors_.get(node_input) - or self.value_tensors_.get(node_input) - or self.output_tensors_.get(node_input) - ) - if tensor: - for dim in tensor.type.tensor_type.shape.dim: - if dim.HasField("dim_value"): - input_tensor_info[node_input].shape.append(dim.dim_value) - elif dim.HasField("dim_param"): - input_tensor_info[node_input].shape.append(dim.dim_param) - - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - tensor.type.tensor_type.elem_type - ) - elif node_input in self.init_tensors_: - input_tensor_info[node_input].shape.extend( - [dim for dim in self.init_tensors_[node_input].dims] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - self.init_tensors_[node_input].data_type - ) - else: - dtype_str = None - dtype_bytes = None - - input_tensor_info[node_input].dtype = dtype_str - input_tensor_info[node_input].dtype_bytes = dtype_bytes - - if ( - all(isinstance(s, int) for s in input_tensor_info[node_input].shape) - and dtype_bytes - ): - tensor_size = float( - np.prod(np.array(input_tensor_info[node_input].shape)) - ) - input_tensor_info[node_input].size_kbytes = ( - tensor_size * float(dtype_bytes) / 1024.0 - ) - - output_tensor_info = TensorData() - for node_output in onnx_node.output: - output_tensor_info[node_output] = TensorInfo() - if ( - node_output in self.input_tensors_ - or node_output in self.value_tensors_ - or node_output in self.output_tensors_ - ): - tensor = ( - self.input_tensors_.get(node_output) - or self.value_tensors_.get(node_output) - or self.output_tensors_.get(node_output) - ) - if tensor: - output_tensor_info[node_output].shape.extend( - [ - int(dim.dim_value) - for dim in tensor.type.tensor_type.shape.dim - ] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - tensor.type.tensor_type.elem_type - ) - elif node_output in self.init_tensors_: - output_tensor_info[node_output].shape.extend( - [dim for dim in self.init_tensors_[node_output].dims] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - self.init_tensors_[node_output].data_type - ) - - else: - dtype_str = None - dtype_bytes = None - - output_tensor_info[node_output].dtype = dtype_str - output_tensor_info[node_output].dtype_bytes = dtype_bytes - - if ( - all(isinstance(s, int) for s in output_tensor_info[node_output].shape) - and dtype_bytes - ): - tensor_size = float( - np.prod(np.array(output_tensor_info[node_output].shape)) - ) - output_tensor_info[node_output].size_kbytes = ( - tensor_size * float(dtype_bytes) / 1024.0 - ) - - return input_tensor_info, output_tensor_info - - def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None: - """ - Calculate total number of FLOPs found in the onnx model. - FLOP is defined as one floating-point operation. This distinguishes - from multiply-accumulates (MACs) where FLOPs == 2 * MACs. - """ - - # Initialze to zero so we can accumulate. Set to None during the - # model FLOPs calculation if it errors out. - self.model_flops = 0 - - # Check to see if the model inputs have any dynamic shapes - if get_dynamic_input_dims(onnx_model): - self.model_flops = None - - try: - onnx_model, _ = optimize_onnx_model(onnx_model) - - onnx_model = onnx.shape_inference.infer_shapes( - onnx_model, strict_mode=True, data_prop=True - ) - except Exception as e: # pylint: disable=broad-except - print(f"ONNX utils: {str(e)}") - self.model_flops = None - - # If the ONNX model contains one of the following unsupported ops, then this - # function will return None since the FLOP total is expected to be incorrect - unsupported_ops = [ - "Einsum", - "RNN", - "GRU", - "DeformConv", - ] - - if not self.input_tensors_: - self.input_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.input - } - - if not self.output_tensors_: - self.output_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.output - } - - if not self.value_tensors_: - self.value_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.value_info - } - - if not self.init_tensors_: - self.init_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.initializer - } - - for node in onnx_model.graph.node: # pylint: disable=E1101 - - node_info = NodeInfo() - - # TODO: I have encountered models containing nodes with no name. It would be a good idea - # to have this type of model info fed back to the user through a warnings section. - if not node.name: - node.name = f"{node.op_type}_{len(self.per_node_info)}" - - node_info.node_type = node.op_type - input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node) - node_info.inputs = input_tensor_info - node_info.outputs = output_tensor_info - - # Check if this node has parameters through the init tensors - for input_name, input_tensor in node_info.inputs.items(): - if input_name in self.init_tensors_: - if all(isinstance(dim, int) for dim in input_tensor.shape): - input_parameters = int(np.prod(np.array(input_tensor.shape))) - node_info.parameters += input_parameters - self.model_parameters += input_parameters - self.node_type_parameters[node.op_type] = ( - self.node_type_parameters.get(node.op_type, 0) - + input_parameters - ) - else: - print(f"Tensor with params has unknown shape: {input_name}") - - for attribute in node.attribute: - node_info.attributes.update(attribute_to_dict(attribute)) - - # if node.name in self.per_node_info: - # print(f"Node name {node.name} is a duplicate.") - - self.per_node_info[node.name] = node_info - - if node.op_type in unsupported_ops: - self.model_flops = None - node_info.flops = None - - try: - - if ( - node.op_type == "MatMul" - or node.op_type == "MatMulInteger" - or node.op_type == "QLinearMatMul" - ): - - input_a = node_info.get_input(0).shape - if node.op_type == "QLinearMatMul": - input_b = node_info.get_input(3).shape - else: - input_b = node_info.get_input(1).shape - - if not all( - isinstance(dim, int) for dim in input_a - ) or not isinstance(input_b[-1], int): - node_info.flops = None - self.model_flops = None - continue - - node_info.flops = int( - 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1] - ) - - elif ( - node.op_type == "Mul" - or node.op_type == "Div" - or node.op_type == "Add" - ): - input_a = node_info.get_input(0).shape - input_b = node_info.get_input(1).shape - - if not all(isinstance(dim, int) for dim in input_a) or not all( - isinstance(dim, int) for dim in input_b - ): - node_info.flops = None - self.model_flops = None - continue - - node_info.flops = int( - np.prod(np.array(input_a), dtype=np.int64) - ) + int(np.prod(np.array(input_b), dtype=np.int64)) - - elif node.op_type == "Gemm" or node.op_type == "QGemm": - x_shape = node_info.get_input(0).shape - if node.op_type == "Gemm": - w_shape = node_info.get_input(1).shape - else: - w_shape = node_info.get_input(3).shape - - if not all(isinstance(dim, int) for dim in x_shape) or not all( - isinstance(dim, int) for dim in w_shape - ): - node_info.flops = None - self.model_flops = None - continue - - mm_dims = [ - ( - x_shape[0] - if not node_info.attributes.get("transA", 0) - else x_shape[1] - ), - ( - x_shape[1] - if not node_info.attributes.get("transA", 0) - else x_shape[0] - ), - ( - w_shape[1] - if not node_info.attributes.get("transB", 0) - else w_shape[0] - ), - ] - - node_info.flops = int( - 2 * np.prod(np.array(mm_dims), dtype=np.int64) - ) - - if len(mm_dims) == 3: # if there is a bias input - bias_shape = node_info.get_input(2).shape - node_info.flops += int(np.prod(np.array(bias_shape))) - - elif ( - node.op_type == "Conv" - or node.op_type == "ConvInteger" - or node.op_type == "QLinearConv" - or node.op_type == "ConvTranspose" - ): - # N, C, d1, ..., dn - x_shape = node_info.get_input(0).shape - - # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose - if node.op_type == "QLinearConv": - w_shape = node_info.get_input(3).shape - else: - w_shape = node_info.get_input(1).shape - - if not all(isinstance(dim, int) for dim in x_shape): - node_info.flops = None - self.model_flops = None - continue - - x_shape_ints = cast(List[int], x_shape) - w_shape_ints = cast(List[int], w_shape) - - has_bias = False # Note, ConvInteger has no bias - if node.op_type == "Conv" and len(node_info.inputs) == 3: - has_bias = True - elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9: - has_bias = True - - num_dims = len(x_shape_ints) - 2 - strides = node_info.attributes.get( - "strides", [1] * num_dims - ) # type: List[int] - dilation = node_info.attributes.get( - "dilations", [1] * num_dims - ) # type: List[int] - kernel_shape = w_shape_ints[2:] - batch_size = x_shape_ints[0] - out_channels = w_shape_ints[0] - out_dims = [batch_size, out_channels] - output_shape = node_info.attributes.get( - "output_shape", [] - ) # type: List[int] - - # If output_shape is given then we do not need to compute it ourselves - # The output_shape attribute does not include batch_size or channels and - # is only valid for ConvTranspose - if output_shape: - out_dims.extend(output_shape) - else: - auto_pad = node_info.attributes.get( - "auto_pad", "NOTSET".encode() - ).decode() - # SAME expects padding so that the output_shape = CEIL(input_shape / stride) - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_dims.extend( - [x * s for x, s in zip(x_shape_ints[2:], strides)] - ) - else: - # NOTSET means just use pads attribute - if auto_pad == "NOTSET": - pads = node_info.attributes.get( - "pads", [0] * num_dims * 2 - ) - # VALID essentially means no padding - elif auto_pad == "VALID": - pads = [0] * num_dims * 2 - - for i in range(num_dims): - dim_in = x_shape_ints[i + 2] # type: int - - if node.op_type == "ConvTranspose": - out_dim = ( - strides[i] * (dim_in - 1) - + ((kernel_shape[i] - 1) * dilation[i] + 1) - - pads[i] - - pads[i + num_dims] - ) - else: - out_dim = ( - dim_in - + pads[i] - + pads[i + num_dims] - - dilation[i] * (kernel_shape[i] - 1) - - 1 - ) // strides[i] + 1 - - out_dims.append(out_dim) - - kernel_flops = int( - np.prod(np.array(kernel_shape)) * w_shape_ints[1] - ) - output_points = int(np.prod(np.array(out_dims))) - bias_ops = output_points if has_bias else int(0) - node_info.flops = 2 * kernel_flops * output_points + bias_ops - - elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM": - - x_shape = node_info.get_input( - 0 - ).shape # seq_length, batch_size, input_dim - - if not all(isinstance(dim, int) for dim in x_shape): - node_info.flops = None - self.model_flops = None - continue - - x_shape_ints = cast(List[int], x_shape) - hidden_size = node_info.attributes["hidden_size"] - direction = ( - 2 - if node_info.attributes.get("direction") - == "bidirectional".encode() - else 1 - ) - - has_bias = True if len(node_info.inputs) >= 4 else False - if has_bias: - bias_shape = node_info.get_input(3).shape - if isinstance(bias_shape[1], int): - bias_ops = bias_shape[1] - else: - bias_ops = 0 - else: - bias_ops = 0 - # seq_length, batch_size, input_dim = x_shape - if not isinstance(bias_ops, int): - bias_ops = int(0) - num_gates = int(4) - gate_input_flops = int(2 * x_shape_ints[2] * hidden_size) - gate_hid_flops = int(2 * hidden_size * hidden_size) - unit_flops = ( - num_gates * (gate_input_flops + gate_hid_flops) + bias_ops - ) - node_info.flops = ( - x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops - ) - # In this case we just hit an op that doesn't have FLOPs - else: - node_info.flops = None - - except IndexError as err: - print(f"Error parsing node {node.name}: {err}") - node_info.flops = None - self.model_flops = None - continue - - # Update the model level flops count - if node_info.flops is not None and self.model_flops is not None: - self.model_flops += node_info.flops - - # Update the node type flops count - self.node_type_flops[node.op_type] = ( - self.node_type_flops.get(node.op_type, 0) + node_info.flops - ) - - def save_txt_report(self, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - report_date = datetime.now().strftime("%B %d, %Y") - - with open(filepath, "w", encoding="utf-8") as f_p: - f_p.write(f"Report created on {report_date}\n") - if self.filepath: - f_p.write(f"ONNX file: {self.filepath}\n") - f_p.write(f"Name of the model: {self.model_name}\n") - f_p.write(f"Model version: {self.model_version}\n") - f_p.write(f"Name of the graph: {self.graph_name}\n") - f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n") - f_p.write(f"Ir version: {self.ir_version}\n") - f_p.write(f"Opset: {self.opset}\n\n") - f_p.write("Import list\n") - for name, version in self.imports.items(): - f_p.write(f"\t{name}: {version}\n") - - f_p.write("\n") - f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n") - f_p.write(f"Number of parameters: {self.model_parameters}\n") - if self.model_flops: - f_p.write(f"Number of FLOPs: {self.model_flops}\n") - f_p.write("\n") - - table_op_intensity = PrettyTable() - table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"] - for op_type, count in self.node_type_flops.items(): - if count > 0: - table_op_intensity.add_row( - [ - op_type, - count, - 100.0 * float(count) / float(self.model_flops), - ] - ) - - f_p.write("Op intensity:\n") - f_p.write(table_op_intensity.get_string()) - f_p.write("\n\n") - - node_counts_table = PrettyTable() - node_counts_table.field_names = ["Node", "Occurrences"] - for op, count in self.node_type_counts.items(): - node_counts_table.add_row([op, count]) - f_p.write("Nodes and their occurrences:\n") - f_p.write(node_counts_table.get_string()) - f_p.write("\n\n") - - input_table = PrettyTable() - input_table.field_names = [ - "Input Name", - "Shape", - "Type", - "Tensor Size (KB)", - ] - for input_name, input_details in self.model_inputs.items(): - if input_details.size_kbytes: - kbytes = f"{input_details.size_kbytes:.2f}" - else: - kbytes = "" - - input_table.add_row( - [ - input_name, - input_details.shape, - input_details.dtype, - kbytes, - ] - ) - f_p.write("Input Tensor(s) Information:\n") - f_p.write(input_table.get_string()) - f_p.write("\n\n") - - output_table = PrettyTable() - output_table.field_names = [ - "Output Name", - "Shape", - "Type", - "Tensor Size (KB)", - ] - for output_name, output_details in self.model_outputs.items(): - if output_details.size_kbytes: - kbytes = f"{output_details.size_kbytes:.2f}" - else: - kbytes = "" - - output_table.add_row( - [ - output_name, - output_details.shape, - output_details.dtype, - kbytes, - ] - ) - f_p.write("Output Tensor(s) Information:\n") - f_p.write(output_table.get_string()) - f_p.write("\n\n") - - def save_nodes_csv_report(self, filepath: str) -> None: - save_nodes_csv_report(self.per_node_info, filepath) - - def get_node_type_counts(self) -> Union[NodeTypeCounts, None]: - if not self.node_type_counts and self.model_proto: - self.node_type_counts = get_node_type_counts(self.model_proto) - return self.node_type_counts if self.node_type_counts else None - - def get_node_shape_counts(self) -> NodeShapeCounts: - tensor_shape_counter = NodeShapeCounts() - for _, info in self.per_node_info.items(): - shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()]) - if info.node_type: - tensor_shape_counter[info.node_type][shape_hash] += 1 - return tensor_shape_counter - - -def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - flattened_data = [] - fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"] - input_fieldnames = [] - output_fieldnames = [] - for name, node_info in node_data.items(): - row = OrderedDict() - row["Node Name"] = name - row["Node Type"] = str(node_info.node_type) - row["Parameters"] = str(node_info.parameters) - row["FLOPs"] = str(node_info.flops) - if node_info.attributes: - row["Attributes"] = str({k: v for k, v in node_info.attributes.items()}) - else: - row["Attributes"] = "" - - for i, (input_name, input_info) in enumerate(node_info.inputs.items()): - column_name = f"Input{i+1} (Shape, Dtype, Size (kB))" - row[column_name] = ( - f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})" - ) - - # Dynamically add input column names to fieldnames if not already present - if column_name not in input_fieldnames: - input_fieldnames.append(column_name) - - for i, (output_name, output_info) in enumerate(node_info.outputs.items()): - column_name = f"Output{i+1} (Shape, Dtype, Size (kB))" - row[column_name] = ( - f"{output_name} ({output_info.shape}, " - f"{output_info.dtype}, {output_info.size_kbytes})" - ) - - # Dynamically add input column names to fieldnames if not already present - if column_name not in output_fieldnames: - output_fieldnames.append(column_name) - - flattened_data.append(row) - - fieldnames = fieldnames + input_fieldnames + output_fieldnames - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") - writer.writeheader() - writer.writerows(flattened_data) - - -def save_node_type_counts_csv_report(node_data: NodeTypeCounts, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - header = ["Node Type", "Count"] - - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.writer(csvfile, lineterminator="\n") - writer.writerow(header) - for node_type, node_count in node_data.items(): - writer.writerow([node_type, node_count]) - - -def save_node_shape_counts_csv_report( - node_data: NodeShapeCounts, filepath: str -) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - header = ["Node Type", "Input Tensors Shapes", "Count"] - - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.writer(csvfile, dialect="excel", lineterminator="\n") - writer.writerow(header) - for node_type, node_info in node_data.items(): - info_iter = iter(node_info.items()) - for shape, count in info_iter: - writer.writerow([node_type, shape, count]) - - def load_onnx(onnx_path: str, load_external_data: bool = True) -> onnx.ModelProto: if os.path.exists(onnx_path): return onnx.load(onnx_path, load_external_data=load_external_data) diff --git a/test/resnet18_reports/resnet18_heatmap.png b/test/resnet18_reports/resnet18_heatmap.png new file mode 100644 index 0000000..1fb614e Binary files /dev/null and b/test/resnet18_reports/resnet18_heatmap.png differ diff --git a/test/resnet18_reports/resnet18_histogram.png b/test/resnet18_reports/resnet18_histogram.png new file mode 100644 index 0000000..eb13e01 Binary files /dev/null and b/test/resnet18_reports/resnet18_histogram.png differ diff --git a/test/resnet18_reports/resnet18_node_type_counts.csv b/test/resnet18_reports/resnet18_node_type_counts.csv new file mode 100644 index 0000000..29504ba --- /dev/null +++ b/test/resnet18_reports/resnet18_node_type_counts.csv @@ -0,0 +1,8 @@ +Node Type,Count +Conv,20 +Relu,17 +Add,8 +MaxPool,1 +GlobalAveragePool,1 +Flatten,1 +Gemm,1 diff --git a/test/resnet18_test_nodes.csv b/test/resnet18_reports/resnet18_nodes.csv similarity index 100% rename from test/resnet18_test_nodes.csv rename to test/resnet18_reports/resnet18_nodes.csv diff --git a/test/resnet18_test_summary.txt b/test/resnet18_reports/resnet18_report.txt similarity index 86% rename from test/resnet18_test_summary.txt rename to test/resnet18_reports/resnet18_report.txt index a5b4cfb..fdda0bf 100644 --- a/test/resnet18_test_summary.txt +++ b/test/resnet18_reports/resnet18_report.txt @@ -1,5 +1,5 @@ -Report created on June 02, 2024 -ONNX file: resnet18.onnx +Report created on December 06, 2024 +ONNX file: C:\Users\pcolange\Projects\digestai\test\resnet18.onnx Name of the model: resnet18 Model version: 0 Name of the graph: main_graph @@ -9,6 +9,13 @@ Opset: 17 Import list : 17 + ai.onnx.ml: 5 + ai.onnx.preview.training: 1 + ai.onnx.training: 1 + com.microsoft: 1 + com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 + org.pytorch.aten: 1 Total graph nodes: 49 Number of parameters: 11684712 diff --git a/test/resnet18_reports/resnet18_report.yaml b/test/resnet18_reports/resnet18_report.yaml new file mode 100644 index 0000000..9df22be --- /dev/null +++ b/test/resnet18_reports/resnet18_report.yaml @@ -0,0 +1,56 @@ +report_date: December 06, 2024 +model_file: C:\Users\pcolange\Projects\digestai\test\resnet18.onnx +model_type: onnx +model_name: resnet18 +model_version: 0 +graph_name: main_graph +producer_name: pytorch +producer_version: 2.1.0 +ir_version: 8 +opset: 17 +import_list: + ? '' + : 17 + ai.onnx.ml: 5 + ai.onnx.preview.training: 1 + ai.onnx.training: 1 + com.microsoft: 1 + com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 + org.pytorch.aten: 1 +graph_nodes: 49 +parameters: 11684712 +flops: 3632136680 +node_type_counts: + Conv: 20 + Relu: 17 + Add: 8 + MaxPool: 1 + GlobalAveragePool: 1 + Flatten: 1 + Gemm: 1 +node_type_flops: + Conv: 3629606400 + Add: 1505280 + Gemm: 1025000 +node_type_parameters: + Conv: 11171712 + Gemm: 513000 +input_tensors: + input.1: + dtype: float32 + dtype_bytes: 4 + size_kbytes: 588.0 + shape: + - 1 + - 3 + - 224 + - 224 +output_tensors: + '191': + dtype: float32 + dtype_bytes: 4 + size_kbytes: 3.90625 + shape: + - 1 + - 1000 diff --git a/test/test_gui.py b/test/test_gui.py index 0e1d351..59fbb8f 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -8,60 +8,97 @@ # pylint: disable=no-name-in-module from PySide6.QtTest import QTest -from PySide6.QtCore import Qt, QDeadlineTimer +from PySide6.QtCore import Qt from PySide6.QtWidgets import QApplication import digest.main from digest.node_summary import NodeSummary -ONNX_BASENAME = "resnet18" -TEST_DIR = os.path.abspath(os.path.dirname(__file__)) -ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{ONNX_BASENAME}.onnx")) - class DigestGuiTest(unittest.TestCase): + MODEL_BASENAME = "resnet18" + TEST_DIR = os.path.abspath(os.path.dirname(__file__)) + ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{MODEL_BASENAME}.onnx")) + YAML_FILEPATH = os.path.normpath( + os.path.join( + TEST_DIR, f"{MODEL_BASENAME}_reports", f"{MODEL_BASENAME}_report.yaml" + ) + ) @classmethod def setUpClass(cls): cls.app = QApplication(sys.argv) + return super().setUpClass() + + @classmethod + def tearDownClass(cls): + if isinstance(cls.app, QApplication): + cls.app.closeAllWindows() + cls.app = None def setUp(self): self.digest_app = digest.main.DigestApp() self.digest_app.show() def tearDown(self): - self.wait_all_threads() self.digest_app.close() - def wait_all_threads(self): + def wait_all_threads(self, timeout=10000) -> bool: + all_threads = list(self.digest_app.model_nodes_stats_thread.values()) + list( + self.digest_app.model_similarity_thread.values() + ) - for thread in self.digest_app.model_nodes_stats_thread.values(): - thread.wait(deadline=QDeadlineTimer.Forever) + for thread in all_threads: + thread.wait(timeout) - for thread in self.digest_app.model_similarity_thread.values(): - thread.wait(deadline=QDeadlineTimer.Forever) + # Return True if all threads finished, False if timed out + return all(thread.isFinished() for thread in all_threads) def test_open_valid_onnx(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ( - ONNX_FILEPATH, + self.ONNX_FILEPATH, + "", + ) + + num_tabs_prior = self.digest_app.ui.tabWidget.count() + + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) + + self.assertTrue(self.wait_all_threads()) + + self.assertTrue( + self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 + ) # Check if a tab was added + + self.digest_app.closeTab(num_tabs_prior) + + def test_open_valid_yaml(self): + with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: + mock_dialog.return_value = ( + self.YAML_FILEPATH, "", ) + num_tabs_prior = self.digest_app.ui.tabWidget.count() + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) self.assertTrue( - self.digest_app.ui.tabWidget.count() > 0 + self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 ) # Check if a tab was added + self.digest_app.closeTab(num_tabs_prior) + def test_open_invalid_file(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ("invalid_file.txt", "") + num_tabs_prior = self.digest_app.ui.tabWidget.count() QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.wait_all_threads() - self.assertEqual(self.digest_app.ui.tabWidget.count(), 0) + self.assertTrue(self.wait_all_threads()) + self.assertEqual(self.digest_app.ui.tabWidget.count(), num_tabs_prior) def test_save_reports(self): with patch( @@ -70,7 +107,7 @@ def test_save_reports(self): "PySide6.QtWidgets.QFileDialog.getExistingDirectory" ) as mock_save_dialog: - mock_open_dialog.return_value = (ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILEPATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = tmpdirname @@ -79,45 +116,57 @@ def test_save_reports(self): Qt.MouseButton.LeftButton, ) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) - # This is a slight hack but the issue is that model similarity takes - # a bit longer to complete and we must have it done before the save - # button is enabled guaranteeing all the artifacts are saved. - # wait_all_threads() above doesn't seem to work. The only thing that - # does is just waiting 5 seconds. - QTest.qWait(5000) + self.assertTrue( + self.digest_app.ui.saveBtn.isEnabled(), "Save button is disabled!" + ) QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton) mock_save_dialog.assert_called_once() - result_basepath = os.path.join(tmpdirname, f"{ONNX_BASENAME}_reports") + result_basepath = os.path.join( + tmpdirname, f"{self.MODEL_BASENAME}_reports" + ) # Text report test - txt_report_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_report.txt" + text_report_filepath = os.path.join( + result_basepath, f"{self.MODEL_BASENAME}_report.txt" ) - self.assertTrue(os.path.isfile(txt_report_filepath)) + self.assertTrue( + os.path.isfile(text_report_filepath), + f"{text_report_filepath} not found!", + ) + + # YAML report test + yaml_report_filepath = os.path.join( + result_basepath, f"{self.MODEL_BASENAME}_report.yaml" + ) + self.assertTrue(os.path.isfile(yaml_report_filepath)) # Nodes test nodes_csv_report_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_nodes.csv" + result_basepath, f"{self.MODEL_BASENAME}_nodes.csv" ) self.assertTrue(os.path.isfile(nodes_csv_report_filepath)) # Histogram test histogram_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_histogram.png" + result_basepath, f"{self.MODEL_BASENAME}_histogram.png" ) self.assertTrue(os.path.isfile(histogram_filepath)) # Heatmap test heatmap_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_heatmap.png" + result_basepath, f"{self.MODEL_BASENAME}_heatmap.png" ) self.assertTrue(os.path.isfile(heatmap_filepath)) + num_tabs = self.digest_app.ui.tabWidget.count() + self.assertTrue(num_tabs == 1) + self.digest_app.closeTab(0) + def test_save_tables(self): with patch( "PySide6.QtWidgets.QFileDialog.getOpenFileName" @@ -125,10 +174,10 @@ def test_save_tables(self): "PySide6.QtWidgets.QFileDialog.getSaveFileName" ) as mock_save_dialog: - mock_open_dialog.return_value = (ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILEPATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = ( - os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv"), + os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv"), "", ) @@ -136,17 +185,19 @@ def test_save_tables(self): self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton ) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) QTest.mouseClick( self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton ) - # We assume there is only model loaded + # We assume there is only one model loaded _, node_window = self.digest_app.nodes_window.popitem() node_summary = node_window.main_window.centralWidget() self.assertIsInstance(node_summary, NodeSummary) + + # This line of code seems redundant but we do this to clean pylance if isinstance(node_summary, NodeSummary): QTest.mouseClick( node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton @@ -156,11 +207,15 @@ def test_save_tables(self): self.assertTrue( os.path.exists( - os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv") + os.path.join(tmpdirname, f"{self.MODEL_BASENAME}_nodes.csv") ), "Nodes csv file not found.", ) + num_tabs = self.digest_app.ui.tabWidget.count() + self.assertTrue(num_tabs == 1) + self.digest_app.closeTab(0) + if __name__ == "__main__": unittest.main() diff --git a/test/test_reports.py b/test/test_reports.py index a16c4d8..ae99ab9 100644 --- a/test/test_reports.py +++ b/test/test_reports.py @@ -1,17 +1,22 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. -"""Unit tests for Vitis ONNX Model Analyzer """ - import os import unittest import tempfile import csv -from utils.onnx_utils import DigestOnnxModel, load_onnx +import utils.onnx_utils as onnx_utils +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import compare_yaml_files TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_ONNX = os.path.join(TEST_DIR, "resnet18.onnx") -TEST_SUMMARY_TXT_REPORT = os.path.join(TEST_DIR, "resnet18_test_summary.txt") -TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_test_nodes.csv") +TEST_SUMMARY_TEXT_REPORT = os.path.join( + TEST_DIR, "resnet18_reports/resnet18_report.txt" +) +TEST_SUMMARY_YAML_REPORT = os.path.join( + TEST_DIR, "resnet18_reports/resnet18_report.yaml" +) +TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_reports/resnet18_nodes.csv") class TestDigestReports(unittest.TestCase): @@ -46,27 +51,35 @@ def compare_csv_files(self, file1, file2, skip_lines=0): self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}") def test_against_example_reports(self): - model_proto = load_onnx(TEST_ONNX) + model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False) model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0] + opt_model, _ = onnx_utils.optimize_onnx_model(model_proto) digest_model = DigestOnnxModel( - model_proto, onnx_filepath=TEST_ONNX, model_name=model_name, save_proto=False, + opt_model, + onnx_filepath=TEST_ONNX, + model_name=model_name, + save_proto=False, ) with tempfile.TemporaryDirectory() as tmpdir: - # Model summary text report - summary_filepath = os.path.join(tmpdir, f"{model_name}_summary.txt") - digest_model.save_txt_report(summary_filepath) - - with self.subTest("Testing summary text file"): - self.compare_files_line_by_line( - TEST_SUMMARY_TXT_REPORT, - summary_filepath, - skip_lines=2, + # Model yaml report + yaml_report_filepath = os.path.join(tmpdir, f"{model_name}_report.yaml") + digest_model.save_yaml_report(yaml_report_filepath) + with self.subTest("Testing report yaml file"): + self.assertTrue( + compare_yaml_files( + TEST_SUMMARY_YAML_REPORT, + yaml_report_filepath, + skip_keys=["report_date", "model_file", "digest_version"], + ) ) # Save CSV containing node-level information nodes_filepath = os.path.join(tmpdir, f"{model_name}_nodes.csv") digest_model.save_nodes_csv_report(nodes_filepath) - with self.subTest("Testing nodes csv file"): self.compare_csv_files(TEST_NODES_CSV_REPORT, nodes_filepath) + + +if __name__ == "__main__": + unittest.main()