Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Changelog

**Bug Fixes**

- Fix ``ShapeInferenceError`` during ONNX INT8 + FP16 quantization (``--high_precision_dtype fp16``) of weakly-typed models (e.g. TensorFlow exports) that carry stale rank-0 ``graph.output`` shapes or ops such as ``TopK`` that ONNX's static shape inference cannot resolve. ``clear_stale_value_info`` now reconciles stale output shapes via symbolic shape inference (keeping every output's shape field populated), and AutoCast runs ONNX shape inference in strict mode and falls back to schema-based standalone type inference when it fails, so unresolved ops no longer leave tensors untyped.
- In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False.
- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance.

Expand Down
12 changes: 8 additions & 4 deletions modelopt/onnx/autocast/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ def convert_to_mixed_precision(
graph_sanitizer.sanitize()
model = graph_sanitizer.model

# Setup internal mappings
model = onnx_utils.infer_types(model, use_standalone_type_inference)
# Setup internal mappings. Use strict shape inference so an op ONNX cannot resolve surfaces
# as an exception (triggering infer_types' standalone type-inference fallback) instead of
# silently leaving tensors untyped, which would break later type lookups.
model = onnx_utils.infer_types(model, use_standalone_type_inference, strict_mode=True)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

# Automatically add 'trt' to list of providers if custom ops are detected
Expand Down Expand Up @@ -267,8 +269,10 @@ def convert_to_f16(
sanitizer.convert_fp64_to_fp32()
model = sanitizer.model

# Setup internal mappings
model = onnx_utils.infer_types(model, use_standalone_type_inference)
# Setup internal mappings. Use strict shape inference so an op ONNX cannot resolve surfaces
# as an exception (triggering infer_types' standalone type-inference fallback) instead of
# silently leaving tensors untyped, which would break later type lookups.
model = onnx_utils.infer_types(model, use_standalone_type_inference, strict_mode=True)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

precision_converter = PrecisionConverter(
Expand Down
141 changes: 132 additions & 9 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,19 +1205,32 @@ def infer_types(
When use_standalone_type_inference is True, uses a standalone type inference implementation
that only infers types. Otherwise, uses ONNX's infer_shapes which infers both types and shapes.

ONNX's ``infer_shapes`` can fail on weakly-typed models -- with ``strict_mode=True`` it raises
on an op it cannot resolve (e.g. a ``TopK`` whose axis it resolves to a stale dimension)
instead of silently leaving that node's outputs untyped. On any shape-inference failure this
falls back to the standalone type inferencer, which derives types from operator schemas
regardless of shapes, so downstream type lookups (e.g. in AutoCast) do not fail. Callers that
need a fully typed graph should pass ``strict_mode=True`` so incomplete inference surfaces as
an exception that triggers the fallback.

Args:
model: ONNX model to infer types/shapes for.
use_standalone_type_inference: If True, use standalone type inference (_infer_types_only).
If False, use ONNX's shape inference (infer_shapes).
**kwargs: Additional arguments passed to infer_shapes when not using standalone type inference.
**kwargs: Additional arguments passed to infer_shapes when not using standalone type
inference (e.g. ``strict_mode``, ``check_type``, ``data_prop``).

Returns:
onnx.ModelProto: Model with inferred types (and shapes if not using standalone type inference).
"""
if use_standalone_type_inference:
return _infer_types_only(model)
else:

try:
return infer_shapes(model, **kwargs)
except Exception as e:
logger.debug("ONNX shape inference failed (%s); using standalone type inference.", e)
return _infer_types_only(model)


def onnx_type_str_to_enum(dtype: str) -> int:
Expand Down Expand Up @@ -1862,14 +1875,119 @@ def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) ->
return model


def _reconcile_stale_output_shapes(model: onnx.ModelProto) -> int:
"""Re-derive stale ``graph.output`` shapes from the operator graph.
Comment on lines +1878 to +1879
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ajrasane There's similar logic in PrecisionConverter::_clear_types_and_shapes_recursive. In which we don't check what's stale and what's not, just clear everything since we're going to infer types (and optionally shapes) from the graph.
Perhaps we can extract it to utils + add the fallback to standalone type inference on shape inference expcetion?

def _clear_types_and_shapes_recursive(
self, graph: onnx.GraphProto, is_subgraph: bool = False
) -> None:
"""Recursively clear type/shape information for a graph and all its subgraphs.
If use_standalone_type_inference is True, we clear only types, not shapes.
For subgraphs, input types/shapes are cleared, so that the input types/shapes are propagated
from the main graph.
Args:
graph: The ONNX graph to clear types and shapes for.
is_subgraph: Whether this is a subgraph (True) or the main graph (False).
"""
def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> None:
logger.debug(
f"Clearing types/shapes in {'subgraph' if is_sub else 'main graph'}: {g.name}"
)
# Clear type/shape information for inputs (only for subgraphs, not main graph inputs)
if is_sub:
for inp in g.input:
if inp.type.HasField("tensor_type"):
inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
if not self.use_standalone_type_inference:
for idx, d in enumerate(inp.type.tensor_type.shape.dim):
if d.dim_value:
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
if is_sub:
# Identify which tensors are produced by nodes in this subgraph
subgraph_outputs = set()
for node in g.node:
subgraph_outputs.update(node.output)
# Clear value_info only for intermediates produced by nodes in this subgraph
for vi in g.value_info:
if vi.name in subgraph_outputs:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
if not self.use_standalone_type_inference:
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
else:
for vi in g.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
# Clear outputs for both main graph and subgraphs
for out in g.output:
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
if not self.use_standalone_type_inference:
for idx, d in enumerate(out.type.tensor_type.shape.dim):
if d.dim_value:
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph)


Weakly-typed models (e.g. exported from TensorFlow) can declare an output rank
that conflicts with the graph topology -- most commonly a leftover rank-0
(scalar) annotation on a tensor that is really rank-2+. Such a stale rank poisons
downstream shape inference: ORT fails while augmenting the model for INT8
calibration (``axis must be in [-rank, rank-1]. Input rank was 0``), and
``onnx.shape_inference`` with ``strict_mode=True`` raises ``Inferred shape and
existing shape differ in rank`` during fp16 autocast.

Strategy: snapshot the declared output shapes, clear them, and re-derive them from
the operator graph -- preferring ORT's symbolic shape inference (it resolves ops
such as ``TopK`` that ONNX's static inference gives up on) and falling back to the
size-aware ``infer_shapes`` wrapper. A declared shape is only overwritten when it is
genuinely stale -- a rank mismatch (the rank-0-vs-rank-N bug) or a conflicting
concrete dimension. Outputs that merely differ in symbolic ``dim_param`` names (e.g.
a re-derived ``unk__0`` vs a declared ``batch``) keep their original declaration, so
healthy models -- including dynamic batch/sequence dims -- are left untouched. A
graph output is never left without a shape (``onnx.checker`` requires the field).

Args:
model: Loaded in-memory onnx ModelProto, ideally with ``value_info`` already
cleared so re-inference derives shapes from the operator graph.

Returns:
Number of graph outputs whose shape was changed.
"""
outputs = model.graph.output
if not outputs:
return 0

def _sig(shape: onnx.TensorShapeProto | None) -> bytes | None:
return None if shape is None else shape.SerializeToString()

def _is_stale(declared: onnx.TensorShapeProto | None, inferred: onnx.TensorShapeProto | None):
# Only treat a declaration as stale when inference contradicts it: a different
# rank, or a concrete dim that disagrees with an inferred concrete dim. A missing
# declaration is "stale" (adopt whatever was inferred); a missing inference is not
# (keep the declaration). Symbolic dim_param renames are intentionally ignored.
if inferred is None:
return False
if declared is None:
return True
if len(declared.dim) != len(inferred.dim):
return True
return any(
d.HasField("dim_value") and i.HasField("dim_value") and d.dim_value != i.dim_value
for d, i in zip(declared.dim, inferred.dim)
)

# Snapshot declared shapes, then clear them so re-inference starts from the
# topology instead of being biased by the stale annotations.
declared: dict[str, onnx.TensorShapeProto | None] = {}
for o in outputs:
tt = o.type.tensor_type
if tt.HasField("shape"):
snapshot = onnx.TensorShapeProto()
snapshot.CopyFrom(tt.shape)
declared[o.name] = snapshot
else:
declared[o.name] = None
tt.ClearField("shape")

# Re-derive output shapes from the cleared model (neither inference call mutates it):
# prefer ORT symbolic shape inference, then fall back to the size-aware infer_shapes
# wrapper if it is unavailable or yields nothing.
inferred: dict[str, onnx.TensorShapeProto] = {}
try:
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

inferred_model = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
inferred = {
o.name: o.type.tensor_type.shape
for o in inferred_model.graph.output
if o.type.tensor_type.HasField("shape")
}
except Exception as e:
logger.debug("Symbolic shape inference unavailable/failed: %s", e)
if not inferred:
try:
inferred_model = infer_shapes(model, strict_mode=False, data_prop=True)
inferred = {
o.name: o.type.tensor_type.shape
for o in inferred_model.graph.output
if o.type.tensor_type.HasField("shape")
}
except Exception as e:
logger.debug("ONNX shape inference for output reconciliation failed: %s", e)

changed = 0
for o in outputs:
decl = declared[o.name]
# Adopt the inferred shape only when the declaration is genuinely stale; otherwise
# restore the declared shape (never leaving a graph output shapeless).
new_shape = inferred.get(o.name) if _is_stale(decl, inferred.get(o.name)) else decl
if new_shape is not None:
o.type.tensor_type.shape.CopyFrom(new_shape)
if _sig(decl) != _sig(new_shape):
changed += 1
return changed


def clear_stale_value_info(model: onnx.ModelProto) -> int:
"""Clear stale type metadata that would otherwise trip ORT's type checker.
"""Clear stale type/shape metadata that would otherwise trip ORT's type checker.

Walks every ``Cast`` node and forces the ``elem_type`` of any
``graph.output`` entry produced by that Cast to match the Cast's ``to``
attribute (the spec-defined contract for a Cast's output dtype). Then
clears ``value_info`` wholesale so ORT/shape-inference re-derives
intermediate-tensor types from the operator graph during session setup.
Walks every ``Cast`` node and forces the ``elem_type`` of any ``graph.output``
entry produced by that Cast to match the Cast's ``to`` attribute (the spec-defined
contract for a Cast's output dtype). Clears ``value_info`` wholesale so
ORT/shape-inference re-derives intermediate-tensor types from the operator graph
during session setup. Finally, reconciles stale ``graph.output`` *shapes* (e.g. a
leftover rank-0 scalar on a tensor that is really rank-2+) which would otherwise
propagate a wrong rank into downstream shape inference.

Args:
model: Loaded in-memory onnx ModelProto.
Expand All @@ -1893,4 +2011,9 @@ def clear_stale_value_info(model: onnx.ModelProto) -> int:
n_vi = len(model.graph.value_info)
if n_vi:
del model.graph.value_info[:]
return fixed_outputs + n_vi

# Reconcile output shapes after value_info is cleared so the re-inference inside
# the helper derives shapes cleanly from the operator graph.
fixed_shapes = _reconcile_stale_output_shapes(model)

return fixed_outputs + fixed_shapes + n_vi
34 changes: 34 additions & 0 deletions tests/unit/onnx/autocast/test_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import modelopt.onnx.utils as onnx_utils
from modelopt.onnx.autocast import convert_to_mixed_precision
from modelopt.onnx.autocast.__main__ import get_parser, main
from modelopt.onnx.autocast.convert import convert_to_f16
from modelopt.onnx.autocast.logging_config import configure_logging

configure_logging("DEBUG")
Expand Down Expand Up @@ -321,3 +322,36 @@ def test_opset_parser_argument():
# Test parsing without opset (should be None)
args = parser.parse_args(["--onnx_path", "test.onnx"])
assert args.opset is None


@pytest.fixture
def weakly_typed_topk_model():
# TopK k (5) exceeds the static axis size (3), so ONNX shape inference cannot resolve it.
x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 3])
out = onnx.helper.make_tensor_value_info("out", onnx.TensorProto.FLOAT, [1, 5])
weight = onnx.numpy_helper.from_array(np.ones((1, 3), dtype=np.float32), name="weight")
k = onnx.numpy_helper.from_array(np.array([5], dtype=np.int64), name="k")
nodes = [
onnx.helper.make_node("Add", ["X", "weight"], ["a"], name="add"),
onnx.helper.make_node("TopK", ["a", "k"], ["vals", "inds"], axis=1, name="topk"),
onnx.helper.make_node("Cast", ["inds"], ["out"], to=onnx.TensorProto.FLOAT, name="cast"),
]
graph = onnx.helper.make_graph(nodes, "weakly_typed_topk", [x], [out], [weight, k])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 17)])
model.ir_version = 10
return model


def test_convert_to_f16_falls_back_on_unresolvable_op(weakly_typed_topk_model):
"""A weakly-typed graph ONNX shape inference cannot resolve must still convert.

The TopK ``k`` (5) exceeds the static axis size (3), so ONNX shape inference raises
in strict mode (the same failure class as NVBug 6058907). ``convert_to_f16`` -- the
path used by INT8 + ``--high_precision_dtype fp16`` quantization -- runs infer_types
in strict mode and must fall back to standalone type inference instead of crashing,
typing the TopK's int64 indices output that feeds the downstream Cast.
"""
converted_model = convert_to_f16(weakly_typed_topk_model, keep_io_types=True)

onnx.checker.check_model(converted_model)
assert any(n.op_type == "TopK" for n in converted_model.graph.node)
92 changes: 92 additions & 0 deletions tests/unit/onnx/test_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
clear_stale_value_info,
get_input_names_from_bytes,
get_output_names_from_bytes,
infer_types,
randomize_weights_onnx_bytes,
remove_node_training_mode,
remove_weights_data,
Expand Down Expand Up @@ -364,3 +365,94 @@ def test_clear_stale_value_info(output_elem_type, with_value_info, expected_coun
assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.FLOAT
assert len(model.graph.value_info) == 0
assert count == expected_count


def _make_matmul_model(output_shape):
"""Build an X[3,4] @ W[4,5] -> Y model with Y declared using ``output_shape``."""
weights = make_tensor("W", onnx.TensorProto.FLOAT, [4, 5], np.zeros(20, dtype=np.float32))
nodes = [make_node("MatMul", ["X", "W"], ["Y"], name="matmul")]
inputs = [make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 4])]
outputs = [make_tensor_value_info("Y", onnx.TensorProto.FLOAT, output_shape)]
graph = make_graph(nodes, "matmul_graph", inputs, outputs, initializer=[weights])
return make_model(graph, producer_name="modelopt test", opset_imports=[make_opsetid("", 17)])


def test_clear_stale_value_info_reconciles_stale_rank0_output():
# Y is really rank-2 [3, 5] but the model declares it as a rank-0 scalar (stale
# metadata typical of weakly-typed exports). This is the rank-(N)-vs-(0) class of
# conflict that crashes downstream shape inference (NVBug 6058907).
model = _make_matmul_model(output_shape=[])
assert len(model.graph.output[0].type.tensor_type.shape.dim) == 0 # stale rank-0

clear_stale_value_info(model)

out_type = model.graph.output[0].type.tensor_type
assert out_type.HasField("shape") # shape field must remain (onnx.checker requires it)
assert [d.dim_value for d in out_type.shape.dim] == [3, 5] # reconciled to the real shape
onnx.checker.check_model(model)


def test_clear_stale_value_info_preserves_valid_output_shape():
# A correct output shape must be left untouched (no-op for healthy models).
model = _make_matmul_model(output_shape=[3, 5])

clear_stale_value_info(model)

out_type = model.graph.output[0].type.tensor_type
assert [d.dim_value for d in out_type.shape.dim] == [3, 5]


def _make_dynamic_dim_model():
"""Build an X[batch,4] -> Relu -> Y[my_batch,4] model (output declares a different dim_param)."""
nodes = [make_node("Relu", ["X"], ["Y"], name="relu")]
inputs = [make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["batch", 4])]
outputs = [make_tensor_value_info("Y", onnx.TensorProto.FLOAT, ["my_batch", 4])]
graph = make_graph(nodes, "dyn_graph", inputs, outputs)
return make_model(graph, producer_name="modelopt test", opset_imports=[make_opsetid("", 17)])


def test_clear_stale_value_info_preserves_dynamic_dim_names():
# A healthy output with a named dynamic dim must not be rewritten just because
# symbolic shape inference re-derives a different dim_param. Y is declared with
# "my_batch" while the graph would infer "batch" from the input: same rank, no
# concrete-dim conflict, so the declaration (incl. its dim_param) must be preserved.
model = _make_dynamic_dim_model()

clear_stale_value_info(model)

out_dims = model.graph.output[0].type.tensor_type.shape.dim
assert [d.dim_param or d.dim_value for d in out_dims] == ["my_batch", 4]
onnx.checker.check_model(model)


def _make_topk_overflow_model():
"""Build a model whose TopK ``k`` (5) exceeds the static axis dim (3).

ONNX shape inference raises "Axis has less than the requested k elements" on this
model (the same failure class seen in NVBug 6058907), while standalone type
inference can still derive the output types (values float, indices int64).
"""
k = make_tensor("k", onnx.TensorProto.INT64, [1], [5])
nodes = [
make_node("TopK", ["X", "k"], ["vals", "inds"], axis=1, name="topk"),
make_node("Cast", ["inds"], ["out"], to=onnx.TensorProto.FLOAT, name="cast_inds"),
]
inputs = [make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 3])]
outputs = [make_tensor_value_info("out", onnx.TensorProto.FLOAT, [1, 5])]
graph = make_graph(nodes, "topk_overflow", inputs, outputs, initializer=[k])
return make_model(graph, producer_name="modelopt test", opset_imports=[make_opsetid("", 17)])


def test_infer_types_falls_back_to_standalone_when_onnx_fails():
# ONNX shape inference cannot resolve this model's TopK. With strict_mode=True it raises
# (instead of silently leaving the TopK outputs untyped), so infer_types catches the
# error and falls back to standalone type inference, which still types every tensor.
model = _make_topk_overflow_model()

inferred = infer_types(model, strict_mode=True)

value_info_types = {vi.name: vi.type.tensor_type.elem_type for vi in inferred.graph.value_info}
output_types = {o.name: o.type.tensor_type.elem_type for o in inferred.graph.output}
assert value_info_types.get("vals") == onnx.TensorProto.FLOAT
assert value_info_types.get("inds") == onnx.TensorProto.INT64 # TopK indices
assert output_types.get("out") == onnx.TensorProto.FLOAT
Loading