Add op for Gather (#68)
This PR adds the `tcp.gather` op. The design and semantics of this op
are defined in
[docs/gather.md](https://github.com/cruise-automation/mlir-tcp/blob/main/docs/gather.md).

More specifically, this PR includes:
* The definition of the `tcp.gather` op.
* Lowering from Torch to Tcp for the `torch.gather` op.
* Lowering from Tcp to Linalg for the `tcp.gather` op.
* Lit tests for the op and the lowerings.
* An AOT test for the `tcp.gather` op.
* Renaming the op in the design document from `tcp.gather_elements` to `tcp.gather`.

This PR does not update the existing pattern that lowers `aten.gather` to
`tcp.custom("aten.gather")`, defined
[here](https://github.com/cruise-automation/mlir-tcp/blob/b812717001121df6974687e2b2e7e82b900390dd/lib/Conversion/TorchToTcp/TcpCustomOp.cpp#L29-L45).
We should be able to control how `aten.gather` gets lowered by fixing
the order in which the conversion passes are called in the pipeline.
navahgar authored Apr 28, 2024
1 parent 04f327c commit 553b127
Showing 14 changed files with 297 additions and 159 deletions.
1 change: 1 addition & 0 deletions BUILD
@@ -246,6 +246,7 @@ cc_library(
name = "TcpToLinalg",
srcs = [
"lib/Conversion/PassDetail.h",
"lib/Conversion/TcpToLinalg/DataMovement.cpp",
"lib/Conversion/TcpToLinalg/Elementwise.cpp",
"lib/Conversion/TcpToLinalg/Misc.cpp",
"lib/Conversion/TcpToLinalg/PopulatePatterns.h",
32 changes: 16 additions & 16 deletions docs/gather.md
@@ -2,11 +2,11 @@

## Gather elements along a given dim

`tcp.gather_elements` op gathers elements from a given tensor based on indices that index along a given dim.
`tcp.gather` op gathers elements from a given tensor based on indices that index along a given dim.

Syntax:

operation ::= `tcp.gather_elements` $input `,` $indices attr-dict `:`
operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
type($input) `,` type($indices) `->` type($out)

Attributes:
@@ -43,11 +43,11 @@ This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/gener
indices = torch.tensor([[0, 1], [2, 0], [2, 3]]) # Shape is [3, 2]
x = torch.gather(input, 1, indices) # Result shape is [3, 2]

This will get mapped to `tcp.gather_elements` as follows:
This will get mapped to `tcp.gather` as follows:

%input = ...
%indices = ...
%x = tcp.gather_elements %input, %indices { gather_dim = 1 } :
%x = tcp.gather %input, %indices { dim = 1 } :
(tensor<3x4xf32>, tensor<3x2xi64>) -> tensor<3x2xf32>

2. Modeling `onnx.GatherElements`
@@ -56,22 +56,22 @@ This op is similar to `torch.gather` [[1]](https://pytorch.org/docs/stable/gener
indices = ... # Shape is [2, 3]
x = onnx.GatherElements(input, indices, 0) # Result shape is [2, 3]

This will get mapped to `tcp.gather_elements` as follows:
This will get mapped to `tcp.gather` as follows:

%input = ...
%indices = ...
%x = tcp.gather_elements %input, %indices { gather_dim = 0 } :
%x = tcp.gather %input, %indices { dim = 0 } :
(tensor<3x3xf32>, tensor<2x3xi64>) -> tensor<2x3xf32>
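
The element-wise semantics shared by the two examples above can be sketched in plain Python. This is only an illustrative reference over nested lists, not code from this PR: the helper names (`shape_of`, `element_at`, `gather`) and the input values are made up for the example. The output takes the shape of `indices`, and each output element is read from `input` at the same position, except that the coordinate along `dim` is replaced by the value stored in `indices`.

```python
def shape_of(tensor):
    """Shape of a rectangular nested list (hypothetical helper)."""
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0]
    return shape


def element_at(tensor, index):
    """Read one element of a nested list at a multi-dimensional index."""
    for i in index:
        tensor = tensor[i]
    return tensor


def gather(inp, indices, dim):
    """Reference sketch: out[pos] = inp[pos with pos[dim] := indices[pos]]."""
    out_shape = shape_of(indices)  # the result has the shape of `indices`

    def build(prefix):
        if len(prefix) == len(out_shape):
            src = list(prefix)
            src[dim] = element_at(indices, prefix)  # swap in the gathered index
            return element_at(inp, src)
        return [build(prefix + [i]) for i in range(out_shape[len(prefix)])]

    return build([])


# Same shapes as the torch.gather example: input [3, 4], indices [3, 2], dim = 1.
inp = [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
indices = [[0, 1], [2, 0], [2, 3]]
print(gather(inp, indices, 1))  # [[10, 11], [22, 20], [32, 33]]
```

With `dim = 0` the same function covers the `onnx.GatherElements` case, where the row coordinate is the one taken from `indices`.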


## Gather slices along a given dim

This requires gathering slices from a given tensor based on indices that index along a given dim.

Our design is to use `tcp.gather_elements` op for these cases as follows. Suppose that the `input` has shape `[a, b, c]`, `indices` has shape `[x, y]` and `dim = 0`. Shape of `output` in this case will be `[x, y, b, c]`.
Our design is to use `tcp.gather` op for these cases as follows. Suppose that the `input` has shape `[a, b, c]`, `indices` has shape `[x, y]` and `dim = 0`. Shape of `output` in this case will be `[x, y, b, c]`.
* Broadcast `input` from `[a, b, c]` to `[a, y, b, c]` by introducing `y` dim.
* Broadcast `indices` from `[x, y]` to `[x, y, b, c]` by introducing `b` and `c` dims.
* Perform `tcp.gather_elements` on these broadcasted `input` and `indices`, whose `output` will now have the shape `[x, y, b, c]`.
* Perform `tcp.gather` on these broadcasted `input` and `indices`, whose `output` will now have the shape `[x, y, b, c]`.
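
As a rough illustration of the steps above, the sketch below handles the simplest case, a 2-D `input` with 1-D `indices` and `dim = 0` (the `torch.index_select` shape of problem), where only `indices` needs to be expanded and broadcast. It is plain Python over nested lists; `gather_2d` and `index_select_dim0` are hypothetical names chosen for this example and are not part of the PR.

```python
def gather_2d(inp, indices, dim):
    """Element gather for 2-D nested lists (illustrative only)."""
    rows, cols = len(indices), len(indices[0])
    if dim == 0:
        return [[inp[indices[i][j]][j] for j in range(cols)] for i in range(rows)]
    return [[inp[i][indices[i][j]] for j in range(cols)] for i in range(rows)]


def index_select_dim0(inp, idx):
    """Select whole rows by expanding and broadcasting the indices, then gathering."""
    num_cols = len(inp[0])
    # Expand idx from [n] to [n, 1], then broadcast along axis 1 to [n, num_cols].
    idx_bcast = [[k] * num_cols for k in idx]
    return gather_2d(inp, idx_bcast, 0)


inp = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]  # shape [3, 4]
print(index_select_dim0(inp, [0, 2]))  # [[1, 2, 3, 4], [9, 10, 11, 12]]
```

Every column of a broadcast index row holds the same value, so the element gather ends up copying an entire slice of `input`, which is why no separate slice-gather op is needed.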


This approach can be used to represent ops like `torch.index_select` [[3]](https://pytorch.org/docs/stable/generated/torch.index_select.html), `tf.gather` [[4]](https://www.tensorflow.org/api_docs/python/tf/gather), and `onnx.Gather` [[5]](https://onnx.ai/onnx/operators/onnx__Gather.html#l-onnx-doc-gather).
@@ -84,7 +84,7 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
indices = torch.tensor([0, 2]) # Shape is [2]
x = torch.index_select(input, 0, indices) # Result shape is [2, 4]

This will get mapped to `tcp.gather_elements` as follows:
This will get mapped to `tcp.gather` as follows:

%input = ... # Shape is [3, 4]
%indices = ... # Shape is [2]
@@ -93,7 +93,7 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
%cst4 = arith.constant 4 : index
%indices_bcast = tcp.broadcast_to %indices_2d, %cst4 { axes = [1] } :
(tensor<2x1xi64>, index) -> tensor<2x4xi64>
%x = tcp.gather_elements %input, %indices_bcast { gather_dim = 0 } :
%x = tcp.gather %input, %indices_bcast { dim = 0 } :
(tensor<3x4xf32>, tensor<2x4xi64>) -> tensor<2x4xf32>

2. Modeling `tf.gather`
@@ -102,15 +102,15 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https
indices = ... # Shape is [3]
x = tf.gather(input, indices, axis=1) # Result shape is [3, 3, 5]

This will get mapped to `tcp.gather_elements` as follows:
This will get mapped to `tcp.gather` as follows:

%input = ... # Shape is [3, 4, 5]
%indices = ... # Shape is [3]
%indices_3d = tensor.expand_shape %indices [[0, 1, 2]] :
(tensor<3xi64>) -> tensor<1x3x1xi64>
%indices_bcast = tcp.broadcast_to %indices_3d, %cst3, %cst5 { axes = [0, 2] } :
(tensor<1x3x1xi64>, index, index) -> tensor<3x3x5xi64>
%x = tcp.gather_elements %input, %indices_bcast { gather_dim = 1 } :
%x = tcp.gather %input, %indices_bcast { dim = 1 } :
(tensor<3x4x5xf32>, tensor<3x3x5xi64>) -> tensor<3x3x5xf32>

3. Modeling `onnx.Gather`
@@ -119,11 +119,11 @@ This approach can be used to represent ops like `torch.index_select` [[3]](https

### Alternative considered

We considered a separate `tcp.gather` op for this particular case with the following design.
We considered a separate `tcp.gather_slice` op for this particular case with the following design.

Syntax:

operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
operation ::= `tcp.gather_slice` $input `,` $indices attr-dict `:`
type($input) `,` type($indices) `->` type($out)

Attributes:
@@ -150,15 +150,15 @@ Semantics:
out[i][j][k][m] = input[i][indices[j][k]][m] # if dim == 1
out[i][j][k][m] = input[i][j][indices[k][m]] # if dim == 2

The above approach of reusing `tcp.gather_elements` is preferred to avoid adding a new op here.
The above approach of reusing `tcp.gather` is preferred to avoid adding a new op here.
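
For comparison, the `dim == 1` rule of the rejected `tcp.gather_slice` alternative can be written out directly. The sketch below is illustrative plain Python, assuming a rank-3 `input` and rank-2 `indices` as in the semantics listed above; `gather_slice_dim1` is a hypothetical name and the values are made up.

```python
def gather_slice_dim1(inp, indices):
    """out[i][j][k][m] = inp[i][indices[j][k]][m] for rank-3 inp, rank-2 indices."""
    return [
        [
            [
                [inp[i][indices[j][k]][m] for m in range(len(inp[0][0]))]
                for k in range(len(indices[0]))
            ]
            for j in range(len(indices))
        ]
        for i in range(len(inp))
    ]


inp = [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]  # shape [2, 3, 2]
indices = [[2, 0]]                                            # shape [1, 2]
out = gather_slice_dim1(inp, indices)                         # shape [2, 1, 2, 2]
print(out)  # [[[[4, 5], [0, 1]]], [[[10, 11], [6, 7]]]]
```

The indexed dim of `input` is replaced by all dims of `indices` in the output, which is the behavior the broadcast-then-gather approach reproduces without a new op.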

## Gather slices along N dims

`tcp.gather_nd` op gathers slices from a given tensor based on indices that index along the first `n` dims.

Syntax:

operation ::= `tcp.gather` $input `,` $indices attr-dict `:`
operation ::= `tcp.gather_nd` $input `,` $indices attr-dict `:`
type($input) `,` type($indices) `->` type($out)

Inputs:
23 changes: 23 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -595,4 +595,27 @@ def Tcp_CastOp: Tcp_Op<"cast", [Pure, Elementwise, SameOperandsAndResultShape]>
let hasVerifier = 1;
}

def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]>]> {

let summary = "Gather elements from input based on indices";

let description = [{
Gathers elements from a given tensor based on indices that index along a given dim.

More details regarding this op: docs/gather.md
}];

let arguments = (ins
Tcp_Tensor:$input,
Tcp_IntTensor:$indices,
IndexAttr:$dim
);

let results = (outs
Tcp_Tensor:$out
);

let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";
}

#endif // TCP_OPS
1 change: 1 addition & 0 deletions include/mlir-tcp/Dialect/IR/TcpTypes.td
@@ -55,6 +55,7 @@ def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>;
def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>;

def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>;
def Tcp_IntTensor : RankedTensorOf<[AnySignlessInteger]>;
def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>;


103 changes: 103 additions & 0 deletions lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -0,0 +1,103 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"

#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tcp;

namespace {

class ConvertGatherOp : public OpConversionPattern<GatherOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(GatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTensorType = getTypeConverter()
->convertType(op.getOut().getType())
.cast<RankedTensorType>();

auto inputTensor = adaptor.getInput();
auto indicesTensor = adaptor.getIndices();
int64_t gatherDim = adaptor.getDimAttr().getValue().getSExtValue();

auto resultRank = resultTensorType.getRank();

SmallVector<Value> resultDimSizes;
for (int64_t i = 0; i < resultRank; ++i) {
resultDimSizes.push_back(
rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
}

SmallVector<AffineMap, 2> indexingMaps;
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

SmallVector<utils::IteratorType> iteratorTypes(
resultRank, utils::IteratorType::parallel);

Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes),
resultTensorType.getElementType());

auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
SmallVector<Value> extractIndices;
for (int64_t i = 0; i < resultRank; ++i) {
if (i == gatherDim) {
auto indexCast = b.create<arith::IndexCastOp>(loc, b.getIndexType(),
payloadArgs[0]);
extractIndices.push_back(indexCast);
} else {
auto iterIndex = b.create<linalg::IndexOp>(loc, b.getIndexType(),
b.getI64IntegerAttr(i));
extractIndices.push_back(iterIndex);
}
}
auto extract = b.create<tensor::ExtractOp>(
loc, resultTensorType.getElementType(), inputTensor, extractIndices);
b.create<linalg::YieldOp>(loc, extract.getResult());
};
Value generic =
rewriter
.create<linalg::GenericOp>(loc, emptyTensor.getType(),
indicesTensor, emptyTensor, indexingMaps,
iteratorTypes, bodyBuilder)
.getResult(0);
rewriter.replaceOp(op, generic);
return success();
}
};

} // namespace

void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();

target.addIllegalOp<GatherOp>();
patterns.add<ConvertGatherOp>(typeConverter, context);
}
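
To make the structure of the `linalg.generic` produced by `ConvertGatherOp` easier to follow, here is a loose Python emulation of it. It is a sketch only, not generated from the pattern: the helper names are hypothetical, tensors are nested lists, and the all-parallel iteration space is modeled with `itertools.product`. Each step mirrors the C++ above: the result sizes come from `indices`, the payload argument is the `indices` element at the current iteration point, and the extract index uses that payload along the gather dim and the iteration index (`linalg.index`) everywhere else.

```python
from itertools import product


def shape_of(tensor):
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0]
    return shape


def load(tensor, index):
    for i in index:
        tensor = tensor[i]
    return tensor


def store(tensor, index, value):
    for i in index[:-1]:
        tensor = tensor[i]
    tensor[index[-1]] = value


def empty(shape):
    """Stand-in for tensor.empty: an uninitialized nested list."""
    if len(shape) == 1:
        return [None] * shape[0]
    return [empty(shape[1:]) for _ in range(shape[0])]


def lowered_gather(inp, indices, gather_dim):
    result_shape = shape_of(indices)  # resultDimSizes are taken from `indices`
    out = empty(result_shape)
    rank = len(result_shape)
    # Identity indexing maps + all-parallel iterators: visit every result position.
    for iter_index in product(*(range(n) for n in result_shape)):
        payload = load(indices, iter_index)  # payloadArgs[0]
        extract_index = [
            payload if i == gather_dim else iter_index[i]  # index_cast vs linalg.index
            for i in range(rank)
        ]
        store(out, iter_index, load(inp, extract_index))  # tensor.extract + yield
    return out


inp = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]          # shape [2, 2, 2]
indices = [[[1, 0], [0, 0]], [[1, 1], [0, 1]]]      # shape [2, 2, 2]
print(lowered_gather(inp, indices, 2))  # [[[1, 0], [2, 2]], [[5, 5], [6, 7]]]
```

Because `input` is only read through `tensor.extract` inside the body, it never appears as an operand with its own indexing map, which is why the generic op has just two identity maps (one for `indices`, one for the output).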
3 changes: 3 additions & 0 deletions lib/Conversion/TcpToLinalg/PopulatePatterns.h
@@ -18,6 +18,9 @@ void populateElementwisePatternsAndLegality(TypeConverter &typeConverter,
void populateMiscPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateDataMovementPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

} // namespace TcpToLinalg
} // namespace mlir
2 changes: 2 additions & 0 deletions lib/Conversion/TcpToLinalg/TcpToLinalg.cpp
@@ -50,6 +50,8 @@ class ConvertTcpToLinalg : public ConvertTcpToLinalgBase<ConvertTcpToLinalg> {
target);
TcpToLinalg::populateMiscPatternsAndLegality(typeConverter, patterns,
target);
TcpToLinalg::populateDataMovementPatternsAndLegality(typeConverter,
patterns, target);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
37 changes: 37 additions & 0 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -22,6 +22,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tcp;
@@ -207,6 +208,39 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
return success();
}
};

class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto input = adaptor.getSelf();
auto indices = adaptor.getIndex();
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.template cast<RankedTensorType>();

int64_t dim = 0;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op.emitError("dim on torch.gather must be an int constant");
auto inputType = input.getType().cast<RankedTensorType>();
dim = Torch::toPositiveDim(dim, inputType.getRank());

bool sparseGrad = false;
if (!matchPattern(op.getSparseGrad(), m_TorchConstantBool(&sparseGrad)))
return op.emitError(
"sparse_grad on torch.gather must be a bool constant");
if (sparseGrad)
return op.emitError("unimplemented: sparse_grad is not supported yet");

rewriter.replaceOpWithNewOp<tcp::GatherOp>(op, resultType, input, indices,
rewriter.getIndexAttr(dim));
return success();
}
};

} // namespace

void torch_to_tcp::populateDataMovementPatternsAndLegality(
@@ -217,4 +251,7 @@ void torch_to_tcp::populateDataMovementPatternsAndLegality(
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenSliceTensorOp,
AtenSliceTensorOp>(
typeConverter, patterns, target, convertTorchOpsSet);
torch_to_tcp::addPatternIfOpInConvertTorchOpsSet<ConvertAtenGatherOp,
AtenGatherOp>(
typeConverter, patterns, target, convertTorchOpsSet);
}
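
The attribute handling in `ConvertAtenGatherOp` can be summarized with a small Python sketch. It assumes that `Torch::toPositiveDim` maps a negative `dim` to `dim + rank` and leaves non-negative values unchanged; that reading, along with the function names and the returned dictionary, is an assumption made for illustration rather than something taken from this diff.

```python
def to_positive_dim(dim, rank):
    """Assumed behavior of Torch::toPositiveDim: wrap negative dims by the rank."""
    return dim if dim >= 0 else dim + rank


def gather_attrs(dim, input_rank, sparse_grad=False):
    """Hypothetical helper: compute the attribute carried by the new tcp.gather op."""
    if sparse_grad:
        # Mirrors the "unimplemented: sparse_grad is not supported yet" error path.
        raise NotImplementedError("sparse_grad is not supported yet")
    return {"dim": to_positive_dim(dim, input_rank)}


print(gather_attrs(-1, 3))  # {'dim': 2}
print(gather_attrs(1, 3))   # {'dim': 1}
```

Normalizing `dim` at this stage means the Tcp-to-Linalg lowering can compare it directly against loop indices without handling negative values.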
1 change: 1 addition & 0 deletions test/AotCompile/BUILD
@@ -34,6 +34,7 @@ AOT_TEST_SUITE = [
("broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic", False),
("broadcast_unit_dim_to_static_with_rank_increase", False),
("broadcast_unit_dim_to_dynamic_with_rank_increase", False),
("gather_elements", False),
]

py_library(