Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Resolves] Support ingesting a digest model report/cache #3

Merged
merged 15 commits into from
Jan 24, 2025
11 changes: 0 additions & 11 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ enable =
expression-not-assigned,
pcolange marked this conversation as resolved.
Show resolved Hide resolved
confusing-with-statement,
unnecessary-lambda,
assign-to-new-keyword,
redeclared-assigned-name,
pointless-statement,
pointless-string-statement,
Expand Down Expand Up @@ -123,7 +122,6 @@ enable =
invalid-length-returned,
protected-access,
attribute-defined-outside-init,
no-init,
abstract-method,
invalid-overridden-method,
arguments-differ,
Expand Down Expand Up @@ -165,9 +163,7 @@ enable =
### format
# Line length, indentation, whitespace:
bad-indentation,
mixed-indentation,
unnecessary-semicolon,
bad-whitespace,
missing-final-newline,
line-too-long,
mixed-line-endings,
Expand All @@ -187,7 +183,6 @@ enable =
import-self,
preferred-module,
reimported,
relative-import,
deprecated-module,
wildcard-import,
misplaced-future,
Expand Down Expand Up @@ -282,12 +277,6 @@ indent-string = ' '
# black doesn't always obey its own limit. See pyproject.toml.
max-line-length = 100

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check =

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt = no
Expand Down
40 changes: 21 additions & 19 deletions examples/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import csv
from collections import Counter, defaultdict
from tqdm import tqdm
from digest.model_class.digest_model import (
NodeShapeCounts,
NodeTypeCounts,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
)
from digest.model_class.digest_onnx_model import DigestOnnxModel
from utils.onnx_utils import (
get_dynamic_input_dims,
load_onnx,
DigestOnnxModel,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
NodeTypeCounts,
NodeShapeCounts,
)

GLOBAL_MODEL_HEADERS = [
Expand Down Expand Up @@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str):

global_model_data[model_name] = {
"opset": digest_model.opset,
"parameters": digest_model.model_parameters,
"flops": digest_model.model_flops,
"parameters": digest_model.parameters,
"flops": digest_model.flops,
}

# Model summary text report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt")
digest_model.save_txt_report(summary_filepath)
digest_model.save_text_report(summary_filepath)

# Model summary yaml report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml")
digest_model.save_yaml_report(summary_filepath)

# Save csv containing node-level information
nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)

# Save csv containing node type counter
node_type_counter = digest_model.get_node_type_counts()
node_type_filepath = os.path.join(
output_dir, f"{model_name}_node_type_counts.csv"
)
if node_type_counter:
save_node_type_counts_csv_report(node_type_counter, node_type_filepath)

digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Update global data structure for node type counter
global_node_type_counter.update(node_type_counter)
global_node_type_counter.update(digest_model.node_type_counts)

# Save csv containing node shape counts per op_type
node_shape_counts = digest_model.get_node_shape_counts()
node_shape_filepath = os.path.join(
output_dir, f"{model_name}_node_shape_counts.csv"
)
save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath)
digest_model.save_node_shape_counts_csv_report(node_shape_filepath)

# Update global data structure for node shape counter
for node_type, shape_counts in node_shape_counts.items():
for node_type, shape_counts in digest_model.get_node_shape_counts().items():
global_node_shape_counter[node_type].update(shape_counts)

if len(onnx_file_list) > 1:
global_filepath = os.path.join(output_dir, "global_node_type_counts.csv")
global_node_type_counter = NodeTypeCounts(
global_node_type_counter.most_common()
)
save_node_type_counts_csv_report(global_node_type_counter, global_filepath)
global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common())
save_node_type_counts_csv_report(global_node_type_counts, global_filepath)

global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv")
save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.0.0",
version="1.1.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
14 changes: 12 additions & 2 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,23 @@ class WarnDialog(QDialog):

def __init__(self, warning_message: str, parent=None):
super().__init__(parent)
self.setWindowTitle("Warning Message")

self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))

self.setWindowTitle("Warning Message")
self.setWindowFlags(Qt.WindowType.Dialog)
self.setMinimumWidth(300)

self.setWindowModality(Qt.WindowModality.WindowModal)

layout = QVBoxLayout()

# Application Version
layout.addWidget(QLabel("<b>Something went wrong</b>"))
layout.addWidget(QLabel("<b>Warning</b>"))
layout.addWidget(QLabel(warning_message))

ok_button = QPushButton("OK")
ok_button.clicked.connect(self.accept) # Close dialog when clicked
layout.addWidget(ok_button)

self.setLayout(layout)
2 changes: 1 addition & 1 deletion src/digest/gui_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# For EXE releases we can block certain features e.g. to customers

modules:
huggingface: false
huggingface: true
6 changes: 3 additions & 3 deletions src/digest/histogramchartwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
super(StackedHistogramWidget, self).__init__(*args, **kwargs)

self.plot_widget = pg.PlotWidget()
self.plot_widget.setMaximumHeight(150)
self.plot_widget.setMaximumHeight(200)
plot_item = self.plot_widget.getPlotItem()
if plot_item:
plot_item.setContentsMargins(0, 0, 0, 0)
Expand All @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs):
self.bar_spacing = 25

def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False):

title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)"
self.plot_widget.setLabel(
"left",
Expand All @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal
x_positions = list(range(len(op_count)))
total_count = sum(op_count)
width = 0.6
self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing)
self.plot_widget.setFixedWidth(500)

for count, x_pos, tick in zip(op_count, x_positions, op_names):
x0 = x_pos - width / 2
y0 = 0
Expand Down
Loading
Loading