Commit 6524013

Andrew Grebenisan authored and facebook-github-bot committed
Fix ReplaceConvolutionOptionalArgsWithConcreteArgsPass
Summary: The targets list was missing the .default variant of transposed_convolution, which resulted in transposed convolution nodes being skipped as targets.

Differential Revision: D88705767
1 parent 9156fff commit 6524013
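
The root cause is the distinction between an edge op overload packet and a concrete overload: a torch.fx node's target is a concrete EdgeOpOverload (here the .default variant), so listing the bare transposed_convolution packet in targets() meant the membership check never matched. The sketch below illustrates that distinction; it assumes an ExecuTorch install with the Cadence dialect ops registered, and the membership check is an illustration, not the actual RemoveOrReplacePassInterface code.

```python
# Minimal sketch (assumption: ExecuTorch is installed and the Cadence custom ops
# are registered so exir_ops.edge.cadence.transposed_convolution resolves; the
# is_target helper below is illustrative, not the real pass logic).
from executorch.exir.dialects._ops import ops as exir_ops

# An overload packet groups all overloads of an op; a graph node's .target is a
# concrete EdgeOpOverload such as the ".default" variant.
packet = exir_ops.edge.cadence.transposed_convolution
overload = exir_ops.edge.cadence.transposed_convolution.default


def is_target(node_target, targets):
    # Graph nodes carry the concrete overload as node.target, so a packet entry
    # in `targets` never compares equal and the node is skipped.
    return node_target in targets


print(is_target(overload, [packet]))    # False -> node skipped with the old list
print(is_target(overload, [overload]))  # True  -> node matched with the fix
```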

1 file changed (+6, −4)

backends/cadence/aot/replace_ops.py

Lines changed: 6 additions & 4 deletions
@@ -498,14 +498,13 @@ def targets(self) -> list[EdgeOpOverload]:
             exir_ops.edge.cadence.conv1d.default,
             exir_ops.edge.cadence.conv2d.default,
             exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.transposed_convolution,
+            exir_ops.edge.cadence.transposed_convolution.default,
         ]

     def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Check if this is a transposed convolution
         assert isinstance(node.target, EdgeOpOverload)
-        op_packet = get_edge_overload_packet(node.target)
-        is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution
+        is_transposed = node.target == exir_ops.edge.cadence.transposed_convolution.default
         num_expected_args = 9 if is_transposed else 7
         assert len(node.args) == num_expected_args
         # Check if the bias is concrete
@@ -515,13 +514,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # The bias length is the number of out channels.
         out_shape = node.meta["val"].shape
         bias_size = out_shape[1]
-        # Create a zero bias tensor (bias is not a constant tensor,
+
+        # Create a zero bias tensor
         with node.graph.inserting_before(node):
             zero_bias = node.graph.call_function(
                 exir_ops.edge.aten.full.default,
                 args=([bias_size], 0.0),
                 kwargs={"dtype": torch.float32},
             )
+            # Create proper metadata for the zero_bias node
             zero_bias.meta = node.meta
         new_args = list(node.args)
         new_args[2] = zero_bias
@@ -539,6 +540,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         return True


+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceRepeatWithCatPass(RemoveOrReplacePassInterface):
     """
