diff --git a/backends/xnnpack/operators/op_cat.py b/backends/xnnpack/operators/op_cat.py index 706073ef9b..1bf2854a81 100644 --- a/backends/xnnpack/operators/op_cat.py +++ b/backends/xnnpack/operators/op_cat.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import cast, Dict, List import torch @@ -17,6 +19,7 @@ XNNConcatenate2, XNNConcatenate3, XNNConcatenate4, + XNNConcatenate5, XNNGraph, XNode, ) @@ -71,6 +74,7 @@ def define_node( input2_id=vals_to_ids[list_of_tensors[1]], input3_id=XNN_INVALID_VALUE_ID, input4_id=XNN_INVALID_VALUE_ID, + input5_id=XNN_INVALID_VALUE_ID, output_id=vals_to_ids[node], flags=0, ) @@ -81,6 +85,7 @@ def define_node( input2_id=vals_to_ids[list_of_tensors[1]], input3_id=vals_to_ids[list_of_tensors[2]], input4_id=XNN_INVALID_VALUE_ID, + input5_id=XNN_INVALID_VALUE_ID, output_id=vals_to_ids[node], flags=0, ) @@ -91,6 +96,18 @@ def define_node( input2_id=vals_to_ids[list_of_tensors[1]], input3_id=vals_to_ids[list_of_tensors[2]], input4_id=vals_to_ids[list_of_tensors[3]], + input5_id=XNN_INVALID_VALUE_ID, + output_id=vals_to_ids[node], + flags=0, + ) + elif num_tensors_to_cat == 5: + xnode = XNNConcatenate5( + axis=axis, + input1_id=vals_to_ids[list_of_tensors[0]], + input2_id=vals_to_ids[list_of_tensors[1]], + input3_id=vals_to_ids[list_of_tensors[2]], + input4_id=vals_to_ids[list_of_tensors[3]], + input5_id=vals_to_ids[list_of_tensors[4]], output_id=vals_to_ids[node], flags=0, ) diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index cb41c87ed2..c97f27700e 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -174,17 +174,17 @@ class CatConfig(GenericNodePartitionerConfig): def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: """ - Only support concatenation of 2 - 4 tensors + Only support concatenation of 2 - 5 tensors """ if not self.check_common_constraints(node, ep): return False num_tensors = len(node.all_input_nodes) - if not (num_tensors >= 2 and num_tensors <= 4): + if not (num_tensors >= 2 and num_tensors <= 5): why( node, - reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors", + reason=f"only support concatenation of 2 - 5 tensors, got {num_tensors} tensors", ) return False diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index b948aa8623..3d4d2e6821 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1600,7 +1600,7 @@ Error defineConcatenate2Node( } /* -Defines serialized concatenate2 node into the subgraph, +Defines serialized concatenate3 node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the tensor value */ @@ -1633,7 +1633,7 @@ Error defineConcatenate3Node( } /* -Defines serialized concatenate2 node into the subgraph, +Defines serialized concatenate4 node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the tensor value */ @@ -1666,6 +1666,41 @@ Error defineConcatenate4Node( return Error::Ok; } +/* +Defines serialized concatenate5 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value +*/ +Error defineConcatenate5Node( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNConcatenate5(); + + xnn_status status = xnn_define_concatenate5( + subgraph_ptr, + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->input3_id()), + remapped_ids.at(graph_node->input4_id()), + remapped_ids.at(graph_node->input5_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create cat5 node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Defines serialized static_slice node into the subgraph, using the remapped ids to map the serialized ids, @@ -1832,6 +1867,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Concatenate2) _DEFINE(Concatenate3) _DEFINE(Concatenate4) + _DEFINE(Concatenate5) _DEFINE(StaticSlice) _DEFINE(ScaledDotProductAttention) _DEFINE(BatchMatrixMultiply) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index efe717e085..0c6ee86912 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -136,6 +136,7 @@ union XNodeUnion { XNNStaticSlice, XNNScaledDotProductAttention, XNNBatchMatrixMultiply: _XNNNode2x1, + XNNConcatenate5: _XNNCat, } union XValueUnion { @@ -209,6 +210,7 @@ table _XNNCat { input4_id: uint; output_id: uint; flags: uint; + input5_id: uint; } table XNNELU { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 33571195d6..45f1248463 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -132,6 +132,7 @@ union XNodeUnion { XNNStaticSlice, XNNScaledDotProductAttention, XNNBatchMatrixMultiply: _XNNNode2x1, + XNNConcatenate5: _XNNCat, } union XValueUnion { @@ -205,6 +206,7 @@ table _XNNCat { input4_id: uint; output_id: uint; flags: uint; + input5_id: uint; } table XNNELU { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index e3e699c58f..ca0fc60bdc 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -42,6 +42,7 @@ class XNNCat: input4_id: int output_id: int flags: int + input5_id: int # Generic node data class for convolution type nodes @@ -177,6 +178,11 @@ class XNNConcatenate4(XNNCat): pass +@dataclass +class XNNConcatenate5(XNNCat): + pass + + @dataclass class XNNBatchMatrixMultiply(XNNNode2x1): pass @@ -357,6 +363,7 @@ class XNNScaledDotProductAttention: XNNConcatenate2, XNNConcatenate3, XNNConcatenate4, + XNNConcatenate5, XNNStaticSlice, XNNScaledDotProductAttention, XNNBatchMatrixMultiply, diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py index 039da2c075..9a7adaeb0f 100644 --- a/backends/xnnpack/test/ops/test_cat.py +++ b/backends/xnnpack/test/ops/test_cat.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import unittest import torch @@ -11,27 +13,9 @@ class TestCat(unittest.TestCase): - class Cat2(torch.nn.Module): - def forward(self, arg1, arg2): - xs = [arg1, arg2] - x = torch.cat(xs) - return x + x # Quantize by propagation. - - class Cat3(torch.nn.Module): - def forward(self, arg1, arg2, arg3): - xs = [arg1, arg2, arg3] - x = torch.cat(xs) - return x + x # Quantize by propagation. - - class Cat4(torch.nn.Module): - def forward(self, arg1, arg2, arg3, arg4): - xs = [arg1, arg2, arg3, arg4] - x = torch.cat(xs) - return x + x # Quantize by propagation. - - class Cat5(torch.nn.Module): - def forward(self, arg1, arg2, arg3, arg4, arg5): - xs = [arg1, arg2, arg3, arg4, arg5] + class Cat(torch.nn.Module): + def forward(self, *args): + xs = [*args] x = torch.cat(xs) return x + x # Quantize by propagation. @@ -84,7 +68,7 @@ def test_fp16_cat2(self): torch.randn(1, 2, 3).to(torch.float16), torch.randn(3, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp16_cat3(self): """ @@ -95,7 +79,7 @@ def test_fp16_cat3(self): torch.randn(3, 2, 3).to(torch.float16), torch.randn(2, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat3(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp16_cat4(self): """ @@ -107,15 +91,15 @@ def test_fp16_cat4(self): torch.randn(2, 2, 3).to(torch.float16), torch.randn(5, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat4(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp32_cat2(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3)) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp32_cat3(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3)) - self._test_cat(self.Cat3(), inputs) + self._test_cat(self.Cat(), inputs) def test_fp32_cat4(self): inputs = ( @@ -124,15 +108,25 @@ def test_fp32_cat4(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), ) - self._test_cat(self.Cat4(), inputs) + self._test_cat(self.Cat(), inputs) + + def test_fp32_cat5(self): + inputs = ( + torch.randn(1, 2, 3), + torch.randn(3, 2, 3), + torch.randn(2, 2, 3), + torch.randn(5, 2, 3), + torch.randn(1, 2, 3), + ) + self._test_cat(self.Cat(), inputs) def test_qs8_cat2(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3)) - self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True) + self._test_cat(self.Cat(), inputs, cat_num=2, quant=True) def test_qs8_cat3(self): inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3)) - self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True) + self._test_cat(self.Cat(), inputs, cat_num=3, quant=True) def test_qs8_cat4(self): inputs = ( @@ -141,7 +135,7 @@ def test_qs8_cat4(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), ) - self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True) + self._test_cat(self.Cat(), inputs, cat_num=4, quant=True) def test_fp32_cat_unsupported(self): """ @@ -153,9 +147,10 @@ def test_fp32_cat_unsupported(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), torch.randn(1, 2, 3), + torch.randn(2, 2, 3), ) ( - Tester(self.Cat5(), inputs) + Tester(self.Cat(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) .to_edge_transform_and_lower() @@ -164,7 +159,7 @@ def test_fp32_cat_unsupported(self): def test_fp32_cat_unsupported_legacy_mode(self): """ - XNNPACK only supports concatenating up to 4 values, so it should not delegate here. + XNNPACK only supports concatenating up to 5 values, so it should not delegate here. """ inputs = ( torch.randn(1, 2, 3), @@ -172,9 +167,10 @@ def test_fp32_cat_unsupported_legacy_mode(self): torch.randn(2, 2, 3), torch.randn(5, 2, 3), torch.randn(1, 2, 3), + torch.randn(6, 2, 3), ) ( - Tester(self.Cat5(), inputs) + Tester(self.Cat(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) .to_edge()