diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index ccadd1e7a88..0b88e687224 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -498,14 +498,15 @@ def targets(self) -> list[EdgeOpOverload]: exir_ops.edge.cadence.conv1d.default, exir_ops.edge.cadence.conv2d.default, exir_ops.edge.cadence.conv3d.default, - exir_ops.edge.cadence.transposed_convolution, + exir_ops.edge.cadence.transposed_convolution.default, ] def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Check if this is a transposed convolution assert isinstance(node.target, EdgeOpOverload) - op_packet = get_edge_overload_packet(node.target) - is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution + is_transposed = ( + node.target == exir_ops.edge.cadence.transposed_convolution.default + ) num_expected_args = 9 if is_transposed else 7 assert len(node.args) == num_expected_args # Check if the bias is concrete @@ -515,13 +516,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # The bias length is the number of out channels. out_shape = node.meta["val"].shape bias_size = out_shape[1] - # Create a zero bias tensor (bias is not a constant tensor, + + # Create a zero bias tensor with node.graph.inserting_before(node): zero_bias = node.graph.call_function( exir_ops.edge.aten.full.default, args=([bias_size], 0.0), kwargs={"dtype": torch.float32}, ) + # Create proper metadata for the zero_bias node zero_bias.meta = node.meta new_args = list(node.args) new_args[2] = zero_bias