Skip to content

Commit

Permalink
fixed critical threading issue with matplotlib
Browse files Browse the repository at this point in the history
- various quality updates to gui
  • Loading branch information
Philip Colangelo committed Dec 12, 2024
1 parent db35c40 commit 496ef19
Show file tree
Hide file tree
Showing 13 changed files with 409 additions and 254 deletions.
4 changes: 2 additions & 2 deletions src/digest/histogramchartwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs):
self.bar_spacing = 25

def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False):

title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)"
self.plot_widget.setLabel(
"left",
Expand All @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal
x_positions = list(range(len(op_count)))
total_count = sum(op_count)
width = 0.6
self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing)
self.plot_widget.setFixedWidth(500)

for count, x_pos, tick in zip(op_count, x_positions, op_names):
x0 = x_pos - width / 2
y0 = 0
Expand Down
98 changes: 59 additions & 39 deletions src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, Tuple, Optional, Union
import tempfile
from enum import IntEnum
import pandas as pd
import yaml

# This is a temporary workaround since the Qt designer generated files
Expand Down Expand Up @@ -37,7 +38,7 @@
from PySide6.QtCore import Qt, QSize

from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog
from digest.thread import StatsThread, SimilarityThread
from digest.thread import StatsThread, SimilarityThread, post_process
from digest.popup_window import PopupWindow
from digest.huggingface_page import HuggingfacePage
from digest.multi_model_selection_page import MultiModelSelectionPage
Expand Down Expand Up @@ -214,15 +215,15 @@ def __init__(self, model_file: Optional[str] = None):

# Set up the HUGGINGFACE Page
huggingface_page = HuggingfacePage()
huggingface_page.model_signal.connect(self.load_onnx)
huggingface_page.model_signal.connect(self.load_model)
self.ui.stackedWidget.insertWidget(self.Page.HUGGINGFACE, huggingface_page)

# Set up the multi model page and relevant button
self.multimodelselection_page = MultiModelSelectionPage()
self.ui.stackedWidget.insertWidget(
self.Page.MULTIMODEL, self.multimodelselection_page
)
self.multimodelselection_page.model_signal.connect(self.load_onnx)
self.multimodelselection_page.model_signal.connect(self.load_model)

# Load model file if given as input to the executable
if model_file:
Expand Down Expand Up @@ -285,25 +286,14 @@ def closeTab(self, index):
self.ui.singleModelWidget.hide()

def openFile(self):
filename, _ = QFileDialog.getOpenFileName(
file_name, _ = QFileDialog.getOpenFileName(
self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)"
)

if not filename:
if not file_name:
return

file_ext = os.path.splitext(filename)[-1]

if file_ext == ".onnx":
self.load_onnx(filename)
elif file_ext == ".yaml":
self.load_report(filename)
else:
bad_ext_dialog = StatusDialog(
f"Digest does not support files with the extension {file_ext}",
parent=self,
)
bad_ext_dialog.show()
self.load_model(file_name)

