Skip to content

[6058841] Consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion#1628

Draft
ajrasane wants to merge 1 commit into
mainfrom
arasane/nvbug-6058841-onnx-fp16-if-subgraph
Draft

[6058841] Consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion#1628
ajrasane wants to merge 1 commit into
mainfrom
arasane/nvbug-6058841-onnx-fp16-if-subgraph

Conversation

@ajrasane
Copy link
Copy Markdown
Contributor

@ajrasane ajrasane commented Jun 4, 2026

What does this PR do?

Type of change: Bug fix

Fixes [6058841] — python -m modelopt.onnx.quantization --high_precision_dtype fp16 crashed with "Inconsistent type on If node" on models containing control-flow If/Loop/Scan subgraphs.

Root cause. The FP16/BF16 PrecisionConverter blindly converted every subgraph initializer to the parent control-flow node's precision, without the activation-bracketing casts it uses in the main graph. This left inconsistent tensor types that crash ONNX shape inference / TensorRT strongly-typed parsing:

  • A Gemm inside an If branch reading an outer-scope activation (fp32) ended up with fp16 weights → B has inconsistent type tensor(float16).
  • Resize scales (which must stay fp32 per the ONNX spec) was converted to fp16 → ParseData type mismatch ... Expected:float Actual:float16.

Fix.

  • A subgraph node is converted to low precision only when all of its float inputs are subgraph initializers eligible for low precision. Any node consuming a float activation / outer-scope tensor, or an input that must stay high precision (e.g. Resize scales), stays high precision — so each node's inputs share a single precision.
  • Float outer-scope captures and low→high precision boundaries inside subgraphs are reconciled with Cast nodes; a captured tensor's preserved subgraph value_info is synced to its real main-graph precision; control-flow node outputs are treated as high precision (their bodies are kept high precision).
  • ConstantCast folding in remove_redundant_casts now refreshes the constant's value_info so a same-type-constrained consumer (e.g. Greater) isn't left with a stale, conflicting type. (Pre-existing main-graph bug surfaced once the If models completed conversion.)

Usage

python -m modelopt.onnx.quantization \
  --quantize_mode int8 --high_precision_dtype fp16 \
  --onnx_path model.onnx \
  --output_path model_strongType_int8+fp16.onnx

Testing

Validated on both reported models:

graph pattern convert strict infer_shapes(check_type=True) ORT load numerics vs FP32
If branch with Gemm reading an outer-scope input bit-exact (Gemms kept fp32)
If branch with Resize (scales initializer) max abs err 9e-5 (fp16 tol)

Also verified with keep_io_types=False. Added 5 regression tests in tests/unit/onnx/autocast/test_precisionconverter.py (Gemm-with-outer-input, Resize-scales, chained-If capture, Constant-fold value_info) that fail without the fix and pass with it. Full tests/unit/onnx/autocast/ suite passes (214). pre-commit clean.

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ❌ (draft)

Additional Information

Fixes bug [6058841]. Draft pending final review.

🤖 Generated with Claude Code

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Jun 4, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jun 4, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 5eecfb2e-1cbc-4f8b-9617-18f087aae561

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch arasane/nvbug-6058841-onnx-fp16-if-subgraph

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 4, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1628/

Built to branch gh-pages at 2026-06-04 19:28 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 4, 2026

Codecov Report

❌ Patch coverage is 78.91156% with 31 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.39%. Comparing base (6b73e93) to head (80a7500).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/onnx/autocast/precisionconverter.py 78.16% 31 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##             main    #1628    +/-   ##
========================================
  Coverage   77.38%   77.39%            
========================================
  Files         482      482            
  Lines       52960    53096   +136     
========================================
+ Hits        40984    41094   +110     
- Misses      11976    12002    +26     
Flag Coverage Δ
unit 53.99% <78.91%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Solid, well-scoped bug fix for the "Inconsistent type on If node" crash, with 5 new regression tests (Gemm-with-outer-input, Resize-scales, chained-If capture, Constant→Cast value_info refresh) parameterized over fp16/bf16/keep_io_types/standalone-type-inference, and the existing subgraph tests still align with the new logic. Comments are clear and the policy is documented.

