From a7d0a6b60b225a7b96763c5ecd24bd1d76791496 Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Fri, 1 Aug 2025 13:06:26 -0700 Subject: [PATCH] [mlir-tensorrt] Transpose Reshape Elimination pass. This is a new pass that is designed to replace the Transpose and Reshape Elemination passes. This pass adds a lot of new rewrite rules which enable pushing the transposes and reshapes around so that they can be combined and then eliminated. The motivation for this pass is that there are some cases where shuffles can get inserted around matrix multiplications and element wise ops which break various fusions inside of TensorRT. To accomplish this, this pass uses several rewrite rules that push transposes and reshapes around to combine them into identity transposes and reshapes which can be eliminated from the program. The rewrite rules are as follows: 1. "canonicalize" the network into simpler ops - `shuffle(x)` -> `reshape(transpose(reshape(x)))` - `matrix_multiply(x, y)` -> `einsum("ij,jk->ik", x, y)` - `expand_rank(x)` -> `reshape(x)` - `collapse_rank(x)` -> `reshape(x)` 2. Push down `reshape` and `transpose` ops as much as possible. Merging and eliminating when possible - `einsum(transpose(x), ...)` -> `einsum(x, ...)` Merge transpose into einsum - `einsum(...)` -> `transpose(einsum(...))` Pull transpose out of einsum (to try to match matrix multiply pattern) - `einsum(reshape(x), y, ...)` -> `transpose(reshape(einsum(x, reshape(transpose(y)), ...)))` Push reshape down. Possibly add reshape and transposes to other inputs as needed. Conditioned on heuristic checking if "better" - `unary(transpose(x))` -> `transpose(unary(x))` - `activation(transpose(x))` -> `transpose(activation(x))` - `identity_op(transpose(x))` -> `transpose(identity_op(x))` - `activation(reshape(x))` -> `reshape(activation(x))` - `unary(reshape(x))` -> `reshape(unary(x))` - `identity_op(reshape(x))` -> `reshape(identity_op(x))` - `reshape(transpose(x))` -> `transpose(reshape(x))` if possible put reshape before transpose - `qdq(transpose(x))` -> `transpose(qdq(x))` if the scale is 0-dim - `qdq(reshape(x))` -> `reshape(qdq(x))` if the scale is 0-dim - `reshape(reshape(x))` -> `reshape(x)` - `transpose(transpose(x))` -> `transpose(x)` - `reshape(x)` -> `x` if `reshape` is identity - `transpose(x)` -> `x` if `transpose` is identity - `elementwise(reshape(a), b)` -> `reshape(elementwise(a, reshape(b)))` conditioned on heuristic - `elementwise(transpose(a), b)` -> `transpose(elementwise(a, transpose(b)))` - `softmax(transpose(x))` -> `transpose(softmax(x))` - `softmax(reshape(x))` -> `reshape(softmax(x))` 3. Push up `reshape` and `transpose` ops as much as possible. Merging and eliminating when possible - `transpose(einsum(...))` -> `einsum(...)`. Merge transpose into einsum - `einsum(...)` -> `einsum(transpose(x), ...)`. Pull transposes out of einsum (to try to match matrix multiply pattern) - `reshape(einsum(...))` -> `einsum(reshape(transpose(x)), ...)` Push reshapes up through einsum. Adding transposes as needed - `transpose(activation(x))` -> `activation(transpose(x))` - `transpose(unary(x))` -> `unary(transpose(x))` - `transpose(identity_op(x))` -> `identity_op(transpose(x))` - `reshape(activation(x))` -> `activation(reshape(x))` - `reshape(unary(x))` -> `unary(reshape(x))` - `reshape(identity_op(x))` -> `identity_op(reshape(x))` - `reshape(reshape(x))` -> `reshape(x)` - `transpose(transpose(x))` -> `transpose(x)` - `reshape(x)` -> `x` if `reshape` is identity - `transpose(x)` -> `x` if `transpose` is identity - `transpose(reshape(x))` -> `reshape(transpose(x))` if possible put transpose before reshape - `transpose(qdq(x))` -> `qdq(transpose(x))` if the scale is 0-dim - `reshape(qdq(x))` -> `qdq(reshape(x))` if the scale is 0-dim - `reshape(elementwise(a, b))` -> `elementwise(reshape(a), reshape(b))` - `transpose(elementwise(a, b))` -> `elementwise(transpose(a), transpose(b))` - `transpose(softmax(x))` -> `softmax(transpose(x))` - `reshape(softmax(x))` -> `softmax(reshape(x))` 4. Convert back to matrix multiplication form to assist with TRT's pattern matching - `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern - `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible 5. Final clean ups, additional merging of transpose/reshapes into leftover einsums - `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern - `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible - `transpose(einsum(...))` -> `einsum(...)` - `einsum(tranpose(x), ...)` -> `einsum(...)` - `einsum(collapse_rank(x), ...)` -> `einsum(...)` - `expand_rank(einsum(...))` -> `einsum(...)` --- .../TensorRT/IR/TensorRTOps.td | 5 + .../TensorRT/Transforms/Passes.td | 143 +- .../tensorrt/lib/TensorRT/IR/EinsumHelper.cpp | 11 +- .../tensorrt/lib/TensorRT/IR/TensorRT.cpp | 7 + .../lib/TensorRT/Transforms/CMakeLists.txt | 3 +- .../lib/TensorRT/Transforms/Passes.cpp | 4 +- .../Transforms/ReshapeElimination.cpp | 237 -- .../Transforms/TransposeElimination.cpp | 469 --- .../TransposeReshapeElimination.cpp | 3093 +++++++++++++++++ .../Dialect/TensorRT/reshape-elimination.mlir | 66 +- .../TensorRT/transpose-elimination.mlir | 35 +- .../transpose-reshape-elimination.mlir | 368 ++ 12 files changed, 3611 insertions(+), 830 deletions(-) delete mode 100644 mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/ReshapeElimination.cpp delete mode 100644 mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp create mode 100644 mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp create mode 100644 mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td index e11ef94e6..17b22c688 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td @@ -3875,6 +3875,11 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape", let extraClassDeclaration = [{ /// Returns true if created op is valid for TensorRT major version. bool isValidForTensorRTVersion(int64_t trtMajorVersion); + + /// Get canonicalization patterns which rewrite as ReshapeOp and + /// do NOT include rewrites which do not get to a different kind of Op + /// (e.g. ExpandRankOp, CollapseRankOp). + static void getCanonicalizationPatternsSameOp(RewritePatternSet &results, MLIRContext *context); }] # baseClassDeclaration; } diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td index efd53bf8a..9d1b80c7a 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td @@ -176,84 +176,87 @@ def LegalizeInt8Pass : Pass<"tensorrt-legalize-int8", "func::FuncOp"> { } //===----------------------------------------------------------------------===// -// TransposeEliminationPass +// TransposeReshapeEliminationPass //===----------------------------------------------------------------------===// -def TransposeEliminationPass : Pass<"tensorrt-transpose-elimination"> { - let summary = "try to eliminate tensorrt.transpose operations"; +def TransposeReshapeEliminationPass : Pass<"tensorrt-transpose-reshape-elimination"> { + let summary = "try to eliminate tensorrt.transpose, tensorrt.reshape, and tensorrt.shuffle operations"; let description = [{ - - It is well-known that excessive number of transpose ops (either - "tensorrt.transpose" or "tensorrt.shuffle" operations with identity reshape) - can cause performance issues with TensorRT. This commonly occurs when the - input source being converted represents convolutions in "NHWC" format vs. - TensorRT's preferred "NCHW" format. In the conversion of these types of + It is well-known that excessive number of transpose or reshapes ops (either + "tensorrt.transpose", "tensorrt.reshape" or "tensorrt.shuffle") + can cause performance issues with TensorRT. For example, this commonly occurs + when the input source being converted represents convolutions in "NHWC" format + vs. TensorRT's preferred "NCHW" format. In the conversion of these types of convolutions, a number of transpose operations must be inserted. These transpose operations can prevent fusions. For example, a transpose operation between a convolution and a pointwise addition can prevent convolution-bias - fusion. - - This pass tries to eliminate transpose operations by applying the following - patterns in a greedy manner: - - 1) rotating `tensorrt.transpose` "forward" certain computational operations, - especially `tensorrt.element_wise` ops. This means that the transpose will - be applied to the result of the elementwise operation as well as the other - branch of the operation. To avoid an infinite ping-pong application of this - pattern certain heuristics are applied to determine whether or not this is - beneficial. For example: - - ``` - func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>) - -> tensor<2x2xf32> { - %1 = tensorrt.transpose { - permutation = affine_map<(d0, d1)->(d1, d0)> - } %arg0 : tensor<2x2xf32> to tensor<2x2xf32> - %2 = tensorrt.element_wise ( - %1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> - return %2 : tensor<2x2xf32> - } - ``` - - becomes - - ``` - func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, - %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> { - %0 = tensorrt.transpose {permutation = #map} - %arg1 : tensor<1x2xf32> to tensor<2x1xf32> - %1 = tensorrt.element_wise - (%arg0, %0 : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32> - %2 = tensorrt.transpose {permutation = #map} %1 : tensor<2x2xf32> to tensor<2x2xf32> - return %2 : tensor<2x2xf32> - } - ``` - - In this case, moving the transpose to the other branch results in lower - memory cost on the inputs, but higher total memory cost (because a transpose - on the result is also added). However, we always prefer to push transpose - operations as far forward as possible in this transformation. - - 2) Const-folding transpose operations. Often, it is undesirable to let - weights be transposed at runtime. Instead, weights should be pre-transformed - to put them into a form that is suitable for TensorRT convolutions. - Therefore, we apply global transpose-const folding. This can be quite - expensive for large weights but is important to reduce runtime transpose - costs. - + fusion. Fusions can also be blocked if a reshape is placed between a + matrix multiplication and an activation. + + The pass tries to eliminate transposes and reshapes by pushing the transposes + and reshapes around to combine them into identity transposes and reshapes which + can be eliminated from the program. To accomplish this, this pass uses several + rewrite rules that push transposes and reshapes from the output of an Operation + to the operation's input, and vice-versa. The rewrites currently included in this + pass handle common cases, though currently does not handle every possible scenario---one + may wish to extend this pass in the future as needed. + + The process is as follows: + 1) Normalize transpose, reshape, shuffle and matrix multiply into a common set of ops. + - shuffle(x) -> reshape(transpose(reshape(x))) + - matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y) + - expand_rank(x) -> reshape(x) + - collapse_rank(x) -> reshape(x) + 2) Push down reshape and transpose, eliminating when possible. E.g. op(transpose(x)) -> transpose(op(x)) + - einsum(transpose(x), ...) -> einsum(x, ...) + - einsum(...) -> transpose(einsum(...)) Pull transposes out of einsum (to try to match matrix multiply pattern) + - einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...)) + - unary(transpose(x)) -> transpose(unary(x)) + - activation(transpose(x)) -> transpose(activation(x)) + - identity_op(transpose(x)) -> transpose(identity_op(x)) + - activation(reshape(x)) -> reshape(activation(x)) + - unary(reshape(x)) -> reshape(unary(x)) + - identity_op(reshape(x)) -> reshape(identity_op(x)) + - reshape(transpose(x)) -> transpose(reshape(x)) if possible put reshape before transpose + - dequantize(quantize(transpose(x))) -> transpose(dequantize(quantize((x))) if the scale is 0-dim + - dequantize(quantize(reshape(x))) -> reshape(dequantize(quantize(x))) if the scale is 0-dim + - reshape(reshape(x)) -> reshape(x) + - transpose(transpose(x)) -> transpose(x) + - reshape(x) -> x if reshape is identity + - transpose(x) -> x if transpose is identity + - elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) conditioned on heuristic + - elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b))) + 3) Push up reshape and transpose, eliminating when possible. E.g. transpose(op(x)) -> op(transpose(x)) + - transpose(einsum(...)) -> einsum(...) + - einsum(...) -> einsum(transpose(x), ...) Pull transposes out of einsum (to try to match matrix multiply pattern) + - reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...) + - transpose(activation(x)) -> activation(transpose(x)) + - transpose(unary(x)) -> unary(transpose(x)) + - transpose(identity_op(x)) -> identity_op(transpose(x)) + - reshape(activation(x)) -> activation(reshape(x)) + - reshape(unary(x)) -> unary(reshape(x)) + - reshape(identity_op(x)) -> identity_op(reshape(x)) + - reshape(reshape(x)) -> reshape(x) + - transpose(transpose(x)) -> transpose(x) + - reshape(x) -> x if reshape is identity + - transpose(x) -> x if transpose is identity + - transpose(reshape(x)) -> reshape(transpose(x)) if possible put transpose before reshape + - transpose(dequantize(quantize(x))) -> dequantize(quantize(transpose(x))) if the scale is 0-dim + - reshape(dequantize(quantize(x))) -> dequantize(quantize(reshape(x))) if the scale is 0-dim + - reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b)) + - transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b)) + 4) Final clean up. Fuse leftover transposes and reshapes with other ops. + - einsum("ij,jk->ik", x, y) -> matrix_multiply(x, y) if einsum matches a matrix multiply pattern + - matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) merge transpose if possible + - transpose(einsum(...)) -> einsum(...) + - einsum(transpose(x), ...) -> einsum(...) + - einsum(collapse_rank(x), ...) -> einsum(...) + - expand_rank(einsum(...)) -> einsum(...) + + To avoid an infinite ping-pong application of these patterns, heuristics are + applied to determine when a pattern is beneficial. }]; } -//===----------------------------------------------------------------------===// -// ReshapeEliminationPass -//===----------------------------------------------------------------------===// -def ReshapeEliminationPass : Pass<"tensorrt-reshape-elimination"> { - let summary = "try to eliminate tensorrt.reshape operations"; - - let description = [{ - Reshape elimination pass captures pattern with un-necessary reshape and - simplifies it by eliminating reshape operations. - }]; -} #endif // MLIR_TENSORRT_DIALECT_TENSORRT_TRANSFORMS_PASSES diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/EinsumHelper.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/EinsumHelper.cpp index ac1b89c0a..33da7a599 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/EinsumHelper.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/EinsumHelper.cpp @@ -143,10 +143,11 @@ static LogicalResult validateInputsSubscript(const IOSubscripts &subscripts, // match. for example, ('ij,jk->ik', a, b) is valid for a = // tensor<4x5xf32>, b = tensor<5x6xf32> but invalid for a = // tensor<4x6xf32>, b = tensor<5x6xf32> - if (allLabelDims.count(label) == 0) { - allLabelDims.insert(std::pair(label, dimension)); + // Einsum also supports broadcasting + if (allLabelDims.count(label) == 0 || allLabelDims[label] == 1) { + allLabelDims[label] = dimension; } else { - if (allLabelDims[label] != dimension) + if (allLabelDims[label] != dimension && dimension != 1) return emitErrorFn(loc, Twine("label `") + Twine(label) + Twine("` is repeated between inputs but " "dimensions are not same")); @@ -203,8 +204,8 @@ static LogicalResult inferOutputShapeImpl(const IOSubscripts &ioSubscripts, llvm::zip((ioSubscripts).inputs, inputOperands)) { for (const auto &[label, dims] : llvm::zip(subscript, cast(operand).getShape())) - if (inputLabelsDims.count(label) == 0) - inputLabelsDims.insert(std::pair(label, dims)); + if (inputLabelsDims.count(label) == 0 || inputLabelsDims[label] == 1) + inputLabelsDims[label] = dims; } for (const auto &label : (ioSubscripts).outputs) { diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp index ee9c3de0b..c5c0782d8 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp @@ -1850,6 +1850,13 @@ void tensorrt::ReshapeOp::getCanonicalizationPatterns( SimplifyReshapeToRankExpandCollapse>(context); } +void tensorrt::ReshapeOp::getCanonicalizationPatternsSameOp( + RewritePatternSet &results, MLIRContext *context) { + results.insert, SimplifyReshapeReshape + // NOT INCLUDED: SimplifyReshapeToRankExpandCollapse + >(context); +} + void tensorrt::ReshapeOp::build(OpBuilder &builder, OperationState &state, Type result, Value input) { ReshapeOp::build(builder, state, result, input, Value()); diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt index e77665668..5f4734708 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt @@ -19,8 +19,7 @@ add_mtrtd_library(MLIRTensorRTTransforms Passes.cpp RaiseActivations.cpp RaiseNormalizations.cpp - ReshapeElimination.cpp - TransposeElimination.cpp + TransposeReshapeElimination.cpp DEPENDS MLIRTensorRTTransformsActivationsPdllGen diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp index f54401b85..22e319431 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp @@ -86,9 +86,7 @@ void tensorrt::buildTensorRTModuleSimplificationPipeline(OpPassManager &pm) { // Try to eliminate as many `tensorrt.broadcast` ops as possible. pm.addPass(tensorrt::createBroadcastEliminationPass()); addCleanupPasses(pm); - pm.addPass(tensorrt::createTransposeEliminationPass()); - addCleanupPasses(pm); - pm.addPass(tensorrt::createReshapeEliminationPass()); + pm.addPass(tensorrt::createTransposeReshapeEliminationPass()); addCleanupPasses(pm); pm.addPass(tensorrt::createRaiseNormalizationsPass()); addCleanupPasses(pm); diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/ReshapeElimination.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/ReshapeElimination.cpp deleted file mode 100644 index 56189d50e..000000000 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/ReshapeElimination.cpp +++ /dev/null @@ -1,237 +0,0 @@ -//===- ReshapeElimination.cpp ---------------------------------------------===// -// -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. -// All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// Definition of reshape elimination pass. -/// -//===----------------------------------------------------------------------===// -#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" -#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" - -namespace mlir { -namespace tensorrt { -#define GEN_PASS_DEF_RESHAPEELIMINATIONPASS -#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h.inc" -} // namespace tensorrt -} // namespace mlir - -#define DEBUG_TYPE "tensorrt-reshape-elimination" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") - -using namespace mlir; -using namespace mlir::tensorrt; - -// Checks if reduction of an input tensor by a reshape `op` is valid, -// considering output of the reshape `op` is fed to a matmul as the LHS input. -// Reduction is valid if K is not reduced and M is reduced. LHS input to a -// matmul has shape [batch ...]xMxK. -static bool isLhsReductionValid(ReshapeOp op) { - RankedTensorType originalType = op.getInput().getType(); - RankedTensorType collapsedType = op.getResult().getType(); - return (originalType.getShape()[originalType.getRank() - 1] == - collapsedType.getShape()[collapsedType.getRank() - 1]) && - (originalType.getShape()[originalType.getRank() - 2] != - collapsedType.getShape()[collapsedType.getRank() - 2]); -} - -// Checks if reduction of an input tensor by a reshape `op` is valid, -// considering output of the reshape `op` is fed to a matmul as the RHS input. -// Reduction is valid if K is not reduced. RHS input to a matmul has shape -// [batch ...]xKxN. -static bool isRhsReductionValid(ReshapeOp op) { - RankedTensorType originalType = op.getInput().getType(); - RankedTensorType collapsedType = op.getResult().getType(); - return originalType.getShape()[originalType.getRank() - 2] == - collapsedType.getShape()[collapsedType.getRank() - 2]; -} - -namespace { -/// Tries to eliminate reshape ops before and after matmul in the following -/// case. -/// %1 = reshape(%0) // collapse batch dims and M -/// %2 = matmul(%1, %k) -/// %3 = reshape(%2) // expand to previously collapsed batch dims and M -/// to -/// %1 = expand(%k) // expand by adding 1's -/// %2 = matmul(%0, %1) -/// -/// In short, we try to keep original shape of %1 (i.e. %0) by expanding rank of -/// %k. LHS operand to the matmul has shape [batch ...]xMxK. Matmul reduction -/// dimension i.e. K should not be collapsed. Matrix operation for both operands -/// is set to `MatrixOperation::kNONE`. This pattern does not apply when input -/// to the parent reshape op of matmul is dynamic. -struct EliminateReshapeBeforeAndAfterMatmulLHS : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ReshapeOp op, - PatternRewriter &rewriter) const override { - // Check if expansion happens with this reshape - if (op.getInput().getType().getRank() >= op.getResult().getType().getRank()) - return failure(); - - auto matmul = op.getInput().getDefiningOp(); - if (!matmul || matmul.getOp0() != MatrixOperation::kNONE || - matmul.getOp1() != MatrixOperation::kNONE) - return failure(); - - // Get expanded shape (by inserting unit dimensions in batch dimensions) of - // matmul input lhs/rhs tensor to match the rank of the other side. - auto getExpandedShapeOfRhs = [](RankedTensorType tensorToExpandType, - int64_t otherSideRank) { - assert(otherSideRank > tensorToExpandType.getRank() && - "operand to expand should have a smaller rank"); - // -2 for contraction dims - int64_t numOnesToAdd = otherSideRank - 2; - SmallVector expandedShape; - if (tensorToExpandType.getRank() > 2) { - ArrayRef tensorToExpandBatchDims = - tensorToExpandType.getShape().drop_back(2); - llvm::append_range(expandedShape, tensorToExpandBatchDims); - numOnesToAdd -= tensorToExpandBatchDims.size(); - } - expandedShape.insert(expandedShape.end(), numOnesToAdd, 1); - // Finally add contraction dims - llvm::append_range(expandedShape, - tensorToExpandType.getShape().take_back(2)); - return RankedTensorType::get(expandedShape, - tensorToExpandType.getElementType()); - }; - - // If lhsParentReshape exists, we need to make sure same dimensions as that - // of collapsed by `lhsParentReshape` are expanded by `op` and input to the - // `lhsParentReshape` is not dynamic. - auto lhsParentReshape = matmul.getInput0().getDefiningOp(); - if (!lhsParentReshape || - !lhsParentReshape.getInput().getType().getShape().drop_back(1).equals( - op.getResult().getType().getShape().drop_back(1)) || - !isLhsReductionValid(lhsParentReshape) || - !lhsParentReshape.getInput().getType().hasStaticShape()) - return failure(); - - auto rhsExpandedType = - getExpandedShapeOfRhs(matmul.getInput1().getType(), - lhsParentReshape.getInput().getType().getRank()); - auto rhsExpanded = rewriter - .create(op->getLoc(), - /*result=*/rhsExpandedType, - /*input=*/matmul.getInput1()) - .getResult(); - - rewriter.replaceOpWithNewOp( - op, - /*input0=*/lhsParentReshape.getInput(), - /*input1=*/rhsExpanded, - /*op0=*/matmul.getOp0(), - /*op1=*/matmul.getOp1()); - return success(); - } -}; - -/// Simplify reshape ops by eliminating one reshape in the following case. -/// %1 = reshape(%0) // collapse batch dims but not K -/// %2 = matmul(%k, %1) -/// %3 = reshape(%2) // expand to previously collapsed batch dims -/// to -/// %1 = reshape(%k) -/// %2 = matmul(%1, %0) -/// -/// LHS operand to the matmul has shape [batch ...]xMxK and RHS operand has -/// shape [batch ...]xKxN. Unlike `EliminateReshapeBeforeAndAfterMatmulLHS` -/// case, we can't expand `k` and keep original shape of %1 input (i.e. %0). -/// This is because for RHS input, reshape can only strictly change batch -/// dimension and not any of contracting dims (e.g. for LHS, M could be -/// collapsed). It means few or all only batch dims are collapsed for RHS. LHS -/// %k needs to be reshaped to have same batch dims as %0 (collapsed batch dims -/// of %0 are same as %k thus reshape works). This rewriter works only when -/// matrix operation for both operands is set to `MatrixOperation::kNONE`. This -/// pattern does not apply when input to the parent reshape op of matmul is -/// dynamic. -struct SimplifyReshapeBeforeAndAfterMatmulRHS : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ReshapeOp op, - PatternRewriter &rewriter) const override { - // Check if expansion happens with this reshape - if (op.getInput().getType().getRank() >= op.getResult().getType().getRank()) - return failure(); - - auto matmul = op.getInput().getDefiningOp(); - if (!matmul || matmul.getOp0() != MatrixOperation::kNONE || - matmul.getOp1() != MatrixOperation::kNONE) - return failure(); - - auto rhsParentReshape = matmul.getInput1().getDefiningOp(); - // If rhsParentReshape exists, we need to make sure same dimensions as that - // of collapsed by `rhsParentReshape` are expanded by `op` and input to the - // `rhsParentShape` is not dynamic. - if (!rhsParentReshape || - !rhsParentReshape.getInput().getType().getShape().drop_back(2).equals( - op.getResult().getType().getShape().drop_back(2)) || - !isRhsReductionValid(rhsParentReshape) || - !rhsParentReshape.getInput().getType().hasStaticShape()) - return failure(); - - SmallVector lhsReshapedShape( - rhsParentReshape.getInput().getType().getShape().drop_back(2)); - SmallVector lhsReductionDims( - matmul.getInput0().getType().getShape().drop_front( - matmul.getInput0().getType().getRank() - 2)); - llvm::append_range(lhsReshapedShape, lhsReductionDims); - auto lhsReshapedType = RankedTensorType::get( - lhsReshapedShape, matmul.getInput0().getType().getElementType()); - auto lhsReshaped = rewriter - .create(op->getLoc(), - /*result=*/lhsReshapedType, - /*input=*/matmul.getInput0()) - .getResult(); - rewriter.replaceOpWithNewOp( - op, - /*input0=*/lhsReshaped, - /*input1=*/rhsParentReshape.getInput(), - /*op0=*/matmul.getOp0(), - /*op1=*/matmul.getOp1()); - return success(); - } -}; -} // namespace - -namespace { -class ReshapeEliminationPass - : public tensorrt::impl::ReshapeEliminationPassBase< - ReshapeEliminationPass> { -public: - using Base::Base; - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - Operation *op = getOperation(); - - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - ReshapeOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) { - emitError(op->getLoc()) - << "failed to apply patterns in " << getArgument(); - return signalPassFailure(); - } - } -}; -} // namespace diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp deleted file mode 100644 index 684e587eb..000000000 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp +++ /dev/null @@ -1,469 +0,0 @@ -//===- TransposeElimination.cpp -------------------------------------------===// -// -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. -// All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// Definition of transpose elimination pass. -/// -//===----------------------------------------------------------------------===// -#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" -#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h" -#include "mlir-tensorrt-dialect/Utils/ConstantFoldUtils.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" -#include - -namespace mlir { -namespace tensorrt { -#define GEN_PASS_DEF_TRANSPOSEELIMINATIONPASS -#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h.inc" -} // namespace tensorrt -} // namespace mlir - -#define DEBUG_TYPE "tensorrt-transpose-elimination" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") - -using namespace mlir; -using namespace mlir::tensorrt; - -// Set the max size of tensors which can be constant-folded to 131072 (0.5 MB -// for f32 constants). -constexpr int64_t kFoldOpEltLimit = 1 << 17; - -static int64_t memoryCost(RankedTensorType type) { - // If the type is dynamic, then return max. - if (!type.hasStaticShape()) - return std::numeric_limits::max(); - ArrayRef shape = type.getShape(); - return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); -} - -static TransposeOp getLowestTransposeCost(ElementWiseOp consumer, - TransposeOp op1, TransposeOp op2) { - // If there's only one transpose, then return it. - if (!op1 || !op2) - return op1 ? op1 : op2; - // Otherwise, compute the lowest cost. - int64_t cost1 = memoryCost(consumer.getType()) + memoryCost(op2.getType()); - int64_t cost2 = memoryCost(consumer.getType()) + memoryCost(op1.getType()); - LLVM_DEBUG(DBGS() << "cost1=" << cost1 << ", cost2=" << cost2 << "\n"); - return cost1 <= cost2 ? op1 : op2; -} - -static std::pair -getTransposeProducers(ElementWiseOp op) { - auto producer1 = op.getInput1().getDefiningOp(); - auto producer2 = op.getInput2().getDefiningOp(); - return std::make_pair(producer1, producer2); -} - -static TensorValue getOtherEwiseInput(ElementWiseOp op, Operation *producer) { - assert(producer != nullptr && "expected valid producer"); - Operation *producer1 = op.getInput1().getDefiningOp(); - return producer1 == producer ? op.getInput2() : op.getInput1(); -} - -// If there is only one ewise branch with a transpose, the below pushdown -// pattern may not terminate by repeatedly ping-ponging. We avoid this by having -// a set of conditions on which to move transpose to the other branch. We do -// this if we know doing so will result in additional elimination patterns or a -// smaller transpose cost. -bool pushDownTransposePrecondition(ElementWiseOp op, - TransposeOp transposeToPushdown) { - TensorValue otherInput = getOtherEwiseInput(op, transposeToPushdown); - Operation *otherProducer = otherInput.getDefiningOp(); - bool otherBranchHasSmallerCost = - memoryCost(otherInput.getType()) < - memoryCost(transposeToPushdown.getResult().getType()); - if (otherBranchHasSmallerCost) - return true; - // Even if the other branch has higher memory cost: - // 1. If its a constant, then we can fold. - // 2. If its a transpose, then we can combine transpose ops. - return isa_and_nonnull(otherProducer); -} - -namespace { -/// This rewrite tries to eliminate transpose operations by rotating them -/// "forward" (e.g. pushing them past their user(s)) when they have a single -/// user and when that user is some sort of computation operation (e.g. -/// elementwise, unary, or convolution). -struct PushdownTransposeEwise : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ElementWiseOp op, - PatternRewriter &rewriter) const override { - auto [transpose1, transpose2] = getTransposeProducers(op); - // Exit if there are not transpose producers. - if (!transpose1 && !transpose2) - return failure(); - // Choose which transpose to push down. The one to push down should be the - // one that results in the lowest memory cost. In all cases there will be - // two transpose produced. Choose the one with the lowest cost. If they are - // equal, choose either one. - TransposeOp transposeToPushdown = - getLowestTransposeCost(op, transpose1, transpose2); - - // If the other branch has smaller cost, we always move the transpose. - if (!pushDownTransposePrecondition(op, transposeToPushdown)) - return rewriter.notifyMatchFailure( - op, "does not meet transpose pushdown preconditions"); - - LLVM_DEBUG(DBGS() << "pushing down transpose " << transposeToPushdown - << "\n"); - - // Execute the transformation. - AffineMap permutation = transposeToPushdown.getPermutation(); - AffineMap inversePerm = inversePermutation(permutation); - Value otherInput = getOtherEwiseInput(op, transposeToPushdown); - Value transposedOther = - rewriter.create(op.getLoc(), otherInput, inversePerm); - bool pushdownIsInput1 = op.getInput1() == transposeToPushdown.getResult(); - Value ewiseOutput = rewriter.create( - op.getLoc(), - pushdownIsInput1 ? transposeToPushdown.getInput() : transposedOther, - !pushdownIsInput1 ? transposeToPushdown.getInput() : transposedOther, - op.getElementwiseOperation()); - rewriter.replaceOpWithNewOp(op, ewiseOutput, permutation); - return success(); - } -}; - -/// Rewrites `act(transpose(x))` to `transpose(act(x))`. -struct PushDownTransposeActivationRewriter - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ActivationOp op, - PatternRewriter &rewriter) const override { - auto producer = op.getInput().getDefiningOp(); - if (!producer) - return failure(); - AffineMap permutation = producer.getPermutation(); - auto activationOp = rewriter.create( - op.getLoc(), producer.getInput(), op.getActivationType(), - op.getAlphaAttr(), op.getBetaAttr()); - rewriter.replaceOpWithNewOp(op, activationOp.getResult(), - permutation); - return success(); - } -}; - -/// Push transpose below `tensorrt.identity` if the identity consumer is an -/// elementwise op. After the "pushdown" phase, the "push up" phase will restore -/// the transpose above the identity if it could not be eliminated and the -/// source has a smaller memory cost. -struct PushdownTransposeIdentity : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(IdentityOp op, - PatternRewriter &rewriter) const override { - TransposeOp transposeProducer = op.getInput().getDefiningOp(); - if (!transposeProducer) - return failure(); - Type newIdentityType = - RankedTensorType::Builder( - cast(transposeProducer.getInput().getType())) - .setElementType(op.getType().getElementType()); - Value newIdentityResult = rewriter.create( - op.getLoc(), newIdentityType, transposeProducer.getInput()); - rewriter.replaceOpWithNewOp( - op, newIdentityResult, transposeProducer.getPermutation()); - return success(); - } -}; - -/// Rewrite transpose(unary(x)) to unary(transpose(x)). -template -struct PushUpTransposeUnary : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TransposeOp op, - PatternRewriter &rewriter) const override { - auto unaryOp = op.getInput().getDefiningOp(); - if (!unaryOp) - return failure(); - Value newUnaryResult = rewriter.create( - op.getLoc(), unaryOp.getInput(), op.getPermutation()); - rewriter.replaceOpWithNewOp(op, op.getType(), newUnaryResult, - unaryOp->getAttrs()); - return success(); - } -}; - -/// The cost of a `reshape` is always zero, because the implementation of a -/// reshape is always just rearrangement of metadata and it does not involve -/// launching a CUDA kernel. The cost of a `transpose` will be zero if its input -/// is a `ConstantOp`. However, such case doesn't occur since constant folding -/// eliminates transpose op even before this pass runs. -static int64_t -reshapeOrTransposeOrConstantCost(Operation *reshapeOrTransposeOrConstant) { - if (isa(reshapeOrTransposeOrConstant)) - return 0; - if (isa(reshapeOrTransposeOrConstant)) - return 0; - TransposeOp op = dyn_cast(reshapeOrTransposeOrConstant); - assert(op); - return memoryCost(op.getResult().getType()); -} - -/// Computes benefit of pushing transpose above elementwise and after -/// reshape/transpose producer. -/// Benefit = Original cast - New cost -/// Original cost = memoryCost(original reshape/transpose/constant) + -/// memoryCost(transpose after elementwise) -/// New cost = memoryCost(original reshape/transpose/constant) + -/// memoryCost(pushed up transpose) + memoryCost(newly added transpose on the -/// other side) -static int64_t pushUpBenefit(ElementWiseOp elementwise, - Operation *reshapeOrTransposeOrConstantProducer) { - int64_t originalTransposeReshapeCost = - reshapeOrTransposeOrConstantCost(reshapeOrTransposeOrConstantProducer); - int64_t originalCost = originalTransposeReshapeCost + - memoryCost(elementwise.getResult().getType()); - // The cost of the pushed up transpose will be 0 in the following cases, - // a. The parent of the elementwise input is a `reshape` op with a - // `ConstantOp` input. - // b. The parent of the elementwise input is a `ConstantOp`. - // Such case won't happens for `transpose` op since its folded away even - // before this pass. - int64_t pushedUpTransposeCost = - (matchPattern(reshapeOrTransposeOrConstantProducer, - m_Op(m_Op())) || - isa(reshapeOrTransposeOrConstantProducer)) - ? 0 - : memoryCost(elementwise.getResult().getType()); - // If other side input to elementwise is coming from `ConstantOp`, cost of the - // newly added transpose will be zero. - TensorValue otherSideValue = - getOtherEwiseInput(elementwise, reshapeOrTransposeOrConstantProducer); - int64_t otherSideNewlyAddedTransposeCost = - otherSideValue.getDefiningOp() - ? 0 - : memoryCost(otherSideValue.getType()); - - int64_t newCost = originalTransposeReshapeCost + pushedUpTransposeCost + - otherSideNewlyAddedTransposeCost; - return originalCost - newCost; -} - -/// Returns true if pushing transpose op above given `elementwise` op is -/// beneficial. -static bool shouldPushUpTransposeElementwise( - ElementWiseOp elementwise, - Operation *reshapeOrTransposeOrConstantProducer) { - return (pushUpBenefit(elementwise, reshapeOrTransposeOrConstantProducer) > 0); -} - -/// Push transpose above elementwise if both of the following conditions hold -/// true, -/// 1. Input to the transpose is the output of an elementwise op. -/// 2. One or both of the inputs to the elementwise op is an output of transpose -/// or reshape op. -/// Idea is, if transpose is pushed up in this case, transpose/transpose or -/// reshape/transpose pair at the input of elementwise op will result in a -/// single shuffle op (both reshape and transpose are canonicalized to the -/// shuffle op, later in the pipeline). -/// -/// For example, subgraph in the form -/// -/// %2 = [reshape, transpose](%0) -/// %3 = [reshape, transpose](%1) -/// %4 = elementwise(%2, %3) -/// %5 = transpose(%4) -/// -/// where [reshape, transpose](%k) means if (%k can be passed without any -/// transformation as well) transformation is applied to %k, it will be either -/// reshape or transpose. However, if both %0 and %1 does not have any -/// transformation, this pattern doesn't apply. Both or any one of %0 and %1 can -/// be an output of constant op. -/// -/// is rewritten to -/// -/// %2 = [reshape, transpose](%0) -/// %4 = transpose1(%2) -/// %3 = [reshape, transpose](%1) -/// %5 = transpose2(%3) -/// %6 = elementwise(%4, %5) -/// -/// In short, the transpose (called `pushed-up` transpose, hereafter) after -/// elementwise is pushed above elementwise, before lhs or rhs input, if input's -/// parent is reshape/transpose. A `new` transpose is added on the other side of -/// `pushed-up` transpose. Benefit of applying this pattern is computed based on -/// `memoryCost` difference in the original and modified subgraph. - -struct PushUpTransposeElementwise : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TransposeOp op, - PatternRewriter &rewriter) const override { - auto elementwiseOp = op.getInput().getDefiningOp(); - if (!elementwiseOp) - return failure(); - - Location loc = op->getLoc(); - Operation *lhsParent = elementwiseOp.getInput1().getDefiningOp(); - Operation *rhsParent = elementwiseOp.getInput2().getDefiningOp(); - bool isLhsParentReshapeOrTransposeOrConstant = - lhsParent && isa(lhsParent); - bool isRhsParentReshapeOrTransposeOrConstant = - rhsParent && isa(rhsParent); - if (!isLhsParentReshapeOrTransposeOrConstant && - !isRhsParentReshapeOrTransposeOrConstant) - return failure(); - - auto addTransposeOp = [&](Value input, AffineMap perm) { - return rewriter - .create(loc, - /*input=*/input, - /*permutation=*/perm) - .getResult(); - }; - - auto rewriteElementwiseOp = [&](Value lhs, Value rhs) { - return rewriter - .create(loc, - /*input1=*/lhs, - /*input2=*/rhs, - /*op=*/elementwiseOp.getElementwiseOperation()) - .getResult(); - }; - - auto pushUpTransposeOnLhs = [&]() { - // Add a transpose before RHS - auto rhsTranspose = - addTransposeOp(elementwiseOp.getInput2(), op.getPermutation()); - // Push up transpose from after elementwise to before elementwise on LHS - auto pushedUpTransposeLhs = - addTransposeOp(elementwiseOp.getInput1(), op.getPermutation()); - // Update elementwise op - auto updatedElementwiseOp = - rewriteElementwiseOp(pushedUpTransposeLhs, rhsTranspose); - // Replace uses of old transpose - rewriter.replaceAllUsesWith(op.getResult(), updatedElementwiseOp); - }; - - auto pushUpTransposeOnRhs = [&]() { - // Add a transpose before LHS - auto lhsTranspose = - addTransposeOp(elementwiseOp.getInput1(), op.getPermutation()); - // Push up transpose from after elementwise to before elementwise on RHS - auto pushedUpTransposeRhs = - addTransposeOp(elementwiseOp.getInput2(), op.getPermutation()); - // Update elementwise op - auto updatedElementwiseOp = - rewriteElementwiseOp(lhsTranspose, pushedUpTransposeRhs); - // Replace uses of old transpose - rewriter.replaceAllUsesWith(op.getResult(), updatedElementwiseOp); - }; - - if (isLhsParentReshapeOrTransposeOrConstant && - shouldPushUpTransposeElementwise(elementwiseOp, lhsParent)) { - pushUpTransposeOnLhs(); - return success(); - } - - if (isRhsParentReshapeOrTransposeOrConstant && - shouldPushUpTransposeElementwise(elementwiseOp, rhsParent)) { - pushUpTransposeOnRhs(); - return success(); - } - return failure(); - } -}; -} // namespace - -namespace { -/// Constant fold transpose -struct TransposeConstantFold : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransposeOp op, - PatternRewriter &rewriter) const override { - TensorType inputType = op.getInput().getType(); - - // Fold the input to a constant if possible, otherwise return. - ElementsAttr inputConst; - if (!matchPattern(op.getInput(), m_Constant(&inputConst))) - return failure(); - assert(inputType.hasStaticShape() && "constants should have static shape"); - - // Don't fold transpose if input has > 1 user and input is non-splat - // constant. - if (!inputConst.isSplat() && - (!op.getInput().hasOneUse() || - inputConst.getNumElements() > kFoldOpEltLimit)) - return failure(); - - ElementsAttr result = - constantFoldTranspose(inputConst, op.getPermutation()); - if (!result) - return failure(); - rewriter.replaceOpWithNewOp(op, result); - return success(); - } -}; -} // namespace - -namespace { -class TransposeEliminationPass - : public tensorrt::impl::TransposeEliminationPassBase< - TransposeEliminationPass> { -public: - using Base::Base; - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - Operation *op = getOperation(); - - // First, we try to eliminate transpose operations by "pushing down" the - // transpose operations. This involves performing rewrites of the form - // "op(transpose(y))->transpose(op(y))". Often, this will eliminate most - // transpose operations in CNN networks produced by frameworks that use NHWC - // conventions (e.g. Tensorflow and often JAX/Flax models). - { - RewritePatternSet patterns(ctx); - patterns.insert(ctx); - TransposeOp::getCanonicalizationPatterns(patterns, ctx); - ExpandRankOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) { - emitError(op->getLoc()) - << "failed to apply pushdown patterns in " << getArgument(); - return signalPassFailure(); - } - } - - // Second, we try to eliminate transpose operations by "pushing up" (commute - // in the reverse direction). This can possible eliminate additional - // transpose ops. - { - RewritePatternSet patterns(ctx); - patterns.insert, - PushUpTransposeUnary, - PushUpTransposeUnary, - PushUpTransposeElementwise>(ctx); - TransposeOp::getCanonicalizationPatterns(patterns, ctx); - ExpandRankOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) { - emitError(op->getLoc()) - << "failed to apply pushup patterns in " << getArgument(); - return signalPassFailure(); - } - } - } -}; -} // namespace diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp new file mode 100644 index 000000000..459abb204 --- /dev/null +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp @@ -0,0 +1,3093 @@ +//===- TransposeElimination.cpp -------------------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// Definition of transpose elimination pass. +/// +//===----------------------------------------------------------------------===// +#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" +#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h" +#include "mlir-tensorrt-dialect/Utils/ConstantFoldUtils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace mlir { +namespace tensorrt { +#define GEN_PASS_DEF_TRANSPOSERESHAPEELIMINATIONPASS +#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h.inc" +} // namespace tensorrt +} // namespace mlir + +#define DEBUG_TYPE "tensorrt-transpose-reshape-elimination" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +using namespace mlir; +using namespace mlir::tensorrt; + +// Set the max size of tensors which can be constant-folded to 131072 (0.5 MB +// for f32 constants). +constexpr int64_t kFoldOpEltLimit = 1 << 17; + +static int64_t memoryCost(RankedTensorType type) { + // If the type is dynamic, then return max. + if (!type.hasStaticShape()) + return std::numeric_limits::max(); + return type.getNumElements(); +} + +static TransposeOp getLowestTransposeCost(ElementWiseOp consumer, + TransposeOp op1, TransposeOp op2) { + // If there's only one transpose, then return it. + if (!op1 || !op2) + return op1 ? op1 : op2; + // Otherwise, compute the lowest cost. + int64_t cost1 = memoryCost(consumer.getType()) + memoryCost(op2.getType()); + int64_t cost2 = memoryCost(consumer.getType()) + memoryCost(op1.getType()); + LLVM_DEBUG(DBGS() << "cost1=" << cost1 << ", cost2=" << cost2 << "\n"); + if (cost1 == 0 && cost2 == 0) + return {}; + return cost1 <= cost2 ? op1 : op2; +} + +static std::pair +getTransposeProducers(ElementWiseOp op) { + auto producer1 = op.getInput1().getDefiningOp(); + auto producer2 = op.getInput2().getDefiningOp(); + if (producer1 && producer1.getInput().getDefiningOp()) + producer1 = {}; + if (producer2 && producer2.getInput().getDefiningOp()) + producer2 = {}; + return std::make_pair(producer1, producer2); +} + +static TensorValue getOtherEwiseInput(ElementWiseOp op, Operation *producer) { + assert(producer != nullptr && "expected valid producer"); + Operation *producer1 = op.getInput1().getDefiningOp(); + return producer1 == producer ? op.getInput2() : op.getInput1(); +} + +// If there is only one ewise branch with a transpose, the below pushdown +// pattern may not terminate by repeatedly ping-ponging. We avoid this by having +// a set of conditions on which to move transpose to the other branch. We do +// this if we know doing so will result in additional elimination patterns or a +// smaller transpose cost. +static bool pushDownTransposePrecondition(ElementWiseOp op, + TransposeOp transposeToPushdown) { + TensorValue otherInput = getOtherEwiseInput(op, transposeToPushdown); + Operation *otherProducer = otherInput.getDefiningOp(); + bool otherBranchHasSmallerCost = + memoryCost(otherInput.getType()) < + memoryCost(transposeToPushdown.getResult().getType()); + if (otherBranchHasSmallerCost) + return true; + // Even if the other branch has higher memory cost: + // 1. If its a constant, then we can fold. + // 2. If its a transpose, then we can combine transpose ops. + return isa_and_nonnull(otherProducer); +} + +namespace { +/// This rewrite tries to eliminate transpose operations by rotating them +/// "forward" (e.g. pushing them past their user(s)) when they have a single +/// user and when that user is some sort of computation operation (e.g. +/// elementwise, unary, or convolution). +struct PushdownTransposeEwise : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ElementWiseOp op, + PatternRewriter &rewriter) const override { + auto [transpose1, transpose2] = getTransposeProducers(op); + // Exit if there are not transpose producers. + if (!transpose1 && !transpose2) + return failure(); + // Choose which transpose to push down. The one to push down should be the + // one that results in the lowest memory cost. In all cases there will be + // two transpose produced. Choose the one with the lowest cost. If they are + // equal, choose either one. + TransposeOp transposeToPushdown = + getLowestTransposeCost(op, transpose1, transpose2); + + // If the other branch has smaller cost, we always move the transpose. + if (!pushDownTransposePrecondition(op, transposeToPushdown)) + return rewriter.notifyMatchFailure( + op, "does not meet transpose pushdown preconditions"); + + LLVM_DEBUG(DBGS() << "pushing down transpose " << transposeToPushdown + << "\n"); + + // Execute the transformation. + AffineMap permutation = transposeToPushdown.getPermutation(); + AffineMap inversePerm = inversePermutation(permutation); + Value otherInput = getOtherEwiseInput(op, transposeToPushdown); + Value transposedOther = + rewriter.create(op.getLoc(), otherInput, inversePerm); + bool pushdownIsInput1 = op.getInput1() == transposeToPushdown.getResult(); + Value ewiseOutput = rewriter.create( + op.getLoc(), + pushdownIsInput1 ? transposeToPushdown.getInput() : transposedOther, + !pushdownIsInput1 ? transposeToPushdown.getInput() : transposedOther, + op.getElementwiseOperation()); + rewriter.replaceOpWithNewOp(op, ewiseOutput, permutation); + return success(); + } +}; + +/// Rewrites `act(transpose(x))` to `transpose(act(x))`. +struct PushDownTransposeActivationRewriter + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ActivationOp op, + PatternRewriter &rewriter) const override { + auto producer = op.getInput().getDefiningOp(); + if (!producer) + return failure(); + AffineMap permutation = producer.getPermutation(); + auto activationOp = rewriter.create( + op.getLoc(), producer.getInput(), op.getActivationType(), + op.getAlphaAttr(), op.getBetaAttr()); + auto newTranspose = rewriter.create( + producer.getLoc(), activationOp.getResult(), permutation); + rewriter.replaceOp(op, newTranspose.getResult()); + return success(); + } +}; + +// Rewrites unary(transpose(x)) to transpose(unary(x)) +struct PushDownTransposeUnary : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(UnaryOp op, + PatternRewriter &rewriter) const override { + auto producer = op.getInput().getDefiningOp(); + if (!producer) + return failure(); + AffineMap permutation = producer.getPermutation(); + auto unary = rewriter.create(op.getLoc(), producer.getInput(), + op.getUnaryOperationAttr()); + auto newTranspose = rewriter.create( + producer.getLoc(), unary.getResult(), permutation); + rewriter.replaceOp(op, newTranspose.getResult()); + return success(); + } +}; + +/// Push transpose below `tensorrt.identity` if the identity consumer is an +/// elementwise op. After the "pushdown" phase, the "push up" phase will restore +/// the transpose above the identity if it could not be eliminated and the +/// source has a smaller memory cost. +struct PushdownTransposeIdentity : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IdentityOp op, + PatternRewriter &rewriter) const override { + TransposeOp transposeProducer = op.getInput().getDefiningOp(); + if (!transposeProducer) + return failure(); + Type newIdentityType = + RankedTensorType::Builder( + cast(transposeProducer.getInput().getType())) + .setElementType(op.getType().getElementType()); + Value newIdentityResult = rewriter.create( + op.getLoc(), newIdentityType, transposeProducer.getInput()); + rewriter.replaceOpWithNewOp( + op, newIdentityResult, transposeProducer.getPermutation()); + return success(); + } +}; + +/// Rewrite transpose(unary(x)) to unary(transpose(x)). +template +struct PushUpTransposeUnary : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter &rewriter) const override { + auto unaryOp = op.getInput().getDefiningOp(); + if (!unaryOp) + return failure(); + Value newUnaryResult = rewriter.create( + op.getLoc(), unaryOp.getInput(), op.getPermutation()); + rewriter.replaceOpWithNewOp(op, op.getType(), newUnaryResult, + unaryOp->getAttrs()); + return success(); + } +}; + +/// The cost of a `reshape` is always zero, because the implementation of a +/// reshape is always just rearrangement of metadata and it does not involve +/// launching a CUDA kernel. The cost of a `transpose` will be zero if its input +/// is a `ConstantOp`. However, such case doesn't occur since constant folding +/// eliminates transpose op even before this pass runs. +static int64_t +reshapeOrTransposeOrConstantCost(Operation *reshapeOrTransposeOrConstant) { + if (isa(reshapeOrTransposeOrConstant)) + return 0; + if (isa(reshapeOrTransposeOrConstant)) + return 0; + TransposeOp op = dyn_cast(reshapeOrTransposeOrConstant); + assert(op); + return memoryCost(op.getResult().getType()); +} + +/// Computes benefit of pushing transpose above elementwise and after +/// reshape/transpose producer. +/// Benefit = Original cast - New cost +/// Original cost = memoryCost(original reshape/transpose/constant) + +/// memoryCost(transpose after elementwise) +/// New cost = memoryCost(original reshape/transpose/constant) + +/// memoryCost(pushed up transpose) + memoryCost(newly added transpose on the +/// other side) +static int64_t pushUpBenefit(ElementWiseOp elementwise, + Operation *reshapeOrTransposeOrConstantProducer) { + int64_t originalTransposeReshapeCost = + reshapeOrTransposeOrConstantCost(reshapeOrTransposeOrConstantProducer); + int64_t originalCost = originalTransposeReshapeCost + + memoryCost(elementwise.getResult().getType()); + // The cost of the pushed up transpose will be 0 in the following cases, + // a. The parent of the elementwise input is a `reshape` op with a + // `ConstantOp` input. + // b. The parent of the elementwise input is a `ConstantOp`. + // Such case won't happens for `transpose` op since its folded away even + // before this pass. + int64_t pushedUpTransposeCost = + (matchPattern(reshapeOrTransposeOrConstantProducer, + m_Op(m_Op())) || + isa(reshapeOrTransposeOrConstantProducer)) + ? 0 + : memoryCost(elementwise.getResult().getType()); + // If other side input to elementwise is coming from `ConstantOp`, cost of the + // newly added transpose will be zero. + TensorValue otherSideValue = + getOtherEwiseInput(elementwise, reshapeOrTransposeOrConstantProducer); + int64_t otherSideNewlyAddedTransposeCost = + otherSideValue.getDefiningOp() + ? 0 + : memoryCost(otherSideValue.getType()); + + int64_t newCost = originalTransposeReshapeCost + pushedUpTransposeCost + + otherSideNewlyAddedTransposeCost; + return originalCost - newCost; +} + +/// Returns true if pushing transpose op above given `elementwise` op is +/// beneficial. +static bool shouldPushUpTransposeElementwise( + ElementWiseOp elementwise, + Operation *reshapeOrTransposeOrConstantProducer) { + return (pushUpBenefit(elementwise, reshapeOrTransposeOrConstantProducer) > 0); +} + +/// Push transpose above elementwise if both of the following conditions hold +/// true, +/// 1. Input to the transpose is the output of an elementwise op. +/// 2. One or both of the inputs to the elementwise op is an output of transpose +/// or reshape op. +/// Idea is, if transpose is pushed up in this case, transpose/transpose or +/// reshape/transpose pair at the input of elementwise op will result in a +/// single shuffle op (both reshape and transpose are canonicalized to the +/// shuffle op, later in the pipeline). +/// +/// For example, subgraph in the form +/// +/// %2 = [reshape, transpose](%0) +/// %3 = [reshape, transpose](%1) +/// %4 = elementwise(%2, %3) +/// %5 = transpose(%4) +/// +/// where [reshape, transpose](%k) means if (%k can be passed without any +/// transformation as well) transformation is applied to %k, it will be either +/// reshape or transpose. However, if both %0 and %1 does not have any +/// transformation, this pattern doesn't apply. Both or any one of %0 and %1 can +/// be an output of constant op. +/// +/// is rewritten to +/// +/// %2 = [reshape, transpose](%0) +/// %4 = transpose1(%2) +/// %3 = [reshape, transpose](%1) +/// %5 = transpose2(%3) +/// %6 = elementwise(%4, %5) +/// +/// In short, the transpose (called `pushed-up` transpose, hereafter) after +/// elementwise is pushed above elementwise, before lhs or rhs input, if input's +/// parent is reshape/transpose. A `new` transpose is added on the other side of +/// `pushed-up` transpose. Benefit of applying this pattern is computed based on +/// `memoryCost` difference in the original and modified subgraph. + +struct PushUpTransposeElementwise : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter &rewriter) const override { + auto elementwiseOp = op.getInput().getDefiningOp(); + if (!elementwiseOp) + return failure(); + + Location loc = op->getLoc(); + Operation *lhsParent = elementwiseOp.getInput1().getDefiningOp(); + Operation *rhsParent = elementwiseOp.getInput2().getDefiningOp(); + bool isLhsParentReshapeOrTransposeOrConstant = + lhsParent && isa(lhsParent); + bool isRhsParentReshapeOrTransposeOrConstant = + rhsParent && isa(rhsParent); + if (!isLhsParentReshapeOrTransposeOrConstant && + !isRhsParentReshapeOrTransposeOrConstant) + return failure(); + + auto addTransposeOp = [&](Value input, AffineMap perm) { + return rewriter + .create(loc, + /*input=*/input, + /*permutation=*/perm) + .getResult(); + }; + + auto rewriteElementwiseOp = [&](Value lhs, Value rhs) { + return rewriter + .create(loc, + /*input1=*/lhs, + /*input2=*/rhs, + /*op=*/elementwiseOp.getElementwiseOperation()) + .getResult(); + }; + + auto pushUpTransposeOnLhs = [&]() { + // Add a transpose before RHS + auto rhsTranspose = + addTransposeOp(elementwiseOp.getInput2(), op.getPermutation()); + // Push up transpose from after elementwise to before elementwise on LHS + auto pushedUpTransposeLhs = + addTransposeOp(elementwiseOp.getInput1(), op.getPermutation()); + // Update elementwise op + auto updatedElementwiseOp = + rewriteElementwiseOp(pushedUpTransposeLhs, rhsTranspose); + // Replace uses of old transpose + rewriter.replaceAllUsesWith(op.getResult(), updatedElementwiseOp); + }; + + auto pushUpTransposeOnRhs = [&]() { + // Add a transpose before LHS + auto lhsTranspose = + addTransposeOp(elementwiseOp.getInput1(), op.getPermutation()); + // Push up transpose from after elementwise to before elementwise on RHS + auto pushedUpTransposeRhs = + addTransposeOp(elementwiseOp.getInput2(), op.getPermutation()); + // Update elementwise op + auto updatedElementwiseOp = + rewriteElementwiseOp(lhsTranspose, pushedUpTransposeRhs); + // Replace uses of old transpose + rewriter.replaceAllUsesWith(op.getResult(), updatedElementwiseOp); + }; + + if (isLhsParentReshapeOrTransposeOrConstant && + shouldPushUpTransposeElementwise(elementwiseOp, lhsParent)) { + pushUpTransposeOnLhs(); + return success(); + } + + if (isRhsParentReshapeOrTransposeOrConstant && + shouldPushUpTransposeElementwise(elementwiseOp, rhsParent)) { + pushUpTransposeOnRhs(); + return success(); + } + return failure(); + } +}; +} // namespace + +namespace { +/// Constant fold transpose +struct TransposeConstantFold : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter &rewriter) const override { + TensorType inputType = op.getInput().getType(); + + // Fold the input to a constant if possible, otherwise return. + ElementsAttr inputConst; + if (!matchPattern(op.getInput(), m_Constant(&inputConst))) + return failure(); + assert(inputType.hasStaticShape() && "constants should have static shape"); + + // Don't fold transpose if input has > 1 user and input is non-splat + // constant. + if (!inputConst.isSplat() && + (!op.getInput().hasOneUse() || + inputConst.getNumElements() > kFoldOpEltLimit)) + return failure(); + + ElementsAttr result = + constantFoldTranspose(inputConst, op.getPermutation()); + if (!result) + return failure(); + rewriter.replaceOpWithNewOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +// Convert a matrix multiply to an einsum +// einsum allows is more flexible with the inputs, braodcasting dimensions and +// transposing. Hence, we can easily implement rewrites that merge transpose +// into einsum and push reshape through an einsum +class MatmulToEinsum : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::MatrixMultiplyOp op, + PatternRewriter &rewriter) const override { + using tensorrt::MatrixOperation; + + int numBatchDims = op.getCollectionRank(0); + if (numBatchDims != op.getCollectionRank(1)) + return failure(/* unknown number of batch dimensions */); + + std::string arg0Pattern = "", arg1Pattern = "", outPattern = ""; + char nextChar = 'a'; + for (int i = 0; i < numBatchDims; i++) { + // einsum supports broadcasting, so we just add the batch dims to the + // pattern + arg0Pattern += nextChar; + arg1Pattern += nextChar; + outPattern += nextChar++; + } + + char matrix0A, matrix0B, matrix1A, matrix1B, multiplyLetter; + if (op.getOp0() == MatrixOperation::kVECTOR) { + matrix0A = 0; + multiplyLetter = matrix0B = nextChar++; + arg0Pattern += matrix0B; + } else if (op.getOp0() == MatrixOperation::kNONE) { // normal matrix + matrix0A = nextChar++; + multiplyLetter = matrix0B = nextChar++; + arg0Pattern += matrix0A; + arg0Pattern += matrix0B; + outPattern += matrix0A; + } else if (op.getOp0() == MatrixOperation::kTRANSPOSE) { + multiplyLetter = matrix0A = nextChar++; + matrix0B = nextChar++; + arg0Pattern += matrix0A; + arg0Pattern += matrix0B; + outPattern += matrix0B; + } else { + return failure(/* unknown matrix operation */); + } + + if (op.getOp1() == MatrixOperation::kVECTOR) { + matrix1A = multiplyLetter; + matrix1B = 0; + arg1Pattern += matrix1A; + } else if (op.getOp1() == MatrixOperation::kNONE) { // normal matrix + matrix1A = multiplyLetter; + matrix1B = nextChar++; + arg1Pattern += matrix1A; + arg1Pattern += matrix1B; + outPattern += matrix1B; + } else if (op.getOp1() == MatrixOperation::kTRANSPOSE) { + matrix1A = nextChar++; + matrix1B = multiplyLetter; + arg1Pattern += matrix1A; + arg1Pattern += matrix1B; + outPattern += matrix1A; + } else { + return failure(/* unknown matrix operation */); + } + + SmallVector args{op.getInput0(), op.getInput1()}; + std::string einsum = arg0Pattern + "," + arg1Pattern + "->" + outPattern; + rewriter.replaceOpWithNewOp(op, op.getType(), args, + einsum); + return success(); + } +}; +} // namespace + +namespace { +// convert tensorrt.shuffle to tensorrt.transpose and tensorrt.reshape +// tensorrt.shuffle is the "lower level" op that eventually gets converted to +// INetwork layers it is possible that tensorrt.shuffle already exist in the +// network, hence, convert it back to the "simpler" reshape and transpose ops +// shuffle -> transpose(reshape(transpose(x))) +class ShuffleToTransposeAndReshape + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ShuffleOp op, + PatternRewriter &rewriter) const override { + Value input = op.getInput(); + + if (op.getZeroIsPlaceholder()) + return failure(); + + input = rewriter.createOrFold( + op.getLoc(), input, + AffineMap::getPermutationMap(op.getFirstTranspose(), op.getContext())); + if (op.getReshape()) { + input = rewriter.createOrFold( + op.getLoc(), + cast(input.getType()).clone(*op.getReshape()), + input); + } else if (op.getDynamicReshape()) { + SmallVector shape(ShapedType::kDynamic, + op.getResult().getType().getRank()); + input = rewriter.createOrFold( + op.getLoc(), cast(input.getType()).clone(shape), + input, op.getDynamicReshape()); + } + input = rewriter.createOrFold( + op.getLoc(), input, + AffineMap::getPermutationMap(op.getSecondTranspose(), op.getContext())); + rewriter.replaceOp(op, input); + return success(); + } +}; +} // namespace + +namespace { +template +class RankChangeToReshape : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getInput()); + return success(); + } +}; +} // namespace + +namespace { +struct EinsumEquation { + std::string equation; + SmallVector lhsParts; + std::string lhs; + std::string rhs; + + LogicalResult parse(llvm::StringRef einsumEquation) { + std::string e{einsumEquation}; + return parse(e); + } + + LogicalResult parse(const std::string &einsumEquation) { + size_t pos = einsumEquation.find("->"); + if (pos == std::string::npos) + return failure(); + equation = einsumEquation; + lhs = einsumEquation.substr(0, pos); + rhs = einsumEquation.substr(pos + 2); + std::istringstream lhsStream(lhs); + std::string currentPart; + while (std::getline(lhsStream, currentPart, ',')) { + lhsParts.push_back(currentPart); + for (char c : currentPart) + if (!(c >= 'a' && c <= 'z')) + return failure(); + } + return success(); + } + + std::string generateEquation() const { + std::string ret = lhsParts[0]; + for (size_t i = 1; i < lhsParts.size(); i++) + ret += "," + lhsParts[i]; + ret += "->" + rhs; + return ret; + } +}; +} // namespace + +namespace { +// Control when fusing a transpose into another op. +// Currently always fuse +static bool shouldFuseTranspose(tensorrt::TransposeOp transposeOp, + mlir::Operation *targetFusion) { + return true; +} +} // namespace + +namespace { +// Push down transpose to into an einsum, rearranging the axes of the input +// tensors in the einsum as needed einsum(x1, transpose(x2), ...) -> einsum(x1, +// x2, ...) +class PushDownTransposeToEinsum : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + EinsumEquation einsumEquation; + if (failed(einsumEquation.parse(op.getEquation()))) + return failure(); + + bool hasTransposeInput = false; + SmallVector newInputs; + for (size_t i = 0; i < op.getInputs().size(); i++) { + auto input = op.getInputs()[i]; + tensorrt::TransposeOp transpose = + input.getDefiningOp(); + if (transpose && shouldFuseTranspose(transpose, op)) { + AffineMap perm = transpose.getPermutation(); + if (!perm.isPermutation()) + return failure(/* Transpose is not a permutation */); + SmallVector equation; + for (char c : einsumEquation.lhsParts[i]) + equation.push_back(c); + + equation = inversePermutation(perm).compose(equation); + einsumEquation.lhsParts[i] = ""; + for (size_t j = 0; j < equation.size(); j++) + einsumEquation.lhsParts[i] += (char)equation[j]; + newInputs.push_back(transpose.getInput()); + hasTransposeInput = true; + } else { + newInputs.push_back(input); + } + } + + if (!hasTransposeInput) + return failure(); + + std::string newEinsumEquation = einsumEquation.generateEquation(); + assert(einsumEquation.rhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp(op, op.getType(), newInputs, + newEinsumEquation); + return success(); + } +}; +} // namespace + +namespace { +// Push up transpose from an einsum, rearranging the axes of the output tensor +// in the einsum as needed transpose(einsum(x1, x2, ...)) -> einsum(x1, x2, ...) +class PushUpTransposeToEinsum : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::TransposeOp op, + PatternRewriter &rewriter) const override { + AffineMap perm = op.getPermutation(); + if (!perm.isPermutation()) + return failure(); + + auto einsum = op.getInput().getDefiningOp(); + if (!einsum) + return failure(); + + if (!einsum->hasOneUse()) + return failure(); + + if (!shouldFuseTranspose(op, einsum)) + return failure(); + + EinsumEquation einsumEquation; + if (failed(einsumEquation.parse(einsum.getEquation()))) + return failure(); + + SmallVector einsumRhs; + for (char c : einsumEquation.rhs) + einsumRhs.push_back(c); + einsumRhs = perm.compose(einsumRhs); + einsumEquation.rhs = ""; + for (size_t i = 0; i < einsumRhs.size(); i++) + einsumEquation.rhs += (char)einsumRhs[i]; + + std::string newEinsumEquation = einsumEquation.generateEquation(); + + auto newEinsum = rewriter.create( + op.getLoc(), op.getType(), einsum.getInputs(), newEinsumEquation); + assert(einsumEquation.rhs.size() == newEinsum.getType().getShape().size()); + rewriter.replaceOp(op, newEinsum.getResult()); + return success(); + } +}; +} // namespace + +namespace { +// Create an new transpose op from an einsum. Rearrange the output axes to +// match the ordering of the input axes This should enable converting the einsum +// back to a matmul einsum(x1, x2, ...) -> transpose(einsum(x1, x2, ...)) +class EinsumPushDownTranspose : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + for (auto input : op.getInputs()) { + if (input.getDefiningOp()) + return failure(); // Wait until the transpose is pushed into the einsum + // first + } + // determine the "best" order. + // Ideally, we want the einsum to be reducable to a matmul. So the batch + // elements should appear first in the output + + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + + SmallVector> outputAxes; + for (size_t i = 0; i < equation.rhs.size(); i++) + outputAxes.push_back(std::make_pair(equation.rhs[i], i)); + std::sort(outputAxes.begin(), outputAxes.end(), + [&](const std::pair &a, + const std::pair &b) { + for (std::string &eqLhs : equation.lhsParts) { + if (eqLhs.find(a.first) != std::string::npos) { + if (eqLhs.find(b.first) != std::string::npos) { + return eqLhs.find(a.first) < eqLhs.find(b.first); + } else { + return true; + } + } else if (eqLhs.find(b.first) != std::string::npos) { + return false; + } + } + return a.first < b.first; + }); + + LLVM_DEBUG({ + std::stringstream out; + out << "outputAxes: ["; + for (auto x : outputAxes) + out << x.first << "(" << x.second << ") "; + out << "]\n"; + DBGS() << out.str(); + }); + + SmallVector newEinsumShape; + SmallVector forwardPerm; + std::string newEinsumRhs = ""; + for (auto &[c, i] : outputAxes) { + newEinsumRhs += c; + newEinsumShape.push_back(op.getType().getDimSize(i)); + forwardPerm.push_back(i); + } + if (newEinsumRhs == equation.rhs) + return failure(); // no change + + equation.rhs = newEinsumRhs; + std::string newEinsumEquation = equation.generateEquation(); + + auto newEinsum = rewriter.create( + op.getLoc(), op.getType().clone(newEinsumShape), op.getInputs(), + newEinsumEquation); + assert(equation.rhs.size() == newEinsum.getType().getShape().size()); + + auto forwardMap = + AffineMap::getPermutationMap(forwardPerm, op.getLoc().getContext()); + + auto newTranspose = rewriter.create( + op.getLoc(), newEinsum.getResult(), inversePermutation(forwardMap)); + + assert(op.getType() == newTranspose.getType()); + rewriter.replaceOp(op, newTranspose.getResult()); + + return success(); + } +}; +} // namespace + +namespace { +// Create an new transpose op from an einsum. Rearrange the input axes to match +// the ordering of the output axes. This should enable converting the einsum +// back to a matmul using the `EinsumToMatrixMultiply` pattern. einsum(x1, x2, +// ...) -> einsum(x1, transpose(x2), ...) +class EinsumPushUpTranspose : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + + llvm::SmallSetVector multipliedAxes; + llvm::SmallSetVector uniqueAxes; + for (size_t i = 0; i < equation.lhsParts.size(); i++) { + for (size_t j = 0; j < equation.lhsParts[i].size(); j++) { + if (equation.rhs.find(equation.lhsParts[i][j]) == std::string::npos) { + multipliedAxes.insert(equation.lhsParts[i][j]); + } else { + // contained in rhs + bool found = false; + for (size_t k = 0; k < equation.lhsParts.size(); k++) { + if (k != i) { + if (equation.lhsParts[k].find(equation.lhsParts[i][j]) != + std::string::npos) { + found = true; + break; + } + } + } + if (!found) { + // contained in the i-th lhs, contained in the rhs and not contained + // in any other lhs this is not a batch axis of the multiplication. + // E.g in "abc,acd->abd" this would be "bd" + uniqueAxes.insert(equation.lhsParts[i][j]); + } + } + } + } + + bool didChange = false; + SmallVector newInputs; + for (size_t i = 0; i < op.getInputs().size(); i++) { + auto input = cast>(op.getInputs()[i]); + RankedTensorType inputType = input.getType(); + SmallVector> inputAxes; + for (int j = 0; j < inputType.getRank(); j++) + inputAxes.push_back(std::make_pair(equation.lhsParts[i][j], j)); + std::sort(inputAxes.begin(), inputAxes.end(), + [&](const std::pair &a, + const std::pair &b) { + size_t posA = equation.rhs.find(a.first); + size_t posB = equation.rhs.find(b.first); + if (posA != std::string::npos && posB != std::string::npos) { + // both letters are in the rhs, meaning that these are + // either batch or dims of the matrix try to match the order + // of the output so that these can become batch dims later + return posA < posB; + } else if (posA == std::string::npos && + posB == std::string::npos) { + return a.second < b.second; // preserve the order if neither + // is found in output + } else { + // one is in the output, and the other is not in the output + // if the character is one of the last two outputs, then we + // would rather preserve the ordering as the transpose + // property on the matrix multiply can be used to handle + if ((i == 0 && (posA == equation.rhs.size() - 2 || + posB == equation.rhs.size() - 2)) || + (i == 1 && (posA == equation.rhs.size() - 1 || + posB == equation.rhs.size() - 1))) { + return a.second < b.second; // preserve ordering + } + // does not match expected pattern, put the ordering so that + // the one in the output is first + return posA != std::string::npos; + } + }); + std::string newEquation = ""; + for (size_t j = 0; j < inputAxes.size(); j++) + newEquation += inputAxes[j].first; + if (newEquation != equation.lhsParts[i]) { + equation.lhsParts[i] = newEquation; + didChange = true; + SmallVector perm; + for (size_t j = 0; j < inputAxes.size(); j++) + perm.push_back(inputAxes[j].second); + auto newTranspose = rewriter.create( + op.getLoc(), input, + AffineMap::getPermutationMap(perm, op.getContext())); + newInputs.push_back(newTranspose.getResult()); + } else { + newInputs.push_back(input); + } + } + + if (!didChange) + return failure(); + + std::string newEquation = equation.generateEquation(); + assert(equation.rhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp(op, op.getType(), newInputs, + newEquation); + return success(); + } +}; +} // namespace + +namespace { +// When one of the input axes are 1, then we can push that up as a reshape. +// For example, +// %3 = tensorrt.einsum("abc,cd->abd", %1: tensor<1x2x3xf32>, %2: +// tensor<3x4xf32>) -> tensor<1x2x4xf32> +// will become +// %r1 = tensorrt.reshape %1 : tensor<1x2x3xf32> to tensor<2x3xf32> reshape +// the input to remove the 1 dim %e1 = tensorrt.einsum("bc,cd->bd", %r1, %2) +// -> tensor<2x4xf32> %3 = tensorrt.reshape %e1 : tensor<2x4xf32> to +// tensor<1x2x4xf32> reshape the output to add the 1 dim back +class EinsumEliminate1Axis : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + EinsumEquation einsumEquation; + if (failed(einsumEquation.parse(op.getEquation()))) + return failure(); + + bool madeChange = false; + SmallVector newInputs; + + for (size_t i = 0; i < op.getInputs().size(); i++) { + auto input = cast>(op.getInputs()[i]); + RankedTensorType inputType = input.getType(); + std::string equation = ""; + bool change = false; + SmallVector newInputShape; + for (int j = 0; j < inputType.getRank(); j++) { + if (inputType.getDimSize(j) == 1) { + // this axis is size 1, and not used in the multiplication, we can + // remove it from the einsum + madeChange = change = true; + } else { + equation += einsumEquation.lhsParts[i][j]; + newInputShape.push_back(inputType.getDimSize(j)); + } + } + if (change) { + auto newInput = + rewriter + .create( + op.getLoc(), inputType.clone(newInputShape), input) + .getResult(); + newInputs.push_back(newInput); + einsumEquation.lhsParts[i] = equation; + } else { + newInputs.push_back(input); + } + } + + if (!madeChange) + return failure(); + + RankedTensorType outputType = op.getType(); + EinsumEquation newEinsumEquation = einsumEquation; + newEinsumEquation.rhs = ""; + SmallVector newOutputShape; + bool changeOutput = false; + for (int i = 0; i < outputType.getRank(); i++) { + if (outputType.getDimSize(i) == 1) { + // this axis is size 1, and not used in the multiplication, we can + // remove it from the einsum + changeOutput = true; + } else { + newEinsumEquation.rhs += einsumEquation.rhs[i]; + newOutputShape.push_back(outputType.getDimSize(i)); + } + } + std::string newEquation = newEinsumEquation.generateEquation(); + + if (changeOutput) { + auto newEinsum = rewriter.create( + op.getLoc(), outputType.clone(newOutputShape), newInputs, + newEquation); + assert(newEinsumEquation.rhs.size() == + newEinsum.getType().getShape().size()); + auto outReshape = + rewriter + .create(op.getLoc(), op.getType(), + newEinsum.getResult()) + .getResult(); + assert(op.getType() == outReshape.getType()); + rewriter.replaceOp(op, outReshape); + return success(); + } else { + assert(newEinsumEquation.rhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp(op, op.getType(), + newInputs, newEquation); + return success(); + } + } +}; +} // namespace + +namespace { + +// When one of the input axes has a 1 shaped input and there is a reshape on the +// input, then the reshape can be merged with the einsum. E.g. +// %1 = tensorrt.reshape %0 : tensor<1x2x3xf32> to tensor<2x3xf32> +// %2 = tensorrt.einsum("bc,cd->bd", %1, %2) -> tensor<2x4xf32> +// will become +// %2 = tensorrt.einsum("abc,cd->bd", %0, %2) -> tensor<2x4xf32> +class EinsumMergeDown1Axis : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + + char nextChar = 'a'; + auto getNextChar = [&]() -> char { + while (nextChar <= 'z') { + char c = nextChar++; + if (equation.equation.find(c) == std::string::npos) + return c; + } + return 0; + }; + + SmallVector newInputs; + bool madeChange = false; + for (size_t i = 0; i < op.getInputs().size(); i++) { + Value input = op.getInputs()[i]; + if (auto collapse = input.getDefiningOp()) { + RankedTensorType inputType = collapse.getInput().getType(); + if (!inputType.hasStaticShape()) + return failure(/* collapse rank op with dynamic shape */); + auto inputShape = inputType.getShape(); + std::string newEquation = ""; + size_t k = 0; + for (size_t j = 0; j < inputShape.size(); j++) { + if (inputShape[j] == 1) { + char c = getNextChar(); + if (c == 0) + return failure(/* no more einsum characters available */); + newEquation += c; + } else { + newEquation += equation.lhsParts[i][k++]; + } + } + assert(k == equation.lhsParts[i].size()); + newInputs.push_back(collapse.getInput()); + equation.lhsParts[i] = newEquation; + madeChange = true; + } else { + newInputs.push_back(input); + } + } + + if (!madeChange) + return failure(); + + std::string newEquation = equation.generateEquation(); + assert(equation.rhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp(op, op.getType(), newInputs, + newEquation); + return success(); + } +}; +} // namespace + +namespace { + +// In the case that the output of an einsum has a 1 shaped output, then the +// reshape can be merged with the einsum if there is an input that is also one +// shaped. E.g. +// %1 = tensorrt.einsum("abc,cd->bd", %0 : tensor<1x2x3xf32>, %1) -> +// tensor<2x4xf32> %2 = tensorrt.reshape %1 : tensor<2x4xf32> to +// tensor<1x2x4xf32> +// will become +// %2 = tensorrt.einsum("abc,cd->abd", %0, %1) -> tensor<1x2x4xf32> +class EinsumMergeUp1Axis : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ExpandRankOp op, + PatternRewriter &rewriter) const override { + if (!op.getType().hasStaticShape()) + return failure(/* only handle static expand rank */); + auto einsum = op.getInput().getDefiningOp(); + if (!einsum) + return failure(); + if (!einsum->hasOneUse()) + return failure(/* einsum used more than once, can't modify */); + + EinsumEquation equation; + if (failed(equation.parse(einsum.getEquation()))) + return failure(); + + llvm::SmallSetVector oneAxisChars; + llvm::SmallSetVector nonOneAxisChars; + for (size_t i = 0; i < einsum.getInputs().size(); i++) { + auto inputShape = + cast(einsum.getInputs()[i].getType()).getShape(); + for (size_t j = 0; j < inputShape.size(); j++) { + if (inputShape[j] == 1) + oneAxisChars.insert(equation.lhsParts[i][j]); + else + nonOneAxisChars.insert(equation.lhsParts[i][j]); + } + } + + // an axis can hvae 1 and non-1 shapes associated in the case that the axis + // is broadcast. In which case, it is not a 1 shaped axis on the output. + oneAxisChars.remove_if([&](char c) { return nonOneAxisChars.contains(c); }); + if (oneAxisChars.empty()) + return failure(/* no one axis inputs found */); + + auto einsumShape = op.getInput().getType().getShape(); + auto outputShape = op.getResult().getType().getShape(); + std::string newRhs = ""; + for (size_t i = 0, j = 0, k = 0; i < outputShape.size(); i++) { + if (outputShape[i] == 1) { + if (k >= oneAxisChars.size()) + return failure(); + newRhs += oneAxisChars[k++]; + } else { + if (j >= equation.rhs.size()) + return failure(); + assert(einsumShape[j] == outputShape[i]); + newRhs += equation.rhs[j++]; + } + } + + std::string newEquation = equation.lhs + "->" + newRhs; + assert(newRhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp( + op, op.getType(), einsum.getInputs(), newEquation); + return success(); + } +}; +} // namespace + +namespace { +// In the case of an einsum that is performing a broadcast, increase the rank of +// its inputs so that it can better match the matrix multiply pattern. +class EinsumPushUp1AxisReshape : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + if (op->getNumOperands() != 2) + return failure(); + + assert(equation.rhs.size() == op.getType().getShape().size()); + + char matrixAxes[2] = {0, 0}; + char multipliedAxis = 0; + + for (size_t i = 0; i < 2; i++) { + for (int j = equation.lhsParts[i].size() - 1; j >= 0; j--) { + char c = equation.lhsParts[i][j]; + if (multipliedAxis == 0 && + equation.lhsParts[1 - i].find(c) != std::string::npos && + equation.rhs.find(c) == std::string::npos) + multipliedAxis = c; + if (matrixAxes[i] == 0 && + equation.lhsParts[1 - i].find(c) == std::string::npos && + equation.rhs[equation.rhs.size() - 2 + i] == c) + matrixAxes[i] = c; + } + } + + RankedTensorType inputType[2] = { + cast(op.getInputs()[0].getType()), + cast(op.getInputs()[1].getType())}; + if (!inputType[0].hasStaticShape() || !inputType[1].hasStaticShape()) + return failure(); + + SmallVector newInputShapes[2] = { + SmallVector{inputType[0].getShape()}, + SmallVector{inputType[1].getShape()}}; + EinsumEquation newEquation = equation; + + for (int i = 0; i < 2; i++) { + for (char c : equation.lhsParts[i]) { + if (c == multipliedAxis || c == matrixAxes[i] || + equation.rhs.find(c) == std::string::npos) + continue; + if (newEquation.lhsParts[1 - i].find(c) == std::string::npos) { + // figure out the best place to insert "c" + // Find the best index to insert 'c' into newEquation.lhsParts[1-i] + // so that all letters to the left of 'c' in equation.lhsParts[i] are + // to the left of 'c' and all letters to the right of 'c' in + // equation.lhsParts[i] are to the right of 'c' + size_t insertIdx = 0; + // Find the leftmost position such that all letters in + // equation.lhsParts[i] before 'c' are to the left of 'c' in + // newEquation.lhsParts[1-i] and all letters after 'c' are to the + // right We do this by finding the first position in + // newEquation.lhsParts[1-i] where a letter that comes after 'c' in + // equation.lhsParts[i] appears. If none, insert at the end. + std::string &target = newEquation.lhsParts[1 - i]; + const std::string &src = newEquation.lhsParts[i]; + size_t cPos = src.find(c); + for (insertIdx = 0; insertIdx <= target.size(); ++insertIdx) { + bool valid = true; + // Check all letters before c in src + for (size_t l = 0; l < cPos; ++l) { + char leftChar = src[l]; + size_t posInTarget = target.find(leftChar); + if (posInTarget != std::string::npos && + posInTarget >= insertIdx) { + valid = false; + break; + } + } + if (!valid) + continue; + // Check all letters after c in src + for (size_t r = cPos + 1; r < src.size(); ++r) { + char rightChar = src[r]; + size_t posInTarget = target.find(rightChar); + if (posInTarget != std::string::npos && posInTarget < insertIdx) { + valid = false; + break; + } + } + if (valid) + break; + } + target.insert(target.begin() + insertIdx, c); + newInputShapes[1 - i].insert( + newInputShapes[1 - i].begin() + insertIdx, 1); + assert(target.size() == newInputShapes[1 - i].size()); + } + } + } + + RankedTensorType newInputTypes[2] = {inputType[0].clone(newInputShapes[0]), + inputType[1].clone(newInputShapes[1])}; + + if (newInputTypes[0] == inputType[0] && newInputTypes[1] == inputType[1]) + return failure(/* nothing changed */); + + assert(newInputShapes[0].size() == newEquation.lhsParts[0].size() && + newInputShapes[1].size() == newEquation.lhsParts[1].size()); + + SmallVector reshapes{ + rewriter.createOrFold(op.getLoc(), newInputTypes[0], + op.getInputs()[0]), + rewriter.createOrFold(op.getLoc(), newInputTypes[1], + op.getInputs()[1])}; + + assert(newEquation.rhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp(op, op.getType(), reshapes, + newEquation.generateEquation()); + + return success(); + } +}; +} // namespace + +namespace { +// Push up a reshape through an einum to its inputs +// reshape(einsum(x1, x2, ..)) -> +// einsum(reshape(transpose(x1)), reshape(transpose(x2)), ...) +class PushReshapeUpThroughEinsum + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ReshapeOp op, + PatternRewriter &rewriter) const override { + if (!op.getResult().getType().hasStaticShape()) + return failure(/* only handle static reshapes */); + + auto einsum = op.getInput().getDefiningOp(); + if (!einsum) + return failure(); + if (!einsum->hasOneUse()) + return failure(); + + EinsumEquation equation; + if (failed(equation.parse(einsum.getEquation()))) + return failure(); + + char nextChar = 'a'; + auto getNextChar = [&]() -> char { + while (nextChar <= 'z') { + char c = nextChar++; + if (equation.equation.find(c) == std::string::npos) + return c; + } + return 0; + }; + + SmallVector> inputShapes; + for (Value input : einsum.getInputs()) { + SmallVector shape( + cast(input.getType()).getShape()); + inputShapes.push_back(shape); + } + + auto reshapeInShape = op.getInput().getType().getShape(); + auto reshapeOutShape = op.getResult().getType().getShape(); + + struct ReshapeInfo { + std::string newAxes; + SmallVector newShape; + SmallVector oldShape; + }; + + bool hasNonTrivalReshape = false; + std::unordered_map inputToReshapedMap; + size_t inputNumElems = 1; + size_t outputNumElems = 1; + std::string inAxes = ""; + std::string outAxes = ""; + std::string prevInAxes = ""; + SmallVector outShape; + SmallVector inShape; + for (size_t i = 0, j = 0; i < reshapeOutShape.size(); i++) { + if (reshapeOutShape[i] == 0) + return failure(/* 0-shape not supported */); + outputNumElems *= reshapeOutShape[i]; + outShape.push_back(reshapeOutShape[i]); + char c = getNextChar(); + if (c == 0) + return failure(/* no more einsum characters available */); + outAxes += c; + while (j < reshapeInShape.size() && inputNumElems < outputNumElems) { + inputNumElems *= reshapeInShape[j]; + inShape.push_back(reshapeInShape[j]); + inAxes += equation.rhs[j++]; + } + if (inputNumElems == outputNumElems) { + if (inAxes.empty()) { + if (!prevInAxes.empty() && reshapeOutShape[i] == 1 && + outAxes.size() == 1) { + auto &p = inputToReshapedMap[prevInAxes]; + p.newAxes.push_back(c); + p.newShape.push_back(1); + if (prevInAxes.size() != p.newAxes.size()) + hasNonTrivalReshape = true; + outAxes = ""; + outShape.clear(); + inShape.clear(); + } + continue; + } + if (inAxes.size() != outAxes.size()) + hasNonTrivalReshape = true; + inputToReshapedMap[inAxes] = ReshapeInfo{ + .newAxes = outAxes, .newShape = outShape, .oldShape = inShape}; + outShape.clear(); + inShape.clear(); + prevInAxes = inAxes; + inAxes = ""; + outAxes = ""; + } + } + if (inputNumElems != outputNumElems || !inAxes.empty() || !outAxes.empty()) + return failure(/* should not happen, unexpected reshape */); + if (!hasNonTrivalReshape) + return failure(/* reshape is only expanding rank */); + + llvm::SmallMapVector charToGroup; + for (auto &[k, v] : inputToReshapedMap) + for (auto c : k) + charToGroup[c] = k; + + // check that all of the inputs are have the right groupping. If this + // doesn't happen then that means that the reshape can not get pushed + // through + for (std::string &eqLhs : equation.lhsParts) { + for (char c : eqLhs) { + auto it = charToGroup.find(c); + if (it == charToGroup.end()) + continue; + for (char c2 : it->second) + if (eqLhs.find(c2) == std::string::npos) + return failure(/* Not able to push reshape through einsum */); + } + } + + EinsumEquation newEquation = equation; + newEquation.rhs = ""; + for (char c : equation.rhs) { + assert(charToGroup.count(c)); + if (charToGroup[c][0] == c) + newEquation.rhs += inputToReshapedMap[charToGroup[c]].newAxes; + } + + // generate a `x` -> `reshape(transpose(x))` if necessary + SmallVector newInputs; + newEquation.lhsParts.clear(); + + LLVM_DEBUG({ + std::stringstream out; + out << "==== Einsum Reshape/Transpose Pushup Debug ====\n"; + for (const auto &entry : charToGroup) { + out << " charToGroup[" << entry.first << "] = " << entry.second + << "\n"; + } + for (const auto &entry : inputToReshapedMap) { + out << " inputToReshapedMap[" << entry.first + << "]: axes = " << entry.second.newAxes << ", shape = ["; + for (size_t si = 0; si < entry.second.newShape.size(); ++si) { + out << entry.second.newShape[si]; + if (si + 1 < entry.second.newShape.size()) + out << ", "; + } + out << "], old shape = ["; + for (size_t si = 0; si < entry.second.oldShape.size(); ++si) { + out << entry.second.oldShape[si]; + if (si + 1 < entry.second.oldShape.size()) + out << ", "; + } + out << "]"; + out << "\n"; + } + DBGS() << out.str(); + }); + + // check that the input shape for all of the inputs match (that there are no + // broadcasts happening on some inputs) + for (auto &[inputAxes, reshapeInfo] : inputToReshapedMap) { + // this is a single axis, so broadcasting is allowed in this case, hence + // do not check + if (inputAxes.size() == 1 && reshapeInfo.newAxes.size() == 1) + continue; + + for (size_t i = 0; i < einsum.getInputs().size(); i++) { + auto inputShape = + cast(einsum.getInputs()[i].getType()).getShape(); + for (size_t j = 0; j < inputAxes.size(); j++) { + size_t pos = equation.lhsParts[i].find(inputAxes[j]); + if (pos != std::string::npos && + inputShape[pos] != reshapeInfo.oldShape[j]) + return failure(/* input shape does not match output shape*/); + } + } + } + + for (size_t i = 0; i < einsum.getInputs().size(); i++) { + Value input = einsum.getInputs()[i]; + auto inputType = cast(input.getType()); + std::string newInputEquation = ""; + SmallVector newInputShape; + SmallVector newInputTranspose; + for (int j = 0; j < inputType.getRank(); j++) { + auto group = charToGroup.find(equation.lhsParts[i][j]); + if (group == charToGroup.end()) { + // this must be going into the multply, so it should just keep this + // letter + newInputEquation += equation.lhsParts[i][j]; + newInputTranspose.push_back(j); + newInputShape.push_back(inputType.getDimSize(j)); + } else { + // then there is some pattern that is getting consumed + if (group->second[0] != equation.lhsParts[i][j]) + continue; // then this isn't the first character, so it should have + // already been processed + // this is the first character in the group. So process all of the + // group + for (char c : group->second) + newInputTranspose.push_back(equation.lhsParts[i].find(c)); + newInputEquation += inputToReshapedMap[group->second].newAxes; + for (int64_t v : inputToReshapedMap[group->second].newShape) { + if (v != 1 && group->second.size() == 1 && + inputType.getDimSize(j) == 1) { + // if the group is of size 1, then it can have different sizes for + // each input due to broadcasting + newInputShape.push_back(1); + } else { + newInputShape.push_back(v); + } + } + } + } + + // Debug print for this input's result + LLVM_DEBUG({ + std::stringstream out; + out << "Input #" << i << " orig eq: " << equation.lhsParts[i] + << " new eq: " << newInputEquation << "\n"; + out << " newInputTranspose: ["; + for (size_t ti = 0; ti < newInputTranspose.size(); ++ti) { + out << newInputTranspose[ti]; + if (ti + 1 < newInputTranspose.size()) + out << ", "; + } + out << "]\n"; + out << " newInputShape: ["; + for (size_t si = 0; si < newInputShape.size(); ++si) { + out << newInputShape[si]; + if (si + 1 < newInputShape.size()) + out << ", "; + } + out << "]\n"; + out << " oldShape: ["; + for (size_t si = 0; si < inputType.getShape().size(); ++si) { + out << inputType.getShape()[si]; + if (si + 1 < inputType.getShape().size()) + out << ", "; + } + out << "]\n"; + DBGS() << out.str() << "\n"; + }); + + auto newTranspose = rewriter.createOrFold( + op.getLoc(), input, + AffineMap::getPermutationMap(newInputTranspose, + op.getLoc().getContext())); + auto newReshape = rewriter.createOrFold( + op.getLoc(), inputType.clone(newInputShape), newTranspose); + + newInputs.push_back(newReshape); + newEquation.lhsParts.push_back(newInputEquation); + } + std::string newEquationStr = newEquation.generateEquation(); + + LLVM_DEBUG({ + DBGS() << newEquationStr << "\n" + << "===============================================\n"; + }); + + auto newEinsum = rewriter.create( + einsum.getLoc(), op.getType(), newInputs, newEquationStr); + assert(newEquation.rhs.size() == newEinsum.getType().getShape().size()); + assert(op.getType() == newEinsum.getType()); + rewriter.replaceOp(op, newEinsum.getResult()); + + return success(); + } +}; +} // namespace + +namespace { + +// if there are mutliple axes that are getting multiplied together in an einsum, +// push up a reshape so that there is only a single axis. This will help with +// the conversion from einsum to matrix multiply. For example, +// %0 = tensorrt.einsum {equation = "acd,bcd->ab"} ins(%arg0, %arg1) +// will become +// %0 = tensorrt.reshape %arg0 +// %1 = tensorrt.reshape %arg1 +// %2 = tensorrt.einsum {equation = "ac,bc->ab"} ins(%0, %1) +class EinsumPushUpMultipleMulitipliedAxes + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + + std::string multipliedAxes = ""; + for (char c : equation.lhsParts[0]) + if (equation.rhs.find(c) == std::string::npos) + multipliedAxes += c; + if (multipliedAxes.size() <= 1) + return failure(/* pattern does not match */); + for (size_t i = 0; i < equation.lhsParts.size(); i++) { + if (equation.lhsParts[i].find(multipliedAxes) == std::string::npos) + return failure(/* pattern does not match */); + if (!cast(op.getInputs()[i].getType()).hasStaticShape()) + return failure(); + } + char nextChar = 'a'; + while (nextChar <= 'z') { + if (equation.equation.find(nextChar) == std::string::npos) + break; + nextChar++; + } + if (nextChar > 'z') + return failure(/* No more characters available */); + + SmallVector newInputs; + EinsumEquation newEquation = equation; + for (size_t i = 0; i < equation.lhsParts.size(); i++) { + auto inputType = cast(op.getInputs()[i].getType()); + SmallVector newInputShape; + std::string newInputEquation = ""; + size_t j = 0; + while (j < equation.lhsParts[i].size() && + multipliedAxes.find(equation.lhsParts[i][j]) == + std::string::npos) { + newInputShape.push_back(inputType.getDimSize(j)); + newInputEquation += equation.lhsParts[i][j]; + j++; + } + + int64_t combinedInputShape = 1; + while (j < equation.lhsParts[i].size() && + multipliedAxes.find(equation.lhsParts[i][j]) != std::string::npos) + combinedInputShape *= inputType.getDimSize(j++); + newInputShape.push_back(combinedInputShape); + newInputEquation += nextChar; + while (j < equation.lhsParts[i].size()) { + newInputShape.push_back(inputType.getDimSize(j)); + newInputEquation += equation.lhsParts[i][j]; + j++; + } + + newEquation.lhsParts[i] = newInputEquation; + auto reshape = rewriter.createOrFold( + op.getLoc(), inputType.clone(newInputShape), op.getInputs()[i]); + newInputs.push_back(reshape); + } + + assert(newEquation.rhs.size() == op.getType().getShape().size()); + rewriter.replaceOpWithNewOp( + op, op.getType(), newInputs, newEquation.generateEquation()); + return success(); + } +}; +} // namespace + +namespace { +static uint64_t estimateShuffleCost(Value input) { + // This is a heuristic. One may wish to update this in the future depending + // on their use case. This heuristic currently attempts to put shuffles + // "together" which allows two shuffles to be merged, and put shuffles on + // constant values which allows for them to be merged with the constant. + + Operation *op = input.getDefiningOp(); + bool foundShuffle = false; + bool canMergeUp = true; + for (int i = 0; op && i < 10; i++) { + if (isa(op)) + return 0; // This has found a constant. The constant can be + // reshaped/rearranged as necessary to absorb the shuffle. + // Hence, mark this as having 0 cost. + if (canMergeUp && + isa( + op)) + foundShuffle = true; + if (!isa( + op)) + canMergeUp = false; + if (op->getNumOperands() != 1) + break; + op = op->getOperand(0).getDefiningOp(); + } + + if (foundShuffle) + return 100; // should be able to merge this op up into another shuffle. So + // it gets a lower cost + + // if there is no constant or no existing shuffle, then there is nothing to + // merge with, so we are going to mark this as "high cost" + return 1000; +} +} // namespace + +namespace { +// Push a reshape down through an einsum +// einsum(reshape(x), y) -> transpose(reshape(einsum(x, reshape(transpose(y)), +// ...)) +class PushReshapeDownThroughEinsum + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + // this needs to some "heuristic" to determine if a reshape should + // get pushed down as reshapes might need to get added to other inputs to + // make the shapes work + bool hasReshapeInput = false; + Location reshapeLoc = op.getLoc(); + for (auto input : op.getInputs()) { + if (!cast(input.getType()).hasStaticShape()) { + return failure(/* dynamic input not supported */); + } + if (auto reshape = input.getDefiningOp()) { + if (!reshape.getInput().getType().hasStaticShape()) + return failure(/* dynamic reshape input not supported */); + hasReshapeInput = true; + reshapeLoc = reshape.getLoc(); + } + } + if (!hasReshapeInput) + return failure(); + + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + + char nextChar = 'a'; + auto getNextChar = [&]() -> char { + while (nextChar <= 'z') { + char c = nextChar++; + if (equation.equation.find(c) == std::string::npos) + return c; + } + return 0; + }; + + uint64_t currentEstimatedCost = 0; + + struct ReshapeInfo { + SmallVector inputShape; + SmallVector outputShape; + std::string newEinsumStr; + }; + std::unordered_map inputToReshapedMap; + for (size_t i = 0; i < op.getInputs().size(); i++) { + auto input = op.getInputs()[i]; + RankedTensorType einsumInputType = + cast(input.getType()); // reshape output type + if (auto reshape = input.getDefiningOp()) { + currentEstimatedCost += estimateShuffleCost(input); + size_t inputNumElems = 1; + size_t outputNumElems = 1; + SmallVector inputShape; + SmallVector outputShape; + std::string outputEinsumStr = ""; + RankedTensorType reshapeInputType = reshape.getInput().getType(); + for (int j = 0, k = 0; j < einsumInputType.getRank(); j++) { + if (einsumInputType.getDimSize(j) <= 1) { + // if 0-shape, then means the tensor is empty. Annoying edge case + // that not going to handle + // TODO: if 1-shape, then need additional logic to handle this + return failure(/* 0 or 1 dim not supported */); + } + outputNumElems *= einsumInputType.getDimSize(j); + outputShape.push_back(einsumInputType.getDimSize(j)); + outputEinsumStr += equation.lhsParts[i][j]; + while (k < reshapeInputType.getRank() && + inputNumElems < outputNumElems) { + if (reshapeInputType.getDimSize(k) == 1) + return failure(/* 1 dim not supported */); + inputNumElems *= reshapeInputType.getDimSize(k); + inputShape.push_back(reshapeInputType.getDimSize(k++)); + } + if (inputNumElems == outputNumElems) { + auto it = inputToReshapedMap.find(outputEinsumStr); + if (it != inputToReshapedMap.end()) { + if (it->second.inputShape != inputShape || + it->second.outputShape != outputShape) + return failure( + /* a single axis has multiple inconsistent reshapes */); + } else { + if (outputShape != inputShape) { + std::string newEinsumStr = ""; + for (size_t l = 0; l < inputShape.size(); l++) { + char c = getNextChar(); + if (c == 0) + return failure(/* no more characters available */); + newEinsumStr += c; + } + assert(outputEinsumStr.size() == outputShape.size()); + inputToReshapedMap[outputEinsumStr] = + ReshapeInfo{.inputShape = inputShape, + .outputShape = outputShape, + .newEinsumStr = newEinsumStr}; + } else { + // do not register this as there is no change in the shape + // so if something else requires a change, then it will get + // registered for this symbol instead + assert(outputShape.size() == 1); + } + } + inputShape.clear(); + outputShape.clear(); + outputEinsumStr = ""; + } + } + assert(inputNumElems == outputNumElems); + } + } + + llvm::SmallMapVector charToGroup; + for (auto &[k, v] : inputToReshapedMap) { + for (char c : k) { + auto it = charToGroup.find(c); + if (it == charToGroup.end()) + charToGroup[c] = k; + else + return failure( + /* a single axis has multiple inconsistent reshapes */); + } + } + + for (std::string &part : equation.lhsParts) { + for (char c : part) { + auto group = charToGroup.find(c); + if (group == charToGroup.end()) + continue; + for (char c2 : group->second) { + if (part.find(c2) == std::string::npos) + return failure( + /* Missing dimensions that need to be reshaped together */); + } + } + } + + for (char c : equation.rhs) { + auto group = charToGroup.find(c); + if (group == charToGroup.end()) + continue; + for (char c2 : group->second) { + if (equation.rhs.find(c2) == std::string::npos) + return failure( + /* Missing dimensions that need to be reshaped together */); + } + } + + size_t newEstimatedCost = 0; + + for (size_t i = 0; i < op.getInputs().size(); i++) { + Value input = op.getInputs()[i]; + RankedTensorType inputType = cast(input.getType()); + SmallVector newInputShape; + for (int j = 0; j < inputType.getRank(); j++) { + char c = equation.lhsParts[i][j]; + auto it = charToGroup.find(c); + if (it == charToGroup.end()) { + newInputShape.push_back(inputType.getDimSize(j)); + } else { + if (it->second[0] != c) + continue; // this will be processed on the first letter + auto group = inputToReshapedMap.find(it->second); + for (size_t k = 0; k < group->second.inputShape.size(); k++) { + newInputShape.push_back(group->second.inputShape[k]); + } + } + } + Value reshapeIn = input; + while (auto reshape = reshapeIn.getDefiningOp()) { + reshapeIn = reshape.getInput(); + } + SmallVector reshapeInShape{ + cast(reshapeIn.getType()).getShape()}; + if (reshapeInShape != newInputShape) + newEstimatedCost += estimateShuffleCost(reshapeIn); + } + + if (newEstimatedCost >= currentEstimatedCost) + return failure(/* new cost is not better than current cost */); + + // done matching against the pattern. Going to start modifying the MLIR at + // this point + + SmallVector newInputs; + EinsumEquation newEquation; + for (size_t i = 0; i < op.getInputs().size(); i++) { + Value input = op.getInputs()[i]; + RankedTensorType inputType = cast(input.getType()); + SmallVector newInputShape; + SmallVector newInputTranspose; + std::string newEinsumStr = ""; + for (int j = 0; j < inputType.getRank(); j++) { + char c = equation.lhsParts[i][j]; + auto it = charToGroup.find(c); + if (it == charToGroup.end()) { + newInputShape.push_back(inputType.getDimSize(j)); + newInputTranspose.push_back(j); + newEinsumStr += c; + } else { + if (it->second[0] != c) + continue; // this will be processed on the first letter + auto group = inputToReshapedMap.find(it->second); + newEinsumStr += group->second.newEinsumStr; + for (size_t k = 0; k < group->second.inputShape.size(); k++) { + newInputShape.push_back(group->second.inputShape[k]); + } + for (char c2 : group->first) { + size_t pos = equation.lhsParts[i].find(c2); + assert(pos != std::string::npos); + newInputTranspose.push_back(pos); + } + } + } + + Value reshapeIn = rewriter.createOrFold( + op.getLoc(), input, + AffineMap::getPermutationMap(newInputTranspose, op.getContext())); + while (auto definingOp = reshapeIn.getDefiningOp()) { + // two sequential reshapes just results in the shape of the last + // reshape. There are canonicalization patterns that do this as well + // but do it here so that the reshape op that was an input is no longer + // used. + reshapeIn = definingOp.getInput(); + } + auto reshape = rewriter.createOrFold( + op.getLoc(), inputType.clone(newInputShape), reshapeIn); + + newInputs.push_back(reshape); + newEquation.lhsParts.push_back(newEinsumStr); + } + + RankedTensorType outputType = op.getType(); + SmallVector einsumOutputShape; + SmallVector afterEinsumReshape; + SmallVector afterReshapeTranspose; + + for (int j = 0; j < outputType.getRank(); j++) { + char c = equation.rhs[j]; + auto it = charToGroup.find(c); + if (it == charToGroup.end()) { + einsumOutputShape.push_back(outputType.getDimSize(j)); + afterEinsumReshape.push_back(outputType.getDimSize(j)); + afterReshapeTranspose.push_back(j); + newEquation.rhs += c; + } else { + if (it->second[0] != c) + continue; + auto group = inputToReshapedMap.find(it->second); + newEquation.rhs += group->second.newEinsumStr; + for (size_t k = 0; k < group->second.inputShape.size(); k++) { + // the output shape of the einsum is the input shape of the reshape + // now as the reshape will appear after the einsum + einsumOutputShape.push_back(group->second.inputShape[k]); + } + for (size_t k = 0; k < group->second.outputShape.size(); k++) + afterEinsumReshape.push_back(group->second.outputShape[k]); + for (char c2 : it->second) { + size_t pos = equation.rhs.find(c2); + assert(pos != std::string::npos); + afterReshapeTranspose.push_back(pos); + } + } + } + + std::string newEinsumEquation = newEquation.generateEquation(); + + auto newEinsum = rewriter.create( + op.getLoc(), outputType.clone(einsumOutputShape), newInputs, + newEinsumEquation); + assert(newEquation.rhs.size() == newEinsum.getType().getShape().size()); + + auto newReshape = rewriter.createOrFold( + reshapeLoc, outputType.clone(afterEinsumReshape), + newEinsum.getResult()); + + Value newOut = rewriter.createOrFold( + reshapeLoc, newReshape, + AffineMap::getPermutationMap(afterReshapeTranspose, op.getContext())); + + assert(op.getType() == newOut.getType()); + rewriter.replaceOp(op, newOut); + + return success(); + } +}; +} // namespace + +namespace { +// reshape(transpose(x)) -> transpose(reshape(x)) +// NOTE: there are more cases that could be handled here +class MoveReshapeBeforeTranspose + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto transpose = op.getInput().getDefiningOp(); + if (!transpose) + return failure(); + + RankedTensorType transposeInputType = transpose.getInput().getType(); + RankedTensorType reshapeInputType = + op.getInput().getType(); // transpose output type + RankedTensorType reshapeOutputType = op.getType(); + if (!reshapeInputType.hasStaticShape() || + !reshapeOutputType.hasStaticShape() || + !transposeInputType.hasStaticShape()) + return failure(); + + SmallVector transposePerm; + for (int i = 0; i < reshapeInputType.getRank(); i++) + transposePerm.push_back(i); + if (!transpose.getPermutation().isPermutation()) + return failure(/* Transpose is not a permutation */); + + transposePerm = transpose.getPermutation().compose(transposePerm); + + struct ReshapeGroup { + SmallVector transposeInAxes; + SmallVector transposeOutAxes; + SmallVector reshapeOut; + int64_t startOutputIdx; + }; + SmallVector reshapeGroups; + + SmallVector transposeInAxes; + SmallVector transposeOutAxes; + SmallVector groupReshapeOut; + size_t inputNumElems = 1; + size_t outputNumElems = 1; + int j = 0; + for (int i = 0; i < reshapeInputType.getRank(); i++) { + inputNumElems *= reshapeInputType.getDimSize(i); + if (!transposeInAxes.empty() && + transposeInAxes.back() + 1 != transposePerm[i]) + return failure(/* the transpose and the reshape are not commutative */); + transposeInAxes.push_back(transposePerm[i]); + while (j < reshapeOutputType.getRank() && + inputNumElems > outputNumElems) { + outputNumElems *= reshapeOutputType.getDimSize(j); + groupReshapeOut.push_back(reshapeOutputType.getDimSize(j)); + transposeOutAxes.push_back(j++); + } + if (inputNumElems == outputNumElems) { + reshapeGroups.push_back(ReshapeGroup{ + .transposeInAxes = transposeInAxes, + .transposeOutAxes = transposeOutAxes, + .reshapeOut = groupReshapeOut, + .startOutputIdx = -1, // set later + }); + transposeInAxes.clear(); + transposeOutAxes.clear(); + groupReshapeOut.clear(); + } + } + assert(inputNumElems == outputNumElems); + while (j < reshapeOutputType.getRank()) { + outputNumElems *= reshapeOutputType.getDimSize(j); + groupReshapeOut.push_back(reshapeOutputType.getDimSize(j)); + transposeOutAxes.push_back(j++); + } + assert(inputNumElems == outputNumElems); + assert(transposeInAxes.empty()); + if (!transposeOutAxes.empty() || !groupReshapeOut.empty()) + reshapeGroups.push_back(ReshapeGroup{ + .transposeInAxes = transposeInAxes, + .transposeOutAxes = transposeOutAxes, + .reshapeOut = groupReshapeOut, + .startOutputIdx = -1, // set later + }); + + SmallVector newTranspose; + SmallVector newReshape; + + std::sort(reshapeGroups.begin(), reshapeGroups.end(), [](auto &a, auto &b) { + if (a.transposeInAxes.empty()) + return false; + if (b.transposeInAxes.empty()) + return true; + return a.transposeInAxes[0] < b.transposeInAxes[0]; + }); + + for (auto &group : reshapeGroups) { + group.startOutputIdx = newReshape.size(); + for (int64_t i : group.reshapeOut) + newReshape.push_back(i); + } + + std::sort(reshapeGroups.begin(), reshapeGroups.end(), [](auto &a, auto &b) { + if (a.transposeOutAxes.empty()) + return false; + if (b.transposeOutAxes.empty()) + return true; + return a.transposeOutAxes[0] < b.transposeOutAxes[0]; + }); + + LLVM_DEBUG({ + std::stringstream out; + out << "Reshape Groups:\n"; + for (size_t idx = 0; idx < reshapeGroups.size(); ++idx) { + const auto &group = reshapeGroups[idx]; + out << " Group " << idx << ":\n"; + out << " transposeInAxes: ["; + for (size_t i = 0; i < group.transposeInAxes.size(); ++i) { + out << group.transposeInAxes[i]; + if (i + 1 < group.transposeInAxes.size()) + out << ", "; + } + out << "]\n"; + out << " transposeOutAxes: ["; + for (size_t i = 0; i < group.transposeOutAxes.size(); ++i) { + out << group.transposeOutAxes[i]; + if (i + 1 < group.transposeOutAxes.size()) + out << ", "; + } + out << "]\n"; + out << " reshapeOut: ["; + for (size_t i = 0; i < group.reshapeOut.size(); ++i) { + out << group.reshapeOut[i]; + if (i + 1 < group.reshapeOut.size()) + out << ", "; + } + out << "]\n"; + out << " startOutputIdx: " << group.startOutputIdx << "\n"; + } + DBGS() << out.str(); + }); + + for (auto &group : reshapeGroups) + for (size_t i = 0; i < group.reshapeOut.size(); i++) + newTranspose.push_back(group.startOutputIdx + i); + + Value newReshapeOp = rewriter.createOrFold( + op.getLoc(), reshapeInputType.clone(newReshape), transpose.getInput()); + Value newTransposeOp = rewriter.createOrFold( + transpose.getLoc(), newReshapeOp, + AffineMap::getPermutationMap(newTranspose, op.getContext())); + + assert(op.getType() == newTransposeOp.getType()); + rewriter.replaceOp(op, newTransposeOp); + + return success(); + } +}; +} // namespace + +namespace { +// transpose(reshape(x)) -> reshape(transpose(x)) +// NOTE: there are more cases that could be handled here +class MoveTransposeBeforeReshape + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::TransposeOp op, + PatternRewriter &rewriter) const override { + auto reshape = op.getInput().getDefiningOp(); + if (!reshape) + return failure(); + + RankedTensorType reshapeInputType = reshape.getInput().getType(); + RankedTensorType reshapeOutputType = reshape.getType(); + RankedTensorType transposeOutputType = op.getType(); + if (!reshapeInputType.hasStaticShape() || + !reshapeOutputType.hasStaticShape() || + !transposeOutputType.hasStaticShape()) + return failure(); + + SmallVector transposePerm; + for (int i = 0; i < reshapeOutputType.getRank(); i++) { + transposePerm.push_back(i); + } + transposePerm = + inversePermutation(op.getPermutation()).compose(transposePerm); + + struct ReshapeGroup { + SmallVector inputAxes; + SmallVector outputAxes; + SmallVector reshapeOut; + }; + SmallVector reshapeGroups; + + SmallVector inputAxes; + SmallVector outputAxes; + SmallVector groupReshapeOut; + size_t inputNumElems = 1; + size_t outputNumElems = 1; + int j = 0; + for (int i = 0; i < reshapeInputType.getRank(); i++) { + inputNumElems *= reshapeInputType.getDimSize(i); + inputAxes.push_back(i); + while (j < reshapeOutputType.getRank() && + inputNumElems > outputNumElems) { + outputNumElems *= reshapeOutputType.getDimSize(j); + groupReshapeOut.push_back(reshapeOutputType.getDimSize(j)); + if (!outputAxes.empty() && outputAxes.back() + 1 != transposePerm[j]) + return failure( + /* the transpose and the reshape are not commutative */); + outputAxes.push_back(transposePerm[j++]); + } + if (inputNumElems == outputNumElems) { + reshapeGroups.push_back(ReshapeGroup{ + .inputAxes = inputAxes, + .outputAxes = outputAxes, + .reshapeOut = groupReshapeOut, + }); + inputAxes.clear(); + outputAxes.clear(); + groupReshapeOut.clear(); + } + } + assert(inputNumElems == outputNumElems); + while (j < reshapeOutputType.getRank()) { + outputNumElems *= reshapeOutputType.getDimSize(j); + groupReshapeOut.push_back(reshapeOutputType.getDimSize(j)); + outputAxes.push_back(transposePerm[j++]); + } + + assert(inputNumElems == outputNumElems); + assert(inputAxes.empty()); + if (!outputAxes.empty() || !groupReshapeOut.empty()) + reshapeGroups.push_back(ReshapeGroup{ + .inputAxes = inputAxes, + .outputAxes = outputAxes, + .reshapeOut = groupReshapeOut, + }); + + SmallVector newTranspose; + SmallVector newReshape; + + std::sort(reshapeGroups.begin(), reshapeGroups.end(), [](auto &a, auto &b) { + if (a.outputAxes.empty()) + return false; + if (b.outputAxes.empty()) + return true; + return a.outputAxes[0] < b.outputAxes[0]; + }); + + // Debug print of reshapeGroups + LLVM_DEBUG({ + std::stringstream out; + out << "reshapeGroups:\n"; + for (size_t idx = 0; idx < reshapeGroups.size(); ++idx) { + const auto &group = reshapeGroups[idx]; + out << " Group " << idx << ":\n"; + out << " inputAxes: ["; + for (size_t i = 0; i < group.inputAxes.size(); ++i) { + out << group.inputAxes[i]; + if (i + 1 < group.inputAxes.size()) + out << ", "; + } + out << "]\n"; + out << " outputAxes: ["; + for (size_t i = 0; i < group.outputAxes.size(); ++i) { + out << group.outputAxes[i]; + if (i + 1 < group.outputAxes.size()) + out << ", "; + } + out << "]\n"; + out << " reshapeOut: ["; + for (size_t i = 0; i < group.reshapeOut.size(); ++i) { + out << group.reshapeOut[i]; + if (i + 1 < group.reshapeOut.size()) + out << ", "; + } + out << "]\n"; + } + DBGS() << out.str(); + }); + + for (auto &group : reshapeGroups) { + for (int64_t i : group.inputAxes) + newTranspose.push_back(i); + for (int64_t i : group.reshapeOut) + newReshape.push_back(i); + } + + Value newTransposeOp; + if (newTranspose.empty()) { + newTransposeOp = reshape.getInput(); // this can happen in the case of a + // scalar tensor type + } else { + newTransposeOp = rewriter.createOrFold( + op.getLoc(), reshape.getInput(), + AffineMap::getPermutationMap(newTranspose, op.getContext())); + } + Value newReshapeOp = rewriter.createOrFold( + reshape.getLoc(), reshapeInputType.clone(newReshape), newTransposeOp); + + assert(op.getType() == newReshapeOp.getType()); + rewriter.replaceOp(op, newReshapeOp); + return success(); + } +}; +} // namespace + +namespace { +// activation(reshape(x)) -> reshape(activation(x)) +class PushDownReshapeActivationRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ActivationOp op, + PatternRewriter &rewriter) const override { + auto producer = op.getInput().getDefiningOp(); + if (!producer) + return failure(); + + auto activationOp = rewriter.create( + op.getLoc(), producer.getInput(), op.getActivationType(), + op.getAlphaAttr(), op.getBetaAttr()); + auto reshapeOp = rewriter.createOrFold( + producer.getLoc(), op.getType(), activationOp.getResult(), + producer.getShape()); + assert(op.getType() == reshapeOp.getType()); + rewriter.replaceOp(op, reshapeOp); + return success(); + } +}; +} // namespace + +namespace { +// unary(reshape(x)) -> reshape(unary(x)) +class PushDownReshapeUnaryRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::UnaryOp op, + PatternRewriter &rewriter) const override { + auto producer = op.getInput().getDefiningOp(); + if (!producer) + return failure(); + + auto unaryOp = rewriter.create( + op.getLoc(), producer.getInput(), op.getUnaryOperationAttr()); + auto reshapeOp = rewriter.createOrFold( + producer.getLoc(), op.getType(), unaryOp.getResult(), + producer.getShape()); + assert(op.getType() == reshapeOp.getType()); + rewriter.replaceOp(op, reshapeOp); + return success(); + } +}; +} // namespace + +namespace { +// identity(reshape(x)) -> reshape(identity(x)) +class PushDownReshapeIdentityRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::IdentityOp op, + PatternRewriter &rewriter) const override { + auto producer = op.getInput().getDefiningOp(); + if (!producer) + return failure(); + + RankedTensorType newIdentityType = + producer.getInput().getType().clone(op.getType().getElementType()); + Value newIdentityResult = rewriter.create( + op.getLoc(), newIdentityType, producer.getInput()); + auto reshapeOp = rewriter.createOrFold( + producer.getLoc(), op.getType(), newIdentityResult, + producer.getShape()); + assert(op.getType() == reshapeOp.getType()); + rewriter.replaceOp(op, reshapeOp); + return success(); + } +}; +} // namespace + +namespace { +// reshape(unary_OpType(x)) -> unary_OpType(reshape(x)) +template +class PushUpReshapeUnary : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto producer = op.getInput().getDefiningOp(); + if (!producer) + return failure(); + + Type reshapeType = + op.getType().clone(producer.getInput().getType().getElementType()); + + Value newReshapeResult = rewriter.create( + op.getLoc(), reshapeType, producer.getInput(), op.getShape()); + auto newOp = + rewriter.createOrFold(producer.getLoc(), op.getType(), + newReshapeResult, producer->getAttrs()); + assert(op.getType() == newOp.getType()); + rewriter.replaceOp(op, newOp); + return success(); + } +}; +} // namespace + +namespace { +// op(dequantize(quantize(x))) -> dequantize(quantize(op(x))) +template +class PushUpOpQuantizeDequantize : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + auto dequantizeOp = + op.getInput().template getDefiningOp(); + if (!dequantizeOp) + return failure(); + Value scale = dequantizeOp.getScale(); + RankedTensorType scaleType = cast(scale.getType()); + if (!(scaleType.getRank() == 0 || + (scaleType.getRank() == 1 && scaleType.getDimSize(0) == 1)) || + dequantizeOp.getAxis().has_value()) + return failure(); + auto quantizeOp = + dequantizeOp.getInput().template getDefiningOp(); + if (!quantizeOp) + return failure(); + if (quantizeOp.getScale() != scale || quantizeOp.getAxis().has_value()) + return failure(); + + auto input = quantizeOp.getInput(); + auto pushedOp = rewriter.create( + op.getLoc(), op.getResult().getType(), input, op->getAttrs()); + RankedTensorType newQuantizedType = pushedOp.getType().clone( + quantizeOp.getResult().getType().getElementType()); + auto newQuantizeOp = rewriter.create( + quantizeOp.getLoc(), newQuantizedType, pushedOp, scale, + quantizeOp.getAxisAttr()); + auto newDequantizeOp = rewriter.create( + dequantizeOp.getLoc(), op.getResult().getType(), + newQuantizeOp.getResult(), scale, dequantizeOp.getAxisAttr()); + assert(op.getType() == newDequantizeOp.getType()); + rewriter.replaceOp(op, newDequantizeOp.getResult()); + return success(); + } +}; +} // namespace + +namespace { +// dequantize(quantize(op(x))) -> op(dequantize(quantize(x))) +template +class PushDownOpQuantizeDequantize + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::DequantizeOp dequantizeOp, + PatternRewriter &rewriter) const override { + Value scale = dequantizeOp.getScale(); + auto scaleType = cast(scale.getType()); + if (!(scaleType.getRank() == 0 || + (scaleType.getRank() == 1 && scaleType.getDimSize(0) == 1)) || + dequantizeOp.getAxis().has_value()) + return failure(); + auto quantizeOp = + dequantizeOp.getInput().getDefiningOp(); + if (!quantizeOp) + return failure(); + if (quantizeOp.getScale() != scale || quantizeOp.getAxis().has_value()) + return failure(); + + auto op = quantizeOp.getInput().getDefiningOp(); + if (!op) + return failure(); + + auto input = op.getInput(); + RankedTensorType newQuantizedType = input.getType().clone( + quantizeOp.getResult().getType().getElementType()); + auto newQuantizeOp = rewriter.create( + quantizeOp.getLoc(), newQuantizedType, input, scale, + quantizeOp.getAxisAttr()); + RankedTensorType newDequantizedType = newQuantizedType.clone( + dequantizeOp.getResult().getType().getElementType()); + auto newDequantizeOp = rewriter.create( + dequantizeOp.getLoc(), newDequantizedType, newQuantizeOp.getResult(), + scale, dequantizeOp.getAxisAttr()); + auto newOp = + rewriter.create(op.getLoc(), dequantizeOp.getResult().getType(), + newDequantizeOp.getResult(), op->getAttrs()); + assert(dequantizeOp.getType() == newOp.getType()); + rewriter.replaceOp(dequantizeOp, newOp.getResult()); + return success(); + } +}; +} // namespace + +namespace { +// Convert einsum("bij,bjk->bik", %0, %1) -> matrix_multiply(%0, %1) +// by matching different einsum patterns that are supported by the +// matrix_multiply op +class EinsumToMatrixMultiply : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::EinsumOp op, + PatternRewriter &rewriter) const override { + if (op.getInputs().size() != 2) + return failure(); + + EinsumEquation equation; + if (failed(equation.parse(op.getEquation()))) + return failure(); + + char matrixAxis[2] = {0, 0}; + char multipliedAxis = 0; + std::string batchAxes = ""; + + Value inputs[2] = {op.getInputs()[0], op.getInputs()[1]}; + + for (char c : equation.lhsParts[0]) { + if (equation.lhsParts[1].find(c) == std::string::npos) { + if (matrixAxis[0] != 0) + return failure(/* einsum does not match matrix multiply format */); + matrixAxis[0] = c; + } + if (equation.rhs.find(c) == std::string::npos) { + if (multipliedAxis != 0) + return failure(/* einsum does not match matrix multipliy format */); + multipliedAxis = c; + } + } + for (char c : equation.lhsParts[1]) { + if (equation.lhsParts[0].find(c) == std::string::npos) { + if (matrixAxis[1] != 0) + return failure(/* einsum does not match matrix multiply format */); + matrixAxis[1] = c; + } + if (equation.rhs.find(c) == std::string::npos) { + if (multipliedAxis != 0 && multipliedAxis != c) + return failure(/* einsum does not match matrix multiply format */); + multipliedAxis = c; + } + } + + for (size_t i = 0; i < equation.rhs.size(); i++) { + if (equation.lhsParts[0][i] == equation.rhs[i] && + equation.lhsParts[1][i] == equation.rhs[i]) + batchAxes += equation.rhs[i]; + else + break; + } + + if (multipliedAxis == 0) + return failure(); + + if (matrixAxis[0] != 0 && matrixAxis[1] != 0 && + equation.rhs.find(matrixAxis[0]) > equation.rhs.find(matrixAxis[1])) { + // the order of the arguments need to get swapped as the order for a + // matrix multiply requires the first matrix axis appears first + std::swap(equation.lhsParts[0], equation.lhsParts[1]); + std::swap(matrixAxis[0], matrixAxis[1]); + std::swap(inputs[0], inputs[1]); + } + + MatrixOperation opType[2]; + for (int i = 0; i < 2; i++) { + if (matrixAxis[i] == 0) { + if (equation.lhsParts[i] == batchAxes + multipliedAxis) + opType[i] = MatrixOperation::kVECTOR; + else + return failure(/* einsum does not match matrix multiply format */); + } else { + if (equation.lhsParts[i] == + (batchAxes + matrixAxis[i]) + multipliedAxis) + opType[i] = MatrixOperation::kNONE; + else if (equation.lhsParts[i] == + (batchAxes + multipliedAxis) + matrixAxis[i]) + opType[i] = MatrixOperation::kTRANSPOSE; + else + return failure(/* einsum does not match matrix multiply format */); + } + } + + switch (opType[1]) { + case MatrixOperation::kTRANSPOSE: + opType[1] = MatrixOperation::kNONE; + break; + case MatrixOperation::kNONE: + opType[1] = MatrixOperation::kTRANSPOSE; + break; + default:; + } + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), inputs[0], inputs[1], opType[0], + opType[1]); + + return success(); + } +}; +} // namespace + +namespace { + +// Push down reshape through elementwise op +// elementwise(reshape(x), y) -> reshape(elementwise(x, reshape(y))) +class PushDownReshapeElementwise + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ElementWiseOp op, + PatternRewriter &rewriter) const override { + bool hasReshapeInput = false; + uint64_t currentEstimatedCost = 0; + Location reshapeLoc = op.getLoc(); + for (Value input : op.getOperands()) { + if (!cast(input.getType()).hasStaticShape()) { + return failure(); + } + if (auto reshape = input.getDefiningOp()) { + if (!reshape.getInput().getType().hasStaticShape()) + return failure(); + hasReshapeInput = true; + reshapeLoc = reshape.getLoc(); + currentEstimatedCost += estimateShuffleCost(input); + } + } + if (!hasReshapeInput) + return failure(); + + if (op.getInput1().getType().getShape() != + op.getInput2().getType().getShape()) + return failure(); + + uint64_t newCost; + auto reshape1 = op.getInput1().getDefiningOp(); + auto reshape2 = op.getInput2().getDefiningOp(); + bool useLhsShape = true; + if (reshape1 && reshape2 && + reshape1.getInput().getType().getShape() == + reshape2.getInput().getType().getShape()) { + newCost = 0; // should always do it + } else { + uint64_t cost1 = estimateShuffleCost(op.getInput1()); + uint64_t cost2 = estimateShuffleCost(op.getInput2()); + if (cost1 < cost2) { + useLhsShape = + false; // want to put the reshape on the rhs as its cost is lower + newCost = cost1; + } else { + useLhsShape = + true; // want to put the reshape on the lhs as its cost is lower + newCost = cost2; + } + } + + if (newCost >= currentEstimatedCost) { + return failure(); + } + + Value newLhs = op.getInput1(); + Value newRhs = op.getInput2(); + while (auto reshape = newLhs.getDefiningOp()) + newLhs = reshape.getInput(); + while (auto reshape = newRhs.getDefiningOp()) + newRhs = reshape.getInput(); + + auto newShape = useLhsShape ? reshape1.getInput().getType().getShape() + : reshape2.getInput().getType().getShape(); + + RankedTensorType newLhsType = op.getInput1().getType().clone(newShape); + RankedTensorType newRhsType = op.getInput2().getType().clone(newShape); + + newLhs = rewriter.createOrFold(op.getLoc(), newLhsType, + newLhs); + newRhs = rewriter.createOrFold(op.getLoc(), newRhsType, + newRhs); + + RankedTensorType elementwiseType = op.getResult().getType().clone(newShape); + auto newElementwiseOp = rewriter.create( + op.getLoc(), elementwiseType, newLhs, newRhs, + op.getElementwiseOperation()); + auto newReshapeOp = rewriter.createOrFold( + reshapeLoc, op.getType(), newElementwiseOp.getResult()); + assert(op.getType() == newReshapeOp.getType()); + rewriter.replaceOp(op, newReshapeOp); + return success(); + } +}; +} // namespace + +namespace { +// Push up reshape through elementwise op +// reshape(elementwise(x, y)) -> elementwise(reshape(x), reshape(y)) +class PushUpReshapeElementwise : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto elementwiseOp = op.getInput().getDefiningOp(); + if (!elementwiseOp) + return failure(); + + RankedTensorType type = op.getType(); + if (!type.hasStaticShape()) + return failure(); + + if (elementwiseOp.getInput1().getType().getShape() != + elementwiseOp.getInput2().getType().getShape()) + return failure(); + + // heuristic to check if should apply + Operation *lhsParent = elementwiseOp.getInput1().getDefiningOp(); + Operation *rhsParent = elementwiseOp.getInput2().getDefiningOp(); + bool isLhsParentReshapeOrTransposeOrConstant = + lhsParent && isa(lhsParent); + bool isRhsParentReshapeOrTransposeOrConstant = + rhsParent && isa(rhsParent); + if (!isLhsParentReshapeOrTransposeOrConstant && + !isRhsParentReshapeOrTransposeOrConstant) + return failure(); + + RankedTensorType newLhsType = + type.clone(elementwiseOp.getInput1().getType().getElementType()); + RankedTensorType newRhsType = + type.clone(elementwiseOp.getInput2().getType().getElementType()); + auto newLhs = rewriter.createOrFold( + op.getLoc(), newLhsType, elementwiseOp.getInput1()); + auto newRhs = rewriter.createOrFold( + op.getLoc(), newRhsType, elementwiseOp.getInput2()); + + auto newElementwiseOp = rewriter.create( + elementwiseOp.getLoc(), op.getResult().getType(), newLhs, newRhs, + elementwiseOp.getElementwiseOperation()); + assert(op.getType() == newElementwiseOp.getType()); + rewriter.replaceOp(op, newElementwiseOp.getResult()); + + return success(); + } +}; +} // namespace + +namespace { +// This pattern matches matrix multiply operations whose arguments are produced +// by a transpose that swaps only the last two dimensions. In such cases, it +// absorbs the transpose into the matrix multiply operator by toggling the +// internal transpose flag on the corresponding input, eliminating unnecessary +// explicit transpose operations. +class MatrixMultiplyTransposedArguments + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::MatrixMultiplyOp op, + PatternRewriter &rewriter) const override { + bool didChange = false; + const auto replaceArg = + [&](Value arg, + MatrixOperation operation) -> std::tuple { + if (operation == MatrixOperation::kVECTOR) { + return std::make_tuple(arg, operation); + } + auto transpose = arg.getDefiningOp(); + if (transpose && shouldFuseTranspose(transpose, op)) { + AffineMap perm = transpose.getPermutation(); + // Check if perm swaps its last two axes while keeping everything else + // the same + auto permVec = llvm::to_vector(perm.getResults()); + int64_t rank = permVec.size(); + if (rank < 2) + return std::make_tuple(arg, operation); + bool swapsLastTwo = true; + for (int64_t i = 0; i < rank - 2; ++i) { + auto expr = dyn_cast(permVec[i]); + if (!expr || expr.getPosition() != i) { + swapsLastTwo = false; + break; + } + } + if (swapsLastTwo) { + auto expr1 = dyn_cast(permVec[rank - 2]); + auto expr2 = dyn_cast(permVec[rank - 1]); + if (!(expr1 && expr2 && expr1.getPosition() == rank - 1 && + expr2.getPosition() == rank - 2)) + swapsLastTwo = false; + } + if (swapsLastTwo) { + didChange = true; + return std::make_tuple(transpose.getInput(), + operation == MatrixOperation::kTRANSPOSE + ? MatrixOperation::kNONE + : MatrixOperation::kTRANSPOSE); + } + return std::make_tuple(arg, operation); + } else { + return std::make_tuple(arg, operation); + } + }; + + auto [newLhs, newLhsOp] = replaceArg(op.getInput0(), op.getOp0()); + auto [newRhs, newRhsOp] = replaceArg(op.getInput1(), op.getOp1()); + + if (didChange) { + rewriter.replaceOpWithNewOp( + op, op.getType(), newLhs, newRhs, newLhsOp, newRhsOp); + return success(); + } else { + return failure(); + } + } +}; +} // namespace + +namespace { +// push up transpose through softmax +// softmax(transpose(x)) -> transpose(softmax(x)) +class PushUpTransposeSoftmax : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::TransposeOp op, + PatternRewriter &rewriter) const override { + auto softmax = op.getInput().getDefiningOp(); + if (!softmax) + return failure(); + unsigned axis = softmax.getAxis(); + unsigned newAxis = + inversePermutation(op.getPermutation()).getDimPosition(axis); + auto newTranspose = rewriter.create( + op.getLoc(), softmax.getInput(), op.getPermutation()); + auto newSoftmax = rewriter.create( + softmax.getLoc(), newTranspose, newAxis); + assert(op.getType() == newSoftmax.getType()); + rewriter.replaceOp(op, newSoftmax.getResult()); + return success(); + } +}; +} // namespace + +namespace { +// push down transpose through softmax +// transpose(softmax(x)) -> softmax(transpose(x)) +class PushDownTransposeSoftmax : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::SoftMaxOp op, + PatternRewriter &rewriter) const override { + auto transpose = op.getInput().getDefiningOp(); + if (!transpose) + return failure(); + unsigned axis = op.getAxis(); + unsigned newAxis = transpose.getPermutation().getDimPosition(axis); + auto newSoftmax = rewriter.create( + op.getLoc(), transpose.getInput(), newAxis); + auto newTranspose = rewriter.create( + transpose.getLoc(), newSoftmax, transpose.getPermutation()); + assert(op.getType() == newTranspose.getType()); + rewriter.replaceOp(op, newTranspose.getResult()); + return success(); + } +}; +} // namespace + +namespace { +// push up reshape through softmax +// softmax(reshape(x)) -> reshape(softmax(x)) +class PushUpReshapeSoftmax : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::ReshapeOp op, + PatternRewriter &rewriter) const override { + if (!op.getType().hasStaticShape() || + !op.getInput().getType().hasStaticShape()) + return failure(); + auto softmax = op.getInput().getDefiningOp(); + if (!softmax) + return failure(); + int axis = softmax.getAxis(); + int newAxis = -1; + size_t numInputElements = 1; + size_t numOutputElements = 1; + auto inputType = op.getInput().getType(); + auto outputType = op.getType(); + for (int i = 0, j = 0; i < inputType.getRank(); i++) { + numInputElements *= inputType.getDimSize(i); + while (numOutputElements < numInputElements && j < outputType.getRank()) + numOutputElements *= outputType.getDimSize(j++); + if (i == axis) { + if (numInputElements != numOutputElements || + inputType.getDimSize(i) != outputType.getDimSize(j - 1)) { + return failure(/* the reshape impacts the elements that are getting softmaxed */); + } else { + newAxis = j - 1; + break; + } + } + } + if (newAxis == -1) + return failure(); + auto newReshape = rewriter.create( + op.getLoc(), outputType, softmax.getInput()); + auto newSoftmax = rewriter.create( + softmax.getLoc(), newReshape.getResult(), newAxis); + assert(op.getType() == newSoftmax.getType()); + rewriter.replaceOp(op, newSoftmax.getResult()); + return success(); + } +}; +} // namespace + +namespace { +// push down reshape through softmax +// reshape(softmax(x)) -> softmax(reshape(x)) +class PushDownReshapeSoftmax : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::SoftMaxOp op, + PatternRewriter &rewriter) const override { + auto reshapeOp = op.getInput().getDefiningOp(); + if (!reshapeOp) + return failure(); + int axis = op.getAxis(); + int newAxis = -1; + size_t numInputElements = 1; + size_t numOutputElements = 1; + auto inputType = reshapeOp.getInput().getType(); + auto outputType = reshapeOp.getType(); + for (int i = 0, j = 0; i < outputType.getRank(); i++) { + numOutputElements *= outputType.getDimSize(i); + while (numInputElements < numOutputElements && j < inputType.getRank()) + numInputElements *= inputType.getDimSize(j++); + if (i == axis) { + if (numInputElements != numOutputElements || + inputType.getDimSize(j - 1) != outputType.getDimSize(i)) { + return failure(); + } else { + newAxis = j - 1; + break; + } + } + } + if (newAxis == -1) + return failure(); + auto newSoftmax = rewriter.create( + op.getLoc(), reshapeOp.getInput(), newAxis); + auto newReshape = rewriter.create( + reshapeOp.getLoc(), outputType, newSoftmax.getResult()); + assert(op.getType() == newReshape.getType()); + rewriter.replaceOp(op, newReshape.getResult()); + return success(); + } +}; +} // namespace + +namespace { + +// If there is a transpose that is shuffling an axis that is a 1, then that +// transpose could instead ba a reshape A reshape is preferred over a transpose +// as it should not correspond with rearranging the tensor's memory +class SimpleTransposeToReshape + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensorrt::TransposeOp op, + PatternRewriter &rewriter) const override { + if (!op.getInput().getType().hasStaticShape()) + return failure(); + auto transposeInputType = op.getInput().getType(); + SmallVector transposePerm; + int nonOneCount = 0; + for (int i = 0; i < transposeInputType.getRank(); i++) { + transposePerm.push_back(nonOneCount); + if (transposeInputType.getDimSize(i) != 1) + nonOneCount++; + } + if (!op.getPermutation().isPermutation()) + return failure(/* Transpose is not a permutation */); + + transposePerm = op.getPermutation().compose(transposePerm); + for (int i = 1; i < transposePerm.size(); i++) + if (transposePerm[i - 1] > transposePerm[i]) + return failure(/* Pattern failed to match*/); + + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getInput()); + return success(); + } +}; + +} // namespace + +namespace { +class TransposeReshapeEliminationPass + : public tensorrt::impl::TransposeReshapeEliminationPassBase< + TransposeReshapeEliminationPass> { +public: + using Base::Base; + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + Operation *op = getOperation(); + + // 1), convert ops to a "simpler" form using einsums and reshapes and + // transposes + { + RewritePatternSet patterns(ctx); + patterns.insert, + RankChangeToReshape>(ctx); + ReshapeOp::getCanonicalizationPatternsSameOp(patterns, ctx); + TransposeOp::getCanonicalizationPatterns(patterns, ctx); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply simplification patterns in " << getArgument(); + return signalPassFailure(); + } + } + + // 1.1), eliminate 1-axis einsums as these are reshapes that can be pushed + // around further + { + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply simplification patterns in " << getArgument(); + return signalPassFailure(); + } + } + + // 2) we try to eliminate transpose operations by "pushing down" the + // transpose operations. This involves performing rewrites of the form + // "op(transpose(y))->transpose(op(y))". Often, this will eliminate most + // transpose operations in CNN networks produced by frameworks that use NHWC + // conventions (e.g. Tensorflow and often JAX/Flax models). + { + RewritePatternSet patterns(ctx); + patterns.insert< + PushdownTransposeEwise, TransposeConstantFold, + PushdownTransposeIdentity, PushDownTransposeActivationRewriter, + PushDownTransposeUnary, PushDownTransposeToEinsum, + MoveReshapeBeforeTranspose, PushDownReshapeActivationRewriter, + PushDownReshapeUnaryRewriter, PushDownReshapeIdentityRewriter, + PushDownOpQuantizeDequantize, + PushDownOpQuantizeDequantize, + PushReshapeDownThroughEinsum, PushDownReshapeElementwise, + PushDownTransposeSoftmax, PushDownReshapeSoftmax, + SimpleTransposeToReshape>(ctx, PatternBenefit(1)); + patterns.insert(ctx, PatternBenefit(0)); + TransposeOp::getCanonicalizationPatterns(patterns, ctx); + ExpandRankOp::getCanonicalizationPatterns(patterns, ctx); + ReshapeOp::getCanonicalizationPatternsSameOp(patterns, ctx); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply pushdown patterns in " << getArgument(); + return signalPassFailure(); + } + } + + // 3) we try to eliminate transpose operations by "pushing up" (commute + // in the reverse direction). This can possible eliminate additional + // transpose ops. + { + RewritePatternSet patterns(ctx); + patterns.insert< + TransposeConstantFold, PushUpTransposeUnary, + PushUpTransposeUnary, PushUpTransposeUnary, + PushUpTransposeElementwise, PushUpTransposeToEinsum, + MoveTransposeBeforeReshape, PushUpReshapeUnary, + PushUpReshapeUnary, PushUpReshapeUnary, + PushUpOpQuantizeDequantize, + PushUpOpQuantizeDequantize, + PushReshapeUpThroughEinsum, PushUpReshapeElementwise, + PushUpTransposeSoftmax, PushUpReshapeSoftmax, + SimpleTransposeToReshape>(ctx, PatternBenefit(2)); + patterns.insert(ctx, PatternBenefit(1)); + patterns.insert(ctx, PatternBenefit(0)); + TransposeOp::getCanonicalizationPatterns(patterns, ctx); + ExpandRankOp::getCanonicalizationPatterns(patterns, ctx); + ReshapeOp::getCanonicalizationPatternsSameOp(patterns, ctx); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply pushup patterns in " << getArgument(); + return signalPassFailure(); + } + } + + // 4) convert einsums back to matrix multiplies + // (Unsure if this is necessary as TensorRT seems to generate the same + // matrix mulitiply kernels) + { + RewritePatternSet patterns(ctx); + TransposeOp::getCanonicalizationPatterns(patterns, ctx); + ExpandRankOp::getCanonicalizationPatterns(patterns, ctx); + ReshapeOp::getCanonicalizationPatterns( + patterns, ctx); // convert back to expand rank and collapse rank ops + patterns.insert(ctx, PatternBenefit(1)); + patterns.insert< + MatrixMultiplyTransposedArguments, EinsumPushUp1AxisReshape, + EinsumPushUpMultipleMulitipliedAxes, SimpleTransposeToReshape>(ctx); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply convert back to matrix multiply pattern " + << getArgument(); + return signalPassFailure(); + } + } + + // 4.1) if there are any remaining einsums, merge the transposes back into + // the einsum + { + RewritePatternSet patterns(ctx); + TransposeOp::getCanonicalizationPatterns(patterns, ctx); + ExpandRankOp::getCanonicalizationPatterns(patterns, ctx); + ReshapeOp::getCanonicalizationPatterns( + patterns, ctx); // convert back to expand rank and collapse rank ops + patterns.insert(ctx, PatternBenefit(1)); + patterns + .insert( + ctx); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + emitError(op->getLoc()) + << "failed to apply merge stragler transposes to einsum " + << getArgument(); + return signalPassFailure(); + } + } + } +}; +} // namespace diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/reshape-elimination.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/reshape-elimination.mlir index 09d909f3e..f5e9f9ca0 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/reshape-elimination.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/reshape-elimination.mlir @@ -1,4 +1,4 @@ -// RUN: tensorrt-opt %s -split-input-file -tensorrt-reshape-elimination | FileCheck %s +// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-reshape-elimination | FileCheck %s func.func @matmul_eliminate_reshape_lhs(%arg0: tensor<1x2x3x4xf16>, %arg1: tensor<4x2xf16>) -> tensor<1x2x3x2xf16>{ %0 = tensorrt.reshape %arg0 : tensor<1x2x3x4xf16> to tensor<6x4xf16> @@ -10,7 +10,7 @@ func.func @matmul_eliminate_reshape_lhs(%arg0: tensor<1x2x3x4xf16>, %arg1: tenso // CHECK-LABEL: @matmul_eliminate_reshape_lhs // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] +// CHECK-NEXT: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] : tensor<4x2xf16> to tensor<1x1x4x2xf16> // CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg0]], %[[v0]] : {{.*}}) // CHECK-NEXT: return %[[v1]] @@ -26,7 +26,7 @@ func.func @matmul_eliminate_reshape_lhs_2(%arg0: tensor<1x2x3x4x5x6xf16>, %arg1: // CHECK-LABEL: @matmul_eliminate_reshape_lhs_2 // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] +// CHECK-NEXT: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] : tensor<1x2x6x8xf16> to tensor<1x2x1x1x6x8xf16> // CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg0]], %[[v0]] : {{.*}}) // CHECK-NEXT: return %[[v1]] @@ -42,12 +42,12 @@ func.func @matmul_eliminate_reshape_lhs_3(%arg0: tensor<2x2x3x4xf16>, %arg1: ten // CHECK-LABEL: @matmul_eliminate_reshape_lhs_3 // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] +// CHECK-NEXT: %[[v0:.+]] = tensorrt.expand_rank %arg1 : tensor<2x4x5xf16> to tensor<2x1x4x5xf16> // CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg0]], %[[v0]] : {{.*}}) // CHECK-NEXT: return %[[v1]] // ----- -func.func @matmul_eliminate_reshape_lhs_negative(%arg0: tensor<10x20x30x40x50xf16>, %arg1: tensor<10x600x50x30xf16>) -> tensor<10x20x30x40x30xf16>{ +func.func @matmul_eliminate_reshape_lhs_4(%arg0: tensor<10x20x30x40x50xf16>, %arg1: tensor<10x600x50x30xf16>) -> tensor<10x20x30x40x30xf16>{ %0 = tensorrt.reshape %arg0 : tensor<10x20x30x40x50xf16> to tensor<10x600x40x50xf16> %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%0, %arg1 : tensor<10x600x40x50xf16>, tensor<10x600x50x30xf16>) -> tensor<10x600x40x30xf16> @@ -55,16 +55,14 @@ func.func @matmul_eliminate_reshape_lhs_negative(%arg0: tensor<10x20x30x40x50xf1 return %2: tensor<10x20x30x40x30xf16> } -// CHECK-LABEL: @matmul_eliminate_reshape_lhs_negative +// CHECK-LABEL: @matmul_eliminate_reshape_lhs_4 // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg0]] -// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[arg1]] : {{.*}}) -// CHECK-NEXT: %[[v2:.+]] = tensorrt.reshape %[[v1]] -// CHECK-NEXT: return %[[v2]] - +// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] +// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg0]], %[[v0]] : {{.*}}) +// CHECK-NEXT: return %[[v1]] // ----- -func.func @matmul_eliminate_reshape_lhs_negative_dynamic(%arg0: tensor<10x?x30x40x50xf16>, %arg1: tensor<10x600x50x30xf16>) -> tensor<10x20x30x40x30xf16>{ +func.func @matmul_eliminate_reshape_lhs_5_dynamic(%arg0: tensor<10x?x30x40x50xf16>, %arg1: tensor<10x600x50x30xf16>) -> tensor<10x20x30x40x30xf16>{ %0 = tensorrt.reshape %arg0 : tensor<10x?x30x40x50xf16> to tensor<10x600x40x50xf16> %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%0, %arg1 : tensor<10x600x40x50xf16>, tensor<10x600x50x30xf16>) -> tensor<10x600x40x30xf16> @@ -72,16 +70,16 @@ func.func @matmul_eliminate_reshape_lhs_negative_dynamic(%arg0: tensor<10x?x30x4 return %2: tensor<10x20x30x40x30xf16> } -// CHECK-LABEL: @matmul_eliminate_reshape_lhs_negative_dynamic +// CHECK-LABEL: @matmul_eliminate_reshape_lhs_5_dynamic // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) // CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg0]] -// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[arg1]] : {{.*}}) -// CHECK-NEXT: %[[v2:.+]] = tensorrt.reshape %[[v1]] +// CHECK: %[[v1:.+]] = tensorrt.reshape %[[arg1]] +// CHECK-NEXT: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[v1]] : {{.*}}) // CHECK-NEXT: return %[[v2]] // ----- -func.func @matmul_eliminate_reshape_lhs_negative_2(%arg0: tensor<1x2x3x4xf16>, %arg1: tensor<4x6xf16>) -> tensor<6x2x3xf16>{ +func.func @matmul_push_reshape_lhs_6(%arg0: tensor<1x2x3x4xf16>, %arg1: tensor<4x6xf16>) -> tensor<6x2x3xf16>{ %0 = tensorrt.reshape %arg0 : tensor<1x2x3x4xf16> to tensor<6x4xf16> %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%0, %arg1 : tensor<6x4xf16>, tensor<4x6xf16>) -> tensor<6x6xf16> @@ -89,12 +87,14 @@ func.func @matmul_eliminate_reshape_lhs_negative_2(%arg0: tensor<1x2x3x4xf16>, % return %2: tensor<6x2x3xf16> } -// CHECK-LABEL: @matmul_eliminate_reshape_lhs_negative_2 +// CHECK-LABEL: @matmul_push_reshape_lhs_6 // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg0]] -// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[arg1]] : {{.*}}) -// CHECK-NEXT: %[[v2:.+]] = tensorrt.reshape %[[v1]] -// CHECK-NEXT: return %[[v2]] +// CHECK-DAG: %[[v0:.+]] = tensorrt.reshape %[[arg1]] +// CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {{{.*}}} %[[v0]] +// CHECK-DAG: %[[v2:.+]] = tensorrt.reshape %[[arg0]] +// CHECK-DAG: %[[v3:.+]] = tensorrt.expand_rank %[[v1]] +// CHECK-NEXT: %[[v4:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v2]], %[[v3]] : {{.*}}) +// CHECK-NEXT: return %[[v4]] // ----- func.func @matmul_simplify_reshape_rhs(%arg0: tensor<10x20x30x40xf16>, %arg1: tensor<200x60x30xf16>) -> tensor<10x20x60x40xf16>{ @@ -124,11 +124,11 @@ func.func @matmul_simplify_reshape_rhs_2(%arg0: tensor<1x2x3x4x5x6xf16>, %arg1: // CHECK-LABEL: @matmul_simplify_reshape_rhs_2 // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) // CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] -// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[arg0]] : {{.*}}) -// CHECK-NEXT: return %[[v1]] +// CHECK-NEXT: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[arg0]] : {{.*}}) +// CHECK-NEXT: return %[[v2]] // ----- -func.func @matmul_simplify_reshape_rhs_negative(%arg0: tensor<1x2x3x4xf16>, %arg1: tensor<6x6xf16>) -> tensor<1x2x3x4xf16>{ +func.func @matmul_simplify_reshape_rhs_3(%arg0: tensor<1x2x3x4xf16>, %arg1: tensor<6x6xf16>) -> tensor<1x2x3x4xf16>{ %0 = tensorrt.reshape %arg0 : tensor<1x2x3x4xf16> to tensor<6x4xf16> %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%arg1, %0 : tensor<6x6xf16>, tensor<6x4xf16>) -> tensor<6x4xf16> @@ -136,16 +136,16 @@ func.func @matmul_simplify_reshape_rhs_negative(%arg0: tensor<1x2x3x4xf16>, %arg return %2: tensor<1x2x3x4xf16> } -// CHECK-LABEL: @matmul_simplify_reshape_rhs_negative +// CHECK: @matmul_simplify_reshape_rhs_3 // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg0]] -// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg1]], %[[v0]] : {{.*}}) -// CHECK-NEXT: %[[v2:.+]] = tensorrt.reshape %[[v1]] +// CHECK-DAG: %[[v0:.+]] = tensorrt.reshape %[[arg0]] +// CHECK-DAG: %[[v1:.+]] = tensorrt.reshape %[[arg1]] +// CHECK-NEXT: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v1]], %[[v0]] : {{.*}}) // CHECK-NEXT: return %[[v2]] // ----- -func.func @matmul_simplify_reshape_rhs_negative_dynamic(%arg0: tensor, %arg1: tensor<1x2x12x6x5xf16>) -> tensor<1x2x3x4x6x6xf16>{ +func.func @matmul_simplify_reshape_rhs_4_dynamic(%arg0: tensor, %arg1: tensor<1x2x12x6x5xf16>) -> tensor<1x2x3x4x6x6xf16>{ %0 = tensorrt.reshape %arg0 : tensor to tensor<1x2x12x5x6xf16> %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%arg1, %0 : tensor<1x2x12x6x5xf16>, tensor<1x2x12x5x6xf16>) -> tensor<1x2x12x6x6xf16> @@ -153,11 +153,11 @@ func.func @matmul_simplify_reshape_rhs_negative_dynamic(%arg0: tensor } -// CHECK-LABEL: @matmul_simplify_reshape_rhs_negative_dynamic +// CHECK-LABEL: @matmul_simplify_reshape_rhs_4_dynamic // CHECK-SAME: (%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) -// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg0]] -// CHECK-NEXT: %[[v1:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg1]], %[[v0]] : {{.*}}) -// CHECK-NEXT: %[[v2:.+]] = tensorrt.reshape %[[v1]] +// CHECK-DAG: %[[v0:.+]] = tensorrt.reshape %[[arg0]] +// CHECK-DAG: %[[v1:.+]] = tensorrt.reshape %[[arg1]] +// CHECK-NEXT: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v1]], %[[v0]] : {{.*}}) // CHECK-NEXT: return %[[v2]] // ----- diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir index 56a7cf015..5137bf902 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir @@ -1,4 +1,4 @@ -// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-elimination | FileCheck %s +// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-reshape-elimination | FileCheck %s func.func @transpose_const_fold() -> tensor<2x2xi32> { %const = tensorrt.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> @@ -92,7 +92,7 @@ func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf // CHECK: #[[$map:.+]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-LABEL: @transpose_pushdown_switch // CHECK-SAME: (%[[arg0:.+]]: tensor<2x2xf32>, %[[arg1:.+]]: tensor<1x2xf32>) -> tensor<2x2xf32> -// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]] : tensor<1x2xf32> to tensor<2x1xf32> +// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] : tensor<1x2xf32> to tensor<2x1xf32> // CHECK: %[[v1:.+]] = tensorrt.element_wise (%[[arg0]], %[[v0]] : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32> // CHECK: %[[v2:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v1]] : tensor<2x2xf32> to tensor<2x2xf32> // CHECK: return %[[v2]] @@ -224,10 +224,8 @@ func.func @push_up_transpose_elementwise_lhs(%arg0: tensor<1x197x1x64xf32>) -> t // CHECK-LABEL: @push_up_transpose_elementwise_lhs // CHECK-SAME: (%[[arg0:.+]]: {{.*}}) // CHECK-NEXT: %[[cst_f32:.+]] = tensorrt.constant -// CHECK-NEXT: %[[v0:.+]] = tensorrt.reshape %[[cst_f32]] // CHECK-NEXT: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg0]] -// CHECK-NEXT: %[[v2:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v0]] -// CHECK-NEXT: %[[v3:.+]] = tensorrt.element_wise (%[[v2]], %[[v1]] : {{.*}}) +// CHECK-NEXT: %[[v3:.+]] = tensorrt.element_wise (%[[cst_f32]], %[[v1]] : {{.*}}) // CHECK-NEXT: return %[[v3]] // ----- @@ -261,10 +259,8 @@ func.func @push_up_transpose_elementwise_rhs(%arg0: tensor<1x197x1x64xf32>) -> t // CHECK-LABEL: @push_up_transpose_elementwise_rhs // CHECK-SAME: (%[[arg0:.+]]: {{.*}}) // CHECK-NEXT: %[[cst_f32:.+]] = tensorrt.constant -// CHECK-NEXT: %[[v0:.+]] = tensorrt.reshape %[[cst_f32]] // CHECK-NEXT: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg0]] -// CHECK-NEXT: %[[v2:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v0]] -// CHECK-NEXT: %[[v3:.+]] = tensorrt.element_wise (%[[v1]], %[[v2]] : {{.*}}) +// CHECK-NEXT: %[[v3:.+]] = tensorrt.element_wise (%[[v1]], %[[cst_f32]] : {{.*}}) // CHECK-NEXT: return %[[v3]] // ----- @@ -416,9 +412,8 @@ func.func @push_up_transpose_elementwise_reshape_reshape_neg(%arg0: tensor<3152x // CHECK-LABEL: @push_up_transpose_elementwise_reshape_reshape_neg // CHECK-SAME: (%[[arg0:.+]]: {{.*}}) // CHECK-NEXT: %[[cst_f32:.+]] = tensorrt.constant -// CHECK-NEXT: %[[v0:.+]] = tensorrt.reshape %[[cst_f32]] // CHECK-NEXT: %[[v1:.+]] = tensorrt.reshape %[[arg0]] -// CHECK-NEXT: %[[v2:.+]] = tensorrt.element_wise (%[[v1]], %[[v0]] : {{.*}}) +// CHECK-NEXT: %[[v2:.+]] = tensorrt.element_wise (%[[v1]], %[[cst_f32]] : {{.*}}) // CHECK-NEXT: %[[v3:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v2]] // CHECK-NEXT: return %[[v3]] @@ -458,4 +453,22 @@ func.func @push_up_transpose_elementwise_reshape_transpose_neg(%arg0: tensor<10x // CHECK-NEXT: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]] // CHECK-NEXT: %[[v2:.+]] = tensorrt.element_wise (%[[v1]], %[[v0]] : {{.*}}) // CHECK-NEXT: %[[v3:.+]] = tensorrt.transpose {permutation = #[[$map1]]} %[[v2]] -// CHECK-NEXT: return %[[v3]] \ No newline at end of file +// CHECK-NEXT: return %[[v3]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4, d2)> +func.func @element_wise_with_two_constants() -> tensor<1x8x20x35x192xf32> { + %cst_f32 = tensorrt.constant dense_resource<__elided__> : tensor<1x8x192x20x35xf32> + %cst_f32_0 = tensorrt.constant dense_resource<__elided__> : tensor<1x8x20x35x192xf32> + %1 = tensorrt.transpose {permutation = #map} %cst_f32 : tensor<1x8x192x20x35xf32> to tensor<1x8x20x35x192xf32> + %2 = tensorrt.element_wise (%1, %cst_f32_0 : tensor<1x8x20x35x192xf32>, tensor<1x8x20x35x192xf32>) -> tensor<1x8x20x35x192xf32> + return %2 : tensor<1x8x20x35x192xf32> +} + +// CHECK: @element_wise_with_two_constants() +// CHECK: %[[const0:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x8x192x20x35xf32> +// CHECK: %[[const1:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x8x20x35x192xf32> +// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[const0]] +// CHECK: %[[v1:.+]] = tensorrt.element_wise (%[[v0]], %[[const1]] +// CHECK: return %[[v1]] \ No newline at end of file diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir new file mode 100644 index 000000000..624aa7f55 --- /dev/null +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir @@ -0,0 +1,368 @@ +// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-reshape-elimination | FileCheck %s + + +// CHECK: transpose_merge_with_matmul +// CHECK: %[[out1:.+]] = tensorrt.matrix_multiply +// CHECK: return %[[out1]] +func.func @transpose_merge_with_matmul(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<2x3x1x5xf32> { + %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%arg0, %arg1 : tensor<1x2x3x4xf32>,tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> + %2 = tensorrt.shuffle {first_transpose = array, reshape = array, second_transpose = array, zero_is_placeholder = false} ins(%1 : tensor<1x2x3x5xf32>) -> tensor<2x3x1x5xf32> + return %2 : tensor<2x3x1x5xf32> +} + +// ----- + +// CHECK: @reshape_push_up_through_matmul(%[[arg0:.+]]: tensor<16x1024x1024xbf16>, %[[arg1:.+]]: tensor<1x1024x1024xbf16>) +// CHECK: %[[out:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%[[arg0]], %[[arg1]] : tensor<16x1024x1024xbf16>, tensor<1x1024x1024xbf16>) +// CHECK: return %[[out]] +func.func @reshape_push_up_through_matmul(%arg0: tensor<16x1024x1024xbf16>, %arg1: tensor<1x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> { + %6 = tensorrt.shuffle {first_transpose = array, reshape = array, second_transpose = array, zero_is_placeholder = false} ins(%arg0 : tensor<16x1024x1024xbf16>) -> tensor<1x16384x1024xbf16> + %7 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%6, %arg1 : tensor<1x16384x1024xbf16>, tensor<1x1024x1024xbf16>) -> tensor<1x16384x1024xbf16> + %8 = tensorrt.shuffle {first_transpose = array, reshape = array, second_transpose = array, zero_is_placeholder = false} ins(%7 : tensor<1x16384x1024xbf16>) -> tensor<16x1024x1024xbf16> + return %8 : tensor<16x1024x1024xbf16> +} + +// ----- + +// CHECK: func.func @reshape_push_to_constant(%[[arg0:.+]]: tensor<4x30xf32>) +// CHECK: %[[const:.+]] = tensorrt.constant dense_resource<__elided__> +// CHECK: %[[out:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%[[arg0]], %[[const]] : {{.*}}) +// CHECK: return %[[out]] +func.func @reshape_push_to_constant(%arg0: tensor<4x30xf32>) -> tensor<4x50xf32> { + %const = tensorrt.constant dense_resource<__elided__> : tensor<5x6x50xf32> + %1 = tensorrt.reshape %arg0 : tensor<4x30xf32> to tensor<4x5x6xf32> + %out = tensorrt.einsum {equation = "abc,bcd->ad"} ins(%1, %const : tensor<4x5x6xf32>, tensor<5x6x50xf32>) -> tensor<4x50xf32> + return %out : tensor<4x50xf32> +} + +// ----- + +// CHECK: @reshape_transpose_push(%[[arg0:.+]]: tensor<2x3x4x5xf32>) +// CHECK: return %[[arg0]] +func.func @reshape_transpose_push(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>{ + %1 = tensorrt.reshape %arg0 : tensor<2x3x4x5xf32> to tensor<6x4x5xf32> + %2 = tensorrt.transpose { permutation = affine_map<(d0, d1, d2) -> (d0, d2, d1)> } %1 : tensor<6x4x5xf32> to tensor<6x5x4xf32> + %3 = tensorrt.reshape %2 : tensor<6x5x4xf32> to tensor<2x3x5x4xf32> + %4 = tensorrt.transpose { permutation = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)> } %3 : tensor<2x3x5x4xf32> to tensor<2x3x4x5xf32> + return %4 : tensor<2x3x4x5xf32> +} + +// ----- + +// CHECK: @reshape_transpose_cant_push(%[[arg0:.+]]: tensor<6x4x5xf32>) +// CHECK: %[[V1:.+]] = tensorrt.reshape %[[arg0]] +// CHECK: %[[V2:.+]] = tensorrt.transpose {{.*}} %[[V1]] +// CHECK: %[[V3:.+]] = tensorrt.reshape %[[V2]] +// CHECK: return %[[V3]] +func.func @reshape_transpose_cant_push(%arg0: tensor<6x4x5xf32>) -> tensor<6x4x5xf32>{ + %1 = tensorrt.reshape %arg0 : tensor<6x4x5xf32> to tensor<2x3x4x5xf32> + %2 = tensorrt.transpose { permutation = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)> } %1 : tensor<2x3x4x5xf32> to tensor<2x5x4x3xf32> + %3 = tensorrt.reshape %2 : tensor<2x5x4x3xf32> to tensor<6x4x5xf32> + return %3 : tensor<6x4x5xf32> +} + +// ----- + +// CHECK: @unary_push_reshape(%[[arg0:.+]]: tensor<10x3xf32>) +// CHECK: %[[out:.+]] = tensorrt.unary {{.*}} %[[arg0]] +// CHECK: return %[[out]] +func.func @unary_push_reshape(%arg0: tensor<10x3xf32>) -> tensor<10x3xf32> { + %1 = tensorrt.reshape %arg0 : tensor<10x3xf32> to tensor<5x6xf32> + %2 = tensorrt.unary { unaryOperation = #tensorrt.unary_operation } %1 : tensor<5x6xf32> + %3 = tensorrt.reshape %2 : tensor<5x6xf32> to tensor<10x3xf32> + return %3 : tensor<10x3xf32> +} + +// ----- + +// CHECK: @identity_push_reshape(%[[arg0:.+]]: tensor<10x3xf32>) +// CHECK: %[[out:.+]] = tensorrt.identity %[[arg0]] +// CHECK: return %[[out]] +func.func @identity_push_reshape(%arg0: tensor<10x3xf32>) -> tensor<10x3xf16> { + %1 = tensorrt.reshape %arg0 : tensor<10x3xf32> to tensor<5x6xf32> + %2 = tensorrt.identity %1 : tensor<5x6xf32> to tensor<5x6xf16> + %3 = tensorrt.reshape %2 : tensor<5x6xf16> to tensor<10x3xf16> + return %3 : tensor<10x3xf16> +} + +// ----- + +// CHECK: @activation_push_reshape(%[[arg0:.+]]: tensor<10x3xf32>) +// CHECK: %[[out:.+]] = tensorrt.activation {{.*}} %[[arg0]] +// CHECK: return %[[out]] +func.func @activation_push_reshape(%arg0: tensor<10x3xf32>) -> tensor<10x3xf32> { + %1 = tensorrt.reshape %arg0 : tensor<10x3xf32> to tensor<5x6xf32> + %2 = tensorrt.activation { activationType = #tensorrt.activation_type} %1 : tensor<5x6xf32> + %3 = tensorrt.reshape %2 : tensor<5x6xf32> to tensor<10x3xf32> + return %3 : tensor<10x3xf32> +} + +// ----- + +// CHECK: transpose_quantize_dequantize_push(%[[arg0:.+]]: tensor<10x5xf32>, %[[scale:.+]]: tensor) +// CHECK: %[[V0:.+]] = tensorrt.quantize in(%[[arg0]] : tensor<10x5xf32>) scale(%[[scale]] : tensor) -> tensor<10x5xi8> +// CHECK: %[[V1:.+]] = tensorrt.dequantize in(%[[V0]] : tensor<10x5xi8>) scale(%[[scale]] : tensor) -> tensor<10x5xf32> +// CHECK: return %[[V1]] +func.func @transpose_quantize_dequantize_push(%arg0: tensor<10x5xf32>, %scale: tensor) -> tensor<10x5xf32> { + %1 = tensorrt.transpose {permutation = affine_map<(d0, d1)->(d1, d0)>} %arg0 : tensor<10x5xf32> to tensor<5x10xf32> + %2 = tensorrt.quantize in(%1 : tensor<5x10xf32>) scale(%scale : tensor) -> tensor<5x10xi8> + %3 = tensorrt.dequantize in(%2 : tensor<5x10xi8>) scale(%scale : tensor) -> tensor<5x10xf32> + %4 = tensorrt.transpose {permutation = affine_map<(d0, d1)->(d1, d0)>} %3 : tensor<5x10xf32> to tensor<10x5xf32> + return %4 : tensor<10x5xf32> +} + +// ----- + +// CHECK: reshape_quantize_dequantize_push(%[[arg0:.+]]: tensor<10x5xf32>, %[[scale:.+]]: tensor) +// CHECK: %[[V0:.+]] = tensorrt.quantize in(%[[arg0]] : tensor<10x5xf32>) scale(%[[scale]] : tensor) -> tensor<10x5xi8> +// CHECK: %[[V1:.+]] = tensorrt.dequantize in(%[[V0]] : tensor<10x5xi8>) scale(%[[scale]] : tensor) -> tensor<10x5xf32> +// CHECK: return %[[V1]] +func.func @reshape_quantize_dequantize_push(%arg0: tensor<10x5xf32>, %scale: tensor) -> tensor<10x5xf32> { + %1 = tensorrt.reshape %arg0 : tensor<10x5xf32> to tensor<5x10xf32> + %2 = tensorrt.quantize in(%1 : tensor<5x10xf32>) scale(%scale : tensor) -> tensor<5x10xi8> + %3 = tensorrt.dequantize in(%2 : tensor<5x10xi8>) scale(%scale : tensor) -> tensor<5x10xf32> + %4 = tensorrt.reshape %3 : tensor<5x10xf32> to tensor<10x5xf32> + return %4 : tensor<10x5xf32> +} + +// ----- + +// CHECK: @matrix_multiply_keep(%[[arg0:.+]]: tensor<1x2x3x4xf32>, %[[arg1:.+]]: tensor<1x2x4x5xf32>) +// CHECK: tensorrt.matrix_multiply +func.func @matrix_multiply_keep(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> { + %1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%arg0, %arg1 : tensor<1x2x3x4xf32>,tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> + return %1 : tensor<1x2x3x5xf32> +} + +// ----- + +// CHECK: @elementwise_push_down_reshape(%[[arg0:.+]]: tensor<1x2x3x4xf32>) +// CHECK: %[[ret:.+]] = tensorrt.element_wise (%[[arg0]], %[[const:.+]]) +// CHECK: return %[[ret]] +func.func @elementwise_push_down_reshape(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %const = tensorrt.constant dense_resource<__elided__> : tensor<2xf32> + %const_1 = tensorrt.expand_rank %const : tensor<2xf32> to tensor<1x2xf32> + %const_2 = tensorrt.broadcast %const_1 broadcast_dims<0, 1> : tensor<1x2xf32> to tensor<12x2xf32> + %1 = tensorrt.reshape %arg0 : tensor<1x2x3x4xf32> to tensor<12x2xf32> + %2 = tensorrt.element_wise (%1, %const_2 : tensor<12x2xf32>, tensor<12x2xf32>) -> tensor<12x2xf32> + %3 = tensorrt.reshape %2 : tensor<12x2xf32> to tensor<1x2x3x4xf32> + return %3 : tensor<1x2x3x4xf32> +} + +// ----- + +// CHECK: @reshape_with_one(%[[arg0:.+]]: tensor<2x3x4x5xf32>) +// CHECK: %[[const:.+]] = tensorrt.constant dense_resource<__elided__> +// CHECK: %[[V0:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%[[arg0]], %[[const]] : {{.*}}) -> tensor<2x3x4x6xf32> +// CHECK: return %[[V0]] +func.func @reshape_with_one(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x6xf32> { + %const = tensorrt.constant dense_resource<__elided__> : tensor<1x1x2x3x5x6xf32> + %1 = tensorrt.reshape %arg0 : tensor<2x3x4x5xf32> to tensor<1x1x2x3x4x5xf32> + %2 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%1, %const : tensor<1x1x2x3x4x5xf32>, tensor<1x1x2x3x5x6xf32>) -> tensor<1x1x2x3x4x6xf32> + %3 = tensorrt.reshape %2 : tensor<1x1x2x3x4x6xf32> to tensor<2x3x4x6xf32> + return %3 : tensor<2x3x4x6xf32> +} + +// ----- + +// CHECK: @elementwise_reshape(%[[arg0:.+]]: tensor<12x3x3xf32>, %[[arg1:.+]]: tensor<12xf32>) +// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] : tensor<12xf32> to tensor<12x1x1xf32> +// CHECK: %[[v1:.+]] = tensorrt.element_wise (%[[arg0]], %[[v0]] : tensor<12x3x3xf32>, tensor<12x1x1xf32>) -> tensor<12x3x3xf32> +// CHECK: %[[v2:.+]] = tensorrt.transpose {permutation = #map} %[[v1]] : tensor<12x3x3xf32> to tensor<12x3x3xf32> +// CHECK: return %[[v2]] +#map = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +func.func @elementwise_reshape(%arg0: tensor<12x3x3xf32>, %arg1: tensor<12xf32>) -> tensor<12x3x3xf32> { + %0 = tensorrt.transpose {permutation = #map} %arg0 : tensor<12x3x3xf32> to tensor<12x3x3xf32> + %1 = tensorrt.expand_rank %arg1 : tensor<12xf32> to tensor<12x1x1xf32> + %2 = tensorrt.element_wise (%0, %1 : tensor<12x3x3xf32>, tensor<12x1x1xf32>) -> tensor<12x3x3xf32> + return %2 : tensor<12x3x3xf32> +} + +// ----- + +// CHECK: @matmul_argument_swap(%[[arg0:.+]]: tensor<1x2x4x3x561xf32>, %[[arg1:.+]]: tensor<1x2x4x3x3xf32>) -> tensor<1x2x4x561x3xf32> +// CHECK-DAG: %[[v0:.+]] = tensorrt.[[op1:.+]] %[[arg0]] : tensor[[shape1:.+]] +// CHECK-DAG: %[[v1:.+]] = tensorrt.[[op2:.+]] %[[v0]] : tensor[[shape2:.+]] +// CHECK-DAG: %[[v2:.+]] = tensorrt.[[op3:.+]] %[[arg1]] : tensor[[shape3:.+]] +// CHECK-DAG: %[[v3:.+]] = tensorrt.[[op4:.+]] %[[v2]] : tensor[[shape4:.+]] +// CHECK: %[[v4:.+]] = tensorrt.matrix_multiply [[params:.+]] ins(%[[v3]], %[[v1]] : {{.*}}) +// CHECK: return %[[v4]] +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4, d3)> +func.func @matmul_argument_swap(%arg0: tensor<1x2x4x3x561xf32>, %arg1: tensor<1x2x4x3x3xf32>) -> tensor<1x2x4x561x3xf32> { + %0 = tensorrt.reshape %arg1 : tensor<1x2x4x3x3xf32> to tensor<8x3x3xf32> + %1 = tensorrt.reshape %arg0 : tensor<1x2x4x3x561xf32> to tensor<8x3x561xf32> + %2 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%0, %1 : tensor<8x3x3xf32>, tensor<8x3x561xf32>) -> tensor<8x3x561xf32> + %3 = tensorrt.reshape %2 : tensor<8x3x561xf32> to tensor<1x2x4x3x561xf32> + %4 = tensorrt.transpose {permutation = #map} %3 : tensor<1x2x4x3x561xf32> to tensor<1x2x4x561x3xf32> + return %4 : tensor<1x2x4x561x3xf32> +} + +// ----- + +// CHECK: @transpose_reshape_reorder(%[[arg0:.+]]: tensor<12x256x8x8x16x8xf32>) +// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<12x256x8x8x16x8xf32> to tensor<12x8x8x16x8x256xf32> +// CHECK: %[[v1:.+]] = tensorrt.reshape %[[v0]] : tensor<12x8x8x16x8x256xf32> to tensor<12x64x128x256xf32> +// CHECK: return %[[v1]] +#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +func.func @transpose_reshape_reorder(%arg0: tensor<12x256x8x8x16x8xf32>) -> tensor<12x64x128x256xf32> { + %0 = tensorrt.reshape %arg0 : tensor<12x256x8x8x16x8xf32> to tensor<12x256x64x128xf32> + %1 = tensorrt.transpose {permutation = #map} %0 : tensor<12x256x64x128xf32> to tensor<12x64x128x256xf32> + return %1 : tensor<12x64x128x256xf32> +} + +// ----- + +// CHECK: affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)> +// CHECK: @transpose_softmax(%[[arg0:.+]]: tensor<2x3x4x5xf32>) +// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<2x3x4x5xf32> to tensor<2x5x3x4xf32> +// CHECK: %[[v1:.+]] = tensorrt.softmax {axis = 2 : i64} %[[v0]] : tensor<2x5x3x4xf32> +// CHECK: return %[[v1]] +#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +func.func @transpose_softmax(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x5x3x4xf32> { + %0 = tensorrt.transpose {permutation = #map} %arg0 : tensor<2x3x4x5xf32> to tensor<2x4x5x3xf32> + %1 = tensorrt.softmax {axis = 3 : i64} %0 : tensor<2x4x5x3xf32> + %2 = tensorrt.transpose {permutation = #map} %1 : tensor<2x4x5x3xf32> to tensor<2x5x3x4xf32> + return %2 : tensor<2x5x3x4xf32> +} + +// ----- + +// CHECK: @reshape_softmax(%[[arg0:.+]]: tensor<24x5x6xf32>) +// CHECK: %[[v0:.+]] = tensorrt.softmax {axis = 1 : i64} %[[arg0]] : tensor<24x5x6xf32> +// CHECK: return %[[v0]] +func.func @reshape_softmax(%arg0: tensor<24x5x6xf32>) -> tensor<24x5x6xf32> { + %0 = tensorrt.reshape %arg0 : tensor<24x5x6xf32> to tensor<2x3x4x5x6xf32> + %1 = tensorrt.softmax{axis = 3 : i64} %0 : tensor<2x3x4x5x6xf32> + %2 = tensorrt.reshape %1 : tensor<2x3x4x5x6xf32> to tensor<24x5x6xf32> + return %2 : tensor<24x5x6xf32> +} + +// ----- + +// CHECK: @reshape_softmax_cant_push(%[[arg0:.+]]: tensor<2x3x4x5x6xf32>) +// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg0]] : tensor<2x3x4x5x6xf32> to tensor<24x10x3xf32> +// CHECK: %[[v1:.+]] = tensorrt.softmax {axis = 1 : i64} %[[v0]] : tensor<24x10x3xf32> +// CHECK: %[[v2:.+]] = tensorrt.reshape %[[v1]] : tensor<24x10x3xf32> to tensor<2x3x4x5x6xf32> +// CHECK: return %[[v2]] +func.func @reshape_softmax_cant_push(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> { + %0 = tensorrt.reshape %arg0 : tensor<2x3x4x5x6xf32> to tensor<24x10x3xf32> + %1 = tensorrt.softmax {axis = 1 : i64} %0 : tensor<24x10x3xf32> + %2 = tensorrt.reshape %1 : tensor<24x10x3xf32> to tensor<2x3x4x5x6xf32> + return %2 : tensor<2x3x4x5x6xf32> +} + +// ----- + +// CHECK: @reshape_transpose_reorder_ones_dim(%[[arg0:.+]]: tensor<2x1x1x1x1xf32>, %[[arg1:.+]]: tensor<1x2x3x3xf32>) +// CHECK: %[[v0:.+]] = tensorrt.collapse_rank %[[arg0]] : tensor<2x1x1x1x1xf32> to tensor<2x1x1x1xf32> +// CHECK: %[[v1:.+]] = tensorrt.deconvolution [[parmas:.+]] in(%[[arg1]] : tensor<1x2x3x3xf32>) kernelWeights(%[[v0]] : tensor<2x1x1x1xf32>) -> tensor<1x2x3x5xf32> +// CHECK: return %[[v1]] +func.func @reshape_transpose_reorder_ones_dim(%arg0: tensor<2x1x1x1x1xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x3x5xf32> { + %2 = tensorrt.transpose {permutation = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d3, d4)>} %arg0 : tensor<2x1x1x1x1xf32> to tensor<2x1x1x1x1xf32> + %3 = tensorrt.reshape %2 : tensor<2x1x1x1x1xf32> to tensor<2x1x1x1xf32> + %4 = tensorrt.deconvolution {dilation = array, num_groups = 2 : ui32, post_padding = array, pre_padding = array, stride = array} in(%arg1 : tensor<1x2x3x3xf32>) kernelWeights(%3 : tensor<2x1x1x1xf32>) -> tensor<1x2x3x5xf32> + return %4 : tensor<1x2x3x5xf32> +} + +// ----- + +// CHECK: @push_down_transpose_einsum(%[[arg0:.+]]: {{.*}}, %[[arg1:.+]]: {{.*}}) +// CHECK: %[[const0:.+]] = tensorrt.constant dense<1.000000e+00> : {{.*}} +// CHECK: %[[v0:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[arg0]], %[[arg1]] : {{.*}}) +// CHECK: %[[v1:.+]] = tensorrt.reshape %[[v0]] +// CHECK: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v1]], %[[const0]] : {{.*}}) +// CHECK: return %[[v2]] +func.func @push_down_transpose_einsum(%arg0: tensor<1x6x1500x64xf32>, %arg1: tensor<1x6x1500x1500xf32>) -> tensor<1x1500x384xf32> { + %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<384x384xf32> + %0 = tensorrt.reshape %arg0 : tensor<1x6x1500x64xf32> to tensor<6x1500x64xf32> + %1 = tensorrt.reshape %arg1 : tensor<1x6x1500x1500xf32> to tensor<6x1500x1500xf32> + %2 = tensorrt.einsum {equation = "bcd,bec->ebd"} ins(%0, %1 : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<1500x6x64xf32> + %3 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1x1500x6x64xf32> + %4 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1500x384xf32> + %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<384x6x64xf32> + %5 = tensorrt.einsum {equation = "bde,cde->bc"} ins(%2, %cst_f32_0 : tensor<1500x6x64xf32>, tensor<384x6x64xf32>) -> tensor<1500x384xf32> + %6 = tensorrt.reshape %5 : tensor<1500x384xf32> to tensor<1x1500x384xf32> + return %6 : tensor<1x1500x384xf32> +} + +// ----- + +// CHECK: @multihead_attention +// CHECK: %[[v0:.+]] = tensorrt.matrix_multiply +// CHECK: %[[v1:.+]] = tensorrt.element_wise (%[[v0]], %[[const0:.+]] : {{.*}}) +// CHECK: %[[v2:.+]] = tensorrt.element_wise (%[[v1]], %[[const1:.+]] : {{.*}}) +// CHECK: %[[v3:.+]] = tensorrt.softmax {axis = [[axis:.+]] : i64} %[[v2]] +// CHECK: %[[v4:.+]] = tensorrt.matrix_multiply {{.*}} ins(%[[v3]], %[[values:.+]] : {{.*}}) +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map5 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> +func.func @multihead_attention(%arg0: tensor<566x48x64xf32>, %arg1: tensor<566x48x64xf32>, %arg2: tensor<566x48x64xf32>) -> tensor<566x48x64xf32> { + %cst_f32_683 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xf32> + %cst_f32_704 = tensorrt.constant dense<0.000000e+00> : tensor<1x1x1xf32> + %312 = tensorrt.transpose {permutation = #map3} %arg2 : tensor<566x48x64xf32> to tensor<48x566x64xf32> + %314 = tensorrt.transpose {permutation = #map3} %arg0 : tensor<566x48x64xf32> to tensor<48x566x64xf32> + %315 = tensorrt.transpose {permutation = #map5} %arg1 : tensor<566x48x64xf32> to tensor<48x64x566xf32> + %316 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%314, %315 : tensor<48x566x64xf32>, tensor<48x64x566xf32>) -> tensor<48x566x566xf32> + %317 = tensorrt.element_wise (%316, %cst_f32_683 : tensor<48x566x566xf32>, tensor<1x1x1xf32>) -> tensor<48x566x566xf32> + %318 = tensorrt.element_wise (%317, %cst_f32_704 : tensor<48x566x566xf32>, tensor<1x1x1xf32>) -> tensor<48x566x566xf32> + %319 = tensorrt.softmax {axis = 2 : i64} %318 : tensor<48x566x566xf32> + %320 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%319, %312 : tensor<48x566x566xf32>, tensor<48x566x64xf32>) -> tensor<48x566x64xf32> + %321 = tensorrt.transpose {permutation = #map3} %320 : tensor<48x566x64xf32> to tensor<566x48x64xf32> + return %321 : tensor<566x48x64xf32> +} + +// ----- + +// CHECK: @transpose_on_scalar(%[[arg0:.+]]: tensor<4488x4x48xf32>, %[[arg1:.+]]: tensor) +// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] : tensor to tensor<1x1x1xf32> +// CHECK: %[[v1:.+]] = tensorrt.element_wise (%[[arg0]], %[[v0]] : tensor<4488x4x48xf32>, tensor<1x1x1xf32>) -> tensor<4488x4x48xf32> +// CHECK: %[[v2:.+]] = tensorrt.transpose {permutation = #map} %[[v1]] : tensor<4488x4x48xf32> to tensor<4x4488x48xf32> +// CHECK: return %[[v2]] +#map = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +func.func @transpose_on_scalar(%arg0: tensor<4488x4x48xf32>, %arg1: tensor) -> tensor<4x4488x48xf32> { + %0 = tensorrt.transpose {permutation = #map} %arg0 : tensor<4488x4x48xf32> to tensor<4x4488x48xf32> + %1 = tensorrt.expand_rank %arg1 : tensor to tensor<1x1x1xf32> + %2 = tensorrt.element_wise (%0, %1 : tensor<4x4488x48xf32>, tensor<1x1x1xf32>) -> tensor<4x4488x48xf32> + return %2 : tensor<4x4488x48xf32> +} + +// ----- + +// CHECK: @einsum_multiply_two_axis(%[[arg0:.+]]: tensor<10x11x12xf32>, %[[arg1:.+]]: tensor<13x11x12xf32>) +// CHECK-DAG: %[[v0:.+]] = tensorrt.reshape %[[arg0]] : tensor<10x11x12xf32> to tensor<10x132xf32> +// CHECK-DAG: %[[v1:.+]] = tensorrt.reshape %[[arg1]] : tensor<13x11x12xf32> to tensor<13x132xf32> +// CHECK: %[[v2:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation, op1 = #tensorrt.matrix_operation} ins(%[[v0]], %[[v1]] : tensor<10x132xf32>, tensor<13x132xf32>) -> tensor<10x13xf32> +// CHECK: return %[[v2]] +func.func @einsum_multiply_two_axis(%arg0: tensor<10x11x12xf32>, %arg1: tensor<13x11x12xf32>) -> tensor<10x13xf32> { + %0 = tensorrt.einsum {equation = "acd,bcd->ab"} ins(%arg0, %arg1: tensor<10x11x12xf32>, tensor<13x11x12xf32>) -> tensor<10x13xf32> + return %0 : tensor<10x13xf32> +} + +// ----- + +// CHECK: @can_not_push_reshape_through_einsum(%[[arg0:.+]]: tensor<2x20x12x64xf32>, %[[arg1:.+]]: tensor<2x12x20x1xf32>) +// CHECK: %[[v0:.+]] = tensorrt.einsum {{{.*}}} ins(%[[arg0]], %[[arg1]] : {{.*}}) +// CHECK: %[[v1:.+]] = tensorrt.reshape %[[v0]] : tensor<2x12x64xf32> to tensor<2x1x768xf32> +// CHECK: return %[[v1]] +func.func @can_not_push_reshape_through_einsum(%arg0: tensor<2x20x12x64xf32>, %arg1: tensor<2x12x20x1xf32>) -> tensor<2x1x768xf32>{ + %0 = tensorrt.einsum {equation = "acbd,abcd->abd"} ins(%arg0, %arg1 : tensor<2x20x12x64xf32>, tensor<2x12x20x1xf32>) -> tensor<2x12x64xf32> + %1 = tensorrt.reshape %0 : tensor<2x12x64xf32> to tensor<2x1x768xf32> + return %1 : tensor<2x1x768xf32> +} + +// ----- + +// CHECK: @push_reshape_broadcast(%[[arg0:.+]]: tensor<6x64x448xf32>, %[[arg1:.+]]: tensor<6x1x448xf32>) +// CHECK: %[[const:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x1x384x384xf32> +// CHECK-DAG: %[[v0:.+]] = tensorrt.expand_rank %[[arg0]] : tensor<6x64x448xf32> to tensor<1x1x6x64x448xf32> +// CHECK-DAG: %[[v1:.+]] = tensorrt.expand_rank %[[arg1]] : tensor<6x1x448xf32> to tensor<1x1x6x1x448xf32> +// CHECK: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[v1]] : {{.*}}) +// CHECK: %[[v3:.+]] = tensorrt.reshape %[[v2]] : tensor<1x1x6x64xf32> to tensor<1x1x384xf32> +// CHECK: %[[v4:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v3]], %[[const]] : {{.*}}) -> tensor<1x1x384xf32> +// CHECK: return %[[v4]] +func.func @push_reshape_broadcast(%arg0: tensor<6x64x448xf32>, %arg1: tensor<6x1x448xf32>) -> tensor<1x1x384xf32> { + %const = tensorrt.constant dense_resource<__elided__> : tensor<384x6x64xf32> + %1 = tensorrt.einsum {equation = "bdc,bdc->bd"} ins(%arg0, %arg1 : tensor<6x64x448xf32>, tensor<6x1x448xf32>) -> tensor<6x64xf32> + %2 = tensorrt.einsum {equation = "bd,ebd->e"} ins(%1, %const : tensor<6x64xf32>, tensor<384x6x64xf32>) -> tensor<384xf32> + %3 = tensorrt.reshape %2 : tensor<384xf32> to tensor<1x1x384xf32> + return %3 : tensor<1x1x384xf32> +} \ No newline at end of file