diff --git a/BUILD b/BUILD
index 3d0174e..0c9401b 100644
--- a/BUILD
+++ b/BUILD
@@ -246,6 +246,7 @@ cc_library(
     name = "TcpToLinalg",
     srcs = [
         "lib/Conversion/PassDetail.h",
+        "lib/Conversion/TcpToLinalg/DataMovement.cpp",
         "lib/Conversion/TcpToLinalg/Elementwise.cpp",
         "lib/Conversion/TcpToLinalg/Misc.cpp",
         "lib/Conversion/TcpToLinalg/PopulatePatterns.h",
diff --git a/docs/gather.md b/docs/gather.md
index 318385e..f0ad65b 100644
--- a/docs/gather.md
+++ b/docs/gather.md
@@ -2,11 +2,11 @@
 
 ## Gather elements along a given dim
 
-`tcp.gather_elements` op gathers elements from a given tensor based on indices that index along a given dim.
+`tcp.gather` op gathers elements from a given tensor based on indices that index along a given dim.
 
 Syntax:
 
-    operation ::= `tcp.gather_elements` $input `,` $indices attr-dict `:`
+    operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
                       type($input) `,` type($indices) `->` type($out)
 
 Attributes:
@@ -43,11 +43,11 @@ This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/gener
        indices = torch.tensor([[0, 1], [2, 0], [2, 3]]) # Shape is [3, 2]
        x = torch.gather(input, 1, indices) # Result shape is [3, 2]
 
-   This will get mapped to `tcp.gather_elements` as follows:
+   This will get mapped to `tcp.gather` as follows:
 
        %input = ...
        %indices = ...
-       %x = tcp.gather_elements %input, %indices { gather_dim = 1 } :
+       %x = tcp.gather %input, %indices { dim = 1 } :
                (tensor<3x4xf32>, tensor<3x2xi64>) -> tensor<3x2xf32>
 
 2. Modeling `onnx.GatherElements`
@@ -56,11 +56,11 @@ This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/gener
        indices = ... # Shape is [2, 3]
        x = onnx.GatherElements(input, indices, 0) # Result shape is [2, 3]
 
-   This will get mapped to `tcp.gather_elements` as follows:
+   This will get mapped to `tcp.gather` as follows:
 
        %input = ...
       %indices = ...
-       %x = tcp.gather_elements %input, %indices { gather_dim = 0 } :
+       %x = tcp.gather %input, %indices { dim = 0 } :
                (tensor<3x3xf32>, tensor<2x3xi64>) -> tensor<2x3xf32>
 
 
@@ -68,10 +68,10 @@ This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/gener
 This requires gathering slices from a given tensor based on indices that index along a given dim.
 
-Our design is to use `tcp.gather_elements` op for these cases as follows. Suppose that the `input` has shape `[a, b, c]`, `indices` has shape `[x, y]` and `dim = 0`. Shape of `output` in this case will be `[x, y, b, c]`.
+Our design is to use `tcp.gather` op for these cases as follows. Suppose that the `input` has shape `[a, b, c]`, `indices` has shape `[x, y]` and `dim = 0`. Shape of `output` in this case will be `[x, y, b, c]`.
 
 * Broadcast `input` from `[a, b, c]` to `[a, y, b, c]` by introducing `y` dim.
 * Broadcast `indices` from `[x, y]` to `[x, y, b, c]` by introducing `b` and `c` dims.
-* Perform `tcp.gather_elements` on these broadcasted `input` and `indices`, whose `output` will now have the shape `[x, y, b, c]`.
+* Perform `tcp.gather` on these broadcasted `input` and `indices`, whose `output` will now have the shape `[x, y, b, c]`.
 
 This approach can be used to represent ops like `torch.index_select` [[3]](https://pytorch.org/docs/stable/generated/torch.index_select.html), `tf.gather` [[4]](https://www.tensorflow.org/api_docs/python/tf/gather), and `onnx.Gather` [[5]](https://onnx.ai/onnx/operators/onnx__Gather.html#l-onnx-doc-gather).
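For intuition, the broadcast-plus-gather recipe described above can be checked with a small NumPy sketch (illustrative only, not part of this change; shapes follow the `[a, b, c]` / `[x, y]` notation used above):

    import numpy as np

    a, b, c = 5, 3, 2
    x, y = 4, 6
    inp = np.random.randn(a, b, c)
    idx = np.random.randint(0, a, size=(x, y))

    # Broadcast input [a, b, c] -> [a, y, b, c] by introducing the y dim.
    inp_bcast = np.broadcast_to(inp[:, None, :, :], (a, y, b, c))
    # Broadcast indices [x, y] -> [x, y, b, c] by introducing the b and c dims.
    idx_bcast = np.broadcast_to(idx[:, :, None, None], (x, y, b, c))
    # Element-wise gather along dim 0, i.e. what tcp.gather computes.
    out = np.take_along_axis(inp_bcast, idx_bcast, axis=0)

    # Same result as gathering whole slices along dim 0 (torch.index_select / np.take).
    assert out.shape == (x, y, b, c)
    assert np.array_equal(out, np.take(inp, idx, axis=0))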
@@ -84,7 +84,7 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
        indices = torch.tensor([0, 2]) # Shape is [2]
        x = torch.index_select(input, 0, indices) # Result shape is [2, 4]
 
-   This will get mapped to `tcp.gather_elements` as follows:
+   This will get mapped to `tcp.gather` as follows:
 
        %input = ... # Shape is [3, 4]
        %indices = ... # Shape is [2]
@@ -93,7 +93,7 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
        %cst4 = arith.constant 4 : index
        %indices_bcast = tcp.broadcast_to %indices_2d, %cst4 { axes = [1] } :
                (tensor<2x1xi64>, index) -> tensor<2x4xi64>
-       %x = tcp.gather_elements %input, %indices_bcast { gather_dim = 0 } :
+       %x = tcp.gather %input, %indices_bcast { dim = 0 } :
               (tensor<3x4xf32>, tensor<2x4xi64>) -> tensor<2x4xf32>
 
 2. Modeling `tf.gather`
@@ -102,7 +102,7 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
        indices = ... # Shape is [3]
        x = tf.gather(input, indices, axis=1) # Result shape is [3, 3, 5]
 
-   This will get mapped to `tcp.gather_elements` as follows:
+   This will get mapped to `tcp.gather` as follows:
 
        %input = ... # Shape is [3, 4, 5]
        %indices = ... # Shape is [3]
@@ -110,7 +110,7 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
                (tensor<3xi64>) -> tensor<1x3x1xi64>
        %indices_bcast = tcp.broadcast_to %indices_3d, %cst3, %cst5 { axes = [0, 2] } :
                (tensor<1x3x1xi64>, index, index) -> tensor<3x3x5xi64>
-       %x = tcp.gather_elements %input, %indices_bcast { gather_dim = 1 } :
+       %x = tcp.gather %input, %indices_bcast { dim = 1 } :
                (tensor<3x4x5xf32>, tensor<3x3x5xi64>) -> tensor<3x3x5xf32>
 
 3. Modeling `onnx.Gather`
@@ -119,11 +119,11 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
 
 ### Alternative considered
 
-We considered a separate `tcp.gather` op for this particular case with the following design.
+We considered a separate `tcp.gather_slice` op for this particular case with the following design.
 
 Syntax:
 
-    operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
+    operation ::= `tcp.gather_slice` $input `,` $indices attr-dict `:`
                       type($input) `,` type($indices) `->` type($out)
 
 Attributes:
@@ -150,7 +150,7 @@ Semantics:
        out[i][j][k][m] = input[i][indices[j][k]][m]  # if dim == 1
        out[i][j][k][m] = input[i][j][indices[k][m]]  # if dim == 2
 
-The above approach of reusing `tcp.gather_elements` is preferred to avoid adding a new op here.
+The above approach of reusing `tcp.gather` is preferred to avoid adding a new op here.
 
 ## Gather slices along N dims
 
@@ -158,7 +158,7 @@ The above approach of reusing `tcp.gather_elements` is preferred to avoid adding
 
 Syntax:
 
-    operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
+    operation ::= `tcp.gather_nd` $input `,` $indices attr-dict `:`
                       type($input) `,` type($indices) `->` type($out)
 
 Inputs:
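The N-dim case is only introduced here; its full description follows in the doc. As a rough illustration, gathering slices along the leading N dims in the style of `onnx.GatherND` can be sketched in NumPy as follows (the helper name and the leading-dims convention are assumptions for this example):

    import numpy as np

    def gather_nd_leading_dims(inp, indices):
        # indices has shape [..., N]; each length-N vector picks a slice of inp.
        n = indices.shape[-1]
        batch_shape = indices.shape[:-1]
        out = np.empty(batch_shape + inp.shape[n:], dtype=inp.dtype)
        for pos in np.ndindex(*batch_shape):
            out[pos] = inp[tuple(indices[pos])]
        return out

    inp = np.arange(24).reshape(2, 3, 4)
    indices = np.array([[0, 1], [1, 2]])          # Two (i, j) pairs
    print(gather_nd_leading_dims(inp, indices))   # Shape [2, 4]: rows inp[0, 1] and inp[1, 2]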
diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index 10c3e65..cc4196b 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -595,4 +595,27 @@ def Tcp_CastOp: Tcp_Op<"cast", [Pure, Elementwise, SameOperandsAndResultShape]>
   let hasVerifier = 1;
 }
 
+def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]>]> {
+
+  let summary = "Gather elements from input based on indices";
+
+  let description = [{
+    Gathers elements from a given tensor based on indices that index along a given dim.
+
+    More details regarding this op: docs/gather.md
+  }];
+
+  let arguments = (ins
+    Tcp_Tensor:$input,
+    Tcp_IntTensor:$indices,
+    IndexAttr:$dim
+  );
+
+  let results = (outs
+    Tcp_Tensor:$out
+  );
+
+  let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";
+}
+
 #endif // TCP_OPS
diff --git a/include/mlir-tcp/Dialect/IR/TcpTypes.td b/include/mlir-tcp/Dialect/IR/TcpTypes.td
index 4d0b124..e91efb3 100644
--- a/include/mlir-tcp/Dialect/IR/TcpTypes.td
+++ b/include/mlir-tcp/Dialect/IR/TcpTypes.td
@@ -55,6 +55,7 @@ def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>;
 def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>;
 
 def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>;
+def Tcp_IntTensor : RankedTensorOf<[AnySignlessInteger]>;
 def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>;
 
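Stated in plain Python, the constraints that the `Tcp_GatherOp` ODS above encodes amount to the following checks (a rough restatement for readability; the real verification is generated from the ODS):

    def verify_gather(input_elem_ty, indices_elem_ty, out_elem_ty, dim):
        # AllElementTypesMatch<["input", "out"]>
        assert input_elem_ty == out_elem_ty
        # Tcp_IntTensor: indices must have signless integer elements (e.g. i32, i64)
        assert indices_elem_ty in ("i1", "i8", "i16", "i32", "i64")
        # IndexAttr:$dim is a static attribute; the Torch -> TCP conversion below
        # normalizes it to a non-negative dimension.
        assert isinstance(dim, int)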
diff --git a/lib/Conversion/TcpToLinalg/DataMovement.cpp b/lib/Conversion/TcpToLinalg/DataMovement.cpp
new file mode 100644
index 0000000..f3ba789
--- /dev/null
+++ b/lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -0,0 +1,103 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"
+
+#include "mlir-tcp/Dialect/IR/TcpDialect.h"
+#include "mlir-tcp/Dialect/IR/TcpOps.h"
+
+#include "../PassDetail.h"
+#include "PopulatePatterns.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tcp;
+
+namespace {
+
+class ConvertGatherOp : public OpConversionPattern<GatherOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto resultTensorType = getTypeConverter()
+                                ->convertType(op.getOut().getType())
+                                .cast<RankedTensorType>();
+
+    auto inputTensor = adaptor.getInput();
+    auto indicesTensor = adaptor.getIndices();
+    int64_t gatherDim = adaptor.getDimAttr().getValue().getSExtValue();
+
+    auto resultRank = resultTensorType.getRank();
+
+    SmallVector<Value> resultDimSizes;
+    for (int64_t i = 0; i < resultRank; ++i) {
+      resultDimSizes.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
+    }
+
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultRank, utils::IteratorType::parallel);
+
+    Value emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes),
+                                         resultTensorType.getElementType());
+
+    auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
+      SmallVector<Value> extractIndices;
+      for (int64_t i = 0; i < resultRank; ++i) {
+        if (i == gatherDim) {
+          auto indexCast = b.create<arith::IndexCastOp>(loc, b.getIndexType(),
+                                                        payloadArgs[0]);
+          extractIndices.push_back(indexCast);
+        } else {
+          auto iterIndex = b.create<linalg::IndexOp>(loc, b.getIndexType(),
+                                                     b.getI64IntegerAttr(i));
+          extractIndices.push_back(iterIndex);
+        }
+      }
+      auto extract = b.create<tensor::ExtractOp>(
+          loc, resultTensorType.getElementType(), inputTensor, extractIndices);
+      b.create<linalg::YieldOp>(loc, extract.getResult());
+    };
+    Value generic =
+        rewriter
+            .create<linalg::GenericOp>(loc, emptyTensor.getType(),
+                                       indicesTensor, emptyTensor, indexingMaps,
+                                       iteratorTypes, bodyBuilder)
+            .getResult(0);
+    rewriter.replaceOp(op, generic);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  MLIRContext *context = patterns.getContext();
+
+  target.addIllegalOp<GatherOp>();
+  patterns.add<ConvertGatherOp>(typeConverter, context);
+}
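The emitted `linalg.generic` runs one parallel loop per result dimension over the index space of `indices`; for every output element it reads `input` at the same coordinates, except along the gather dim where the loaded index value is used instead. A plain-Python reference of that computation (illustrative only):

    import numpy as np

    def lowered_gather_reference(inp, indices, dim):
        out = np.empty(indices.shape, dtype=inp.dtype)
        for out_coord in np.ndindex(*indices.shape):
            in_coord = list(out_coord)
            in_coord[dim] = indices[out_coord]      # arith.index_cast of the loaded index
            out[out_coord] = inp[tuple(in_coord)]   # tensor.extract from the input
        return out

    x = np.random.randn(1, 4, 3)
    idx = np.random.randint(0, 3, size=(1, 4, 2))
    assert np.array_equal(lowered_gather_reference(x, idx, 2),
                          np.take_along_axis(x, idx, axis=2))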
diff --git a/lib/Conversion/TcpToLinalg/PopulatePatterns.h b/lib/Conversion/TcpToLinalg/PopulatePatterns.h
index b49ef87..afd9f09 100644
--- a/lib/Conversion/TcpToLinalg/PopulatePatterns.h
+++ b/lib/Conversion/TcpToLinalg/PopulatePatterns.h
@@ -18,6 +18,9 @@ void populateElementwisePatternsAndLegality(TypeConverter &typeConverter,
 void populateMiscPatternsAndLegality(TypeConverter &typeConverter,
                                      RewritePatternSet &patterns,
                                      ConversionTarget &target);
+void populateDataMovementPatternsAndLegality(TypeConverter &typeConverter,
+                                             RewritePatternSet &patterns,
+                                             ConversionTarget &target);
 
 } // namespace TcpToLinalg
 } // namespace mlir
diff --git a/lib/Conversion/TcpToLinalg/TcpToLinalg.cpp b/lib/Conversion/TcpToLinalg/TcpToLinalg.cpp
index 5b239cf..fd276c6 100644
--- a/lib/Conversion/TcpToLinalg/TcpToLinalg.cpp
+++ b/lib/Conversion/TcpToLinalg/TcpToLinalg.cpp
@@ -50,6 +50,8 @@ class ConvertTcpToLinalg : public ConvertTcpToLinalgBase<ConvertTcpToLinalg> {
                                                  target);
     TcpToLinalg::populateMiscPatternsAndLegality(typeConverter, patterns,
                                                  target);
+    TcpToLinalg::populateDataMovementPatternsAndLegality(typeConverter,
+                                                         patterns, target);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index bf0d0d6..b37332d 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -22,6 +22,7 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::tcp;
@@ -207,6 +208,39 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
     return success();
   }
 };
+
+class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto input = adaptor.getSelf();
+    auto indices = adaptor.getIndex();
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .template cast<RankedTensorType>();
+
+    int64_t dim = 0;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return op.emitError("dim on torch.gather must be an int constant");
+    auto inputType = input.getType().cast<RankedTensorType>();
+    dim = Torch::toPositiveDim(dim, inputType.getRank());
+
+    bool sparseGrad = false;
+    if (!matchPattern(op.getSparseGrad(), m_TorchConstantBool(&sparseGrad)))
+      return op.emitError(
+          "sparse_grad on torch.gather must be a bool constant");
+    if (sparseGrad)
+      return op.emitError("unimplemented: sparse_grad is not supported yet");
+
+    rewriter.replaceOpWithNewOp<tcp::GatherOp>(op, resultType, input, indices,
+                                               rewriter.getIndexAttr(dim));
+    return success();
+  }
+};
+
 } // namespace
 
 void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -217,4 +251,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
   torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenSliceTensorOp,
                                                    AtenSliceTensorOp>(
       typeConverter, patterns, target, convertTorchOpsSet);
+  torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenGatherOp,
+                                                   AtenGatherOp>(
+      typeConverter, patterns, target, convertTorchOpsSet);
 }
diff --git a/test/AotCompile/BUILD b/test/AotCompile/BUILD
index c64d1a4..f542e7b 100644
--- a/test/AotCompile/BUILD
+++ b/test/AotCompile/BUILD
@@ -34,6 +34,7 @@ AOT_TEST_SUITE = [
     ("broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic", False),
     ("broadcast_unit_dim_to_static_with_rank_increase", False),
     ("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
+    ("gather_elements", False),
 ]
 
 py_library(
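The `ConvertAtenGatherOp` pattern above only fires when `dim` and `sparse_grad` are compile-time constants: a negative `dim` is wrapped into range with `Torch::toPositiveDim`, and `sparse_grad = true` is rejected as unimplemented. The same decision logic in plain Python (illustrative only):

    def convert_aten_gather_dim(dim, rank, sparse_grad):
        if not isinstance(dim, int):
            raise ValueError("dim on torch.gather must be an int constant")
        if not isinstance(sparse_grad, bool):
            raise ValueError("sparse_grad on torch.gather must be a bool constant")
        if sparse_grad:
            raise NotImplementedError("sparse_grad is not supported yet")
        return dim + rank if dim < 0 else dim   # Torch::toPositiveDim

    # e.g. torch.gather(..., dim=-1, ...) on a rank-3 tensor becomes tcp.gather with dim = 2.
    assert convert_aten_gather_dim(-1, 3, False) == 2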
diff --git a/test/AotCompile/model_loader_lib.py b/test/AotCompile/model_loader_lib.py
index 913dedd..647498b 100644
--- a/test/AotCompile/model_loader_lib.py
+++ b/test/AotCompile/model_loader_lib.py
@@ -36,9 +36,7 @@ def forward(
     }
 
     return TorchLoaderOutput(
-        model=AddMulSingleOutput(),
-        inputs=(x, y, z),
-        dynamic_shapes=dynamic_shapes,
+        model=AddMulSingleOutput(), inputs=(x, y, z), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -69,9 +67,7 @@ def forward(
     }
 
     return TorchLoaderOutput(
-        model=AddMulMultiOutput(),
-        inputs=(x, y, z),
-        dynamic_shapes=dynamic_shapes,
+        model=AddMulMultiOutput(), inputs=(x, y, z), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -90,15 +86,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": None,
-        "y": {0: batch},
-    }
+    dynamic_shapes = {"x": None, "y": {0: batch}}
 
     return TorchLoaderOutput(
-        model=AddTensorMixedRanks(),
-        inputs=(x, y),
-        dynamic_shapes=dynamic_shapes,
+        model=AddTensorMixedRanks(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -117,15 +108,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-        "y": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch}}
 
     return TorchLoaderOutput(
-        model=AddTensorWithAlpha(),
-        inputs=(x, y),
-        dynamic_shapes=dynamic_shapes,
+        model=AddTensorWithAlpha(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -144,15 +130,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-        "y": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch}}
 
     return TorchLoaderOutput(
-        model=SubTensorWithAlpha(),
-        inputs=(x, y),
-        dynamic_shapes=dynamic_shapes,
+        model=SubTensorWithAlpha(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -171,15 +152,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": None,
-        "y": {0: batch},
-    }
+    dynamic_shapes = {"x": None, "y": {0: batch}}
 
     return TorchLoaderOutput(
-        model=DivTensorMixedRanks(),
-        inputs=(x, y),
-        dynamic_shapes=dynamic_shapes,
+        model=DivTensorMixedRanks(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -204,14 +180,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
     return TorchLoaderOutput(
-        model=AddSubMulDivScalar(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
+        model=AddSubMulDivScalar(), inputs=(x,), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -228,14 +200,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
     return TorchLoaderOutput(
-        model=Sigmoid(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
+        model=Sigmoid(), inputs=(x,), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -252,15 +220,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
    batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
-    return TorchLoaderOutput(
-        model=Tanh(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
-    )
+    return TorchLoaderOutput(model=Tanh(), inputs=(x,), dynamic_shapes=dynamic_shapes)
 
 
 def clamp_loader() -> TorchLoaderOutput:
@@ -278,15 +240,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
-    return TorchLoaderOutput(
-        model=Clamp(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
-    )
+    return TorchLoaderOutput(model=Clamp(), inputs=(x,), dynamic_shapes=dynamic_shapes)
 
 
 def relu_loader() -> TorchLoaderOutput:
@@ -302,15 +258,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
-    return TorchLoaderOutput(
-        model=Relu(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
-    )
+    return TorchLoaderOutput(model=Relu(), inputs=(x,), dynamic_shapes=dynamic_shapes)
 
 
 def round_even_loader() -> TorchLoaderOutput:
@@ -326,14 +276,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
     return TorchLoaderOutput(
-        model=RoundEven(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
+        model=RoundEven(), inputs=(x,), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -350,14 +296,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
     return TorchLoaderOutput(
-        model=SqrtFloat(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
+        model=SqrtFloat(), inputs=(x,), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -374,14 +316,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch")
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
     return TorchLoaderOutput(
-        model=SqrtInt(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
+        model=SqrtInt(), inputs=(x,), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -400,15 +338,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     # Dynamic dim constraints
     batch_x = Dim("batch_x")
     batch_y = Dim("batch_y")
-    dynamic_shapes = {
-        "x": {0: batch_x},
-        "y": {0: batch_y},
-    }
+    dynamic_shapes = {"x": {0: batch_x}, "y": {0: batch_y}}
 
     return TorchLoaderOutput(
-        model=ConcatFloatTensors(),
-        inputs=(x, y),
-        dynamic_shapes=dynamic_shapes,
+        model=ConcatFloatTensors(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -427,15 +360,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     # Dynamic dim constraints
     batch_x = Dim("batch_x")
     batch_y = Dim("batch_y")
-    dynamic_shapes = {
-        "x": {0: batch_x},
-        "y": {0: batch_y},
-    }
+    dynamic_shapes = {"x": {0: batch_x}, "y": {0: batch_y}}
 
     return TorchLoaderOutput(
-        model=ConcatIntTensors(),
-        inputs=(x, y),
-        dynamic_shapes=dynamic_shapes,
+        model=ConcatIntTensors(), inputs=(x, y), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -452,14 +380,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     # Dynamic dim constraints
     batch = Dim("batch", min=3)
-    dynamic_shapes = {
-        "x": {0: batch},
-    }
+    dynamic_shapes = {"x": {0: batch}}
 
     return TorchLoaderOutput(
-        model=SliceTensor(),
-        inputs=(x,),
-        dynamic_shapes=dynamic_shapes,
+        model=SliceTensor(), inputs=(x,), dynamic_shapes=dynamic_shapes
     )
 
 
@@ -475,8 +399,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     x = torch.randn(1, 2)
 
     return TorchLoaderOutput(
-        model=BroadcastUnitDimToStaticWithExplicitDimStatic(),
-        inputs=(x,),
+        model=BroadcastUnitDimToStaticWithExplicitDimStatic(), inputs=(x,)
     )
 
 
@@ -494,8 +417,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     x = torch.randn(1, 2)
 
     return TorchLoaderOutput(
-        model=BroadcastUnitDimToStaticWithUnchangedDimStatic(),
-        inputs=(x,),
+        model=BroadcastUnitDimToStaticWithUnchangedDimStatic(), inputs=(x,)
     )
 
 
@@ -513,9 +435,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     x = torch.randn(1, 2)
 
     dim_1 = Dim("dim_1")
-    dynamic_shapes = {
-        "x": {1: dim_1},
-    }
+    dynamic_shapes = {"x": {1: dim_1}}
 
     return TorchLoaderOutput(
         model=BroadcastUnitDimToStaticWithUnchangedDimDynamic(),
@@ -539,10 +459,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     y = torch.randn(10)
 
     dim_0 = Dim("dim_0")
-    dynamic_shapes = {
-        "x": {},
-        "y": {0: dim_0},
-    }
+    dynamic_shapes = {"x": {}, "y": {0: dim_0}}
 
     return TorchLoaderOutput(
         model=BroadcastUnitDimToDynamicWithUnchangedDimStatic(),
@@ -567,10 +484,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
     dim_0 = Dim("dim_0")
     dim_1 = Dim("dim_1")
-    dynamic_shapes = {
-        "x": {1: dim_1},
-        "y": {0: dim_0},
-    }
+    dynamic_shapes = {"x": {1: dim_1}, "y": {0: dim_0}}
 
     return TorchLoaderOutput(
         model=BroadcastUnitDimToDynamicWithUnchangedDimDynamic(),
@@ -592,8 +506,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     y = torch.randn(4, 3, 2)
 
     return TorchLoaderOutput(
-        model=BroadcastUnitDimToStaticWithRankIncrease(),
-        inputs=(x, y),
+        model=BroadcastUnitDimToStaticWithRankIncrease(), inputs=(x, y)
     )
 
 
@@ -610,13 +523,31 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     y = torch.randn(4, 3, 2)
 
     dim_0 = Dim("dim_0")
-    dynamic_shapes = {
-        "x": {},
-        "y": {0: dim_0},
-    }
+    dynamic_shapes = {"x": {}, "y": {0: dim_0}}
 
     return TorchLoaderOutput(
         model=BroadcastUnitDimToDynamicWithRankIncrease(),
         inputs=(x, y),
         dynamic_shapes=dynamic_shapes,
     )
+
+
+def gather_elements_loader() -> TorchLoaderOutput:
+    class GatherElements(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return torch.gather(x, 0, y)
+
+    # Sample inputs
+    x = torch.randn(4, 3)
+    y = torch.tensor([[0, 0, 0], [1, 1, 1]])
+
+    # Dynamic dim constraints
+    batch = Dim("batch", min=3)
+    dynamic_shapes = {"x": {0: batch}, "y": {}}
+
+    return TorchLoaderOutput(
+        model=GatherElements(), inputs=(x, y), dynamic_shapes=dynamic_shapes
+    )
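For the sample inputs used by `gather_elements_loader`, the gather picks elements column by column along dim 0, so with `y = [[0, 0, 0], [1, 1, 1]]` the result is simply the first two rows of `x`. A quick check (illustrative only):

    import torch

    x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
    y = torch.tensor([[0, 0, 0], [1, 1, 1]])

    # out[i][j] = x[y[i][j]][j]
    out = torch.gather(x, 0, y)
    print(out)   # tensor([[0., 1., 2.],
                 #         [3., 4., 5.]])
    assert torch.equal(out, x[:2])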
diff --git a/test/Conversion/TcpToLinalg/data_movement.mlir b/test/Conversion/TcpToLinalg/data_movement.mlir
new file mode 100644
index 0000000..ee85853
--- /dev/null
+++ b/test/Conversion/TcpToLinalg/data_movement.mlir
@@ -0,0 +1,20 @@
+// RUN: tcp-opt %s -convert-tcp-to-linalg -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @gather
+// CHECK-SAME:      %[[ARG0:.+]]: tensor<1x4x3xf32>,
+// CHECK-SAME:      %[[ARG1:.+]]: tensor<1x4x2xi64>) -> tensor<1x4x2xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x2xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:      ins(%[[ARG1]] : tensor<1x4x2xi64>) outs(%[[EMPTY]] : tensor<1x4x2xf32>)
+// CHECK:         ^bb0(%[[IN:.+]]: i64, %[[OUT:.+]]: f32):
+// CHECK:           %[[I0:.+]] = linalg.index 0 : index
+// CHECK:           %[[I1:.+]] = linalg.index 1 : index
+// CHECK:           %[[I2:.+]] = arith.index_cast %[[IN]] : i64 to index
+// CHECK:           %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]] : tensor<1x4x3xf32>
+// CHECK:           linalg.yield %[[EXTRACT]] : f32
+// CHECK:         } -> tensor<1x4x2xf32>
+// CHECK:         return %[[GENERIC]] : tensor<1x4x2xf32>
+func.func @gather(%arg0 : tensor<1x4x3xf32>, %arg1 : tensor<1x4x2xi64>) -> tensor<1x4x2xf32> {
+  %0 = "tcp.gather"(%arg0, %arg1) {dim = 2 : index} : (tensor<1x4x3xf32>, tensor<1x4x2xi64>) -> tensor<1x4x2xf32>
+  return %0 : tensor<1x4x2xf32>
+}
diff --git a/test/Conversion/TorchToTcp/data_movement.mlir b/test/Conversion/TorchToTcp/data_movement.mlir
index 070aba6..4856ace 100644
--- a/test/Conversion/TorchToTcp/data_movement.mlir
+++ b/test/Conversion/TorchToTcp/data_movement.mlir
@@ -29,3 +29,21 @@ func.func @torch.aten.slice.Tensor(%arg0: !torch.vtensor<[1,56,?,?],f32>) -> !to
   %0 = torch.aten.slice.Tensor %arg0, %int1, %int0, %int100, %int2 : !torch.vtensor<[1,56,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,28,?,?],f32>
   return %0 : !torch.vtensor<[1,28,?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.gather
+// CHECK-SAME:      %[[ARG0:.+]]: !torch.vtensor<[1,4,3],f32>,
+// CHECK-SAME:      %[[ARG1:.+]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32>
+// CHECK:         %[[V1:.+]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
+// CHECK:         %[[V2:.+]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
+// CHECK:         %[[GATHER:.+]] = tcp.gather %[[V1]], %[[V2]] {dim = 2 : index} :
+// CHECK-SAME:         tensor<1x4x3xf32>, tensor<1x4x2xi64> -> tensor<1x4x2xf32>
+// CHECK:         %[[V3:.+]] = torch_c.from_builtin_tensor %[[GATHER]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
+// CHECK:         return %[[V3]] : !torch.vtensor<[1,4,2],f32>
+func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> {
+  %int-1 = torch.constant.int -1
+  %false = torch.constant.bool false
+  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],f32>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],f32>
+  return %0 : !torch.vtensor<[1,4,2],f32>
+}
diff --git a/test/Dialect/data_movement.mlir b/test/Dialect/data_movement.mlir
new file mode 100644
index 0000000..53bf3b0
--- /dev/null
+++ b/test/Dialect/data_movement.mlir
@@ -0,0 +1,12 @@
+// RUN: tcp-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @test_gather(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<2x?xi64>) -> tensor<2x?xf32>
+// CHECK:         %[[GATHER:.*]] = tcp.gather %[[ARG0]], %[[ARG1]] {dim = 0 : index} : tensor<?x?xf32>, tensor<2x?xi64> -> tensor<2x?xf32>
+// CHECK:         return %[[GATHER]] : tensor<2x?xf32>
+func.func @test_gather(%arg0 : tensor<?x?xf32>, %arg1 : tensor<2x?xi64>) -> tensor<2x?xf32> {
+  %0 = tcp.gather %arg0, %arg1 { dim = 0 : index } :
+          tensor<?x?xf32>, tensor<2x?xi64> -> tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
diff --git a/test/Pipeline/torch_to_tcp_pipeline.mlir b/test/Pipeline/torch_to_tcp_pipeline.mlir
index aa51293..ae2296d 100644
--- a/test/Pipeline/torch_to_tcp_pipeline.mlir
+++ b/test/Pipeline/torch_to_tcp_pipeline.mlir
@@ -113,17 +113,3 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1
   %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
   return %0 : !torch.vtensor<[?, ?],si32>
 }
-
-// -----
-
-// CHECK-LABEL: torch.aten.gather_op
-// CHECK-SAME:     %[[VAL_0:.*]]: tensor<2x2xi64>,
-// CHECK-SAME:     %[[VAL_1:.*]]: tensor<2x2xf32>
-// CHECK:          %[[VAL_2:.*]] = tcp.custom_op("torch.aten.gather") %[[VAL_1]], %[[VAL_0]] {axis = 1 : i64, torch_operand_names = ["self", "index"]} : tensor<2x2xf32>, tensor<2x2xi64> -> tensor<2x2xf32>
-// CHECK:          return %[[VAL_2]] : tensor<2x2xf32>
-func.func @torch.aten.gather_op(%arg0: !torch.vtensor<[2,2],si64>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> {
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.gather %arg1, %int1, %arg0, %false : !torch.vtensor<[2,2],f32>, !torch.int, !torch.vtensor<[2,2],si64>, !torch.bool -> !torch.vtensor<[2,2],f32>
-  return %0 : !torch.vtensor<[2,2],f32>
-}