Skip to content

Commit c998f8a

Browse files
authored
[ET-VK][ez] Apply quantize op replacement to all argument nodes (#15746)
Title says it all! With the way the pass is currently written only the first arg will be inspected for q/dq node replacement. As a consequence, the second arg for i.e. binary ops may not have the quantized op be replaced. Differential Revision: [D86674169](https://our.internmc.facebook.com/intern/diff/D86674169/)
1 parent 069b455 commit c998f8a

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

backends/vulkan/_passes/replace_qdq.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule):
3232
exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
3333
exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
3434
]:
35-
# Replace quantize op feeding into conv2d (first argument is the quantized input)
36-
quantized_input_node = node.args[0]
37-
if isinstance(
38-
quantized_input_node, torch.fx.Node
39-
) and utils.is_quant_node(quantized_input_node):
40-
# Get the arguments from the original quantize node
41-
input_tensor = quantized_input_node.args[0]
42-
scale = quantized_input_node.args[1]
43-
zero_point = quantized_input_node.args[2]
35+
for quantized_input_node in node.args:
36+
if isinstance(
37+
quantized_input_node, torch.fx.Node
38+
) and utils.is_quant_node(quantized_input_node):
39+
# Get the arguments from the original quantize node
40+
input_tensor = quantized_input_node.args[0]
41+
scale = quantized_input_node.args[1]
42+
zero_point = quantized_input_node.args[2]
4443

45-
nodes_to_replace.append(
46-
{
47-
"old_node": quantized_input_node,
48-
"new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
49-
"args": (input_tensor, scale, zero_point),
50-
"node_type": "quantize_input",
51-
}
52-
)
44+
nodes_to_replace.append(
45+
{
46+
"old_node": quantized_input_node,
47+
"new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
48+
"args": (input_tensor, scale, zero_point),
49+
"node_type": "quantize_input",
50+
}
51+
)
5352

5453
# Find dequantize ops that consume the output of this conv2d
5554
for user in node.users:

0 commit comments

Comments
 (0)