Torch PTQ ONNX export example for Windows + small fixes to existing torch ONNX export path #1027
hthadicherla wants to merge 4 commits into main from
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.

No actionable comments were generated in the recent review. 🎉

🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough

A new Windows-optimized LLM-to-ONNX export pipeline is introduced with NVFP4 quantization support, alongside infrastructure updates enabling INT8 SmoothQuant, dynamic onnx_quantizer_type parameter propagation, improved Cast node handling, and DynamicCache-based KV-cache management for export workflows.
Sequence Diagram(s)

sequenceDiagram
actor User as User/CLI
participant Args as llm_arguments<br/>(Parser)
participant Config as get_config_path<br/>(Config Resolver)
participant Loader as ModelLoader
participant Quant as Quantizer
participant Export as export_raw_llm<br/>(ONNX Export)
participant Surgery as surgeon_llm<br/>(Graph Surgery)
participant Output as Output Files
User->>Args: Parse CLI args (dtype, model_path, etc.)
Args->>User: Return parsed arguments
User->>Config: Resolve config.json location
Config->>User: Return config path
User->>Loader: Load HF model (if hf_model_path)
Loader->>User: Return model instance
User->>Quant: Quantize model (FP8/INT4/INT8/NVFP4)
Quant->>User: Return quantized model
User->>Export: Export to raw ONNX
Export->>Export: Apply LLM export (fp16 or quantized)
Export->>User: Return raw ONNX path
User->>Surgery: Apply graph surgery (dtype fixes, GQA, opset updates)
Surgery->>Surgery: Quantize weights to NVFP4<br/>(if NVFP4 mode)
Surgery->>Surgery: Apply GQA surgery<br/>(if hf_model_path provided)
Surgery->>Surgery: Fix logits shape & external data
Surgery->>User: Return optimized ONNX path
User->>Output: Save config.json alongside ONNX
Output->>User: Export complete
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~65 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error)
✅ Passed checks (3 passed)
…e in windows and changed torch and onnx export related files which were broken Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
6180e38 to 0b83d06
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)
405-408: ⚠️ Potential issue | 🔴 Critical

Backward gradient count mismatch after adding `onnx_quantizer_type`.

The `forward` method now accepts 11 input arguments (excluding `ctx`), but `backward` returns `num_args=10`. This will cause the gradient tuple to have insufficient elements. The count should be updated to 11.

🐛 Proposed fix
```diff
 @staticmethod
 def backward(ctx, grad_outputs):
     """Implements straight through estimation with clipping."""
-    return _fake_quant_backward_function(ctx, grad_outputs, num_args=10)
+    return _fake_quant_backward_function(ctx, grad_outputs, num_args=11)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/tensor_quant.py` around lines 405 - 408, The backward implementation of the custom autograd Function (method backward) still calls _fake_quant_backward_function with num_args=10 although forward now accepts an additional onnx_quantizer_type argument (total 11 inputs excluding ctx); update the backward call in backward(ctx, grad_outputs) to pass num_args=11 so the returned gradient tuple matches the forward inputs (reference symbols: backward, forward, _fake_quant_backward_function, onnx_quantizer_type).
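The invariant behind this finding can be sketched without the real quantizer: `torch.autograd.Function.backward` must return exactly one gradient slot per `forward` input (excluding `ctx`), with `None` for non-differentiable arguments. The helper below is illustrative only — it mirrors the shape of `_fake_quant_backward_function` but is not the real implementation:

```python
def fake_quant_backward(grad_output, num_args):
    """Illustrative stand-in for _fake_quant_backward_function:
    straight-through gradient for the first (tensor) input, None for every
    config-only argument. The tuple length must equal the number of
    forward inputs, or autograd rejects it."""
    return (grad_output,) + (None,) * (num_args - 1)

# forward grew from 10 to 11 inputs when onnx_quantizer_type was added,
# so backward must now produce an 11-element tuple:
grads = fake_quant_backward("dL/dx", num_args=11)
```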
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/windows/torch_onnx/llm_export/llm_export.py`:
- Around line 740-746: The current ONNX-only branch renames the external data
file (pre_gqa_data -> model.onnx_data) but does not update the ONNX protobuf
(pre_gqa_onnx/final_onnx), so the model still references the old filename; to
fix, either keep the external data filename unchanged (stop renaming
output_dir/pre_gqa_data) when you rename pre_gqa_onnx -> final_onnx, or load the
protobuf at pre_gqa_onnx and re-save it to final_onnx while updating the
external data location to the new name (so the external_data_location in the
model matches the renamed file); locate the branch handling pre_gqa_onnx,
final_onnx, output_dir and pre_gqa_data and apply one of these two fixes.
- Around line 726-736: The call to replace_attention_with_gqa(...) is hardcoding
max_seq_len=4096; instead read the model config (e.g., load the HuggingFace
config from hf_model_path or the extracted config.json) and derive the
context/window length (e.g., config.max_position_embeddings,
config.context_length, or config.max_seq_len depending on model family) and pass
that value as max_seq_len to replace_attention_with_gqa; update the invocation
that uses pre_gqa_onnx, final_onnx and hf_model_path to compute max_seq_len
before the call and fall back to a sensible default if the config field is
missing.
- Around line 291-322: get_config_path may return None but downstream code calls
shutil.copy(config_path, ...) and os.path.exists(config_path) unconditionally;
update the callers to handle a None config_path (or alternatively make
get_config_path raise a clear FileNotFoundError). Specifically, either (A)
change get_config_path to raise a descriptive exception when no config is found
so callers can fail-fast, or (B) more minimally, guard every use of config_path
(the shutil.copy(config_path, ...) calls and any os.path.exists(config_path)
checks) with an explicit if config_path is not None: ... else: log a warning and
skip the copy/check so ONNX-only runs without a colocated config.json degrade
gracefully. Ensure references to get_config_path and the
shutil.copy(config_path, ...) and os.path.exists(config_path) sites are updated
accordingly.
- Around line 786-809: The current logic overwrites a provided --onnx_path
because the hf_model_path branch always re-exports; change the precedence so
that if args.onnx_path is set you do not re-export. Concretely, keep assigning
raw_onnx_path from args.onnx_path when present, and guard the
ModelLoader/load_model and export_raw_llm calls behind "if args.hf_model_path
and not args.onnx_path" (or equivalent) so ModelLoader, model =
model_loader.load_model(...) and export_raw_llm(...) only run when no onnx_path
was supplied; refer to symbols args.onnx_path, args.hf_model_path,
raw_onnx_path, ModelLoader, load_model, and export_raw_llm to locate and update
the conditional logic.
- Around line 387-423: Pre-quantized local models skip ONNX/checkpoint export
because the block guarded by model_needs_quantization (computed from
modelopt_state) contains llm_to_onnx and export_hf_checkpoint; move or duplicate
the export steps so they run for already-quantized models too. Concretely:
adjust the control flow around model_needs_quantization/modelopt_state so that
after loading or skipping quantize you still call llm_to_onnx (when dtype in
{"fp8","int4_awq","int8_sq","nvfp4"} or otherwise needed) and run
export_hf_checkpoint into quantized_model_dir; keep existing calls to
quantize(), _override_trt_high_precision_dtype, and the dtype-specific Linear
handling (preserve symbols quantize, _override_trt_high_precision_dtype,
llm_to_onnx, export_hf_checkpoint, quantized_model_dir, output_dir) but ensure
exports are performed even when model_needs_quantization is False so
surgeon_llm/main can find output_dir/model.onnx.
- Around line 457-459: infer_shapes_path() is being called with only
raw_onnx_path which mutates the source file; change the call to write to a
separate temp file and load that instead (e.g., create a temporary path like
temp_inferred_path = tempfile.NamedTemporaryFile(suffix=".onnx",
delete=False).name or construct raw_onnx_path + ".inferred.onnx"), call
onnx.shape_inference.infer_shapes_path(raw_onnx_path,
output_path=temp_inferred_path), and then pass temp_inferred_path to
gs.import_onnx(onnx.load(...)) so the original raw_onnx_path (and the saved
original in {output_dir}_raw/) are not modified.
In `@modelopt/onnx/llm_export_utils/export_utils.py`:
- Around line 79-118: The monkey-patches (DynamicLayer.update,
transformers.masking_utils.create_causal_mask, per-model create_causal_mask, and
sdpa_mod.use_gqa_in_sdpa) are applied at import time and never restored; move
these mutations into the export wrapper's forward() method and restore originals
in a finally block. Specifically, in the class that implements forward() (the
exporter wrapper), capture the original symbols (DynamicLayer.update,
transformers.masking_utils.create_causal_mask, the model-specific
create_causal_mask in transformers.models.{model_type}.modeling_{model_type},
and transformers.integrations.sdpa_attention.use_gqa_in_sdpa), apply the patched
lambdas/ functions at the start of forward(), run the export logic, and always
reassign the originals back inside a finally clause so the global state is not
permanently mutated. Ensure you reference and patch the same symbols shown
(DynamicLayer.update, create_causal_mask, and sdpa_mod.use_gqa_in_sdpa) and
handle ImportError/ModuleNotFoundError when restoring the model-specific
create_causal_mask just as in the original diff.
In `@modelopt/torch/quantization/export_onnx.py`:
- Around line 203-206: The code is casting `inputs` instead of the variable
returned (`out`), so the cast is dead; update the block in export_onnx.py to
cast `out` (use g.op("Cast", out, to_i=onnx_dtype_map[input_type])) when
trt_high_precision_dtype != input_type, or if the cast is unnecessary remove the
entire if-block; ensure references are to `out`, `input_type`,
`trt_high_precision_dtype`, `onnx_dtype_map`, and the g.op("Cast") call so the
returned tensor has the intended dtype.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: cf6b1f93-b833-49d6-97b1-94f01c4b66de
📒 Files selected for processing (10)

- examples/windows/torch_onnx/llm_export/README.md
- examples/windows/torch_onnx/llm_export/llm_export.py
- examples/windows/torch_onnx/llm_export/requirements.txt
- modelopt/onnx/export/int4_exporter.py
- modelopt/onnx/graph_surgery/__init__.py
- modelopt/onnx/llm_export_utils/export_utils.py
- modelopt/onnx/llm_export_utils/quantization_utils.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/tensor_quant.py
```python
def get_config_path(args):
    """Get config.json file path from the arguments.

    The default priority is: config_path > hf_model_path/config.json > onnx_path/../config.json
    """
    if args.config_path and os.path.exists(args.config_path):
        return args.config_path
    if args.hf_model_path:
        if os.path.isdir(args.hf_model_path):
            torch_config = os.path.join(args.hf_model_path, "config.json")
            if os.path.exists(torch_config):
                return torch_config
        else:
            try:
                config = AutoConfig.from_pretrained(
                    args.hf_model_path, trust_remote_code=args.trust_remote_code
                )
                temp_config_path = os.path.join(
                    tempfile.gettempdir(), f"config_{args.hf_model_path.replace('/', '_')}.json"
                )
                with open(temp_config_path, "w") as f:
                    json.dump(config.to_dict(), f, indent=2)
                return temp_config_path
            except Exception as e:
                print(f"Warning: Could not download config for {args.hf_model_path}: {e}")

    if args.onnx_path:
        onnx_config = os.path.join(os.path.dirname(args.onnx_path), "config.json")
        if os.path.exists(onnx_config):
            return onnx_config
    print("Warning: cannot find config.json. Please pass in --config_path.")
    return None
```
config_path is optional in the CLI, but not in the implementation.
get_config_path() returns None on this warning path, yet Line 381 and Line 418 unconditionally shutil.copy(config_path, ...), and Line 706 later calls os.path.exists(config_path). An ONNX-only run without a colocated config.json, or a model directory missing that file, will crash instead of degrading gracefully. Either fail fast here or make the downstream copy/check paths handle a missing config explicitly.
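Option (B) from the comment can be sketched as a small guard around the copy sites. This is a hypothetical helper, not a function in the example script:

```python
import os
import shutil

def copy_config_if_present(config_path, output_dir):
    """Guarded variant of the shutil.copy(config_path, ...) call sites:
    degrade gracefully instead of crashing when get_config_path()
    returned None or points at a missing file."""
    if config_path is None or not os.path.exists(config_path):
        print("Warning: no config.json available; skipping copy. "
              "Pass --config_path if later steps need it.")
        return None
    dest = os.path.join(output_dir, "config.json")
    shutil.copy(config_path, dest)
    return dest
```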
```python
    if os.path.isdir(hf_model_path):
        modelopt_state = os.path.join(hf_model_path, "modelopt_state.pth")
        model_needs_quantization = not os.path.exists(modelopt_state)
    else:
        model_needs_quantization = True

    if model_needs_quantization:
        model = quantize(
            model, tokenizer, dtype, lm_head_precision, dataset_dir, calib_size=calib_size
        )

        _override_trt_high_precision_dtype(model, "Half")

        if dtype == "nvfp4":
            for module in model.modules():
                assert not isinstance(module, torch.nn.Linear) or is_quantized_linear(module)
                if isinstance(module, torch.nn.Linear):
                    module.input_quantizer._trt_high_precision_dtype = "Half"
                    module.input_quantizer._onnx_quantizer_type = "dynamic"
                    module.weight_quantizer._onnx_quantizer_type = "static"

        if dtype in {"fp8", "int4_awq", "int8_sq", "nvfp4"}:
            print(f"Exporting {dtype} ONNX model from quantized PyTorch model...")
            llm_to_onnx(
                wrapper_cls(
                    model,
                ),
                output_dir,
                extra_inputs=extra_inputs,
                extra_dyn_axes=extra_dyn_axes,
            )
            shutil.copy(config_path, os.path.join(output_dir, "config.json"))

        quantized_model_dir = f"{output_dir}_{dtype}_quantized"
        os.makedirs(quantized_model_dir, exist_ok=True)
        with torch.inference_mode():
            export_hf_checkpoint(model, dtype=torch.float16, export_dir=quantized_model_dir)
```
Already-quantized local models never produce model.onnx.
When modelopt_state.pth exists, model_needs_quantization is false and this block skips both llm_to_onnx() and export_hf_checkpoint(). main() still points surgeon_llm() at <output_dir>/model.onnx, so the pre-quantized checkpoint path fails before surgery.
♻️ Suggested control-flow fix

```diff
     if model_needs_quantization:
         model = quantize(
             model, tokenizer, dtype, lm_head_precision, dataset_dir, calib_size=calib_size
         )
         _override_trt_high_precision_dtype(model, "Half")
         if dtype == "nvfp4":
             for module in model.modules():
                 assert not isinstance(module, torch.nn.Linear) or is_quantized_linear(module)
                 if isinstance(module, torch.nn.Linear):
                     module.input_quantizer._trt_high_precision_dtype = "Half"
                     module.input_quantizer._onnx_quantizer_type = "dynamic"
                     module.weight_quantizer._onnx_quantizer_type = "static"
-        if dtype in {"fp8", "int4_awq", "int8_sq", "nvfp4"}:
-            print(f"Exporting {dtype} ONNX model from quantized PyTorch model...")
-            llm_to_onnx(
-                wrapper_cls(
-                    model,
-                ),
-                output_dir,
-                extra_inputs=extra_inputs,
-                extra_dyn_axes=extra_dyn_axes,
-            )
-            shutil.copy(config_path, os.path.join(output_dir, "config.json"))
-
+    print(f"Exporting {dtype} ONNX model from quantized PyTorch model...")
+    llm_to_onnx(
+        wrapper_cls(model),
+        output_dir,
+        extra_inputs=extra_inputs,
+        extra_dyn_axes=extra_dyn_axes,
+    )
+    if config_path:
+        shutil.copy(config_path, os.path.join(output_dir, "config.json"))
+
+    if model_needs_quantization:
         quantized_model_dir = f"{output_dir}_{dtype}_quantized"
         os.makedirs(quantized_model_dir, exist_ok=True)
         with torch.inference_mode():
             export_hf_checkpoint(model, dtype=torch.float16, export_dir=quantized_model_dir)
```
```python
    t0 = time.time()
    onnx.shape_inference.infer_shapes_path(raw_onnx_path)
    graph = gs.import_onnx(onnx.load(raw_onnx_path))
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# Locate and examine the file
fd "llm_export.py" examples/ --type f
wc -l examples/windows/torch_onnx/llm_export/llm_export.py
# Read the relevant section and surrounding context (lines 450-470)
sed -n '450,470p' examples/windows/torch_onnx/llm_export/llm_export.py
# Search for --save_original and --onnx_path to understand the intent
rg "save_original|onnx_path" examples/windows/torch_onnx/llm_export/llm_export.py -A 2 -B 2
# Check what happens with the output after surgery - where is it exported?
sed -n '450,480p' examples/windows/torch_onnx/llm_export/llm_export.py
# Look for where the final ONNX is exported/saved
sed -n '460,500p' examples/windows/torch_onnx/llm_export/llm_export.py
```
🌐 Web query:
ONNX shape_inference.infer_shapes_path output_path parameter behavior documentation
💡 Result:
In onnx.shape_inference.infer_shapes_path(model_path, output_path=...), output_path defaults to the empty string "". If you leave it as "", ONNX replaces it with model_path internally (if output_path == "": output_path = model_path) and then writes the inferred model to that path—i.e., it overwrites the input file. [1]
The API docs describe the same behavior: it “outputs the inferred model to the output_path” and “the original model path is used if not specified.” [2]
Sources
[1] ONNX 1.22.0 onnx.shape_inference source code (infer_shapes_path)
[2] ONNX 1.22.0 onnx.shape_inference.infer_shapes_path API documentation
Call infer_shapes_path() with an explicit output_path to avoid mutating the input file.
onnx.shape_inference.infer_shapes_path() without an output_path parameter overwrites the input file in-place. This violates the semantics of --save_original (the file in {output_dir}_raw/ should remain the original export, not the shape-inferred version) and modifies user-supplied files when --onnx_path is used, which users do not expect.
Specify a temporary or separate output path for shape inference, or use an alternative approach that does not mutate the source.
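A minimal sketch of the suggested fix. The helper names (`inferred_copy_path`, `load_with_inferred_shapes`) are illustrative, not existing functions in the script, and the heavy imports are deferred so only the path logic runs without `onnx` installed:

```python
import os

def inferred_copy_path(raw_onnx_path):
    """Sibling path for the shape-inferred copy, leaving the original export
    (or a user-supplied --onnx_path file) untouched."""
    root, ext = os.path.splitext(raw_onnx_path)
    return f"{root}.inferred{ext}"

def load_with_inferred_shapes(raw_onnx_path):
    # Deferred imports: onnx and onnx-graphsurgeon are only needed at run time.
    import onnx
    import onnx_graphsurgeon as gs

    out_path = inferred_copy_path(raw_onnx_path)
    # Passing output_path explicitly avoids the documented in-place overwrite
    # that happens when it defaults to "".
    onnx.shape_inference.infer_shapes_path(raw_onnx_path, output_path=out_path)
    return gs.import_onnx(onnx.load(out_path))
```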
```python
    replace_attention_with_gqa(
        model_path=pre_gqa_onnx,
        output_path=final_onnx,
        hf_model_id=hf_model_path,
        max_seq_len=4096,
        io_dtype="float16",
        use_external_data=True,
        external_data_name="model.onnx_data",
        ir_version=10,
        trust_remote_code=trust_remote_code,
    )
```
Don’t hardcode max_seq_len=4096 for every model.
replace_attention_with_gqa() takes max_seq_len because the rewritten attention graph depends on model context length. Baking 4096 here will undersize models whose config advertises a different window and can misbuild the RoPE/cache state in the exported graph. Please derive this from config.json instead.
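One way to derive the value from the colocated config.json, as the comment suggests. The helper is a hypothetical sketch; the probed field names are the common ones and vary by model family:

```python
import json

def max_seq_len_from_config(config_path, default=4096):
    """Probe the common context-length fields in a HuggingFace-style
    config.json and fall back to a default when none is present."""
    with open(config_path) as f:
        cfg = json.load(f)
    for key in ("max_position_embeddings", "context_length", "max_seq_len"):
        value = cfg.get(key)
        if isinstance(value, int) and value > 0:
            return value
    return default
```

The result would then replace the literal `4096` in the `replace_attention_with_gqa(...)` call.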
```python
    if args.onnx_path:
        raw_onnx_path = args.onnx_path

    model_loader = ModelLoader(args.hf_model_path, args.config_path)

    if args.hf_model_path:
        model = model_loader.load_model(trust_remote_code=args.trust_remote_code)
        onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir
        raw_onnx_path = f"{onnx_dir}/model.onnx"
        extra_inputs, extra_dyn_axes = {}, {}
        export_raw_llm(
            model=model,
            output_dir=onnx_dir,
            dtype=args.dtype,
            config_path=args.config_path,
            hf_model_path=args.hf_model_path,
            lm_head_precision=args.lm_head,
            dataset_dir=args.dataset_dir,
            wrapper_cls=WrapperModelForCausalLM,
            extra_inputs=extra_inputs,
            extra_dyn_axes=extra_dyn_axes,
            calib_size=args.calib_size,
            trust_remote_code=args.trust_remote_code,
        )
```
--onnx_path does not actually skip export.
The help text says this flag should reuse an existing ONNX, but as soon as hf_model_path is also present this branch loads the HF model, re-exports model.onnx, and overwrites raw_onnx_path. That makes it impossible to reuse an existing ONNX while still passing hf_model_path only for GQA metadata/config.
🐛 Suggested precedence fix

```diff
-    if args.onnx_path:
-        raw_onnx_path = args.onnx_path
-
-    model_loader = ModelLoader(args.hf_model_path, args.config_path)
-
-    if args.hf_model_path:
+    if args.onnx_path:
+        raw_onnx_path = args.onnx_path
+    elif args.hf_model_path:
+        model_loader = ModelLoader(args.hf_model_path, args.config_path)
         model = model_loader.load_model(trust_remote_code=args.trust_remote_code)
         onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir
         raw_onnx_path = f"{onnx_dir}/model.onnx"
         extra_inputs, extra_dyn_axes = {}, {}
         export_raw_llm(
```
```python
# Patch DynamicLayer.lazy_initialization so it does NOT create empty
# tensors (which torch.jit.trace bakes as constants). Instead, set
# keys/values to None; the update() cat path handles the rest.
from transformers.cache_utils import DynamicLayer

def _patched_update(self_layer, key_states, value_states, cache_kwargs=None):
    if not self_layer.is_initialized:
        self_layer.dtype = key_states.dtype
        self_layer.device = key_states.device
        self_layer.is_initialized = True
        self_layer.keys = key_states
        self_layer.values = value_states
        return self_layer.keys, self_layer.values
    self_layer.keys = torch.cat([self_layer.keys, key_states], dim=-2)
    self_layer.values = torch.cat([self_layer.values, value_states], dim=-2)
    return self_layer.keys, self_layer.values

DynamicLayer.update = _patched_update

# Monkey-patch create_causal_mask to return None during export.
# This avoids baking mask shapes as constants during JIT tracing.
# SDPA uses is_causal=True internally so the explicit mask is unnecessary.
import importlib

import transformers.masking_utils

setattr(transformers.masking_utils, "create_causal_mask", lambda *args, **kwargs: None)
model_type = getattr(self.config, "model_type", "llama")
try:
    mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}")
    setattr(mod, "create_causal_mask", lambda *args, **kwargs: None)
except (ImportError, ModuleNotFoundError):
    pass

# Force use_gqa_in_sdpa to return False so SDPA does manual repeat_kv
# instead of using enable_gqa=True (which torch.onnx.export doesn't support).
# With attention_mask=None and enable_gqa=False, SDPA uses is_causal=True.
import transformers.integrations.sdpa_attention as sdpa_mod

sdpa_mod.use_gqa_in_sdpa = lambda *args, **kwargs: False
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
# See the full context of the file
cat -n modelopt/onnx/llm_export_utils/export_utils.py | head -200
# Also check the referenced pattern file for comparison
head -250 modelopt/torch/quantization/plugins/huggingface.py | tail -100
# Check for any cleanup mechanism in the WrapperModelForCausalLM class
grep -A 50 "class WrapperModelForCausalLM" modelopt/onnx/llm_export_utils/export_utils.py | head -80
```
Move these monkey patches into forward() and restore them in a finally block.
Lines 96, 105, 109, and 118 mutate transformers globals from __init__, and no restoration mechanism exists. This causes the patches to persist indefinitely at the process level—once one wrapper is instantiated, all subsequent inference/export in the same process uses the patched behavior, and if export fails after construction, the process remains permanently patched. Move these patches into the forward() method and use try/finally to restore the originals, following the pattern already established in modelopt/torch/quantization/plugins/huggingface.py (lines 166–214).
```python
        # Cast the output back if needed.
        if trt_high_precision_dtype != input_type:
            inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type])
```
Dead code: casting inputs instead of out has no effect.
The comment says "Cast the output back if needed" but the code casts inputs (line 205), not out. Since out is what's returned (line 207), this cast has no effect. Either remove this dead code or fix it to cast out if the intent was to restore the original dtype.
🔧 Option 1: Remove dead code

```diff
-    # Cast the output back if needed.
-    if trt_high_precision_dtype != input_type:
-        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type])
-
     return out
```

🔧 Option 2: Cast output if needed

```diff
     # Cast the output back if needed.
     if trt_high_precision_dtype != input_type:
-        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type])
+        out = g.op("Cast", out, to_i=onnx_dtype_map[input_type])
     return out
```
return out📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Current code:

```python
# Cast the output back if needed.
if trt_high_precision_dtype != input_type:
    inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type])
return out
```

Suggested change:

```python
# Cast the output back if needed.
if trt_high_precision_dtype != input_type:
    out = g.op("Cast", out, to_i=onnx_dtype_map[input_type])
return out
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/export_onnx.py` around lines 203 - 206, The code
is casting `inputs` instead of the variable returned (`out`), so the cast is
dead; update the block in export_onnx.py to cast `out` (use g.op("Cast", out,
to_i=onnx_dtype_map[input_type])) when trt_high_precision_dtype != input_type,
or if the cast is unnecessary remove the entire if-block; ensure references are
to `out`, `input_type`, `trt_high_precision_dtype`, `onnx_dtype_map`, and the
g.op("Cast") call so the returned tensor has the intended dtype.
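The point of Option 2 can be exercised with a tiny stand-in for the symbolic graph object. `FakeGraph`, `quantize_symbolic`, and the two-entry dtype map below are illustrative stubs, not the real `torch.onnx` machinery:

```python
# Illustrative stub: the key detail is that the *returned* tensor, not
# the input, must receive the dtype-restoring Cast.
onnx_dtype_map = {"Float": 1, "Half": 10}

class FakeGraph:
    """Records ops the way g.op(...) would create ONNX nodes."""
    def __init__(self):
        self.ops = []
    def op(self, op_type, *args, **kwargs):
        node = (op_type, kwargs)
        self.ops.append(node)
        return node

def quantize_symbolic(g, inputs, input_type, trt_high_precision_dtype):
    out = g.op("FakeQuant")  # stands in for the QDQ subgraph
    # Cast the output back if needed -- cast `out`, not `inputs`.
    if trt_high_precision_dtype != input_type:
        out = g.op("Cast", out, to_i=onnx_dtype_map[input_type])
    return out

g = FakeGraph()
result = quantize_symbolic(g, "x", input_type="Float",
                           trt_high_precision_dtype="Half")
print(result[0])  # the returned node is the Cast, so dtype is restored
```

With the original (buggy) version, the returned node would still be the `FakeQuant` node and the Cast would dangle unused.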
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/quantization/tensor_quant.py (1)
404-408: ⚠️ Potential issue | 🟠 Major

`backward` returns incorrect `num_args` after adding `onnx_quantizer_type`.

The `forward` method now has 11 parameters after `ctx` (inputs, amax, bias, num_bits, unsigned, narrow_range, trt_high_precision_dtype, pass_through_bwd, block_size, axis, onnx_quantizer_type), but `backward` returns `num_args=10`. This mismatch will cause incorrect gradient propagation.

🐛 Proposed fix

```diff
 @staticmethod
 def backward(ctx, grad_outputs):
     """Implements straight through estimation with clipping."""
-    return _fake_quant_backward_function(ctx, grad_outputs, num_args=10)
+    return _fake_quant_backward_function(ctx, grad_outputs, num_args=11)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/tensor_quant.py` around lines 404 - 408, The backward staticmethod is calling _fake_quant_backward_function with num_args=10 but forward now accepts 11 args after ctx (inputs, amax, bias, num_bits, unsigned, narrow_range, trt_high_precision_dtype, pass_through_bwd, block_size, axis, onnx_quantizer_type); update the call in backward to use num_args=11 (or compute it dynamically from the forward signature) so the number of saved tensor/argument slots passed to _fake_quant_backward_function matches forward; modify the call to _fake_quant_backward_function(ctx, grad_outputs, num_args=11) in the backward method.

modelopt/onnx/llm_export_utils/quantization_utils.py (1)
113-114: ⚠️ Potential issue | 🔴 Critical

`# nosec B105` comment violates coding guidelines.

Per the coding guidelines, any use of `# nosec` comments to bypass Bandit security checks is not allowed. If this security-sensitive pattern is genuinely necessary, the PR must be reviewed and approved by @NVIDIA/modelopt-setup-codeowners with an explicit justification in the PR description.

The comparison `tokenizer.pad_token != "<unk>"` triggers B105 (hardcoded password string), but since this is clearly not a password but a token comparison, consider refactoring to avoid the Bandit false positive:

🔧 Suggested refactor to avoid nosec

```diff
-    if tokenizer.pad_token != "<unk>":  # nosec B105
+    UNK_TOKEN = "<unk>"
+    if tokenizer.pad_token != UNK_TOKEN:
         tokenizer.pad_token = tokenizer.eos_token
```

As per coding guidelines: "Any use of '# nosec' comments to bypass Bandit security checks is not allowed."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/llm_export_utils/quantization_utils.py` around lines 113 - 114, Remove the '# nosec B105' bypass and eliminate the hardcoded string comparison by replacing the literal "<unk>" with a well-named constant or tokenizer attribute; specifically update the conditional that references tokenizer.pad_token (in quantization_utils.py) to compare against a constant like UNK_TOKEN or, better, against tokenizer.unk_token if that attribute exists, then set tokenizer.pad_token = tokenizer.eos_token as before—this removes the Bandit false positive without suppressing the check and preserves the existing behavior in the block that assigns tokenizer.pad_token from tokenizer.eos_token.
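The prompt's stronger suggestion — compare against the tokenizer's own `unk_token` attribute rather than any literal — can be sketched with a minimal stand-in tokenizer. `StubTokenizer` and `ensure_pad_token` below are illustrative names; the real code operates on a Hugging Face tokenizer object:

```python
# Sketch of the refactor with a minimal stand-in tokenizer.
class StubTokenizer:
    def __init__(self):
        self.unk_token = "<unk>"
        self.pad_token = None
        self.eos_token = "</s>"

def ensure_pad_token(tokenizer):
    # Comparing against the tokenizer's own unk_token attribute instead of
    # a hardcoded literal sidesteps Bandit's B105 false positive entirely.
    if tokenizer.pad_token != tokenizer.unk_token:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

tok = ensure_pad_token(StubTokenizer())
print(tok.pad_token)  # → </s>
```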
♻️ Duplicate comments (1)
examples/windows/torch_onnx/llm_export/llm_export.py (1)
789-789: ⚠️ Potential issue | 🟠 Major

`ModelLoader` instantiation fails when only `--onnx_path` is provided.

Line 789 unconditionally creates `ModelLoader(args.hf_model_path, args.config_path)`, but when `--onnx_path` is provided without `--hf_model_path`, `args.hf_model_path` is `None`. This will cause `ModelLoader.get_model_type()` to fail when trying to open `self.config_path` (which may also be `None`).

🐛 Proposed fix

```diff
 if args.onnx_path:
     raw_onnx_path = args.onnx_path
-model_loader = ModelLoader(args.hf_model_path, args.config_path)
 if args.hf_model_path:
+    model_loader = ModelLoader(args.hf_model_path, args.config_path)
     model = model_loader.load_model(trust_remote_code=args.trust_remote_code)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/windows/torch_onnx/llm_export/llm_export.py` at line 789, The code always constructs ModelLoader(args.hf_model_path, args.config_path) even when only args.onnx_path is supplied, causing ModelLoader.get_model_type() to open a None config; fix by making the ModelLoader instantiation conditional: if args.hf_model_path or args.config_path is provided instantiate ModelLoader(hf_model_path, config_path) and use its get_model_type(), otherwise skip creating ModelLoader and branch to the ONNX-only path (using args.onnx_path) or pass a safe default to downstream logic; reference ModelLoader, get_model_type, args.hf_model_path, args.config_path, and args.onnx_path when updating the control flow.
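The conditional control flow the prompt describes can be sketched as follows. `ModelLoader` here is a stub and `resolve_paths` is a hypothetical helper; only the attribute names (`onnx_path`, `hf_model_path`, `config_path`) mirror the real argparse flags:

```python
from types import SimpleNamespace

class ModelLoader:
    """Stub: raises like the real loader would when both paths are None."""
    def __init__(self, hf_model_path, config_path):
        if hf_model_path is None and config_path is None:
            raise ValueError("need hf_model_path or config_path")
        self.hf_model_path = hf_model_path

def resolve_paths(args):
    raw_onnx_path, model = None, None
    if args.onnx_path:
        # ONNX-only invocation: skip ModelLoader entirely.
        raw_onnx_path = args.onnx_path
    if args.hf_model_path:
        loader = ModelLoader(args.hf_model_path, args.config_path)
        model = loader.hf_model_path  # stands in for loader.load_model(...)
    return raw_onnx_path, model

# ONNX-only args no longer construct ModelLoader with None paths.
onnx_only = SimpleNamespace(onnx_path="m.onnx", hf_model_path=None,
                            config_path=None)
print(resolve_paths(onnx_only))  # ('m.onnx', None)
```

With the original unconditional instantiation, the `onnx_only` case would raise inside the stub's constructor.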
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/llm_export_utils/export_utils.py`:
- Around line 122-124: The current call to
DynamicCache(ddp_cache_data=past_key_values) is invalid; instantiate the cache
as DynamicCache(config=self.config) (e.g., cache =
DynamicCache(config=self.config)) and then seed it with existing past_key_values
using the correct API (for example a factory like
DynamicCache.from_past_key_values(past_key_values, config=self.config) or a
setter such as cache.set_past_key_values(past_key_values) /
cache.update_from_past_key_values(...) depending on the available methods); then
pass that cache into self.model(input_ids=..., past_key_values=cache,
use_cache=True). Ensure you convert past_key_values into the cache's expected
internal format if needed and prefer the library-provided constructor/factory
rather than the invalid ddp_cache_data parameter.
---
Outside diff comments:
In `@modelopt/onnx/llm_export_utils/quantization_utils.py`:
- Around line 113-114: Remove the '# nosec B105' bypass and eliminate the
hardcoded string comparison by replacing the literal "<unk>" with a well-named
constant or tokenizer attribute; specifically update the conditional that
references tokenizer.pad_token (in quantization_utils.py) to compare against a
constant like UNK_TOKEN or, better, against tokenizer.unk_token if that
attribute exists, then set tokenizer.pad_token = tokenizer.eos_token as
before—this removes the Bandit false positive without suppressing the check and
preserves the existing behavior in the block that assigns tokenizer.pad_token
from tokenizer.eos_token.
In `@modelopt/torch/quantization/tensor_quant.py`:
- Around line 404-408: The backward staticmethod is calling
_fake_quant_backward_function with num_args=10 but forward now accepts 11 args
after ctx (inputs, amax, bias, num_bits, unsigned, narrow_range,
trt_high_precision_dtype, pass_through_bwd, block_size, axis,
onnx_quantizer_type); update the call in backward to use num_args=11 (or compute
it dynamically from the forward signature) so the number of saved
tensor/argument slots passed to _fake_quant_backward_function matches forward;
modify the call to _fake_quant_backward_function(ctx, grad_outputs, num_args=11)
in the backward method.
---
Duplicate comments:
In `@examples/windows/torch_onnx/llm_export/llm_export.py`:
- Line 789: The code always constructs ModelLoader(args.hf_model_path,
args.config_path) even when only args.onnx_path is supplied, causing
ModelLoader.get_model_type() to open a None config; fix by making the
ModelLoader instantiation conditional: if args.hf_model_path or args.config_path
is provided instantiate ModelLoader(hf_model_path, config_path) and use its
get_model_type(), otherwise skip creating ModelLoader and branch to the
ONNX-only path (using args.onnx_path) or pass a safe default to downstream
logic; reference ModelLoader, get_model_type, args.hf_model_path,
args.config_path, and args.onnx_path when updating the control flow.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: de0b28f6-9c77-4067-8539-e7e0acb2e7a9
📒 Files selected for processing (10)
- examples/windows/torch_onnx/llm_export/README.md
- examples/windows/torch_onnx/llm_export/llm_export.py
- examples/windows/torch_onnx/llm_export/requirements.txt
- modelopt/onnx/export/int4_exporter.py
- modelopt/onnx/graph_surgery/__init__.py
- modelopt/onnx/llm_export_utils/export_utils.py
- modelopt/onnx/llm_export_utils/quantization_utils.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/tensor_quant.py
✅ Files skipped from review due to trivial changes (1)
- examples/windows/torch_onnx/llm_export/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
```python
# (inserted when trt_high_precision_dtype is set to "Half")
if matmul_node.op_type == "Cast":
    cast_after_transpose = matmul_node
    nodes_to_remove.append(cast_after_transpose.name)
```
Don't key node removal off `NodeProto.name`.

These new `append(...name)` calls make the existing deletion logic easy to break on valid unnamed ONNX nodes. If either Cast has `name == ""`, the later `node.name not in nodes_to_remove` filter will drop every unnamed node in the graph, not just this cast. Remove by object identity or by a unique output tensor name instead.
Also applies to: 273-273
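The failure mode above is easy to reproduce with a stand-in for `onnx.NodeProto` (the `Node` dataclass and three-node graph below are illustrative, not the actual graph from the PR):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op_type: str
    name: str = ""          # ONNX allows empty node names
    output: list = field(default_factory=list)

graph = [Node("Transpose", name="transpose_1", output=["t0"]),
         Node("Cast", output=["c0"]),     # unnamed Cast we want removed
         Node("MatMul", output=["m0"])]   # also unnamed, must survive

# Buggy: once "" lands in the removal list, name-based filtering drops
# EVERY unnamed node in the graph.
names_to_remove = [graph[1].name]         # [""]
buggy = [n for n in graph if n.name not in names_to_remove]

# Safer: remove by object identity (or by a unique output tensor name).
nodes_to_remove = {id(graph[1])}
fixed = [n for n in graph if id(n) not in nodes_to_remove]

print([n.op_type for n in buggy])  # ['Transpose'] -- MatMul wrongly gone
print([n.op_type for n in fixed])  # ['Transpose', 'MatMul']
```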
```diff
+cache = DynamicCache(ddp_cache_data=past_key_values, config=self.config)
-outputs = self.model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
+outputs = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Search for DynamicCache constructor signature in transformers
pip show transformers | grep Version
# Check the DynamicCache class definition
python3 -c "from transformers import DynamicCache; import inspect; print(inspect.signature(DynamicCache.__init__))" 2>/dev/null || echo "Could not inspect DynamicCache"
```

Repository: NVIDIA/Model-Optimizer
Length of output: 157

🏁 Script executed:

```shell
# First, find the file mentioned in the review
find . -name "export_utils.py" -path "*/onnx/llm_export_utils/*" | head -5
```

Repository: NVIDIA/Model-Optimizer
Length of output: 114

🏁 Script executed:

```shell
# Search for DynamicCache imports and definitions in the repository
rg "DynamicCache" --type py -B 2 -A 2 | head -50
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2982

🏁 Script executed:

```shell
# Check if there's a custom DynamicCache implementation in the repo
fd "*.py" | xargs rg "class DynamicCache" 2>/dev/null
```

Repository: NVIDIA/Model-Optimizer
Length of output: 351

🏁 Script executed:

```shell
# Look at the specific file context around lines 122-124
if [ -f "modelopt/onnx/llm_export_utils/export_utils.py" ]; then
  sed -n '110,135p' "modelopt/onnx/llm_export_utils/export_utils.py"
fi
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1133

🏁 Script executed:

```shell
# Check DynamicCache documentation or usage in transformers version info
rg "class DynamicCache" --type py -A 20 2>/dev/null | head -100
```

Repository: NVIDIA/Model-Optimizer
Length of output: 48

🏁 Script executed:

```shell
# Search for how cache is initialized or used with past_key_values elsewhere
rg "past_key_values" modelopt/onnx/llm_export_utils/export_utils.py -B 3 -A 3
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1606

🏁 Script executed:

```shell
# Check if there are any comments or hints about the cache initialization
cat -n modelopt/onnx/llm_export_utils/export_utils.py | sed -n '100,150p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2683

🏁 Script executed:

```shell
# Search for any git history or comments about this ddp_cache_data parameter
rg "ddp_cache_data" --type py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 193

🏁 Script executed:

```shell
# Check if there's documentation or tests for WrapperModelForCausalLM
rg "WrapperModelForCausalLM" --type py -B 2 -A 10 | head -80
```

Repository: NVIDIA/Model-Optimizer
Length of output: 5775

🏁 Script executed:

```shell
# Look for how DynamicCache is properly used with cache data in transformers patterns
rg "DynamicCache\(" --type py -A 1 | head -40
```

Repository: NVIDIA/Model-Optimizer
Length of output: 532
Fix invalid `DynamicCache` parameter `ddp_cache_data`.

The `ddp_cache_data` parameter is not a valid parameter for the transformers `DynamicCache` constructor. Standard usage in this codebase and across transformers shows `DynamicCache(config=config)`. Verify the intended cache initialization pattern—this may require converting `past_key_values` to the proper cache format or using a different initialization approach.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/llm_export_utils/export_utils.py` around lines 122 - 124, The
current call to DynamicCache(ddp_cache_data=past_key_values) is invalid;
instantiate the cache as DynamicCache(config=self.config) (e.g., cache =
DynamicCache(config=self.config)) and then seed it with existing past_key_values
using the correct API (for example a factory like
DynamicCache.from_past_key_values(past_key_values, config=self.config) or a
setter such as cache.set_past_key_values(past_key_values) /
cache.update_from_past_key_values(...) depending on the available methods); then
pass that cache into self.model(input_ids=..., past_key_values=cache,
use_cache=True). Ensure you convert past_key_values into the cache's expected
internal format if needed and prefer the library-provided constructor/factory
rather than the invalid ddp_cache_data parameter.
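The seeding pattern the prompt describes — build the cache through a supported constructor or factory, then feed it legacy `(key, value)` tuples via the cache's own update API — can be sketched with a stub. `SimpleCache` below is a stand-in, not the transformers class; real code would use `DynamicCache` itself (many transformers versions expose a `from_legacy_cache` factory for exactly this conversion, though availability varies by version):

```python
class SimpleCache:
    """Stand-in for transformers' DynamicCache."""
    def __init__(self, config=None):
        self.config = config
        self.key_cache, self.value_cache = [], []

    def update(self, key, value, layer_idx):
        # Append per-layer keys/values the way DynamicCache.update does.
        self.key_cache.append(key)
        self.value_cache.append(value)

    @classmethod
    def from_legacy_cache(cls, past_key_values, config=None):
        cache = cls(config=config)
        for layer_idx, (key, value) in enumerate(past_key_values):
            cache.update(key, value, layer_idx)
        return cache

# Legacy format: tuple of (key, value) pairs, one per layer.
past_key_values = ((["k0"], ["v0"]), (["k1"], ["v1"]))
cache = SimpleCache.from_legacy_cache(past_key_values,
                                      config={"n_layers": 2})
print(len(cache.key_cache))  # 2
```

The wrapper's forward would then pass `cache` (not the raw tuples) as `past_key_values` to the underlying model.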
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1027      +/-   ##
==========================================
+ Coverage   70.09%   70.10%   +0.01%
==========================================
  Files         221      221
  Lines       25459    25463       +4
==========================================
+ Hits        17845    17852       +7
+ Misses       7614     7611       -3
```

☔ View full report in Codecov by Sentry.

🚀 New features to boost your workflow:
|
…y use hidden_size // num_attention_heads when head_dim is not specified Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
What does this PR do?
Type of change: new example and bug fix
Added a Windows example for Torch PTQ followed by ONNX export and GQA graph surgery (replacing the entire attention subgraph with a single custom node in the graph).
Fixed some issues with the existing export path: past key values were not reflected as inputs in the final model because they were passed in a different format during export, and the INT8 SmoothQuant export path now quantizes both activations and weights as QDQ, instead of DQ-only for weights with QDQ for attention.
Usage
NVFP4:
INT4_AWQ:
INT8 Smooth Quant:
Summary by CodeRabbit