def update_cards(
self,
Expand Down Expand Up @@ -374,7 +364,8 @@ def update_similarity_widget(
completed_successfully: bool,
model_id: str,
most_similar: str,
png_filepath: Union[str, None],
png_filepath: Optional[str] = None,
df_sorted: Optional[pd.DataFrame] = None,
):
widget = None
digest_model = None
Expand All @@ -390,9 +381,24 @@ def update_similarity_widget(
curr_index = index
break

if completed_successfully and isinstance(widget, modelSummary) and png_filepath:
# convert back to a List[str]
most_similar_list = most_similar.split(",")

if (
completed_successfully
and isinstance(widget, modelSummary)
and digest_model
and png_filepath
):

if df_sorted is not None:
post_process(
digest_model.model_name, most_similar_list, df_sorted, png_filepath
)

widget.load_gif.stop()
widget.ui.similarityImg.clear()
# We give the image a 10% haircut to fit it more aesthetically
widget_width = widget.ui.similarityImg.width()

pixmap = QPixmap(png_filepath)
Expand All @@ -411,30 +417,31 @@ def update_similarity_widget(
# Show most correlated models
widget.ui.similarityCorrelation.show()
widget.ui.similarityCorrelationStatic.show()

most_similar_list = most_similar_list[1:4]
if most_similar:
most_similar_models = most_similar.split(",")
text = (
"\n<span style='color:red;text-align:center;'>"
f"{most_similar_models[0]}, {most_similar_models[1]}, "
f"and {most_similar_models[2]}. "
f"{most_similar_list[0]}, {most_similar_list[1]}, "
f"and {most_similar_list[2]}. "
"</span>"
)
else:
# currently the similarity widget expects the most_similar_models
# to allows contains 3 models. For now we will just send three empty
# strings but at some point we should handle an arbitrary case.
most_similar_models = ["", "", ""]
text = ""
most_similar_list = ["", "", ""]
text = "NTD"

# Create option to click to enlarge image
widget.ui.similarityImg.mousePressEvent = (
lambda event: self.open_similarity_report(
model_id, png_filepath, most_similar_models
model_id, png_filepath, most_similar_list
)
)
# Create option to click to enlarge image
self.model_similarity_report[model_id] = SimilarityAnalysisReport(
png_filepath, most_similar_models
png_filepath, most_similar_list
)

widget.ui.similarityCorrelation.setText(text)
Expand Down Expand Up @@ -878,17 +885,38 @@ def load_report(self, filepath: str):
movie.start()

self.update_similarity_widget(
bool(digest_model.similarity_heatmap_path),
digest_model.unique_id,
"",
digest_model.similarity_heatmap_path,
completed_successfully=bool(digest_model.similarity_heatmap_path),
model_id=digest_model.unique_id,
most_similar="",
png_filepath=digest_model.similarity_heatmap_path,
)

progress.close()

except FileNotFoundError as e:
print(f"File not found: {e.filename}")

def load_model(self, file_path: str):

# Ensure the filepath follows a standard formatting:
file_path = os.path.normpath(file_path)

if not os.path.exists(file_path):
return

file_ext = os.path.splitext(file_path)[-1]

if file_ext == ".onnx":
self.load_onnx(file_path)
elif file_ext == ".yaml":
self.load_report(file_path)
else:
bad_ext_dialog = StatusDialog(
f"Digest does not support files with the extension {file_ext}",
parent=self,
)
bad_ext_dialog.show()

def dragEnterEvent(self, event: QDragEnterEvent):
if event.mimeData().hasUrls():
event.acceptProposedAction()
Expand All @@ -897,12 +925,7 @@ def dropEvent(self, event: QDropEvent):
if event.mimeData().hasUrls():
for url in event.mimeData().urls():
file_path = url.toLocalFile()
if file_path.endswith(".onnx"):
self.load_onnx(file_path)
break
elif file_path.endswith(".yaml"):
self.load_report(file_path)
break
self.load_model(file_path)

## functions for changing menu page
def logo_clicked(self):
Expand Down Expand Up @@ -950,9 +973,6 @@ def save_reports(self):
self, "Select Directory"
)

if not save_directory:
return

# Check if the directory exists and is writable
if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK):
self.show_warning_dialog(
Expand Down
44 changes: 36 additions & 8 deletions src/digest/multi_model_analysis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.

import os
from datetime import datetime
import csv
from typing import List, Dict, Union
from collections import Counter, defaultdict, OrderedDict

# pylint: disable=no-name-in-module
from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog
from PySide6.QtCore import Qt
from digest.dialog import ProgressDialog, StatusDialog
from digest.ui.multimodelanalysis_ui import Ui_multiModelAnalysis
from digest.histogramchartwidget import StackedHistogramWidget
Expand Down Expand Up @@ -42,6 +44,9 @@ def __init__(
self.ui.individualCheckBox.stateChanged.connect(self.check_box_changed)
self.ui.multiCheckBox.stateChanged.connect(self.check_box_changed)

# For some reason setting alignments in designer lead to bugs in *ui.py files
self.ui.opHistogramChart.layout().setAlignment(Qt.AlignmentFlag.AlignTop)

if not model_list:
return

Expand Down Expand Up @@ -80,7 +85,8 @@ def __init__(
if isinstance(model, DigestOnnxModel):
item = QTableWidgetItem(str(model.opset))
elif isinstance(model, DigestReportModel):
item = QTableWidgetItem(str(model.model_data.get("opset", "NA")))
item = QTableWidgetItem(str(model.model_data.get("opset", "")))

self.ui.dataTable.setItem(row, 2, item)

item = QTableWidgetItem(str(len(model.node_data)))
Expand Down Expand Up @@ -193,31 +199,53 @@ def __init__(
set_ticks=False,
)
frame_layout = self.ui.stackedHistogramFrame.layout()
frame_layout.addWidget(stacked_histogram_widget)
if frame_layout:
frame_layout.addWidget(stacked_histogram_widget)

# Add a "ghost" histogram to allow us to set the x axis label vertically
model_name = list(node_type_counter.keys())[0]
stacked_histogram_widget = StackedHistogramWidget()
ordered_dict = {key: 1 for key in top_ops}
ordered_dict = OrderedDict({key: 1 for key in top_ops})
stacked_histogram_widget.set_data(
ordered_dict,
model_name="_",
y_max=max_count,
set_ticks=True,
)
frame_layout = self.ui.stackedHistogramFrame.layout()
frame_layout.addWidget(stacked_histogram_widget)
if frame_layout:
frame_layout.addWidget(stacked_histogram_widget)

self.model_list = model_list

def save_reports(self):
# Model summary text report
save_directory = QFileDialog(self).getExistingDirectory(
"""This function saves all available reports for the models that are opened
in the multi-model analysis page."""

base_directory = QFileDialog(self).getExistingDirectory(
self, "Select Directory"
)

if not save_directory:
return
# Check if the directory exists and is writable
if not os.path.exists(base_directory) or not os.access(base_directory, os.W_OK):
bad_ext_dialog = StatusDialog(
f"The directory {base_directory} is not valid or writable.",
parent=self,
)
bad_ext_dialog.show()

# Append a subdirectory to the save_directory so that all reports are co-located
name_id = datetime.now().strftime("%Y%m%d%H%M%S")
sub_directory = f"multi_model_reports_{name_id}"
save_directory = os.path.join(base_directory, sub_directory)
try:
os.makedirs(save_directory)
except OSError as os_err:
bad_ext_dialog = StatusDialog(
f"Failed to create {save_directory} with error {os_err}",
parent=self,
)
bad_ext_dialog.show()

save_individual_reports = self.ui.individualCheckBox.isChecked()
save_multi_reports = self.ui.multiCheckBox.isChecked()
Expand Down
7 changes: 1 addition & 6 deletions src/digest/multi_model_selection_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ def run(self):

self.close_progress.emit()

model_list = [
model
for model in self.model_dict.values()
if isinstance(model, DigestOnnxModel)
or isinstance(model, DigestReportModel)
]
model_list = [model for model in self.model_dict.values()]

self.completed.emit(model_list)

Expand Down
19 changes: 19 additions & 0 deletions src/digest/styles/darkstyle.qss
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,23 @@ QTreeView::item:selected:active {

QTreeView::item:selected:!active {
background-color: #949494;
}

QRadioButton {
spacing: 5px; /* Add spacing between the indicator and text */
color: white; /* Set text color to white */
}

QRadioButton::indicator {
/*width: 15px;
height: 15px;*/
border-radius: 7px; /* Make the indicator circular */
}

QRadioButton::indicator:unchecked {
border: 2px solid gray; /* Add a gray border when unchecked */
}

QRadioButton::indicator:checked {
background-color: lightblue; /* Fill with light blue when checked */
}
Loading

0 comments on commit 496ef19

Please sign in to comment.