Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support 5-input concat in XNNPACK delegate #7401

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/xnnpack/operators/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Dict, List

import torch
Expand All @@ -17,6 +19,7 @@
XNNConcatenate2,
XNNConcatenate3,
XNNConcatenate4,
XNNConcatenate5,
XNNGraph,
XNode,
)
Expand Down Expand Up @@ -71,6 +74,7 @@ def define_node(
input2_id=vals_to_ids[list_of_tensors[1]],
input3_id=XNN_INVALID_VALUE_ID,
input4_id=XNN_INVALID_VALUE_ID,
input5_id=XNN_INVALID_VALUE_ID,
output_id=vals_to_ids[node],
flags=0,
)
Expand All @@ -81,6 +85,7 @@ def define_node(
input2_id=vals_to_ids[list_of_tensors[1]],
input3_id=vals_to_ids[list_of_tensors[2]],
input4_id=XNN_INVALID_VALUE_ID,
input5_id=XNN_INVALID_VALUE_ID,
output_id=vals_to_ids[node],
flags=0,
)
Expand All @@ -91,6 +96,18 @@ def define_node(
input2_id=vals_to_ids[list_of_tensors[1]],
input3_id=vals_to_ids[list_of_tensors[2]],
input4_id=vals_to_ids[list_of_tensors[3]],
input5_id=XNN_INVALID_VALUE_ID,
output_id=vals_to_ids[node],
flags=0,
)
elif num_tensors_to_cat == 5:
xnode = XNNConcatenate5(
axis=axis,
input1_id=vals_to_ids[list_of_tensors[0]],
input2_id=vals_to_ids[list_of_tensors[1]],
input3_id=vals_to_ids[list_of_tensors[2]],
input4_id=vals_to_ids[list_of_tensors[3]],
input5_id=vals_to_ids[list_of_tensors[4]],
output_id=vals_to_ids[node],
flags=0,
)
Expand Down
6 changes: 3 additions & 3 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,17 @@ class CatConfig(GenericNodePartitionerConfig):

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
Only support concatenation of 2 - 4 tensors
Only support concatenation of 2 - 5 tensors
"""
if not self.check_common_constraints(node, ep):
return False

num_tensors = len(node.all_input_nodes)

if not (num_tensors >= 2 and num_tensors <= 4):
if not (num_tensors >= 2 and num_tensors <= 5):
why(
node,
reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors",
reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors",
)
return False

Expand Down
40 changes: 38 additions & 2 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,7 +1600,7 @@ Error defineConcatenate2Node(
}

/*
Defines serialized concatenate2 node into the subgraph,
Defines serialized concatenate3 node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
Expand Down Expand Up @@ -1633,7 +1633,7 @@ Error defineConcatenate3Node(
}

/*
Defines serialized concatenate2 node into the subgraph,
Defines serialized concatenate4 node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
Expand Down Expand Up @@ -1666,6 +1666,41 @@ Error defineConcatenate4Node(
return Error::Ok;
}

/*
Defines serialized concatenate5 node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
Error defineConcatenate5Node(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);

auto graph_node = node->xnode_union_as_XNNConcatenate5();

xnn_status status = xnn_define_concatenate5(
subgraph_ptr,
graph_node->axis(),
remapped_ids.at(graph_node->input1_id()),
remapped_ids.at(graph_node->input2_id()),
remapped_ids.at(graph_node->input3_id()),
remapped_ids.at(graph_node->input4_id()),
remapped_ids.at(graph_node->input5_id()),
remapped_ids.at(graph_node->output_id()),
graph_node->flags());

ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create cat5 node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

return Error::Ok;
}

/*
Defines serialized static_slice node into the subgraph,
using the remapped ids to map the serialized ids,
Expand Down Expand Up @@ -1832,6 +1867,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(Concatenate2)
_DEFINE(Concatenate3)
_DEFINE(Concatenate4)
_DEFINE(Concatenate5)
_DEFINE(StaticSlice)
_DEFINE(ScaledDotProductAttention)
_DEFINE(BatchMatrixMultiply)
Expand Down
2 changes: 2 additions & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ union XNodeUnion {
XNNStaticSlice,
XNNScaledDotProductAttention,
XNNBatchMatrixMultiply: _XNNNode2x1,
XNNConcatenate5: _XNNCat,
}

union XValueUnion {
Expand Down Expand Up @@ -209,6 +210,7 @@ table _XNNCat {
input4_id: uint;
output_id: uint;
flags: uint;
input5_id: uint;
}

table XNNELU {
Expand Down
2 changes: 2 additions & 0 deletions backends/xnnpack/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ union XNodeUnion {
XNNStaticSlice,
XNNScaledDotProductAttention,
XNNBatchMatrixMultiply: _XNNNode2x1,
XNNConcatenate5: _XNNCat,
}

union XValueUnion {
Expand Down Expand Up @@ -205,6 +206,7 @@ table _XNNCat {
input4_id: uint;
output_id: uint;
flags: uint;
input5_id: uint;
}

table XNNELU {
Expand Down
7 changes: 7 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class XNNCat:
input4_id: int
output_id: int
flags: int
input5_id: int


# Generic node data class for convolution type nodes
Expand Down Expand Up @@ -177,6 +178,11 @@ class XNNConcatenate4(XNNCat):
pass


@dataclass
class XNNConcatenate5(XNNCat):
pass


@dataclass
class XNNBatchMatrixMultiply(XNNNode2x1):
pass
Expand Down Expand Up @@ -357,6 +363,7 @@ class XNNScaledDotProductAttention:
XNNConcatenate2,
XNNConcatenate3,
XNNConcatenate4,
XNNConcatenate5,
XNNStaticSlice,
XNNScaledDotProductAttention,
XNNBatchMatrixMultiply,
Expand Down
62 changes: 29 additions & 33 deletions backends/xnnpack/test/ops/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestCat(unittest.TestCase):
class Cat2(torch.nn.Module):
def forward(self, arg1, arg2):
xs = [arg1, arg2]
x = torch.cat(xs)
return x + x # Quantize by propagation.

class Cat3(torch.nn.Module):
def forward(self, arg1, arg2, arg3):
xs = [arg1, arg2, arg3]
x = torch.cat(xs)
return x + x # Quantize by propagation.

class Cat4(torch.nn.Module):
def forward(self, arg1, arg2, arg3, arg4):
xs = [arg1, arg2, arg3, arg4]
x = torch.cat(xs)
return x + x # Quantize by propagation.

class Cat5(torch.nn.Module):
def forward(self, arg1, arg2, arg3, arg4, arg5):
xs = [arg1, arg2, arg3, arg4, arg5]
class Cat(torch.nn.Module):
def forward(self, *args):
xs = [*args]
x = torch.cat(xs)
return x + x # Quantize by propagation.

Expand Down Expand Up @@ -84,7 +68,7 @@ def test_fp16_cat2(self):
torch.randn(1, 2, 3).to(torch.float16),
torch.randn(3, 2, 3).to(torch.float16),
)
self._test_cat(self.Cat2(), inputs)
self._test_cat(self.Cat(), inputs)

def test_fp16_cat3(self):
"""
Expand All @@ -95,7 +79,7 @@ def test_fp16_cat3(self):
torch.randn(3, 2, 3).to(torch.float16),
torch.randn(2, 2, 3).to(torch.float16),
)
self._test_cat(self.Cat3(), inputs)
self._test_cat(self.Cat(), inputs)