Worth a human reviewer's eye:

  • Semantic shift in subgraph init conversion. Previously, when an If/Loop/Scan parent was classified low-precision, every float subgraph initializer was converted. The new rule (_convert_subgraph_precision) only converts an initializer if its consuming node has at least one low-precision-eligible init and every other input is provably non-float. Any unknown-type input (e.g. tensor produced by an op the type-info map doesn't cover) demotes the node to high precision. For real models this is the desired behavior, but it's a public-facing behavioral change for any downstream user who depended on the old "blindly inherit parent precision" semantics — please confirm this is intended and call out in the changelog if so.

  • Loop/Scan formal-input precision. _current_float_type returns high.onnx_type for any float formal_inputs (subgraph inputs). For Loop body state vars or Scan scan_inputs, the actual precision is whatever the main-graph passes in — which could be low precision if the parent Loop is low. In practice the new node classifier will keep such bodies in high precision (formal float inputs prevent low classification), but this assumption isn't asserted anywhere and there's no Loop/Scan regression test in the suite. Worth a quick sanity check on a Loop-body model.

  • Topological assumption when assembling casts. In step 3 the code does next(n.name for n in subgraph.node if input_name in n.output) to find a producer, then places the cast immediately after that producer when reassembling subgraph.node. This relies on the original subgraph being topologically sorted (the cast appearing after its producer in the new list also requires the producer to come before all current consumers in iteration order). Subgraphs from helper.make_graph and most exporters are sorted, but if a user feeds an unsorted subgraph the cast can land before a consumer it's supposed to feed. Cheap fix: do a gs.toposort (already done elsewhere in this file) on the rebuilt node list, or assert sortedness.

  • O(N²) producer lookup. The next(n.name for n in subgraph.node if input_name in n.output) runs per cast and re-scans subgraph.node each time. Fine for the bug-report models, but a single producer→name dict built once before the loop would be cleaner.

  • Constant→Cast value_info refresh in utils.remove_redundant_casts. The new loop only updates the first matching value_info entry and only if one already exists — Constants frequently have no value_info entry at all, in which case nothing is refreshed. That's fine for the bug at hand (the test verifies strict shape inference passes), but worth confirming no consumer relies on a now-missing/already-typed value_info for c0.

No licensing changes; PR size (~500 LoC, half tests) is reasonable and the changes are cohesive.

)
if producer is not None:
producer_node_name = next(n.name for n in subgraph.node if input_name in n.output)
cast_nodes_after_producer[producer_node_name].append(cast_node)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Cheap perf cleanup: this next(...) re-scans subgraph.node once per inserted cast. You already have local_produced above and the new node list iteration; build a producer_by_output: dict[str, str] once before this loop (or include the producer node name when populating casts_to_insert) to avoid the O(casts × nodes) scan.


new_nodes = list(leading_casts)
for node in subgraph.node:
new_nodes.append(node)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Consider asserting / enforcing topological order before reassembling. The cast is placed right after its producer, but a downstream consumer of this cast that happens to appear earlier in subgraph.node (possible if the subgraph was hand-built and not toposorted) would now see the rewritten input name without the Cast having been emitted yet. A gs.import_onnx/toposort pass on the rebuilt subgraph (or an explicit assertion that producer indices < consumer indices) would harden this.

Non-float tensors (int indices/axes/shapes, bool conditions) must never be cast.
"""
if input_name in local_produced:
elem_type = local_elem_type.get(input_name)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

_current_float_type returns high.onnx_type unconditionally for any float formal_inputs. For Loop body state vars / Scan scan_inputs, the actual precision is whatever the main graph hands the parent Loop/Scan node — which can be low precision. The new node-classification rule mostly papers over this (a float formal input demotes the node to high precision), but it would be good to add at least one Loop-body regression test, since none of the new tests exercise Loop/Scan.

return main_producer_precision.get(name, base_type)

def _convert_subgraph_callback(
graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Worth calling out in the docstring (and ideally CHANGELOG) that this is a behavioral change for subgraph initializer precision: previously a low-precision parent caused all subgraph float initializers to be converted, now conversion is gated on the consuming node having no float activation/outer-scope/unknown-typed inputs. Anyone relying on the old behavior to get fp16 weights inside an If branch with an outer-scope activation will see those weights stay fp32 after this PR (which is the correct fix, but the change in observable output deserves a note).

Comment thread modelopt/onnx/utils.py
# conflicts with the now-converted constant value.
cast_to_type = get_cast_to_type(node)
for vi in onnx_model.graph.value_info:
if vi.name == constant_producer.output[0]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

This only updates the first value_info entry that matches and silently does nothing if no entry exists. Constants frequently lack a value_info entry until type inference runs. Fine for the regression test (which immediately re-runs strict shape inference), but if a caller of remove_redundant_casts doesn't follow up with type inference, downstream consumers may still see no/stale type for the folded constant. Consider creating a value_info entry when one is missing, or document that callers must re-infer types after folding.

@ajrasane ajrasane changed the title fix(onnx): consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion [6058841] Consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion Jun 4, 2026
…6 conversion

ONNX `--high_precision_dtype fp16` (and the autocast PrecisionConverter) used to
blindly convert every control-flow subgraph initializer to the parent node's
precision. This produced inconsistent types and broke shape inference / TensorRT
strongly-typed parsing on models with `If`/`Loop`/`Scan` subgraphs:

- A `Gemm` inside an `If` branch reading an outer-scope activation got fp16
  weights but an fp32 activation -> "B has inconsistent type".
- `Resize` `scales` (which must stay fp32) was converted to fp16 -> "ParseData
  type mismatch".

The converter now keeps a subgraph node in low precision only when all of its
float inputs are subgraph initializers eligible for low precision; any node with
a float activation/outer-scope input, or an input that must remain high precision
per the ONNX spec, stays in high precision so each node's inputs share one
precision. Float outer-scope captures and low->high precision boundaries inside a
subgraph are reconciled with `Cast` nodes, captured-tensor `value_info` is synced
to the capturing tensor's real precision, control-flow node outputs are treated as
high precision (their subgraph bodies are kept high precision), and `Constant`
folding refreshes the constant's `value_info` so a same-type-constrained consumer
(e.g. `Greater`) is not left with a stale, conflicting type.

Validated on both reported models: both convert successfully, pass strict
`onnx.shape_inference` type checking, load in ONNX Runtime, and match the FP32
reference outputs within FP16 tolerance.

Fixes bug 6058841

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane force-pushed the arasane/nvbug-6058841-onnx-fp16-if-subgraph branch from 80a7500 to e7201c3 Compare June 4, 2026 22:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants