@@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule):
3232 exir_ops .edge .et_vk .conv2d_q8ta_q8csw_q8to_dw .default ,
3333 exir_ops .edge .et_vk .add_q8ta_q8ta_q8to .default ,
3434 ]:
35- # Replace quantize op feeding into conv2d (first argument is the quantized input)
36- quantized_input_node = node .args [0 ]
37- if isinstance (
38- quantized_input_node , torch .fx .Node
39- ) and utils .is_quant_node (quantized_input_node ):
40- # Get the arguments from the original quantize node
41- input_tensor = quantized_input_node .args [0 ]
42- scale = quantized_input_node .args [1 ]
43- zero_point = quantized_input_node .args [2 ]
35+ for quantized_input_node in node .args :
36+ if isinstance (
37+ quantized_input_node , torch .fx .Node
38+ ) and utils .is_quant_node (quantized_input_node ):
39+ # Get the arguments from the original quantize node
40+ input_tensor = quantized_input_node .args [0 ]
41+ scale = quantized_input_node .args [1 ]
42+ zero_point = quantized_input_node .args [2 ]
4443
45- nodes_to_replace .append (
46- {
47- "old_node" : quantized_input_node ,
48- "new_target" : exir_ops .edge .et_vk .quantize_q8ta_for_conv2d .default ,
49- "args" : (input_tensor , scale , zero_point ),
50- "node_type" : "quantize_input" ,
51- }
52- )
44+ nodes_to_replace .append (
45+ {
46+ "old_node" : quantized_input_node ,
47+ "new_target" : exir_ops .edge .et_vk .quantize_q8ta_for_conv2d .default ,
48+ "args" : (input_tensor , scale , zero_point ),
49+ "node_type" : "quantize_input" ,
50+ }
51+ )
5352
5453 # Find dequantize ops that consume the output of this conv2d
5554 for user in node .users :
0 commit comments