[6058841] Consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion#1628
[6058841] Consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion#1628ajrasane wants to merge 1 commit into
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1628 +/- ##
========================================
Coverage 77.38% 77.39%
========================================
Files 482 482
Lines 52960 53096 +136
========================================
+ Hits 40984 41094 +110
- Misses 11976 12002 +26
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Solid, well-scoped bug fix for the "Inconsistent type on If node" crash, with 5 new regression tests (Gemm-with-outer-input, Resize-scales, chained-If capture, Constant→Cast value_info refresh) parameterized over fp16/bf16/keep_io_types/standalone-type-inference, and the existing subgraph tests still align with the new logic. Comments are clear and the policy is documented.
Worth a human reviewer's eye:
-
Semantic shift in subgraph init conversion. Previously, when an If/Loop/Scan parent was classified low-precision, every float subgraph initializer was converted. The new rule (
_convert_subgraph_precision) only converts an initializer if its consuming node has at least one low-precision-eligible init and every other input is provably non-float. Any unknown-type input (e.g. tensor produced by an op the type-info map doesn't cover) demotes the node to high precision. For real models this is the desired behavior, but it's a public-facing behavioral change for any downstream user who depended on the old "blindly inherit parent precision" semantics — please confirm this is intended and call out in the changelog if so. -
Loop/Scanformal-input precision._current_float_typereturnshigh.onnx_typefor any floatformal_inputs(subgraph inputs). ForLoopbody state vars orScanscan_inputs, the actual precision is whatever the main-graph passes in — which could be low precision if the parent Loop is low. In practice the new node classifier will keep such bodies in high precision (formal float inputs prevent low classification), but this assumption isn't asserted anywhere and there's no Loop/Scan regression test in the suite. Worth a quick sanity check on a Loop-body model. -
Topological assumption when assembling casts. In step 3 the code does
next(n.name for n in subgraph.node if input_name in n.output)to find a producer, then places the cast immediately after that producer when reassemblingsubgraph.node. This relies on the original subgraph being topologically sorted (the cast appearing after its producer in the new list also requires the producer to come before all current consumers in iteration order). Subgraphs fromhelper.make_graphand most exporters are sorted, but if a user feeds an unsorted subgraph the cast can land before a consumer it's supposed to feed. Cheap fix: do ags.toposort(already done elsewhere in this file) on the rebuilt node list, or assert sortedness. -
O(N²) producer lookup. The
next(n.name for n in subgraph.node if input_name in n.output)runs per cast and re-scanssubgraph.nodeeach time. Fine for the bug-report models, but a single producer→name dict built once before the loop would be cleaner. -
Constant→Cast
value_inforefresh inutils.remove_redundant_casts. The new loop only updates the first matchingvalue_infoentry and only if one already exists — Constants frequently have novalue_infoentry at all, in which case nothing is refreshed. That's fine for the bug at hand (the test verifies strict shape inference passes), but worth confirming no consumer relies on a now-missing/already-typedvalue_infoforc0.
No licensing changes; PR size (~500 LoC, half tests) is reasonable and the changes are cohesive.
| ) | ||
| if producer is not None: | ||
| producer_node_name = next(n.name for n in subgraph.node if input_name in n.output) | ||
| cast_nodes_after_producer[producer_node_name].append(cast_node) |
There was a problem hiding this comment.
Bot comment.
Cheap perf cleanup: this next(...) re-scans subgraph.node once per inserted cast. You already have local_produced above and the new node list iteration; build a producer_by_output: dict[str, str] once before this loop (or include the producer node name when populating casts_to_insert) to avoid the O(casts × nodes) scan.
|
|
||
| new_nodes = list(leading_casts) | ||
| for node in subgraph.node: | ||
| new_nodes.append(node) |
There was a problem hiding this comment.
Bot comment.
Consider asserting / enforcing topological order before reassembling. The cast is placed right after its producer, but a downstream consumer of this cast that happens to appear earlier in subgraph.node (possible if the subgraph was hand-built and not toposorted) would now see the rewritten input name without the Cast having been emitted yet. A gs.import_onnx/toposort pass on the rebuilt subgraph (or an explicit assertion that producer indices < consumer indices) would harden this.
| Non-float tensors (int indices/axes/shapes, bool conditions) must never be cast. | ||
| """ | ||
| if input_name in local_produced: | ||
| elem_type = local_elem_type.get(input_name) |
There was a problem hiding this comment.
Bot comment.
_current_float_type returns high.onnx_type unconditionally for any float formal_inputs. For Loop body state vars / Scan scan_inputs, the actual precision is whatever the main graph hands the parent Loop/Scan node — which can be low precision. The new node-classification rule mostly papers over this (a float formal input demotes the node to high precision), but it would be good to add at least one Loop-body regression test, since none of the new tests exercise Loop/Scan.
| return main_producer_precision.get(name, base_type) | ||
|
|
||
| def _convert_subgraph_callback( | ||
| graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool |
There was a problem hiding this comment.
Bot comment.
Worth calling out in the docstring (and ideally CHANGELOG) that this is a behavioral change for subgraph initializer precision: previously a low-precision parent caused all subgraph float initializers to be converted, now conversion is gated on the consuming node having no float activation/outer-scope/unknown-typed inputs. Anyone relying on the old behavior to get fp16 weights inside an If branch with an outer-scope activation will see those weights stay fp32 after this PR (which is the correct fix, but the change in observable output deserves a note).
| # conflicts with the now-converted constant value. | ||
| cast_to_type = get_cast_to_type(node) | ||
| for vi in onnx_model.graph.value_info: | ||
| if vi.name == constant_producer.output[0]: |
There was a problem hiding this comment.
Bot comment.
This only updates the first value_info entry that matches and silently does nothing if no entry exists. Constants frequently lack a value_info entry until type inference runs. Fine for the regression test (which immediately re-runs strict shape inference), but if a caller of remove_redundant_casts doesn't follow up with type inference, downstream consumers may still see no/stale type for the folded constant. Consider creating a value_info entry when one is missing, or document that callers must re-infer types after folding.
…6 conversion ONNX `--high_precision_dtype fp16` (and the autocast PrecisionConverter) used to blindly convert every control-flow subgraph initializer to the parent node's precision. This produced inconsistent types and broke shape inference / TensorRT strongly-typed parsing on models with `If`/`Loop`/`Scan` subgraphs: - A `Gemm` inside an `If` branch reading an outer-scope activation got fp16 weights but an fp32 activation -> "B has inconsistent type". - `Resize` `scales` (which must stay fp32) was converted to fp16 -> "ParseData type mismatch". The converter now keeps a subgraph node in low precision only when all of its float inputs are subgraph initializers eligible for low precision; any node with a float activation/outer-scope input, or an input that must remain high precision per the ONNX spec, stays in high precision so each node's inputs share one precision. Float outer-scope captures and low->high precision boundaries inside a subgraph are reconciled with `Cast` nodes, captured-tensor `value_info` is synced to the capturing tensor's real precision, control-flow node outputs are treated as high precision (their subgraph bodies are kept high precision), and `Constant` folding refreshes the constant's `value_info` so a same-type-constrained consumer (e.g. `Greater`) is not left with a stale, conflicting type. Validated on both reported models: both convert successfully, pass strict `onnx.shape_inference` type checking, load in ONNX Runtime, and match the FP32 reference outputs within FP16 tolerance. Fixes bug 6058841 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
80a7500 to
e7201c3
Compare
What does this PR do?
Type of change: Bug fix
Fixes [6058841] —
python -m modelopt.onnx.quantization --high_precision_dtype fp16crashed with "Inconsistent type on If node" on models containing control-flowIf/Loop/Scansubgraphs.Root cause. The FP16/BF16
PrecisionConverterblindly converted every subgraph initializer to the parent control-flow node's precision, without the activation-bracketing casts it uses in the main graph. This left inconsistent tensor types that crash ONNX shape inference / TensorRT strongly-typed parsing:Gemminside anIfbranch reading an outer-scope activation (fp32) ended up with fp16 weights →B has inconsistent type tensor(float16).Resizescales(which must stay fp32 per the ONNX spec) was converted to fp16 →ParseData type mismatch ... Expected:float Actual:float16.Fix.
Resizescales), stays high precision — so each node's inputs share a single precision.Castnodes; a captured tensor's preserved subgraphvalue_infois synced to its real main-graph precision; control-flow node outputs are treated as high precision (their bodies are kept high precision).Constant→Castfolding inremove_redundant_castsnow refreshes the constant'svalue_infoso a same-type-constrained consumer (e.g.Greater) isn't left with a stale, conflicting type. (Pre-existing main-graph bug surfaced once theIfmodels completed conversion.)Usage
Testing
Validated on both reported models:
infer_shapes(check_type=True)Ifbranch withGemmreading an outer-scope inputIfbranch withResize(scalesinitializer)Also verified with
keep_io_types=False. Added 5 regression tests intests/unit/onnx/autocast/test_precisionconverter.py(Gemm-with-outer-input, Resize-scales, chained-If capture, Constant-foldvalue_info) that fail without the fix and pass with it. Fulltests/unit/onnx/autocast/suite passes (214).pre-commitclean.Before your PR is "Ready for review"
CONTRIBUTING.md: N/AAdditional Information
Fixes bug [6058841]. Draft pending final review.
🤖 Generated with Claude Code