diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 1e22504c6..eaba7c635 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -82,6 +82,10 @@ class PrecisionConverter:
     Public Methods:
         convert: Convert specified nodes to FP16/BF16 precision while keeping others in FP32.
     """
+    def print_byte_size(self, label: str):
+        # Use ByteSize() instead of len(SerializeToString()): serialization raises on models >2GB
+        model_size = self.model.ByteSize()
+        print(f"GAGAM {label} ByteSize: {model_size}")

     def __init__(
         self,
@@ -172,7 +176,7 @@ def convert(
             onnx.ModelProto: The converted mixed precision model.
         """
         try:
-            self.model = onnx_utils.check_model(self.model)
+            onnx_utils.check_model(self.model)
         except onnx.checker.ValidationError as e:
             logger.error(f"Internal error: onnx.checker failed on input model {e}")
             raise Exception(
@@ -1253,7 +1257,9 @@ def _fix_network_output_names(self):
     def _sanity_check(self):
         sanity_ok = True
         try:
+            self.print_byte_size("before check_model")
             onnx_utils.check_model(self.model)
+            self.print_byte_size("after check_model")
         except onnx.checker.ValidationError as e:
             logger.error(f"Internal error: onnx.checker failed: {e}")
             sanity_ok = False
diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py
index 8dc91ff08..896066a73 100644
--- a/modelopt/onnx/autocast/referencerunner.py
+++ b/modelopt/onnx/autocast/referencerunner.py
@@ -24,11 +24,13 @@
 import copy
 import io
 import sys
+import tempfile
 from collections import OrderedDict

 import numpy as np
 import onnx

+from modelopt.onnx import utils as onnx_utils
 from modelopt.onnx.autocast.logging_config import configure_logging, logger
 from modelopt.onnx.quantization.ort_utils import _prepare_ep_list
@@ -118,13 +120,65 @@ def _load_inputs(self, inputs):

         return data_loader

+    def _get_ort_runner(self, model):
+        import onnxruntime as ort
+        from polygraphy.backend.onnx import BytesFromOnnx
+        from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
+
+        # Determine whether the model needs the file-based approach by checking:
+        # 1. If any initializer has data_location set to EXTERNAL (even if the data is loaded)
+        # 2. If the model size would exceed 2GB (indicating the need for external data)
+        has_external_data = any(
+            init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
+            for init in self.model.graph.initializer
+        )
+
+        # Also check whether the model is too large (>2GB) for SerializeToString.
+        # This handles models whose external data was already loaded into memory.
+        if not has_external_data:
+            try:
+                # Estimate the size by serializing the model; if serialization fails
+                # or the size exceeds 2GB, fall back to the file-based approach.
+                model_size = len(self.model.SerializeToString())
+                if model_size > 2 * (1024**3):  # 2GB threshold
+                    has_external_data = True
+                    logger.debug(
+                        f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach"
+                    )
+            except (ValueError, AttributeError) as e:
+                # SerializeToString failed (likely the >2GB protobuf limit); use the file-based approach
+                if "exceeds maximum protobuf size" in str(e) or "2GB" in str(e):
+                    has_external_data = True
+                    logger.debug("Model exceeds protobuf 2GB limit, using file-based approach")
+
+        if has_external_data:
+            logger.debug("Model has external data, using file-based approach")
+            # Get the actual ONNX ModelProto from the ModifyOutputs wrapper
+            modified_model = model()
+
+            # Use a persistent temp file so the external data files remain available
+            tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
+            tmp_file.close()
+            tmp_file_path = tmp_file.name
+            onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
+            logger.debug(f"Model with all outputs saved to {tmp_file_path}")
+            session = ort.InferenceSession(tmp_file_path, providers=self.providers)
+            runners = [OnnxrtRunner(lambda: session)]
+
+        else:
+            # For models without external data, use the original BytesFromOnnx approach (no temp files)
+            logger.debug("Model has no external data, using BytesFromOnnx approach")
+            serialize_onnx = BytesFromOnnx(model)
+            build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
+            runners = [OnnxrtRunner(build_onnxrt_session)]
+
+        return runners
+
     def run(self, inputs=None):
         """Run FP32 inference with provided or random inputs."""
         import onnxruntime as ort
         from polygraphy import constants
-        from polygraphy.backend.onnx import BytesFromOnnx
         from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs
-        from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
         from polygraphy.comparator import Comparator

         logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...")
@@ -133,9 +187,9 @@ def run(self, inputs=None):

         model_copy = copy.deepcopy(self.model)
         modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
-        serialize_onnx = BytesFromOnnx(modify_outputs)
-        build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
-        runners = [OnnxrtRunner(build_onnxrt_session)]
+
+        # Build the runner, falling back to a file-based session for large/external-data models
+        runners = self._get_ort_runner(modify_outputs)

         # Comparator is used despite the fact that we are using ONNXRuntime
         # because it provides the ability to generate random inputs using DataLoader
diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py
index a6b37758e..36d592432 100644
--- a/modelopt/onnx/utils.py
+++ b/modelopt/onnx/utils.py
@@ -552,19 +552,19 @@ def _get_unique_name(old_name):
     return onnx_model, is_modified


-def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
+def check_model(model: onnx.ModelProto) -> None:
     """Checks if the given model is valid."""
     if model.ByteSize() > (2 * (1024**3)):  # 2GB limit
-        with tempfile.TemporaryDirectory() as temp_dir:
-            # ONNX also looks in CWD, so we need to use a unique id
-            unique_id = str(uuid.uuid4())[:8]
-            onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
-            save_onnx(model, onnx_tmp_path, save_as_external_data=True)
-            onnx.checker.check_model(onnx_tmp_path)
-        return onnx.load(onnx_tmp_path)
-    onnx.checker.check_model(model)
-    return model
+        logger.warning("Model exceeds 2GB limit, skipping check_model")
+        # with tempfile.TemporaryDirectory() as temp_dir:
+        #     # ONNX also looks in CWD, so we need to use a unique id
+        #     unique_id = str(uuid.uuid4())[:8]
+        #     onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
+        #     save_onnx(model, onnx_tmp_path, save_as_external_data=True)
+        #     onnx.checker.check_model(onnx_tmp_path)
+    else:
+        onnx.checker.check_model(model)


 def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]:
@@ -644,7 +644,7 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo
     model_proto = model.SerializeToString()
     model_size = len(model_proto)
     save_as_external_data = save_as_external_data or model_size > size_threshold
-    logger.debug(
+    logger.warning(
         f"Model size: {model_size} bytes, using external data: {save_as_external_data}"
     )

@@ -658,7 +658,7 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo

     # Set ir_version to 10, remove it once ORT supports ir_version 11
     model.ir_version = 10
-
+    save_as_external_data = True  # GAGAM: for debug
     if save_as_external_data:
         external_data_path = os.path.basename(onnx_path) + "_data"
         if os.path.exists(external_data_path):
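
A standalone sketch of the size-probing idea that _get_ort_runner relies on; the helper name model_needs_external_data is illustrative, not part of the diff. Protobuf's Python bindings raise a ValueError containing "exceeds maximum protobuf size of 2GB" when serializing past 2 GiB, whereas ByteSize() computes the serialized size without building the buffer:

import onnx

TWO_GB = 2 * (1024**3)


def model_needs_external_data(model: onnx.ModelProto) -> bool:
    # ByteSize() walks the proto and sums field sizes without materializing
    # the serialized buffer, so it is safe where SerializeToString() raises.
    if model.ByteSize() > TWO_GB:
        return True
    # Initializers loaded from external files keep data_location == EXTERNAL
    # even after onnx.load() has pulled the raw bytes into memory.
    return any(
        init.HasField("data_location")
        and init.data_location == onnx.TensorProto.EXTERNAL
        for init in model.graph.initializer
    )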
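The file-based branch hinges on ONNX Runtime resolving external-data files relative to the model path, which is why the diff creates the temp file with delete=False. A minimal sketch under that assumption (session_for_large_model is a hypothetical helper, not part of the diff):

import tempfile

import onnx
import onnxruntime as ort


def session_for_large_model(model: onnx.ModelProto, providers=None) -> ort.InferenceSession:
    # delete=False: ONNX Runtime opens both the .onnx file and the side-car
    # weights file while building the session, so neither may be deleted yet.
    tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
    tmp.close()
    # save_as_external_data=True writes the weights next to tmp.name, keeping
    # the graph proto itself under the 2 GiB protobuf limit.
    onnx.save_model(model, tmp.name, save_as_external_data=True)
    return ort.InferenceSession(tmp.name, providers=providers or ["CPUExecutionProvider"])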
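For context on the commented-out branch in utils.check_model: onnx.checker.check_model also accepts a file path, which is how models beyond the 2 GiB in-memory limit can still be validated. A sketch of that approach (check_large_model is illustrative):

import os
import tempfile
import uuid

import onnx


def check_large_model(model: onnx.ModelProto) -> None:
    # The checker rejects in-memory protos over 2 GiB, but given a path it
    # loads the file, and its external data, by itself.
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, f"model_{uuid.uuid4().hex[:8]}.onnx")
        onnx.save_model(model, path, save_as_external_data=True)
        onnx.checker.check_model(path)

The removed code additionally reloaded the checked model from the temp path and returned it, which is why PrecisionConverter.convert previously reassigned self.model; with check_model now returning None, that assignment is dropped.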