def test_fp16_cat4(self):
"""
Expand All @@ -107,15 +91,15 @@ def test_fp16_cat4(self):
torch.randn(2, 2, 3).to(torch.float16),
torch.randn(5, 2, 3).to(torch.float16),
)
self._test_cat(self.Cat4(), inputs)
self._test_cat(self.Cat(), inputs)

def test_fp32_cat2(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
self._test_cat(self.Cat2(), inputs)
self._test_cat(self.Cat(), inputs)

def test_fp32_cat3(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
self._test_cat(self.Cat3(), inputs)
self._test_cat(self.Cat(), inputs)

def test_fp32_cat4(self):
inputs = (
Expand All @@ -124,15 +108,25 @@ def test_fp32_cat4(self):
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
)
self._test_cat(self.Cat4(), inputs)
self._test_cat(self.Cat(), inputs)

def test_fp32_cat5(self):
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
)
self._test_cat(self.Cat(), inputs)

def test_qs8_cat2(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)
self._test_cat(self.Cat(), inputs, cat_num=2, quant=True)

def test_qs8_cat3(self):
inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)
self._test_cat(self.Cat(), inputs, cat_num=3, quant=True)

def test_qs8_cat4(self):
inputs = (
Expand All @@ -141,7 +135,7 @@ def test_qs8_cat4(self):
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
)
self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)

def test_fp32_cat_unsupported(self):
"""
Expand All @@ -153,9 +147,10 @@ def test_fp32_cat_unsupported(self):
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
torch.randn(2, 2, 3),
)
(
Tester(self.Cat5(), inputs)
Tester(self.Cat(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge_transform_and_lower()
Expand All @@ -164,17 +159,18 @@ def test_fp32_cat_unsupported(self):

def test_fp32_cat_unsupported_legacy_mode(self):
"""
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
XNNPACK only supports concatenating up to 5 values, so it should not delegate here.
"""
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
torch.randn(6, 2, 3),
)
(
Tester(self.Cat5(), inputs)
Tester(self.Cat(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge()
Expand Down