Skip to content

Commit

Permalink
initial commit for ingesting pytorch models
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Colangelo committed Jan 3, 2025
1 parent 06a3467 commit 1573a71
Show file tree
Hide file tree
Showing 16 changed files with 1,782 additions and 177 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.1.0",
version="1.2.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand All @@ -25,6 +25,8 @@
"platformdirs>=4.2.2",
"pyyaml>=6.0.1",
"psutil>=6.0.0",
"torch",
"transformers",
],
classifiers=[],
entry_points={"console_scripts": ["digest = digest.main:main"]},
Expand Down
172 changes: 102 additions & 70 deletions src/digest/main.py

Large diffs are not rendered by default.

43 changes: 23 additions & 20 deletions src/digest/model_class/digest_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class SupportedModelTypes(Enum):
ONNX = "onnx"
REPORT = "report"
PYTORCH = "pytorch"


class NodeParsingException(Exception):
Expand Down Expand Up @@ -94,10 +95,12 @@ def __init__(self, *args, **kwargs):


class DigestModel(ABC):
def __init__(self, filepath: str, model_name: str, model_type: SupportedModelTypes):
def __init__(
self, file_path: str, model_name: str, model_type: SupportedModelTypes
):
# Public members exposed to the API
self.unique_id: str = str(uuid4())
self.filepath: Optional[str] = filepath
self.file_path: Optional[str] = os.path.abspath(file_path)
self.model_name: str = model_name
self.model_type: SupportedModelTypes = model_type
self.node_type_counts: NodeTypeCounts = NodeTypeCounts()
Expand All @@ -122,27 +125,27 @@ def parse_model_nodes(self, *args, **kwargs) -> None:
pass

@abstractmethod
def save_yaml_report(self, filepath: str) -> None:
def save_yaml_report(self, file_path: str) -> None:
pass

@abstractmethod
def save_text_report(self, filepath: str) -> None:
def save_text_report(self, file_path: str) -> None:
pass

def save_nodes_csv_report(self, filepath: str) -> None:
save_nodes_csv_report(self.node_data, filepath)
def save_nodes_csv_report(self, file_path: str) -> None:
save_nodes_csv_report(self.node_data, file_path)

def save_node_type_counts_csv_report(self, filepath: str) -> None:
def save_node_type_counts_csv_report(self, file_path: str) -> None:
if self.node_type_counts:
save_node_type_counts_csv_report(self.node_type_counts, filepath)
save_node_type_counts_csv_report(self.node_type_counts, file_path)

def save_node_shape_counts_csv_report(self, filepath: str) -> None:
save_node_shape_counts_csv_report(self.get_node_shape_counts(), filepath)
def save_node_shape_counts_csv_report(self, file_path: str) -> None:
save_node_shape_counts_csv_report(self.get_node_shape_counts(), file_path)


def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None:
def save_nodes_csv_report(node_data: NodeData, file_path: str) -> None:

parent_dir = os.path.dirname(os.path.abspath(filepath))
parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

Expand Down Expand Up @@ -186,44 +189,44 @@ def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None:

fieldnames = fieldnames + input_fieldnames + output_fieldnames
try:
with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
with open(file_path, "w", encoding="utf-8", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
writer.writerows(flattened_data)
except PermissionError as exception:
raise PermissionError(
f"Saving reports to {filepath} failed with error {exception}"
f"Saving reports to {file_path} failed with error {exception}"
)


def save_node_type_counts_csv_report(
node_type_counts: NodeTypeCounts, filepath: str
node_type_counts: NodeTypeCounts, file_path: str
) -> None:

parent_dir = os.path.dirname(os.path.abspath(filepath))
parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

header = ["Node Type", "Count"]

with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
with open(file_path, "w", encoding="utf-8", newline="") as csvfile:
writer = csv.writer(csvfile, lineterminator="\n")
writer.writerow(header)
for node_type, node_count in node_type_counts.items():
writer.writerow([node_type, node_count])


def save_node_shape_counts_csv_report(
node_shape_counts: NodeShapeCounts, filepath: str
node_shape_counts: NodeShapeCounts, file_path: str
) -> None:

parent_dir = os.path.dirname(os.path.abspath(filepath))
parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

header = ["Node Type", "Input Tensors Shapes", "Count"]

with open(filepath, "w", encoding="utf-8", newline="") as csvfile:
with open(file_path, "w", encoding="utf-8", newline="") as csvfile:
writer = csv.writer(csvfile, dialect="excel", lineterminator="\n")
writer.writerow(header)
for node_type, node_info in node_shape_counts.items():
Expand Down
24 changes: 11 additions & 13 deletions src/digest/model_class/digest_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ class DigestOnnxModel(DigestModel):
def __init__(
self,
onnx_model: onnx.ModelProto,
onnx_filepath: str = "",
onnx_file_path: str = "",
model_name: str = "",
save_proto: bool = True,
) -> None:
super().__init__(onnx_filepath, model_name, SupportedModelTypes.ONNX)

self.model_type = SupportedModelTypes.ONNX
super().__init__(onnx_file_path, model_name, SupportedModelTypes.ONNX)

# Public members exposed to the API
self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None
Expand Down Expand Up @@ -510,9 +508,9 @@ def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None:
self.node_type_flops.get(node.op_type, 0) + node_info.flops
)

