21 changes: 20 additions & 1 deletion examples/torch_onnx/llm_export.py
Expand Up @@ -363,7 +363,26 @@ def main(args):
onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir
# Surgeon graph based on precision
raw_onnx_path = f"{onnx_dir}/model.onnx"
extra_inputs, extra_dyn_axes = {}, {}

batch_size = 1
seq_len = 8
dummy_input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
dummy_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
dummy_position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)

# Correct assignment (no trailing comma)

Review comment: There is a non-UTF-8 character at this line.
extra_inputs = {
"input_ids": dummy_input_ids,
"attention_mask": dummy_attention_mask,
"position_ids": dummy_position_ids,
}
extra_dyn_axes = {
"input_ids": {0: "batch", 1: "seq_len"},
"attention_mask": {0: "batch", 1: "seq_len"},
"position_ids": {0: "batch", 1: "seq_len"},
"logits": {0: "batch", 1: "seq_len"},
}
Comment on lines +374 to +384
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Read-only verification of signature/call-path mismatch
rg -n -C3 'def forward\(self, input_ids.*past_key_values' modelopt/onnx/llm_export_utils/export_utils.py
rg -n -C8 'torch_to_onnx\(' modelopt/onnx/llm_export_utils/export_utils.py
rg -n -C8 'extra_inputs = \{|extra_dyn_axes = \{' examples/torch_onnx/llm_export.py

Repository: NVIDIA/Model-Optimizer



extra_inputs breaks the export call contract and causes runtime failure.

The wrapper's forward signature only accepts input_ids and past_key_values, but extra_inputs passes input_ids, attention_mask, and position_ids as kwargs. When expanded via **extra_inputs in the torch_to_onnx call at line 134, the model's forward receives unexpected keyword arguments (attention_mask, position_ids), causing a TypeError. Additionally, input_ids would be passed twice — once positionally and once as a kwarg, which is invalid.
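The failure mode described above can be demonstrated with a minimal, self-contained sketch (the `forward` below is a hypothetical stand-in for the wrapper's signature, not the actual modelopt code):

```python
def forward(input_ids, past_key_values=None):
    """Stand-in for the wrapper's forward: accepts only these two args."""
    return input_ids

extra_inputs = {
    "input_ids": [101, 102],
    "attention_mask": [1, 1],
    "position_ids": [0, 1],
}

# Expanding **extra_inputs alone: attention_mask and position_ids are
# unexpected keyword arguments, so the call raises TypeError.
try:
    forward(**extra_inputs)
    kwarg_error = None
except TypeError as e:
    kwarg_error = e

# Passing input_ids positionally *and* expanding **extra_inputs supplies
# input_ids twice, which is also a TypeError.
try:
    forward(extra_inputs["input_ids"], **extra_inputs)
    duplicate_error = None
except TypeError as e:
    duplicate_error = e

print(kwarg_error)
print(duplicate_error)
```

Both calls fail before any ONNX tracing begins, which is why the mismatch surfaces as an immediate runtime error rather than a malformed graph.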

Suggested minimal fix
-        extra_inputs = {
-            "input_ids": dummy_input_ids,
-            "attention_mask": dummy_attention_mask,
-            "position_ids": dummy_position_ids,
-        }
-        extra_dyn_axes = {
-            "input_ids": {0: "batch", 1: "seq_len"},
-            "attention_mask": {0: "batch", 1: "seq_len"},
-            "position_ids": {0: "batch", 1: "seq_len"},
-            "logits": {0: "batch", 1: "seq_len"},
-        }
+        extra_inputs = {}
+        extra_dyn_axes = {"logits": {0: "batch", 1: "seq_len"}}

If attention_mask and position_ids are required as ONNX inputs, coordinate updates in WrapperModelForCausalLM.forward signature and llm_to_onnx plumbing first.
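If that coordinated route is taken, the wrapper change could look like the following sketch (a hypothetical simplification of `WrapperModelForCausalLM`; the real class and its tensor types live in `modelopt/onnx/llm_export_utils`):

```python
class WrapperModelForCausalLM:
    """Hypothetical sketch: widen forward() so **extra_inputs matches."""

    def __init__(self, model):
        self.model = model

    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None):
        # Optional kwargs keep the existing input_ids-only call path
        # working while letting **extra_inputs supply the new tensors.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )


# Dummy model callable to show the new kwargs now plumb through.
wrapper = WrapperModelForCausalLM(lambda **kw: sorted(kw))
result = wrapper.forward([101], attention_mask=[1], position_ids=[0])
print(result)
```

Making the new parameters optional with `None` defaults is the less invasive choice: existing callers that pass only `input_ids` keep working, while the export path can opt in via `**extra_inputs`.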

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/llm_export.py` around lines 374 - 384, extra_inputs
currently supplies attention_mask and position_ids, which
WrapperModelForCausalLM.forward does not accept (it takes only input_ids and
past_key_values), so expanding **extra_inputs passes unexpected kwargs and
duplicates input_ids. Either update WrapperModelForCausalLM.forward and the
llm_to_onnx/torch_to_onnx plumbing to accept attention_mask/position_ids, or
minimally set extra_inputs to {} (input_ids is already passed positionally)
and trim extra_dyn_axes to the logits entry so the expanded call matches the
wrapper signature and does not pass input_ids twice. Update the places where
extra_inputs and extra_dyn_axes flow into torch_to_onnx/llm_to_onnx
accordingly.


export_raw_llm(
model=model,
output_dir=onnx_dir,
Expand Down