-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds the `tcp.gather` op. The design and semantics of this op are defined in [docs/gather.md](https://github.com/cruise-automation/mlir-tcp/blob/main/docs/gather.md). More specifically, this PR includes: * The definition of the `tcp.gather` op. * Lowering from Torch to Tcp for the `torch.gather` op. * Lowering from Tcp to Linalg for the `tcp.gather` op. * Lit tests for the op and the lowerings. * AOT test the Tcp gather op. * Renames the op in the design document. This PR does not update the lowering we have from aten.gather to tcp.custom("aten.gather") pattern [here](https://github.com/cruise-automation/mlir-tcp/blob/b812717001121df6974687e2b2e7e82b900390dd/lib/Conversion/TorchToTcp/TcpCustomOp.cpp#L29-L45). We should be able to control how `aten.gather` gets lowered by fixing the order in which the conversion passes are called in the pipeline.
- Loading branch information
Showing
14 changed files
with
297 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h" | ||
|
||
#include "mlir-tcp/Dialect/IR/TcpDialect.h" | ||
#include "mlir-tcp/Dialect/IR/TcpOps.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "PopulatePatterns.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
#include "llvm/Support/Debug.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::tcp; | ||
|
||
namespace { | ||
|
||
class ConvertGatherOp : public OpConversionPattern<GatherOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(GatherOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Location loc = op->getLoc(); | ||
auto resultTensorType = getTypeConverter() | ||
->convertType(op.getOut().getType()) | ||
.cast<RankedTensorType>(); | ||
|
||
auto inputTensor = adaptor.getInput(); | ||
auto indicesTensor = adaptor.getIndices(); | ||
int64_t gatherDim = adaptor.getDimAttr().getValue().getSExtValue(); | ||
|
||
auto resultRank = resultTensorType.getRank(); | ||
|
||
SmallVector<Value> resultDimSizes; | ||
for (int64_t i = 0; i < resultRank; ++i) { | ||
resultDimSizes.push_back( | ||
rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i)); | ||
} | ||
|
||
SmallVector<AffineMap, 2> indexingMaps; | ||
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); | ||
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); | ||
|
||
SmallVector<utils::IteratorType> iteratorTypes( | ||
resultRank, utils::IteratorType::parallel); | ||
|
||
Value emptyTensor = | ||
rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes), | ||
resultTensorType.getElementType()); | ||
|
||
auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { | ||
SmallVector<Value> extractIndices; | ||
for (int64_t i = 0; i < resultRank; ++i) { | ||
if (i == gatherDim) { | ||
auto indexCast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), | ||
payloadArgs[0]); | ||
extractIndices.push_back(indexCast); | ||
} else { | ||
auto iterIndex = b.create<linalg::IndexOp>(loc, b.getIndexType(), | ||
b.getI64IntegerAttr(i)); | ||
extractIndices.push_back(iterIndex); | ||
} | ||
} | ||
auto extract = b.create<tensor::ExtractOp>( | ||
loc, resultTensorType.getElementType(), inputTensor, extractIndices); | ||
b.create<linalg::YieldOp>(loc, extract.getResult()); | ||
}; | ||
Value generic = | ||
rewriter | ||
.create<linalg::GenericOp>(loc, emptyTensor.getType(), | ||
indicesTensor, emptyTensor, indexingMaps, | ||
iteratorTypes, bodyBuilder) | ||
.getResult(0); | ||
rewriter.replaceOp(op, generic); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality( | ||
TypeConverter &typeConverter, RewritePatternSet &patterns, | ||
ConversionTarget &target) { | ||
MLIRContext *context = patterns.getContext(); | ||
|
||
target.addIllegalOp<GatherOp>(); | ||
patterns.add<ConvertGatherOp>(typeConverter, context); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.