Commit e804065

Qualcomm AI Engine Direct - enable operator max_pool3d by decomposition (#15897)

Summary: Enable the max_pool3d operator by applying max_pool2d twice.

Test plan:
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_max_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_max_pool3d -b build-android -H HOST -s DEVICE -m CHIPID
1 parent 0d61efc commit e804065
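
The decomposition works because max pooling is separable across spatial dimensions: the maximum over a kD x kH x kW window equals the maximum over (D, H) followed by the maximum over W. Below is a minimal sketch of that equivalence (not part of the commit; the shapes and pooling parameters are illustrative, and it assumes dilation == 1, which is the only case the decomposed MaxPool2d modules handle):

import torch
import torch.nn.functional as F

x = torch.randn(1, 7, 21, 35, 28)  # (N, C, D, H, W), same shape the tests use
kernel, stride, padding = (3, 3, 3), (1, 1, 1), (1, 1, 1)

# Reference result from a single 3D max pool.
ref = F.max_pool3d(x, kernel, stride, padding)

N, C, D, H, W = x.shape
# Move W into the channel position so max_pool2d pools over (D, H).
t = x.permute(0, 1, 4, 2, 3).reshape(N * C, W, D, H)
t = F.max_pool2d(t, kernel[:2], stride[:2], padding[:2])
# Bring W back to the last dimension and pool over it with a (1, kW) kernel.
t = t.permute(0, 2, 3, 1)
t = F.max_pool2d(t, (1, kernel[2]), (1, stride[2]), (0, padding[2]))
out = t.reshape(N, C, *t.shape[1:])

assert torch.equal(out, ref)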

File tree

6 files changed: +195 −0 lines changed


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .decompose_floor_divide import DecomposeFloorDivide
 from .decompose_glu import DecomposeGlu
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
+from .decompose_maxpool3d import DecomposeMaxPool3d
 from .decompose_minmaxdim import DecomposeMinMaxDim
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
@@ -68,6 +69,7 @@
     DecomposeFloorDivide,
     DecomposeGlu,
     DecomposeLinalgVectorNorm,
+    DecomposeMaxPool3d,
     DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
backends/qualcomm/_passes/decompose_maxpool3d.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, List

import torch
import torch.nn as nn
from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import merge_decomposed_graph


class ModelMaxPool3D(torch.nn.Module):
    def __init__(
        self, filter_size, stride, padding, dilation, return_indices, ceil_mode
    ):
        super().__init__()

        self.pool2d_hw = nn.MaxPool2d(
            kernel_size=[1, filter_size[2]],  # (H, W) part
            stride=[1, stride[2]],
            padding=[0, padding[2]],
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )
        self.pool2d_dh = nn.MaxPool2d(
            kernel_size=filter_size[:2],  # (D, H) part
            stride=stride[:2],
            padding=padding[:2],
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        N, C, D, H, W = x.shape
        x_ = x.permute(0, 1, 4, 2, 3)
        x1_1d = x_.reshape(N * C, W, D, H)
        # first pool over (D, H)
        out_pool1d_0 = self.pool2d_dh(x1_1d)
        D_out = out_pool1d_0.shape[2]
        # NC, W, D, H -> NC, D, H, W
        x1b = out_pool1d_0.permute(0, 2, 3, 1)
        # second pool over (H, W)
        out4d = self.pool2d_hw(x1b)
        H_out2 = out4d.shape[2]
        W_out = out4d.shape[3]
        out = out4d.reshape(N, C, D_out, H_out2, W_out)
        return out


class DecomposeMaxPool3d(ExportPass):
    # max_pool3d is not yet supported by QNN.
    # Decompose: input -> permute -> reshape -> max_pool2d -> permute -> max_pool2d -> reshape -> output

    def __init__(self, quantization_capture=False) -> None:
        super().__init__()
        self.quantization_capture = quantization_capture

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op == "call_function" and "max_pool3d" in str(node.target):
                # kernel info
                filter_size = cast(List[int], node.args[1])
                if len(filter_size) == 1:
                    filter_size *= 3

                num_args = len(node.args)

                # stride info
                stride = filter_size
                if num_args > 2:
                    stride = cast(List[int], node.args[2])
                    if len(stride) == 1:
                        stride *= 3

                # padding info
                padding = [0, 0, 0]
                if num_args > 3:
                    padding = cast(List[int], node.args[3])
                    if len(padding) == 1:
                        padding *= 3

                # dilation info
                dilation = [1, 1, 1]
                if num_args > 4:
                    dilation = cast(List[int], node.args[4])
                    if len(dilation) == 1:
                        dilation *= 3

                ceil_mode = node.args[5] if num_args > 5 else False
                return_indices = node.args[6] if num_args > 6 else False
                if return_indices:
                    warnings.warn(
                        "[QNN Delegate Op Builder]: The case return_indices=True is not supported, fallback",
                        stacklevel=1,
                    )
                    return

                model = ModelMaxPool3D(
                    filter_size, stride, padding, dilation, return_indices, ceil_mode
                )
                if self.quantization_capture:
                    decomposed_module = torch.export.export(
                        model, (node.args[0].meta["val"],), strict=True
                    ).module()
                else:
                    edge_mgr = to_edge(
                        torch.export.export(
                            model, (node.args[0].meta["val"],), strict=True
                        )
                    )
                    decomposed_module = edge_mgr.exported_program()

                with graph.inserting_before(node):
                    # remap is used to map original node values to new node values,
                    # which ensures that references to nodes are correctly updated in the new graph
                    remap = {"x": node.args[0]}
                    merge_decomposed_graph(
                        remap=remap,
                        target_node=node,
                        target_graph=graph,
                        decomposed_graph_module=decomposed_module,
                    )
                    graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
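
For reference, a rough sketch of exercising the new pass in isolation (not from the commit; the module M and the direct invocation below are illustrative assumptions — in practice the pass is scheduled through QnnPassManager, as shown in the qnn_pass_manager.py diff that follows):

import torch
from executorch.backends.qualcomm._passes import DecomposeMaxPool3d

class M(torch.nn.Module):  # illustrative module, not part of the commit
    def forward(self, x):
        return torch.nn.functional.max_pool3d(x, kernel_size=3, stride=1, padding=1)

ep = torch.export.export(M(), (torch.randn(1, 7, 21, 35, 28),), strict=True)
# ExportPass instances are callable on a GraphModule and return a PassResult;
# quantization_capture=True keeps the decomposition in the ATen dialect.
result = DecomposeMaxPool3d(quantization_capture=True)(ep.graph_module)
# After the pass, the graph should contain max_pool2d/permute/reshape nodes
# instead of the original max_pool3d call.
result.graph_module.print_readable()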

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,7 @@
     DecomposeFloorDivide,
     DecomposeGlu,
     DecomposeLinalgVectorNorm,
+    DecomposeMaxPool3d,
     DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
@@ -98,6 +99,7 @@ def get_capture_program_passes():
         (FoldQDQ, True),
         (I64toI32, True),
         (LayoutTransform, True),
+        (DecomposeMaxPool3d, True),
         (RecomposePixelUnshuffle, True),
         (RecomposeRmsNorm, True),
         (Remove0DTensor, True),
@@ -201,6 +203,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ReplaceArangeArgs())
         self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
+        self.add_pass(DecomposeMaxPool3d(quantization_capture=True))
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ def get_passes_dependency_for_capture_program():
         DecomposeAny,
         DecomposeColIm,
         DecomposeLinalgVectorNorm,
+        DecomposeMaxPool3d,
         ExpandBroadcastTensorShape,
         FixedLinearKeepDim,
         FoldQDQ,
@@ -93,6 +94,7 @@ def get_passes_dependency_for_capture_program():
         DecomposeAny: [RemoveRedundancy],
         DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
+        DecomposeMaxPool3d: [RemoveRedundancy],
         ExpandBroadcastTensorShape: [FoldQDQ],
         FixedLinearKeepDim: [FoldQDQ],
         FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],

backends/qualcomm/tests/models.py

Lines changed: 18 additions & 0 deletions
@@ -1466,6 +1466,24 @@ def forward(self, x):
         return self.max_pool2d(x)


+class MaxPool3d(torch.nn.Module):
+    def __init__(
+        self, kernel_size, stride, padding, dilation, ceil_mode, return_indices
+    ):
+        super().__init__()
+        self.max_pool3d = torch.nn.MaxPool3d(
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            return_indices=return_indices,
+            ceil_mode=ceil_mode,
+        )
+
+    def forward(self, x):
+        return self.max_pool3d(x)
+
+
 class Mean(torch.nn.Module):
     def __init__(
         self,
backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 37 additions & 0 deletions
@@ -1406,6 +1406,24 @@ def test_qnn_backend_max_pool2d(self):
         sample_input = (torch.randn(4, 3, 24, 24),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_max_pool3d(self):
+        # NOTE: The pad should be at most half of effective kernel size.
+        modules = [
+            MaxPool3d((3), (1), (1), (1), False, False),  # noqa: F405
+            MaxPool3d((7), (1), (3), (1), False, False),  # noqa: F405
+            MaxPool3d((7), (1), (3), (1), True, False),  # noqa: F405
+            MaxPool3d(  # noqa: F405
+                (7, 7, 7), (1, 1, 1), (3, 3, 3), (1, 1, 1), True, False
+            ),  # noqa: F405
+            MaxPool3d(  # noqa: F405
+                (7, 9, 13), (1, 1, 1), (3, 4, 6), (1, 1, 1), False, False
+            ),  # noqa: F405
+        ]
+        sample_input = (torch.randn(1, 7, 21, 35, 28),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_mean(self):
         test_comb = [
             # Reduce over last two dims, keepdim=True
@@ -3636,6 +3654,25 @@ def test_qnn_backend_max_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_max_pool3d(self):
+        # NOTE: The pad should be at most half of effective kernel size.
+        modules = [
+            MaxPool3d((3), (1), (1), (1), False, False),  # noqa: F405
+            MaxPool3d((7), (1), (3), (1), False, False),  # noqa: F405
+            MaxPool3d((7), (1), (3), (1), True, False),  # noqa: F405
+            MaxPool3d(  # noqa: F405
+                (7, 7, 7), (1, 1, 1), (3, 3, 3), (1, 1, 1), True, False
+            ),  # noqa: F405
+            MaxPool3d(  # noqa: F405
+                (7, 9, 13), (1, 1, 1), (3, 4, 6), (1, 1, 1), False, False
+            ),  # noqa: F405
+        ]
+        sample_input = (torch.randn(1, 7, 21, 35, 28),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_mean(self):
         test_comb = [
             # Reduce over last two dims, keepdim=True