-
Notifications
You must be signed in to change notification settings - Fork 427
[6058841] Consistent types on If/Loop/Scan subgraphs during FP16/BF16 conversion #1628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,9 @@ class InitializerConsumerTracker: | |
|
|
||
| ONNX_TYPES = [t.onnx_type for t in PRECISION_MAP.values()] | ||
|
|
||
| # Reverse mapping from ONNX tensor type to its PrecisionTypes entry (e.g. TensorProto.FLOAT -> fp32). | ||
| ONNX_TYPE_TO_PRECISION = {t.onnx_type: t for t in PRECISION_MAP.values()} | ||
|
|
||
| OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] | ||
|
|
||
| # Mapping of op types to indices of inputs that should not be converted to low precision. | ||
|
|
@@ -748,11 +751,20 @@ def _convert_initializers( | |
| def _convert_initializers_recursive( | ||
| self, low_precision_nodes: list[str], high_precision_nodes: list[str] | ||
| ) -> None: | ||
| """Convert initializers in main graph and all subgraphs to appropriate precision. | ||
| """Convert initializers in the main graph and reconcile precision inside all subgraphs. | ||
|
|
||
| For the main graph, uses consumer tracking to determine each initializer's precision | ||
| (see :meth:`_convert_initializers`). | ||
|
|
||
| For the main graph, uses sophisticated consumer tracking to determine precision. | ||
| For subgraphs, inherits precision from the parent control flow node and converts | ||
| all initializers to that precision (no runtime casts). | ||
| Control-flow subgraphs (e.g. If/Loop/Scan bodies) do not get the activation-bracketing casts | ||
| that the main graph uses, so a subgraph node is only converted to low precision when *all* of | ||
| its float inputs are subgraph initializers that may be in low precision (i.e. the node consumes | ||
| no float activation/outer-scope tensor and none of its inputs must stay high precision per the | ||
| ONNX spec, such as ``Resize`` ``scales``). Such a node's initializers are converted to low | ||
| precision; every other subgraph node and its initializers are kept in high precision so that | ||
| each node's inputs share a single precision. Float tensors captured from the enclosing scope, | ||
| and the outputs of any low-precision subgraph node feeding a high-precision one, are reconciled | ||
| with a ``Cast`` inserted inside the subgraph. | ||
|
|
||
| Args: | ||
| low_precision_nodes: List of node names in main graph that are low precision. | ||
|
|
@@ -761,45 +773,263 @@ def _convert_initializers_recursive( | |
| # Convert main graph initializers with full consumer tracking | ||
| self._convert_initializers(low_precision_nodes, high_precision_nodes) | ||
|
|
||
| # Convert subgraph initializers - walk all subgraphs and convert based on parent node precision | ||
| # Precompute, for each main-graph activation, the (raw) precision it has after main-graph | ||
| # conversion: a subgraph that captures it sees this raw precision, because the main-graph | ||
| # cast-up only rewires main-graph consumers (not subgraph captures). | ||
| low_precision_nodes_set = set(low_precision_nodes) | ||
| main_producer_precision: dict[str, int] = {} | ||
| for node in self.model.graph.node: | ||
| # Control-flow nodes (If/Loop/Scan) execute a subgraph that is kept in high precision, so | ||
| # their outputs are high precision regardless of the node's low/high classification. | ||
| is_control_flow = any( | ||
| attr.type in (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS) | ||
| for attr in node.attribute | ||
| ) | ||
| producer_type = ( | ||
| self.low_precision_type.onnx_type | ||
| if node.name in low_precision_nodes_set and not is_control_flow | ||
| else self.high_precision_type.onnx_type | ||
| ) | ||
| for output_name in node.output: | ||
| main_producer_precision[output_name] = producer_type | ||
|
|
||
| def _capture_type(name: str) -> int | None: | ||
| """Return the float precision (ONNX type) of an outer-scope tensor seen in a subgraph. | ||
|
|
||
| Float-ness is read from the (pre-conversion) type info, since it is invariant under | ||
| precision conversion; the precision is then taken from the producing main-graph node. | ||
| Network inputs use their declared type in ``value_info_map`` (already updated in place | ||
| when ``keep_io_types`` is False). Returns None for non-float or unknown tensors. | ||
| """ | ||
| if name in self.initializer_map: | ||
| base_type = self.initializer_map[name].data_type | ||
| elif name in self.value_info_map: | ||
| base_type = self.value_info_map[name].type.tensor_type.elem_type | ||
| else: | ||
| return None | ||
| if base_type not in ONNX_TYPES: | ||
| return None | ||
| return main_producer_precision.get(name, base_type) | ||
|
|
||
| def _convert_subgraph_callback( | ||
| graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool | ||
| ) -> None: | ||
| if not is_subgraph or parent is None: | ||
| return | ||
| parent_is_low_precision = parent.name in low_precision_nodes_set | ||
| self._convert_subgraph_precision(graph, parent_is_low_precision, _capture_type) | ||
|
|
||
| # Inherit precision from parent control flow node | ||
| target_type = ( | ||
| self.low_precision_type | ||
| if parent.name in low_precision_nodes_set | ||
| else self.high_precision_type | ||
| utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback) | ||
|
|
||
| def _convert_subgraph_precision( | ||
| self, subgraph: onnx.GraphProto, parent_is_low_precision: bool, capture_type_fn | ||
| ) -> None: | ||
| """Convert a single control-flow subgraph to a consistent precision. | ||
|
|
||
| See :meth:`_convert_initializers_recursive` for the conversion policy. | ||
|
|
||
| Args: | ||
| subgraph: The subgraph (e.g. an If branch or a Loop/Scan body) to convert. | ||
| parent_is_low_precision: Whether the parent control-flow node is low precision. | ||
| capture_type_fn: Maps an outer-scope tensor name to its float ONNX type (or None). | ||
| """ | ||
| target = self.low_precision_type | ||
| high = self.high_precision_type | ||
|
|
||
| local_inits = {init.name: init for init in subgraph.initializer} | ||
| local_produced = {out for node in subgraph.node for out in node.output} | ||
| formal_inputs = {inp.name for inp in subgraph.input} | ||
| # Original (pre-conversion) element types of subgraph-local tensors. Whether a tensor is | ||
| # float is invariant under fp32<->fp16 conversion, so this is a reliable "is this float?" | ||
| # source that prevents casting non-float tensors (e.g. int axes/indices/shapes). | ||
| local_elem_type = {vi.name: vi.type.tensor_type.elem_type for vi in subgraph.value_info} | ||
| local_elem_type.update({vi.name: vi.type.tensor_type.elem_type for vi in subgraph.output}) | ||
|
|
||
| # Map of tensor name -> list of (node, input index) consumers within this subgraph. | ||
| consumers: dict[str, list[InputIndexTracker]] = defaultdict(list) | ||
| for node in subgraph.node: | ||
| for idx, input_name in enumerate(node.input): | ||
| if input_name: | ||
| consumers[input_name].append(InputIndexTracker(node=node, node_index=idx)) | ||
|
|
||
| def _is_low_precision_eligible_init(node: onnx.NodeProto, input_name: str) -> bool: | ||
| init = local_inits.get(input_name) | ||
| return ( | ||
| init is not None | ||
| and init.data_type in ONNX_TYPES | ||
| and not self._should_skip_low_precision_input_conversion(node, input_name) | ||
| ) | ||
|
|
||
| # Convert all float initializers to target precision | ||
| for init in graph.initializer: | ||
| if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type: | ||
| continue | ||
| def _known_elem_type(input_name: str) -> int | None: | ||
| """Best-effort (pre-conversion) element type of a tensor visible here, or None.""" | ||
| if input_name in local_inits: | ||
| return local_inits[input_name].data_type | ||
| if input_name in local_elem_type: | ||
| return local_elem_type[input_name] | ||
| if input_name in self.value_info_map: | ||
| return self.value_info_map[input_name].type.tensor_type.elem_type | ||
| if input_name in self.initializer_map: | ||
| return self.initializer_map[input_name].data_type | ||
| return None | ||
|
|
||
| from_type = ( | ||
| self.high_precision_type | ||
| if init.data_type == self.high_precision_type.onnx_type | ||
| else self.low_precision_type | ||
| if init.data_type == self.low_precision_type.onnx_type | ||
| else None | ||
| # 1. Classify each subgraph node. A node is converted to low precision only if it has at least | ||
| # one low-precision-eligible float initializer input and every other input is a tensor we | ||
| # know is non-float (e.g. int axes/indices). Any float activation/outer-scope input (or an | ||
| # input of unknown type) keeps the node in high precision, since no bracketing casts are | ||
| # inserted around subgraph activations. | ||
| node_is_low: dict[str, bool] = {} | ||
| for node in subgraph.node: | ||
| low = parent_is_low_precision and ( | ||
| node.op_type not in self.op_types_not_supported_in_low_precision | ||
| ) | ||
| has_low_precision_init = False | ||
| if low: | ||
| for input_name in node.input: | ||
| if not input_name: | ||
| continue | ||
| if _is_low_precision_eligible_init(node, input_name): | ||
| has_low_precision_init = True | ||
| continue | ||
| elem_type = _known_elem_type(input_name) | ||
| if elem_type is None or elem_type in ONNX_TYPES: | ||
| # Float or unknown non-initializer input: keep the node in high precision. | ||
| low = False | ||
| break | ||
| node_is_low[node.name] = low and has_low_precision_init | ||
|
|
||
| # 2. Convert the initializers consumed by low-precision nodes to low precision. An initializer | ||
| # shared by low- and high-precision nodes is duplicated so each consumer keeps one precision. | ||
| for init in list(subgraph.initializer): | ||
| if init.data_type not in ONNX_TYPES: | ||
| continue | ||
| from_type = ONNX_TYPE_TO_PRECISION.get(init.data_type) | ||
| if from_type is None: | ||
| continue | ||
| low_consumers: list[InputIndexTracker] = [] | ||
| high_consumers: list[InputIndexTracker] = [] | ||
| for c in consumers.get(init.name, []): | ||
| if node_is_low.get(c.node.name) and _is_low_precision_eligible_init( | ||
| c.node, init.name | ||
| ): | ||
| low_consumers.append(c) | ||
| else: | ||
| high_consumers.append(c) | ||
|
|
||
| if not low_consumers: | ||
| # Keep in high precision (covers Resize-scales-style inputs and unused initializers). | ||
| if init.data_type != high.onnx_type: | ||
| init.CopyFrom(self._convert_initializer_data(init, from_type, high)) | ||
| elif not high_consumers: | ||
| # Convert the single-precision initializer in place. | ||
| if init.data_type != target.onnx_type: | ||
| init.CopyFrom(self._convert_initializer_data(init, from_type, target)) | ||
| else: | ||
| # Shared: keep the original high precision and add a low-precision duplicate. | ||
| low_init = self._convert_initializer_data(init, from_type, target) | ||
| low_init.name = f"{init.name}_{target.str_short}" | ||
| subgraph.initializer.extend([low_init]) | ||
| for consumer in low_consumers: | ||
| consumer.node.input[consumer.node_index] = low_init.name | ||
| if init.data_type != high.onnx_type: | ||
| init.CopyFrom(self._convert_initializer_data(init, from_type, high)) | ||
|
|
||
| # 3. Reconcile float tensors whose precision does not match the consuming node: outer-scope | ||
| # captures and the outputs of low-precision nodes feeding high-precision ones. Casts are | ||
| # collected first and inserted afterwards to avoid mutating the node list while iterating. | ||
| local_produced_low = { | ||
| out for node in subgraph.node if node_is_low.get(node.name) for out in node.output | ||
| } | ||
|
|
||
| def _current_float_type(input_name: str) -> int | None: | ||
| """Float precision (ONNX type) of a subgraph input tensor, or None if it is non-float. | ||
|
|
||
| Non-float tensors (int indices/axes/shapes, bool conditions) must never be cast. | ||
| """ | ||
| if input_name in local_produced: | ||
| elem_type = local_elem_type.get(input_name) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| if elem_type is not None and elem_type not in ONNX_TYPES: | ||
| return None # known non-float subgraph activation | ||
| if input_name in local_produced_low: | ||
| return target.onnx_type # output of a converted low-precision node | ||
| return high.onnx_type if elem_type in ONNX_TYPES else None | ||
| if input_name in formal_inputs: | ||
| elem_type = local_elem_type.get(input_name) | ||
| return high.onnx_type if elem_type in ONNX_TYPES else None | ||
| return capture_type_fn(input_name) # outer-scope capture (None if non-float) | ||
|
|
||
| # (tensor name, target onnx type) -> (cast output name, producer node name or None) | ||
| casts_to_insert: dict[tuple[str, int], tuple[str, str | None]] = {} | ||
| rewrites: list[tuple[InputIndexTracker, str]] = [] | ||
| # Float outer-scope captures and the (current) precision they have in the enclosing scope. | ||
| captured_types: dict[str, int] = {} | ||
| for node in subgraph.node: | ||
| node_low = node_is_low.get(node.name, False) | ||
| for idx, input_name in enumerate(node.input): | ||
| if not input_name or input_name in local_inits: | ||
| continue | ||
| current = _current_float_type(input_name) | ||
| if current is None: | ||
| continue # non-float tensor: never cast | ||
| if input_name not in local_produced and input_name not in formal_inputs: | ||
| captured_types[input_name] = current | ||
| desired = ( | ||
| target.onnx_type | ||
| if node_low | ||
| and not self._should_skip_low_precision_input_conversion(node, input_name) | ||
| else high.onnx_type | ||
| ) | ||
| if current == desired: | ||
| continue | ||
|
|
||
| if from_type is None: | ||
| logger.debug( | ||
| f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}" | ||
| key = (input_name, desired) | ||
| if key not in casts_to_insert: | ||
| short = ONNX_TYPE_TO_PRECISION[desired].str_short | ||
| producer = ( | ||
| input_name | ||
| if input_name in local_produced | ||
| else None # outer-scope capture (produced outside this subgraph) | ||
| ) | ||
| continue | ||
| casts_to_insert[key] = ( | ||
| f"{input_name}_subgraph_cast_to_{short}", | ||
| producer, | ||
| ) | ||
| rewrites.append( | ||
| (InputIndexTracker(node=node, node_index=idx), casts_to_insert[key][0]) | ||
| ) | ||
|
|
||
| new_init = self._convert_initializer_data(init, from_type, target_type) | ||
| init.CopyFrom(new_init) | ||
| # Sync any preserved outer-scope value_info inside the subgraph with the capture's current | ||
| # main-graph precision. Otherwise a stale type (e.g. fp32 for a tensor the main graph now | ||
| # produces in fp16) makes strongly-typed parsers (and ORT) reject the If subgraph. | ||
| for vi in subgraph.value_info: | ||
| if vi.name in captured_types and vi.type.HasField("tensor_type"): | ||
| vi.type.tensor_type.elem_type = captured_types[vi.name] | ||
|
|
||
| utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback) | ||
| if not casts_to_insert: | ||
| return | ||
|
|
||
| for tracker, cast_output in rewrites: | ||
| tracker.node.input[tracker.node_index] = cast_output | ||
|
|
||
| # Build the cast nodes and re-assemble the subgraph node list so each cast appears after its | ||
| # producer (captures, produced outside the subgraph, are placed at the front). | ||
| cast_nodes_after_producer: dict[str, list[onnx.NodeProto]] = defaultdict(list) | ||
| leading_casts: list[onnx.NodeProto] = [] | ||
| for (input_name, desired), (cast_output, producer) in casts_to_insert.items(): | ||
| cast_node = helper.make_node( | ||
| "Cast", inputs=[input_name], outputs=[cast_output], to=desired, name=cast_output | ||
| ) | ||
| if producer is not None: | ||
| producer_node_name = next(n.name for n in subgraph.node if input_name in n.output) | ||
| cast_nodes_after_producer[producer_node_name].append(cast_node) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Cheap perf cleanup: this |
||
| else: | ||
| leading_casts.append(cast_node) | ||
|
|
||
| new_nodes = list(leading_casts) | ||
| for node in subgraph.node: | ||
| new_nodes.append(node) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Consider asserting / enforcing topological order before reassembling. The cast is placed right after its producer, but a downstream consumer of this cast that happens to appear earlier in |
||
| new_nodes.extend(cast_nodes_after_producer.get(node.name, [])) | ||
| del subgraph.node[:] | ||
| subgraph.node.extend(new_nodes) | ||
|
|
||
| def _convert_initializer_data( | ||
| self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1558,6 +1558,14 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant" | ||
| constant_producer = cast_producers[0] | ||
| _convert_constant_values(constant_producer, node) | ||
| # Folding changes the Constant output's element type; refresh its value_info so | ||
| # downstream consumers (and strongly-typed parsers) don't see a stale type that | ||
| # conflicts with the now-converted constant value. | ||
| cast_to_type = get_cast_to_type(node) | ||
| for vi in onnx_model.graph.value_info: | ||
| if vi.name == constant_producer.output[0]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This only updates the first |
||
| vi.type.tensor_type.elem_type = cast_to_type | ||
| break | ||
| _bypass_cast_node(onnx_model, node) | ||
| logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Worth calling out in the docstring (and ideally CHANGELOG) that this is a behavioral change for subgraph initializer precision: previously a low-precision parent caused all subgraph float initializers to be converted, now conversion is gated on the consuming node having no float activation/outer-scope/unknown-typed inputs. Anyone relying on the old behavior to get fp16 weights inside an If branch with an outer-scope activation will see those weights stay fp32 after this PR (which is the correct fix, but the change in observable output deserves a note).