diff --git a/examples/torch_onnx/llm_export.py b/examples/torch_onnx/llm_export.py
index 31954c057c..a9c7bba01c 100644
--- a/examples/torch_onnx/llm_export.py
+++ b/examples/torch_onnx/llm_export.py
@@ -363,7 +363,26 @@ def main(args):
     onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir
     # Surgeon graph based on precision
     raw_onnx_path = f"{onnx_dir}/model.onnx"
-    extra_inputs, extra_dyn_axes = {}, {}
+
+    batch_size = 1
+    seq_len = 8
+    dummy_input_ids = torch.randint(0, 1000, (batch_size, seq_len), dtype=torch.long)
+    dummy_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
+    dummy_position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    # Extra graph inputs and their dynamic axes passed to the ONNX export
+    extra_inputs = {
+        "input_ids": dummy_input_ids,
+        "attention_mask": dummy_attention_mask,
+        "position_ids": dummy_position_ids,
+    }
+    extra_dyn_axes = {
+        "input_ids": {0: "batch", 1: "seq_len"},
+        "attention_mask": {0: "batch", 1: "seq_len"},
+        "position_ids": {0: "batch", 1: "seq_len"},
+        "logits": {0: "batch", 1: "seq_len"},
+    }
+
     export_raw_llm(
         model=model,
         output_dir=onnx_dir,