@@ -498,14 +498,13 @@ def targets(self) -> list[EdgeOpOverload]:
             exir_ops.edge.cadence.conv1d.default,
             exir_ops.edge.cadence.conv2d.default,
             exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.transposed_convolution,
+            exir_ops.edge.cadence.transposed_convolution.default,
         ]

     def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Check if this is a transposed convolution
         assert isinstance(node.target, EdgeOpOverload)
-        op_packet = get_edge_overload_packet(node.target)
-        is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution
+        is_transposed = node.target == exir_ops.edge.cadence.transposed_convolution.default
         num_expected_args = 9 if is_transposed else 7
         assert len(node.args) == num_expected_args
         # Check if the bias is concrete
@@ -515,13 +514,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # The bias length is the number of out channels.
         out_shape = node.meta["val"].shape
         bias_size = out_shape[1]
-        # Create a zero bias tensor (bias is not a constant tensor,
+
+        # Create a zero bias tensor
         with node.graph.inserting_before(node):
             zero_bias = node.graph.call_function(
                 exir_ops.edge.aten.full.default,
                 args=([bias_size], 0.0),
                 kwargs={"dtype": torch.float32},
             )
+            # Create proper metadata for the zero_bias node
             zero_bias.meta = node.meta
         new_args = list(node.args)
         new_args[2] = zero_bias
@@ -539,6 +540,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         return True


+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceRepeatWithCatPass(RemoveOrReplacePassInterface):
     """
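
The gist of the change above: the pass now matches the specific .default overload instead of the overload packet, and when a conv node lacks a concrete bias it fabricates a zero bias via aten.full and splices it into the node's arguments. Below is a minimal standalone sketch of just that bias-insertion step; it is not the pass itself, the helper name insert_zero_bias is made up for illustration, and it assumes the usual ExecuTorch edge-dialect exir_ops import and that node.meta["val"] already carries the conv output's fake tensor.

import torch
from executorch.exir.dialects._ops import ops as exir_ops


def insert_zero_bias(node: torch.fx.Node) -> None:
    # Hypothetical helper mirroring the diff: the bias length equals the
    # number of output channels (dim 1 of the node's output shape).
    bias_size = node.meta["val"].shape[1]
    with node.graph.inserting_before(node):
        zero_bias = node.graph.call_function(
            exir_ops.edge.aten.full.default,
            args=([bias_size], 0.0),
            kwargs={"dtype": torch.float32},
        )
        # Reuse the conv node's metadata so later passes still see a "val".
        zero_bias.meta = node.meta
    # The bias is the third positional argument of the cadence conv ops.
    new_args = list(node.args)
    new_args[2] = zero_bias
    node.args = tuple(new_args)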