@@ -498,14 +498,15 @@ def targets(self) -> list[EdgeOpOverload]:
             exir_ops.edge.cadence.conv1d.default,
             exir_ops.edge.cadence.conv2d.default,
             exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.transposed_convolution,
+            exir_ops.edge.cadence.transposed_convolution.default,
         ]
 
     def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Check if this is a transposed convolution
         assert isinstance(node.target, EdgeOpOverload)
-        op_packet = get_edge_overload_packet(node.target)
-        is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution
+        is_transposed = (
+            node.target == exir_ops.edge.cadence.transposed_convolution.default
+        )
         num_expected_args = 9 if is_transposed else 7
         assert len(node.args) == num_expected_args
         # Check if the bias is concrete
@@ -515,13 +516,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # The bias length is the number of out channels.
         out_shape = node.meta["val"].shape
         bias_size = out_shape[1]
-        # Create a zero bias tensor (bias is not a constant tensor,
+
+        # Create a zero bias tensor
         with node.graph.inserting_before(node):
             zero_bias = node.graph.call_function(
                 exir_ops.edge.aten.full.default,
                 args=([bias_size], 0.0),
                 kwargs={"dtype": torch.float32},
             )
+            # Create proper metadata for the zero_bias node
             zero_bias.meta = node.meta
         new_args = list(node.args)
         new_args[2] = zero_bias
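
For context, the change above is plain torch.fx graph surgery: identify the convolution node, build a zero-bias `full` node in front of it with `inserting_before`, and rewire the conv's bias argument. Below is a minimal standalone sketch of that pattern using plain `torch.conv1d`/`torch.full` instead of the `exir_ops.edge.cadence` dialect; the module `M`, channel counts, and shapes are made up for illustration and are not part of this PR.

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Bias-less convolution; the rewrite below gives it an explicit zero bias.
        return torch.conv1d(x, w)


gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.conv1d:
        out_channels = 4  # the real pass reads this from node.meta["val"].shape[1]
        # Insert the zero-bias node right before the convolution it feeds.
        with gm.graph.inserting_before(node):
            zero_bias = gm.graph.call_function(
                torch.full,
                args=([out_channels], 0.0),
                kwargs={"dtype": torch.float32},
            )
        # conv1d(input, weight, bias, ...): bias sits at argument index 2.
        new_args = list(node.args)
        while len(new_args) < 3:
            new_args.append(None)
        new_args[2] = zero_bias
        node.args = tuple(new_args)
gm.recompile()

x = torch.randn(1, 3, 16)  # (batch, in_channels, length)
w = torch.randn(4, 3, 5)   # (out_channels, in_channels, kernel)
print(gm(x, w).shape)      # torch.Size([1, 4, 12])
```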