From e7201c337be05bd2f8bee26146eac17b8900cd24 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 4 Jun 2026 19:13:35 +0000 Subject: [PATCH] fix(onnx): consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion ONNX `--high_precision_dtype fp16` (and the autocast PrecisionConverter) used to blindly convert every control-flow subgraph initializer to the parent node's precision. This produced inconsistent types and broke shape inference / TensorRT strongly-typed parsing on models with `If`/`Loop`/`Scan` subgraphs: - A `Gemm` inside an `If` branch reading an outer-scope activation got fp16 weights but an fp32 activation -> "B has inconsistent type". - `Resize` `scales` (which must stay fp32) was converted to fp16 -> "ParseData type mismatch". The converter now keeps a subgraph node in low precision only when all of its float inputs are subgraph initializers eligible for low precision; any node with a float activation/outer-scope input, or an input that must remain high precision per the ONNX spec, stays in high precision so each node's inputs share one precision. Float outer-scope captures and low->high precision boundaries inside a subgraph are reconciled with `Cast` nodes, captured-tensor `value_info` is synced to the capturing tensor's real precision, control-flow node outputs are treated as high precision (their subgraph bodies are kept high precision), and `Constant` folding refreshes the constant's `value_info` so a same-type-constrained consumer (e.g. `Greater`) is not left with a stale, conflicting type. Validated on both reported models: both convert successfully, pass strict `onnx.shape_inference` type checking, load in ONNX Runtime, and match the FP32 reference outputs within FP16 tolerance. Fixes bug 6058841 Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- CHANGELOG.rst | 1 + modelopt/onnx/autocast/precisionconverter.py | 284 ++++++++++++++++-- modelopt/onnx/utils.py | 8 + .../onnx/autocast/test_precisionconverter.py | 234 +++++++++++++++ 4 files changed, 500 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0e4f5517818..9446f78980c 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -53,6 +53,7 @@ Changelog - In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False. - Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. +- Fix ONNX FP16/BF16 conversion (``--high_precision_dtype fp16``) producing inconsistent tensor types on models with control-flow ``If``/``Loop``/``Scan`` subgraphs (e.g. a ``Gemm`` reading an outer-scope activation alongside converted weights, or ``Resize`` ``scales`` that must stay FP32). Subgraph nodes now only run in low precision when all their float inputs are subgraph initializers; outer-scope captures and low-to-high-precision boundaries inside subgraphs are reconciled with ``Cast`` nodes, and ``Constant`` folding refreshes the constant's ``value_info`` so strongly-typed parsers (TensorRT) no longer reject the model. **Deprecations** diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 3d6cb2a849e..2fef940aaed 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -65,6 +65,9 @@ class InitializerConsumerTracker: ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()] +# Reverse mapping from ONNX tensor type to its PrecisionTypes entry (e.g. TensorProto.FLOAT -> fp32). +ONNX_TYPE_TO_PRECISION = {t.onnx_type: t for t in PRECISION_MAP.values()} + OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] # Mapping of op types to indices of inputs that should not be converted to low precision. @@ -748,11 +751,20 @@ def _convert_initializers( def _convert_initializers_recursive( self, low_precision_nodes: list[str], high_precision_nodes: list[str] ) -> None: - """Convert initializers in main graph and all subgraphs to appropriate precision. + """Convert initializers in the main graph and reconcile precision inside all subgraphs. + + For the main graph, uses consumer tracking to determine each initializer's precision + (see :meth:`_convert_initializers`). - For the main graph, uses sophisticated consumer tracking to determine precision. - For subgraphs, inherits precision from the parent control flow node and converts - all initializers to that precision (no runtime casts). + Control-flow subgraphs (e.g. If/Loop/Scan bodies) do not get the activation-bracketing casts + that the main graph uses, so a subgraph node is only converted to low precision when *all* of + its float inputs are subgraph initializers that may be in low precision (i.e. the node consumes + no float activation/outer-scope tensor and none of its inputs must stay high precision per the + ONNX spec, such as ``Resize`` ``scales``). Such a node's initializers are converted to low + precision; every other subgraph node and its initializers are kept in high precision so that + each node's inputs share a single precision. Float tensors captured from the enclosing scope, + and the outputs of any low-precision subgraph node feeding a high-precision one, are reconciled + with a ``Cast`` inserted inside the subgraph. Args: low_precision_nodes: List of node names in main graph that are low precision. @@ -761,45 +773,263 @@ def _convert_initializers_recursive( # Convert main graph initializers with full consumer tracking self._convert_initializers(low_precision_nodes, high_precision_nodes) - # Convert subgraph initializers - walk all subgraphs and convert based on parent node precision + # Precompute, for each main-graph activation, the (raw) precision it has after main-graph + # conversion: a subgraph that captures it sees this raw precision, because the main-graph + # cast-up only rewires main-graph consumers (not subgraph captures). low_precision_nodes_set = set(low_precision_nodes) + main_producer_precision: dict[str, int] = {} + for node in self.model.graph.node: + # Control-flow nodes (If/Loop/Scan) execute a subgraph that is kept in high precision, so + # their outputs are high precision regardless of the node's low/high classification. + is_control_flow = any( + attr.type in (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS) + for attr in node.attribute + ) + producer_type = ( + self.low_precision_type.onnx_type + if node.name in low_precision_nodes_set and not is_control_flow + else self.high_precision_type.onnx_type + ) + for output_name in node.output: + main_producer_precision[output_name] = producer_type + + def _capture_type(name: str) -> int | None: + """Return the float precision (ONNX type) of an outer-scope tensor seen in a subgraph. + + Float-ness is read from the (pre-conversion) type info, since it is invariant under + precision conversion; the precision is then taken from the producing main-graph node. + Network inputs use their declared type in ``value_info_map`` (already updated in place + when ``keep_io_types`` is False). Returns None for non-float or unknown tensors. + """ + if name in self.initializer_map: + base_type = self.initializer_map[name].data_type + elif name in self.value_info_map: + base_type = self.value_info_map[name].type.tensor_type.elem_type + else: + return None + if base_type not in ONNX_TYPES: + return None + return main_producer_precision.get(name, base_type) def _convert_subgraph_callback( graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool ) -> None: if not is_subgraph or parent is None: return + parent_is_low_precision = parent.name in low_precision_nodes_set + self._convert_subgraph_precision(graph, parent_is_low_precision, _capture_type) - # Inherit precision from parent control flow node - target_type = ( - self.low_precision_type - if parent.name in low_precision_nodes_set - else self.high_precision_type + utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback) + + def _convert_subgraph_precision( + self, subgraph: onnx.GraphProto, parent_is_low_precision: bool, capture_type_fn + ) -> None: + """Convert a single control-flow subgraph to a consistent precision. + + See :meth:`_convert_initializers_recursive` for the conversion policy. + + Args: + subgraph: The subgraph (e.g. an If branch or a Loop/Scan body) to convert. + parent_is_low_precision: Whether the parent control-flow node is low precision. + capture_type_fn: Maps an outer-scope tensor name to its float ONNX type (or None). + """ + target = self.low_precision_type + high = self.high_precision_type + + local_inits = {init.name: init for init in subgraph.initializer} + local_produced = {out for node in subgraph.node for out in node.output} + formal_inputs = {inp.name for inp in subgraph.input} + # Original (pre-conversion) element types of subgraph-local tensors. Whether a tensor is + # float is invariant under fp32<->fp16 conversion, so this is a reliable "is this float?" + # source that prevents casting non-float tensors (e.g. int axes/indices/shapes). + local_elem_type = {vi.name: vi.type.tensor_type.elem_type for vi in subgraph.value_info} + local_elem_type.update({vi.name: vi.type.tensor_type.elem_type for vi in subgraph.output}) + + # Map of tensor name -> list of (node, input index) consumers within this subgraph. + consumers: dict[str, list[InputIndexTracker]] = defaultdict(list) + for node in subgraph.node: + for idx, input_name in enumerate(node.input): + if input_name: + consumers[input_name].append(InputIndexTracker(node=node, node_index=idx)) + + def _is_low_precision_eligible_init(node: onnx.NodeProto, input_name: str) -> bool: + init = local_inits.get(input_name) + return ( + init is not None + and init.data_type in ONNX_TYPES + and not self._should_skip_low_precision_input_conversion(node, input_name) ) - # Convert all float initializers to target precision - for init in graph.initializer: - if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type: - continue + def _known_elem_type(input_name: str) -> int | None: + """Best-effort (pre-conversion) element type of a tensor visible here, or None.""" + if input_name in local_inits: + return local_inits[input_name].data_type + if input_name in local_elem_type: + return local_elem_type[input_name] + if input_name in self.value_info_map: + return self.value_info_map[input_name].type.tensor_type.elem_type + if input_name in self.initializer_map: + return self.initializer_map[input_name].data_type + return None - from_type = ( - self.high_precision_type - if init.data_type == self.high_precision_type.onnx_type - else self.low_precision_type - if init.data_type == self.low_precision_type.onnx_type - else None + # 1. Classify each subgraph node. A node is converted to low precision only if it has at least + # one low-precision-eligible float initializer input and every other input is a tensor we + # know is non-float (e.g. int axes/indices). Any float activation/outer-scope input (or an + # input of unknown type) keeps the node in high precision, since no bracketing casts are + # inserted around subgraph activations. + node_is_low: dict[str, bool] = {} + for node in subgraph.node: + low = parent_is_low_precision and ( + node.op_type not in self.op_types_not_supported_in_low_precision + ) + has_low_precision_init = False + if low: + for input_name in node.input: + if not input_name: + continue + if _is_low_precision_eligible_init(node, input_name): + has_low_precision_init = True + continue + elem_type = _known_elem_type(input_name) + if elem_type is None or elem_type in ONNX_TYPES: + # Float or unknown non-initializer input: keep the node in high precision. + low = False + break + node_is_low[node.name] = low and has_low_precision_init + + # 2. Convert the initializers consumed by low-precision nodes to low precision. An initializer + # shared by low- and high-precision nodes is duplicated so each consumer keeps one precision. + for init in list(subgraph.initializer): + if init.data_type not in ONNX_TYPES: + continue + from_type = ONNX_TYPE_TO_PRECISION.get(init.data_type) + if from_type is None: + continue + low_consumers: list[InputIndexTracker] = [] + high_consumers: list[InputIndexTracker] = [] + for c in consumers.get(init.name, []): + if node_is_low.get(c.node.name) and _is_low_precision_eligible_init( + c.node, init.name + ): + low_consumers.append(c) + else: + high_consumers.append(c) + + if not low_consumers: + # Keep in high precision (covers Resize-scales-style inputs and unused initializers). + if init.data_type != high.onnx_type: + init.CopyFrom(self._convert_initializer_data(init, from_type, high)) + elif not high_consumers: + # Convert the single-precision initializer in place. + if init.data_type != target.onnx_type: + init.CopyFrom(self._convert_initializer_data(init, from_type, target)) + else: + # Shared: keep the original high precision and add a low-precision duplicate. + low_init = self._convert_initializer_data(init, from_type, target) + low_init.name = f"{init.name}_{target.str_short}" + subgraph.initializer.extend([low_init]) + for consumer in low_consumers: + consumer.node.input[consumer.node_index] = low_init.name + if init.data_type != high.onnx_type: + init.CopyFrom(self._convert_initializer_data(init, from_type, high)) + + # 3. Reconcile float tensors whose precision does not match the consuming node: outer-scope + # captures and the outputs of low-precision nodes feeding high-precision ones. Casts are + # collected first and inserted afterwards to avoid mutating the node list while iterating. + local_produced_low = { + out for node in subgraph.node if node_is_low.get(node.name) for out in node.output + } + + def _current_float_type(input_name: str) -> int | None: + """Float precision (ONNX type) of a subgraph input tensor, or None if it is non-float. + + Non-float tensors (int indices/axes/shapes, bool conditions) must never be cast. + """ + if input_name in local_produced: + elem_type = local_elem_type.get(input_name) + if elem_type is not None and elem_type not in ONNX_TYPES: + return None # known non-float subgraph activation + if input_name in local_produced_low: + return target.onnx_type # output of a converted low-precision node + return high.onnx_type if elem_type in ONNX_TYPES else None + if input_name in formal_inputs: + elem_type = local_elem_type.get(input_name) + return high.onnx_type if elem_type in ONNX_TYPES else None + return capture_type_fn(input_name) # outer-scope capture (None if non-float) + + # (tensor name, target onnx type) -> (cast output name, producer node name or None) + casts_to_insert: dict[tuple[str, int], tuple[str, str | None]] = {} + rewrites: list[tuple[InputIndexTracker, str]] = [] + # Float outer-scope captures and the (current) precision they have in the enclosing scope. + captured_types: dict[str, int] = {} + for node in subgraph.node: + node_low = node_is_low.get(node.name, False) + for idx, input_name in enumerate(node.input): + if not input_name or input_name in local_inits: + continue + current = _current_float_type(input_name) + if current is None: + continue # non-float tensor: never cast + if input_name not in local_produced and input_name not in formal_inputs: + captured_types[input_name] = current + desired = ( + target.onnx_type + if node_low + and not self._should_skip_low_precision_input_conversion(node, input_name) + else high.onnx_type ) + if current == desired: + continue - if from_type is None: - logger.debug( - f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}" + key = (input_name, desired) + if key not in casts_to_insert: + short = ONNX_TYPE_TO_PRECISION[desired].str_short + producer = ( + input_name + if input_name in local_produced + else None # outer-scope capture (produced outside this subgraph) ) - continue + casts_to_insert[key] = ( + f"{input_name}_subgraph_cast_to_{short}", + producer, + ) + rewrites.append( + (InputIndexTracker(node=node, node_index=idx), casts_to_insert[key][0]) + ) - new_init = self._convert_initializer_data(init, from_type, target_type) - init.CopyFrom(new_init) + # Sync any preserved outer-scope value_info inside the subgraph with the capture's current + # main-graph precision. Otherwise a stale type (e.g. fp32 for a tensor the main graph now + # produces in fp16) makes strongly-typed parsers (and ORT) reject the If subgraph. + for vi in subgraph.value_info: + if vi.name in captured_types and vi.type.HasField("tensor_type"): + vi.type.tensor_type.elem_type = captured_types[vi.name] - utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback) + if not casts_to_insert: + return + + for tracker, cast_output in rewrites: + tracker.node.input[tracker.node_index] = cast_output + + # Build the cast nodes and re-assemble the subgraph node list so each cast appears after its + # producer (captures, produced outside the subgraph, are placed at the front). + cast_nodes_after_producer: dict[str, list[onnx.NodeProto]] = defaultdict(list) + leading_casts: list[onnx.NodeProto] = [] + for (input_name, desired), (cast_output, producer) in casts_to_insert.items(): + cast_node = helper.make_node( + "Cast", inputs=[input_name], outputs=[cast_output], to=desired, name=cast_output + ) + if producer is not None: + producer_node_name = next(n.name for n in subgraph.node if input_name in n.output) + cast_nodes_after_producer[producer_node_name].append(cast_node) + else: + leading_casts.append(cast_node) + + new_nodes = list(leading_casts) + for node in subgraph.node: + new_nodes.append(node) + new_nodes.extend(cast_nodes_after_producer.get(node.name, [])) + del subgraph.node[:] + subgraph.node.extend(new_nodes) def _convert_initializer_data( self, diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 951ce2cc98c..bc6f6bcefc5 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1558,6 +1558,14 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant" constant_producer = cast_producers[0] _convert_constant_values(constant_producer, node) + # Folding changes the Constant output's element type; refresh its value_info so + # downstream consumers (and strongly-typed parsers) don't see a stale type that + # conflicts with the now-converted constant value. + cast_to_type = get_cast_to_type(node) + for vi in onnx_model.graph.value_info: + if vi.name == constant_producer.output[0]: + vi.type.tensor_type.elem_type = cast_to_type + break _bypass_cast_node(onnx_model, node) logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}") diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index c3e1a51db51..476fc8aebfd 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -1709,3 +1709,237 @@ def test_if_subgraph_outer_scope_type_preservation( assert len(else_x_info) > 0, "X value_info should be preserved in else branch" assert then_x_info[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED assert else_x_info[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED + + +#################################################################################################### +# Regression tests for bug 6058841: inconsistent tensor types on control-flow If nodes during +# ONNX FP16/BF16 conversion. +# +# Converting a model with control-flow subgraphs to FP16 used to blindly convert every subgraph +# initializer to the parent node's precision, which broke models where a subgraph node also consumes +# a float activation/outer-scope tensor (e.g. a Gemm reading a network input) or whose inputs must +# stay in high precision per the ONNX spec (e.g. Resize 'scales'). +#################################################################################################### +@pytest.fixture +def model_if_subgraph_gemm_outer_input(): + """If branches with a Gemm consuming an outer-scope input plus subgraph weight initializers.""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4]) + condition = helper.make_tensor_value_info("condition", TensorProto.BOOL, []) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3]) + + def _branch(name): + w = numpy_helper.from_array(np.random.randn(4, 3).astype(np.float32), name=f"w_{name}") + b = numpy_helper.from_array(np.random.randn(3).astype(np.float32), name=f"b_{name}") + out = helper.make_tensor_value_info(f"{name}_out", TensorProto.FLOAT, [1, 3]) + gemm = helper.make_node( + "Gemm", ["X", f"w_{name}", f"b_{name}"], [f"{name}_out"], name=f"{name}_gemm" + ) + return helper.make_graph([gemm], f"{name}_branch", [], [out], [w, b]) + + if_node = helper.make_node( + "If", + ["condition"], + ["Y"], + name="if_node", + then_branch=_branch("then"), + else_branch=_branch("else"), + ) + main_graph = helper.make_graph([if_node], "model_if_gemm", [x, condition], [y]) + model = helper.make_model(main_graph, producer_name="model_if_gemm") + model.opset_import[0].version = 20 + model.ir_version = 10 + onnx.checker.check_model(model) + return setup_mappings(model) + + +@pytest.mark.parametrize("keep_io_types", [True, False]) +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_if_subgraph_gemm_with_outer_scope_input( + model_if_subgraph_gemm_outer_input, + keep_io_types, + low_precision_type, + use_standalone_type_inference, +): + """A Gemm inside an If branch consuming an outer-scope input must not end up with fp16 weights + feeding alongside an fp32 activation (regression test for bug 6058841).""" + model, value_info_map, initializer_map, node_to_init_map = model_if_subgraph_gemm_outer_input + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=keep_io_types, + low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, + ) + converted_model = converter.convert(high_precision_nodes=[], low_precision_nodes=["if_node"]) + onnx.checker.check_model(converted_model) + # Strict type checking must pass; this is what failed before the fix. + onnx.shape_inference.infer_shapes(converted_model, strict_mode=True, check_type=True) + + +@pytest.fixture +def model_if_subgraph_resize(): + """If branches containing a Resize whose 'roi'/'scales' inputs must remain in high precision.""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8]) + condition = helper.make_tensor_value_info("condition", TensorProto.BOOL, []) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 16, 16]) + + def _branch(name): + roi = numpy_helper.from_array(np.array([], dtype=np.float32), name=f"roi_{name}") + scales = numpy_helper.from_array( + np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32), name=f"scales_{name}" + ) + out = helper.make_tensor_value_info(f"{name}_out", TensorProto.FLOAT, [1, 3, 16, 16]) + resize = helper.make_node( + "Resize", + ["X", f"roi_{name}", f"scales_{name}"], + [f"{name}_out"], + name=f"{name}_resize", + mode="nearest", + ) + return helper.make_graph([resize], f"{name}_branch", [], [out], [roi, scales]) + + if_node = helper.make_node( + "If", + ["condition"], + ["Y"], + name="if_node", + then_branch=_branch("then"), + else_branch=_branch("else"), + ) + main_graph = helper.make_graph([if_node], "model_if_resize", [x, condition], [y]) + model = helper.make_model(main_graph, producer_name="model_if_resize") + model.opset_import[0].version = 20 + model.ir_version = 10 + onnx.checker.check_model(model) + return setup_mappings(model) + + +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_if_subgraph_resize_scales_stay_high_precision( + model_if_subgraph_resize, low_precision_type, use_standalone_type_inference +): + """Resize 'scales' inside an If branch must remain FP32 (regression test for bug 6058841).""" + model, value_info_map, initializer_map, node_to_init_map = model_if_subgraph_resize + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, + ) + converted_model = converter.convert(high_precision_nodes=[], low_precision_nodes=["if_node"]) + onnx.checker.check_model(converted_model) + onnx.shape_inference.infer_shapes(converted_model, strict_mode=True, check_type=True) + + # The 'scales' (and 'roi') initializers must stay FP32 in both branches. + if_node = next(n for n in converted_model.graph.node if n.op_type == "If") + for attr in if_node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + scales = [init for init in attr.g.initializer if init.name.startswith("scales_")] + assert scales, "scales initializer should be present in the branch" + for init in scales: + assert init.data_type == TensorProto.FLOAT, ( + f"Resize scales must remain FP32, but '{init.name}' is {init.data_type}" + ) + + +@pytest.fixture +def model_chained_if_capture(): + """Two chained If nodes; the second's subgraph captures the first If node's output.""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4]) + cond1 = helper.make_tensor_value_info("cond1", TensorProto.BOOL, []) + cond2 = helper.make_tensor_value_info("cond2", TensorProto.BOOL, []) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3]) + + def _gemm_branch(name, data, k, n): + w = numpy_helper.from_array(np.random.randn(k, n).astype(np.float32), name=f"w_{name}") + out = helper.make_tensor_value_info(f"{name}_out", TensorProto.FLOAT, [1, n]) + gemm = helper.make_node("Gemm", [data, f"w_{name}"], [f"{name}_out"], name=f"{name}_gemm") + return helper.make_graph([gemm], f"{name}_branch", [], [out], [w]) + + if1 = helper.make_node( + "If", + ["cond1"], + ["mid"], + name="if1", + then_branch=_gemm_branch("then1", "X", 4, 3), + else_branch=_gemm_branch("else1", "X", 4, 3), + ) + if2 = helper.make_node( + "If", + ["cond2"], + ["Y"], + name="if2", + then_branch=_gemm_branch("then2", "mid", 3, 3), + else_branch=_gemm_branch("else2", "mid", 3, 3), + ) + main_graph = helper.make_graph([if1, if2], "chained_if", [x, cond1, cond2], [y]) + model = helper.make_model(main_graph, producer_name="chained_if") + model.opset_import[0].version = 20 + model.ir_version = 10 + onnx.checker.check_model(model) + return setup_mappings(model) + + +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +def test_chained_if_subgraph_capture(model_chained_if_capture, low_precision_type): + """An If subgraph capturing another control-flow node's output must reconcile its precision + (regression test for bug 6058841; an If subgraph capturing another If node's output).""" + model, value_info_map, initializer_map, node_to_init_map = model_chained_if_capture + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + low_precision_type=low_precision_type, + ) + converted_model = converter.convert(high_precision_nodes=[], low_precision_nodes=["if1", "if2"]) + onnx.checker.check_model(converted_model) + onnx.shape_inference.infer_shapes(converted_model, strict_mode=True, check_type=True) + + +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +def test_constant_cast_fold_refreshes_value_info(low_precision_type): + """Folding a Constant->Cast must refresh the constant's value_info, otherwise a same-type + constrained consumer (e.g. Greater) sees a stale, conflicting type and strict type inference + fails (regression test for bug 6058841; Constant feeding a same-type-constrained Greater).""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4]) + y = helper.make_tensor_value_info("Y", TensorProto.BOOL, [4]) + w = numpy_helper.from_array(np.ones([4], dtype=np.float32), name="w") + nodes = [ + helper.make_node("Mul", ["X", "w"], ["m0"], name="mul0"), + helper.make_node( + "Constant", + [], + ["c0"], + name="const0", + value=numpy_helper.from_array(np.array([0.5, 0.5, 0.5, 0.5], dtype=np.float32), "cv"), + ), + helper.make_node("Greater", ["m0", "c0"], ["Y"], name="greater0"), + ] + graph = helper.make_graph(nodes, "const_greater", [x], [y], [w]) + model = helper.make_model(graph, producer_name="const_greater") + model.opset_import[0].version = 20 + model.ir_version = 10 + onnx.checker.check_model(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + low_precision_type=low_precision_type, + ) + converted_model = converter.convert( + high_precision_nodes=[], low_precision_nodes=["mul0", "greater0"] + ) + onnx.checker.check_model(converted_model) + onnx.shape_inference.infer_shapes(converted_model, strict_mode=True, check_type=True)