8 changes: 7 additions & 1 deletion modelopt/onnx/autocast/precisionconverter.py
@@ -82,6 +82,10 @@ class PrecisionConverter:
     Public Methods:
         convert: Convert specified nodes to FP16/BF16 precision while keeping others in FP32.
     """
+    def print_byte_size(self, label: str):
+        model_proto = self.model.SerializeToString()
+        model_size = len(model_proto)
+        print(f"GAGAM {label} ByteSize: {model_size}")
 
     def __init__(
         self,
@@ -172,7 +176,7 @@ def convert(
             onnx.ModelProto: The converted mixed precision model.
         """
         try:
-            self.model = onnx_utils.check_model(self.model)
+            onnx_utils.check_model(self.model)
         except onnx.checker.ValidationError as e:
             logger.error(f"Internal error: onnx.checker failed on input model {e}")
             raise Exception(
@@ -1253,7 +1257,9 @@ def _fix_network_output_names(self):
     def _sanity_check(self):
         sanity_ok = True
         try:
+            self.print_byte_size("before check_model")
             onnx_utils.check_model(self.model)
+            self.print_byte_size("after check_model")
         except onnx.checker.ValidationError as e:
             logger.error(f"Internal error: onnx.checker failed: {e}")
             sanity_ok = False
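As an aside on the print_byte_size helper: protobuf messages expose ByteSize(), which reports the encoded size without allocating the serialized buffer; check_model in modelopt/onnx/utils.py already relies on it for its 2GB test. A minimal sketch, with a hypothetical model path:

    import onnx

    model = onnx.load("model.onnx")  # hypothetical path

    # ByteSize() computes the serialized size without building the bytes,
    # unlike len(model.SerializeToString()), which allocates the full buffer.
    print(f"ByteSize: {model.ByteSize()}")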
64 changes: 59 additions & 5 deletions modelopt/onnx/autocast/referencerunner.py
@@ -24,11 +24,13 @@
 import copy
 import io
 import sys
+import tempfile
 from collections import OrderedDict
 
 import numpy as np
 import onnx
 
+from modelopt.onnx import utils as onnx_utils
 from modelopt.onnx.autocast.logging_config import configure_logging, logger
 from modelopt.onnx.quantization.ort_utils import _prepare_ep_list
 
@@ -118,13 +120,65 @@ def _load_inputs(self, inputs):
 
         return data_loader
 
+    def _get_ort_runner(self, model):
+        import onnxruntime as ort
+        from polygraphy.backend.onnx import BytesFromOnnx
+        from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
+
+        # Check if the model has external data by checking:
+        # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
+        # 2. If the model size would exceed 2GB (indicating need for external data)
+        has_external_data = any(
+            init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
+            for init in self.model.graph.initializer
+        )
+
+        # Also check if the model would be too large (>2GB) for SerializeToString.
+        # This handles cases where the model was loaded with its external data already in memory.
+        if not has_external_data:
+            try:
+                # Estimate the size by serializing the model; if that fails or the
+                # result exceeds 2GB, we need the file-based approach.
+                model_size = len(self.model.SerializeToString())
+                if model_size > 2 * (1024**3):  # 2GB threshold
+                    has_external_data = True
+                    logger.debug(
+                        f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach"
+                    )
+            except (ValueError, AttributeError) as e:
+                # SerializeToString failed (likely the >2GB limit), use the file-based approach
+                if "exceeds maximum protobuf size" in str(e) or "2GB" in str(e):
+                    has_external_data = True
+                    logger.debug("Model exceeds protobuf 2GB limit, using file-based approach")
+
+        if has_external_data:
+            logger.debug("Model has external data, using file-based approach")
+            # Get the actual ONNX ModelProto from the ModifyOutputs wrapper
+            modified_model = model()
+
+            # Use a persistent temp file so the external data files remain available
+            tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
+            tmp_file.close()
+            tmp_file_path = tmp_file.name
+            onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
+            logger.debug(f"Model with all outputs saved to {tmp_file_path}")
+            session = ort.InferenceSession(tmp_file_path, providers=self.providers)
+            runners = [OnnxrtRunner(lambda: session)]
+
+        else:
+            # For models without external data, use the original BytesFromOnnx approach (no tmp files)
+            logger.debug("Model has no external data, using BytesFromOnnx approach")
+            serialize_onnx = BytesFromOnnx(model)
+            build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
+            runners = [OnnxrtRunner(build_onnxrt_session)]
+
+        return runners
+
     def run(self, inputs=None):
         """Run FP32 inference with provided or random inputs."""
         import onnxruntime as ort
         from polygraphy import constants
-        from polygraphy.backend.onnx import BytesFromOnnx
         from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs
-        from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
         from polygraphy.comparator import Comparator
 
         logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...")
@@ -133,9 +187,9 @@ def run(self, inputs=None):
 
         model_copy = copy.deepcopy(self.model)
         modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
-        serialize_onnx = BytesFromOnnx(modify_outputs)
-        build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
-        runners = [OnnxrtRunner(build_onnxrt_session)]
+
+        # Load the modified model and create an inference session
+        runners = self._get_ort_runner(modify_outputs)
 
         # Comparator is used despite the fact that we are using ONNXRuntime
         # because it provides the ability to generate random inputs using DataLoader
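For context on the two branches in _get_ort_runner: protobuf refuses to serialize a message larger than 2GB, so building a session from in-memory bytes cannot work for such models, while onnxruntime.InferenceSession also accepts a file path and resolves external-data files relative to it. A minimal sketch of both paths, using placeholder file names:

    import onnx
    import onnxruntime as ort

    model = onnx.load("model.onnx")  # placeholder model

    try:
        # In-memory path: raises ValueError once the proto exceeds the 2GB limit.
        session = ort.InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
    except ValueError:
        # File-based path: ORT loads from disk and finds the external-data
        # file(s) next to the model.
        onnx.save_model(model, "model_ext.onnx", save_as_external_data=True)
        session = ort.InferenceSession(
            "model_ext.onnx", providers=["CPUExecutionProvider"]
        )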
22 changes: 11 additions & 11 deletions modelopt/onnx/utils.py
@@ -552,19 +552,19 @@ def _get_unique_name(old_name):
     return onnx_model, is_modified
 
 
-def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
+def check_model(model: onnx.ModelProto) -> None:
     """Checks if the given model is valid."""
     if model.ByteSize() > (2 * (1024**3)):  # 2GB limit
-        with tempfile.TemporaryDirectory() as temp_dir:
-            # ONNX also looks in CWD, so we need to use a unique id
-            unique_id = str(uuid.uuid4())[:8]
-            onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
-            save_onnx(model, onnx_tmp_path, save_as_external_data=True)
-            onnx.checker.check_model(onnx_tmp_path)
-        return onnx.load(onnx_tmp_path)
+        logger.warning("Model exceeds 2GB limit, skipping check_model")
+        # with tempfile.TemporaryDirectory() as temp_dir:
+        #     # ONNX also looks in CWD, so we need to use a unique id
+        #     unique_id = str(uuid.uuid4())[:8]
+        #     onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
+        #     save_onnx(model, onnx_tmp_path, save_as_external_data=True)
+        #     onnx.checker.check_model(onnx_tmp_path)
+
     else:
         onnx.checker.check_model(model)
-    return model
 
 
 def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]:
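The commented-out block is the standard workaround for oversized models: onnx.checker.check_model also accepts a file path, and the path-based form can validate models whose tensors live in external data files. A sketch mirroring those commented lines, should the skip be reverted:

    import os
    import tempfile
    import uuid

    import onnx

    def check_large_model(model: onnx.ModelProto) -> None:
        """Validate a >2GB model from disk instead of in memory."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # ONNX also looks in CWD, so use a unique file name.
            onnx_tmp_path = os.path.join(temp_dir, f"model_{uuid.uuid4().hex[:8]}.onnx")
            onnx.save_model(model, onnx_tmp_path, save_as_external_data=True)
            # Passing a path lets the checker handle models beyond the 2GB limit.
            onnx.checker.check_model(onnx_tmp_path)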
@@ -644,7 +644,7 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
     model_proto = model.SerializeToString()
     model_size = len(model_proto)
     save_as_external_data = save_as_external_data or model_size > size_threshold
-    logger.debug(
+    logger.warning(
         f"Model size: {model_size} bytes, using external data: {save_as_external_data}"
     )
 
@@ -658,7 +658,7 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
 
     # Set ir_version to 10, remove it once ORT supports ir_version 11
     model.ir_version = 10
-
+    save_as_external_data = True  # GAGAM: for debug
     if save_as_external_data:
         external_data_path = os.path.basename(onnx_path) + "_data"
         if os.path.exists(external_data_path):
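For reference, the external-data write that save_onnx performs reduces to roughly the following (file names are illustrative; location is stored in the proto and resolved relative to the model file):

    import onnx

    model = onnx.load("model.onnx")  # placeholder

    # Store every tensor in one side file referenced by relative path,
    # matching save_onnx's "<basename>_data" convention.
    onnx.save_model(
        model,
        "converted.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="converted.onnx_data",
    )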