def save_yaml_report(self, filepath: str) -> None:
def save_yaml_report(self, file_path: str) -> None:

parent_dir = os.path.dirname(os.path.abspath(filepath))
parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

Expand All @@ -526,7 +524,7 @@ def save_yaml_report(self, filepath: str) -> None:
"report_date": report_date,
"digest_version": digest_version,
"model_type": self.model_type.value,
"model_file": self.filepath,
"model_file": self.file_path,
"model_name": self.model_name,
"model_version": self.model_version,
"graph_name": self.graph_name,
Expand All @@ -545,25 +543,25 @@ def save_yaml_report(self, filepath: str) -> None:
"output_tensors": output_tensors,
}

with open(filepath, "w", encoding="utf-8") as f_p:
with open(file_path, "w", encoding="utf-8") as f_p:
yaml.dump(yaml_data, f_p, sort_keys=False)

def save_text_report(self, filepath: str) -> None:
def save_text_report(self, file_path: str) -> None:

parent_dir = os.path.dirname(os.path.abspath(filepath))
parent_dir = os.path.dirname(os.path.abspath(file_path))
if not os.path.exists(parent_dir):
raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

report_date = datetime.now().strftime("%B %d, %Y")

digest_version = importlib.metadata.version("digestai")

with open(filepath, "w", encoding="utf-8") as f_p:
with open(file_path, "w", encoding="utf-8") as f_p:
f_p.write(f"Report created on {report_date}\n")
f_p.write(f"Digest version: {digest_version}\n")
f_p.write(f"Model type: {self.model_type.name}\n")
if self.filepath:
f_p.write(f"ONNX file: {self.filepath}\n")
if self.file_path:
f_p.write(f"ONNX file: {self.file_path}\n")
f_p.write(f"Name of the model: {self.model_name}\n")
f_p.write(f"Model version: {self.model_version}\n")
f_p.write(f"Name of the graph: {self.graph_name}\n")
Expand Down
102 changes: 102 additions & 0 deletions src/digest/model_class/digest_pytorch_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.

import os
from collections import OrderedDict
from typing import List, Tuple, Optional, Any, Union
import inspect
import onnx
import torch
from digest.model_class.digest_onnx_model import DigestOnnxModel
from digest.model_class.digest_model import (
DigestModel,
SupportedModelTypes,
)


class DigestPyTorchModel(DigestModel):
"""The idea of this class is to first support PyTorch models by converting them to ONNX
Eventually, we will want to support a PyTorch specific interface that has a custom GUI.
To facilitate this process, it makes the most sense to use this class as helper class
to convert the PyTorch model to ONNX and store the ONNX info in a member DigestOnnxModel
object. We can also store various PyTorch specific details in this class as well.
"""

def __init__(
self,
pytorch_file_path: str = "",
model_name: str = "",
) -> None:
super().__init__(pytorch_file_path, model_name, SupportedModelTypes.PYTORCH)

assert os.path.exists(
pytorch_file_path
), f"PyTorch file {pytorch_file_path} does not exist."

# Default opset value
self.opset = 17

# Input dictionary to contain the names and shapes
# required for exporting the ONNX model
self.input_tensor_info: OrderedDict[str, List[Any]] = OrderedDict()

self.pytorch_model = torch.load(pytorch_file_path)

