Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Changelog

- In Megatron-Core, EP amax sync is now applied to routed expert weights only when ``sync_expert_weight_amax=True``. Previously, routed expert weights were synced across EP ranks even when ``sync_expert_weight_amax`` was ``False``.
- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. The importer now prefers the per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (falling back to ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU accuracy stayed at chance level.
- Fix ONNX FP16/BF16 conversion (``--high_precision_dtype fp16``) producing inconsistent tensor types on models with control-flow ``If``/``Loop``/``Scan`` subgraphs (e.g. a ``Gemm`` reading an outer-scope activation alongside converted weights, or ``Resize`` ``scales`` that must stay FP32). Subgraph nodes now only run in low precision when all their float inputs are subgraph initializers; outer-scope captures and low-to-high-precision boundaries inside subgraphs are reconciled with ``Cast`` nodes, and ``Constant`` folding refreshes the constant's ``value_info`` so strongly-typed parsers (TensorRT) no longer reject the model.

**Deprecations**

Expand Down
284 changes: 257 additions & 27 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class InitializerConsumerTracker:

ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()]

# Reverse mapping from ONNX tensor type to its PrecisionTypes entry (e.g. TensorProto.FLOAT -> fp32).
ONNX_TYPE_TO_PRECISION = {t.onnx_type: t for t in PRECISION_MAP.values()}

OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]