# Data needed for exporting to ONNX
self.do_constant_folding = True
self.export_params = True

self.onnx_file_path: Optional[str] = None

self.digest_onnx_model: Optional[DigestOnnxModel] = None

def parse_model_nodes(self) -> None:
"""This will be done in the DigestOnnxModel"""

def save_yaml_report(self, file_path: str) -> None:
"""This will be done in the DigestOnnxModel"""

def save_text_report(self, file_path: str) -> None:
"""This will be done in the DigestOnnxModel"""

def generate_random_tensor(self, shape: List[Union[str, int]]):
static_shape = [dim if isinstance(dim, int) else 1 for dim in shape]
return torch.rand(static_shape)

def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]:

dummy_input_names: List[str] = list(self.input_tensor_info.keys())
dummy_inputs: List[torch.Tensor] = []

for shape in self.input_tensor_info.values():
dummy_inputs.append(self.generate_random_tensor(shape))

dynamic_axes = {
name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)}
for name, shape in self.input_tensor_info.items()
}

try:
torch.onnx.export(
self.pytorch_model,
tuple(dummy_inputs),
output_onnx_path,
input_names=dummy_input_names,
do_constant_folding=self.do_constant_folding,
export_params=self.export_params,
opset_version=self.opset,
dynamic_axes=dynamic_axes,
verbose=False,
)

self.onnx_file_path = output_onnx_path

return onnx.load(output_onnx_path)

except (TypeError, RuntimeError) as err:
print(f"Failed to export ONNX: {err}")
raise


def get_model_fwd_parameters(torch_file_path):
torch_model = torch.load(torch_file_path)
return inspect.signature(torch_model.forward).parameters
22 changes: 12 additions & 10 deletions src/digest/model_class/digest_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def parse_tensor_info(
csv_tensor_cell_value,
) -> Tuple[str, list, str, Union[str, float]]:
) -> Tuple[str, list, str, Optional[float]]:
"""This is a helper function that expects the input to come from parsing
the nodes csv and extracting either an input or output tensor."""

Expand All @@ -40,7 +40,9 @@ def parse_tensor_info(
if not isinstance(shape, list):
shape = list(shape)

if size != "None":
if size == "None":
size = None
else:
size = float(size.split()[0])

return name.strip(), shape, dtype.strip(), size
Expand All @@ -49,30 +51,30 @@ def parse_tensor_info(
class DigestReportModel(DigestModel):
def __init__(
self,
report_filepath: str,
report_file_path: str,
) -> None:

self.model_type = SupportedModelTypes.REPORT

self.is_valid = validate_yaml(report_filepath)
self.is_valid = validate_yaml(report_file_path)

if not self.is_valid:
print(f"The yaml file {report_filepath} is not a valid digest report.")
print(f"The yaml file {report_file_path} is not a valid digest report.")
return

self.model_data = OrderedDict()
with open(report_filepath, "r", encoding="utf-8") as yaml_f:
with open(report_file_path, "r", encoding="utf-8") as yaml_f:
self.model_data = yaml.safe_load(yaml_f)

model_name = self.model_data["model_name"]
super().__init__(report_filepath, model_name, SupportedModelTypes.REPORT)
super().__init__(report_file_path, model_name, SupportedModelTypes.REPORT)

self.similarity_heatmap_path: Optional[str] = None
self.node_data = NodeData()

# Given the path to the digest report, let's check if its a complete cache
# and we can grab the nodes csv data and the similarity heatmap
cache_dir = os.path.dirname(os.path.abspath(report_filepath))
cache_dir = os.path.dirname(os.path.abspath(report_file_path))
expected_heatmap_file = os.path.join(cache_dir, f"{model_name}_heatmap.png")
if os.path.exists(expected_heatmap_file):
self.similarity_heatmap_path = expected_heatmap_file
Expand Down Expand Up @@ -139,10 +141,10 @@ def __init__(
def parse_model_nodes(self) -> None:
"""There are no model nodes to parse"""

def save_yaml_report(self, filepath: str) -> None:
def save_yaml_report(self, file_path: str) -> None:
"""Report models are not intended to be saved"""

def save_text_report(self, filepath: str) -> None:
def save_text_report(self, file_path: str) -> None:
"""Report models are not intended to be saved"""


Expand Down
Loading

0 comments on commit 1573a71

Please sign in to comment.