# Mapping of op types to indices of inputs that should not be converted to low precision.
Expand Down Expand Up @@ -748,11 +751,20 @@ def _convert_initializers(
def _convert_initializers_recursive(
self, low_precision_nodes: list[str], high_precision_nodes: list[str]
) -> None:
"""Convert initializers in main graph and all subgraphs to appropriate precision.
"""Convert initializers in the main graph and reconcile precision inside all subgraphs.

For the main graph, uses consumer tracking to determine each initializer's precision
(see :meth:`_convert_initializers`).

For the main graph, uses sophisticated consumer tracking to determine precision.
For subgraphs, inherits precision from the parent control flow node and converts
all initializers to that precision (no runtime casts).
Control-flow subgraphs (e.g. If/Loop/Scan bodies) do not get the activation-bracketing casts
that the main graph uses, so a subgraph node is only converted to low precision when *all* of
its float inputs are subgraph initializers that may be in low precision (i.e. the node consumes
no float activation/outer-scope tensor and none of its inputs must stay high precision per the
ONNX spec, such as ``Resize`` ``scales``). Such a node's initializers are converted to low
precision; every other subgraph node and its initializers are kept in high precision so that
each node's inputs share a single precision. Float tensors captured from the enclosing scope,
and the outputs of any low-precision subgraph node feeding a high-precision one, are reconciled
with a ``Cast`` inserted inside the subgraph.

Args:
low_precision_nodes: List of node names in main graph that are low precision.
Expand All @@ -761,45 +773,263 @@ def _convert_initializers_recursive(
# Convert main graph initializers with full consumer tracking
self._convert_initializers(low_precision_nodes, high_precision_nodes)

# Convert subgraph initializers - walk all subgraphs and convert based on parent node precision
# Precompute, for each main-graph activation, the (raw) precision it has after main-graph
# conversion: a subgraph that captures it sees this raw precision, because the main-graph
# cast-up only rewires main-graph consumers (not subgraph captures).
low_precision_nodes_set = set(low_precision_nodes)
main_producer_precision: dict[str, int] = {}
for node in self.model.graph.node:
# Control-flow nodes (If/Loop/Scan) execute a subgraph that is kept in high precision, so
# their outputs are high precision regardless of the node's low/high classification.
is_control_flow = any(
attr.type in (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS)
for attr in node.attribute
)
producer_type = (
self.low_precision_type.onnx_type
if node.name in low_precision_nodes_set and not is_control_flow
else self.high_precision_type.onnx_type
)
for output_name in node.output:
main_producer_precision[output_name] = producer_type

def _capture_type(name: str) -> int | None:
"""Return the float precision (ONNX type) of an outer-scope tensor seen in a subgraph.

Float-ness is read from the (pre-conversion) type info, since it is invariant under
precision conversion; the precision is then taken from the producing main-graph node.
Network inputs use their declared type in ``value_info_map`` (already updated in place
when ``keep_io_types`` is False). Returns None for non-float or unknown tensors.
"""
if name in self.initializer_map:
base_type = self.initializer_map[name].data_type
elif name in self.value_info_map:
base_type = self.value_info_map[name].type.tensor_type.elem_type
else:
return None
if base_type not in ONNX_TYPES:
return None
return main_producer_precision.get(name, base_type)

def _convert_subgraph_callback(
graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Worth calling out in the docstring (and ideally CHANGELOG) that this is a behavioral change for subgraph initializer precision: previously a low-precision parent caused all subgraph float initializers to be converted, now conversion is gated on the consuming node having no float activation/outer-scope/unknown-typed inputs. Anyone relying on the old behavior to get fp16 weights inside an If branch with an outer-scope activation will see those weights stay fp32 after this PR (which is the correct fix, but the change in observable output deserves a note).

) -> None:
if not is_subgraph or parent is None:
return
parent_is_low_precision = parent.name in low_precision_nodes_set
self._convert_subgraph_precision(graph, parent_is_low_precision, _capture_type)

# Inherit precision from parent control flow node
target_type = (
self.low_precision_type
if parent.name in low_precision_nodes_set
else self.high_precision_type
utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback)

def _convert_subgraph_precision(
self, subgraph: onnx.GraphProto, parent_is_low_precision: bool, capture_type_fn
) -> None:
"""Convert a single control-flow subgraph to a consistent precision.

See :meth:`_convert_initializers_recursive` for the conversion policy.

Args:
subgraph: The subgraph (e.g. an If branch or a Loop/Scan body) to convert.
parent_is_low_precision: Whether the parent control-flow node is low precision.
capture_type_fn: Maps an outer-scope tensor name to its float ONNX type (or None).
"""
target = self.low_precision_type
high = self.high_precision_type

local_inits = {init.name: init for init in subgraph.initializer}
local_produced = {out for node in subgraph.node for out in node.output}
formal_inputs = {inp.name for inp in subgraph.input}
# Original (pre-conversion) element types of subgraph-local tensors. Whether a tensor is
# float is invariant under fp32<->fp16 conversion, so this is a reliable "is this float?"
# source that prevents casting non-float tensors (e.g. int axes/indices/shapes).
local_elem_type = {vi.name: vi.type.tensor_type.elem_type for vi in subgraph.value_info}
local_elem_type.update({vi.name: vi.type.tensor_type.elem_type for vi in subgraph.output})

# Map of tensor name -> list of (node, input index) consumers within this subgraph.
consumers: dict[str, list[InputIndexTracker]] = defaultdict(list)
for node in subgraph.node:
for idx, input_name in enumerate(node.input):
if input_name:
consumers[input_name].append(InputIndexTracker(node=node, node_index=idx))

def _is_low_precision_eligible_init(node: onnx.NodeProto, input_name: str) -> bool:
init = local_inits.get(input_name)
return (
init is not None
and init.data_type in ONNX_TYPES
and not self._should_skip_low_precision_input_conversion(node, input_name)
)

# Convert all float initializers to target precision
for init in graph.initializer:
if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type:
continue
def _known_elem_type(input_name: str) -> int | None:
"""Best-effort (pre-conversion) element type of a tensor visible here, or None."""
if input_name in local_inits:
return local_inits[input_name].data_type
if input_name in local_elem_type:
return local_elem_type[input_name]
if input_name in self.value_info_map:
return self.value_info_map[input_name].type.tensor_type.elem_type
if input_name in self.initializer_map:
return self.initializer_map[input_name].data_type
return None

from_type = (
self.high_precision_type
if init.data_type == self.high_precision_type.onnx_type
else self.low_precision_type
if init.data_type == self.low_precision_type.onnx_type
else None
# 1. Classify each subgraph node. A node is converted to low precision only if it has at least
# one low-precision-eligible float initializer input and every other input is a tensor we
# know is non-float (e.g. int axes/indices). Any float activation/outer-scope input (or an
# input of unknown type) keeps the node in high precision, since no bracketing casts are
# inserted around subgraph activations.
node_is_low: dict[str, bool] = {}
for node in subgraph.node:
low = parent_is_low_precision and (
node.op_type not in self.op_types_not_supported_in_low_precision
)
has_low_precision_init = False
if low:
for input_name in node.input:
if not input_name:
continue
if _is_low_precision_eligible_init(node, input_name):
has_low_precision_init = True
continue
elem_type = _known_elem_type(input_name)
if elem_type is None or elem_type in ONNX_TYPES:
# Float or unknown non-initializer input: keep the node in high precision.
low = False
break
node_is_low[node.name] = low and has_low_precision_init

# 2. Convert the initializers consumed by low-precision nodes to low precision. An initializer
# shared by low- and high-precision nodes is duplicated so each consumer keeps one precision.
for init in list(subgraph.initializer):
if init.data_type not in ONNX_TYPES:
continue
from_type = ONNX_TYPE_TO_PRECISION.get(init.data_type)
if from_type is None:
continue
low_consumers: list[InputIndexTracker] = []
high_consumers: list[InputIndexTracker] = []
for c in consumers.get(init.name, []):
if node_is_low.get(c.node.name) and _is_low_precision_eligible_init(
c.node, init.name
):
low_consumers.append(c)
else:
high_consumers.append(c)

if not low_consumers:
# Keep in high precision (covers Resize-scales-style inputs and unused initializers).
if init.data_type != high.onnx_type:
init.CopyFrom(self._convert_initializer_data(init, from_type, high))
elif not high_consumers:
# Convert the single-precision initializer in place.
if init.data_type != target.onnx_type:
init.CopyFrom(self._convert_initializer_data(init, from_type, target))
else:
# Shared: keep the original high precision and add a low-precision duplicate.
low_init = self._convert_initializer_data(init, from_type, target)
low_init.name = f"{init.name}_{target.str_short}"
subgraph.initializer.extend([low_init])
for consumer in low_consumers:
consumer.node.input[consumer.node_index] = low_init.name
if init.data_type != high.onnx_type:
init.CopyFrom(self._convert_initializer_data(init, from_type, high))

# 3. Reconcile float tensors whose precision does not match the consuming node: outer-scope
# captures and the outputs of low-precision nodes feeding high-precision ones. Casts are
# collected first and inserted afterwards to avoid mutating the node list while iterating.
local_produced_low = {
out for node in subgraph.node if node_is_low.get(node.name) for out in node.output
}

def _current_float_type(input_name: str) -> int | None:
"""Float precision (ONNX type) of a subgraph input tensor, or None if it is non-float.

Non-float tensors (int indices/axes/shapes, bool conditions) must never be cast.
"""
if input_name in local_produced:
elem_type = local_elem_type.get(input_name)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

_current_float_type returns high.onnx_type unconditionally for any float formal_inputs. For Loop body state vars / Scan scan_inputs, the actual precision is whatever the main graph hands the parent Loop/Scan node — which can be low precision. The new node-classification rule mostly papers over this (a float formal input demotes the node to high precision), but it would be good to add at least one Loop-body regression test, since none of the new tests exercise Loop/Scan.

if elem_type is not None and elem_type not in ONNX_TYPES:
return None # known non-float subgraph activation
if input_name in local_produced_low:
return target.onnx_type # output of a converted low-precision node
return high.onnx_type if elem_type in ONNX_TYPES else None
if input_name in formal_inputs:
elem_type = local_elem_type.get(input_name)
return high.onnx_type if elem_type in ONNX_TYPES else None
return capture_type_fn(input_name) # outer-scope capture (None if non-float)

# (tensor name, target onnx type) -> (cast output name, producer node name or None)
casts_to_insert: dict[tuple[str, int], tuple[str, str | None]] = {}
rewrites: list[tuple[InputIndexTracker, str]] = []
# Float outer-scope captures and the (current) precision they have in the enclosing scope.
captured_types: dict[str, int] = {}
for node in subgraph.node:
node_low = node_is_low.get(node.name, False)
for idx, input_name in enumerate(node.input):
if not input_name or input_name in local_inits:
continue
current = _current_float_type(input_name)
if current is None:
continue # non-float tensor: never cast
if input_name not in local_produced and input_name not in formal_inputs:
captured_types[input_name] = current
desired = (
target.onnx_type
if node_low
and not self._should_skip_low_precision_input_conversion(node, input_name)
else high.onnx_type
)
if current == desired:
continue

if from_type is None:
logger.debug(
f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}"
key = (input_name, desired)
if key not in casts_to_insert:
short = ONNX_TYPE_TO_PRECISION[desired].str_short
producer = (
input_name
if input_name in local_produced
else None # outer-scope capture (produced outside this subgraph)
)
continue
casts_to_insert[key] = (
f"{input_name}_subgraph_cast_to_{short}",
producer,
)
rewrites.append(
(InputIndexTracker(node=node, node_index=idx), casts_to_insert[key][0])
)

new_init = self._convert_initializer_data(init, from_type, target_type)
init.CopyFrom(new_init)
# Sync any preserved outer-scope value_info inside the subgraph with the capture's current
# main-graph precision. Otherwise a stale type (e.g. fp32 for a tensor the main graph now
# produces in fp16) makes strongly-typed parsers (and ORT) reject the If subgraph.
for vi in subgraph.value_info:
if vi.name in captured_types and vi.type.HasField("tensor_type"):
vi.type.tensor_type.elem_type = captured_types[vi.name]

utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback)
if not casts_to_insert:
return

for tracker, cast_output in rewrites:
tracker.node.input[tracker.node_index] = cast_output

# Build the cast nodes and re-assemble the subgraph node list so each cast appears after its
# producer (captures, produced outside the subgraph, are placed at the front).
cast_nodes_after_producer: dict[str, list[onnx.NodeProto]] = defaultdict(list)
leading_casts: list[onnx.NodeProto] = []
for (input_name, desired), (cast_output, producer) in casts_to_insert.items():
cast_node = helper.make_node(
"Cast", inputs=[input_name], outputs=[cast_output], to=desired, name=cast_output
)
if producer is not None:
producer_node_name = next(n.name for n in subgraph.node if input_name in n.output)
cast_nodes_after_producer[producer_node_name].append(cast_node)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Cheap perf cleanup: this next(...) re-scans subgraph.node once per inserted cast. You already have local_produced above and the new node list iteration; build a producer_by_output: dict[str, str] once before this loop (or include the producer node name when populating casts_to_insert) to avoid the O(casts × nodes) scan.

else:
leading_casts.append(cast_node)

new_nodes = list(leading_casts)
for node in subgraph.node:
new_nodes.append(node)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Consider asserting / enforcing topological order before reassembling. The cast is placed right after its producer, but a downstream consumer of this cast that happens to appear earlier in subgraph.node (possible if the subgraph was hand-built and not toposorted) would now see the rewritten input name without the Cast having been emitted yet. A gs.import_onnx/toposort pass on the rebuilt subgraph (or an explicit assertion that producer indices < consumer indices) would harden this.

new_nodes.extend(cast_nodes_after_producer.get(node.name, []))
del subgraph.node[:]
subgraph.node.extend(new_nodes)

def _convert_initializer_data(
self,
Expand Down
8 changes: 8 additions & 0 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,14 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant"
constant_producer = cast_producers[0]
_convert_constant_values(constant_producer, node)
# Folding changes the Constant output's element type; refresh its value_info so
# downstream consumers (and strongly-typed parsers) don't see a stale type that
# conflicts with the now-converted constant value.
cast_to_type = get_cast_to_type(node)
for vi in onnx_model.graph.value_info:
if vi.name == constant_producer.output[0]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

This only updates the first value_info entry that matches and silently does nothing if no entry exists. Constants frequently lack a value_info entry until type inference runs. Fine for the regression test (which immediately re-runs strict shape inference), but if a caller of remove_redundant_casts doesn't follow up with type inference, downstream consumers may still see no/stale type for the folded constant. Consider creating a value_info entry when one is missing, or document that callers must re-infer types after folding.

vi.type.tensor_type.elem_type = cast_to_type
break
_bypass_cast_node(onnx_model, node)
logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}")

Expand Down
Loading