@@ -3875,6 +3875,11 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
let extraClassDeclaration = [{
/// Returns true if created op is valid for TensorRT major version.
bool isValidForTensorRTVersion(int64_t trtMajorVersion);

/// Get canonicalization patterns that rewrite into another ReshapeOp,
/// excluding rewrites that produce a different kind of op
/// (e.g. ExpandRankOp, CollapseRankOp).
static void getCanonicalizationPatternsSameOp(RewritePatternSet &results, MLIRContext *context);
}] # baseClassDeclaration;
}
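For illustration, a minimal sketch of the distinction. The printed forms of `tensorrt.reshape` and `tensorrt.expand_rank` are assumed here (mirroring the `tensorrt.transpose` syntax used later in this diff); the pattern names come from the implementation in TensorRT.cpp further down:

```
// Covered by getCanonicalizationPatternsSameOp: the result is still a
// ReshapeOp (SimplifyReshapeReshape, see the implementation below).
%0 = tensorrt.reshape %arg0 : tensor<2x3x4xf32> to tensor<6x4xf32>
%1 = tensorrt.reshape %0 : tensor<6x4xf32> to tensor<24xf32>
// ==> %1 = tensorrt.reshape %arg0 : tensor<2x3x4xf32> to tensor<24xf32>

// NOT covered: SimplifyReshapeToRankExpandCollapse replaces the ReshapeOp
// with a different kind of op.
%2 = tensorrt.reshape %arg1 : tensor<4x5xf32> to tensor<1x4x5xf32>
// ==> %2 = tensorrt.expand_rank %arg1 : tensor<4x5xf32> to tensor<1x4x5xf32>
```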

@@ -176,84 +176,87 @@ def LegalizeInt8Pass : Pass<"tensorrt-legalize-int8", "func::FuncOp"> {
}

//===----------------------------------------------------------------------===//
// TransposeEliminationPass
// TransposeReshapeEliminationPass
//===----------------------------------------------------------------------===//
def TransposeEliminationPass : Pass<"tensorrt-transpose-elimination"> {
let summary = "try to eliminate tensorrt.transpose operations";
def TransposeReshapeEliminationPass : Pass<"tensorrt-transpose-reshape-elimination"> {
let summary = "try to eliminate tensorrt.transpose, tensorrt.reshape, and tensorrt.shuffle operations";

let description = [{

It is well-known that an excessive number of transpose ops (either
"tensorrt.transpose" or "tensorrt.shuffle" operations with identity reshape)
can cause performance issues with TensorRT. This commonly occurs when the
input source being converted represents convolutions in "NHWC" format vs.
TensorRT's preferred "NCHW" format. In the conversion of these types of
It is well-known that an excessive number of transpose or reshape ops (either
"tensorrt.transpose", "tensorrt.reshape", or "tensorrt.shuffle")
can cause performance issues with TensorRT. For example, this commonly occurs
when the input source being converted represents convolutions in "NHWC" format
vs. TensorRT's preferred "NCHW" format. In the conversion of these types of
convolutions, a number of transpose operations must be inserted. These
transpose operations can prevent fusions. For example, a transpose operation
between a convolution and a pointwise addition can prevent convolution-bias
fusion.

This pass tries to eliminate transpose operations by applying the following
patterns in a greedy manner:

1) rotating `tensorrt.transpose` "forward" certain computational operations,
especially `tensorrt.element_wise` ops. This means that the transpose will
be applied to the result of the elementwise operation as well as the other
branch of the operation. To avoid an infinite ping-pong application of this
pattern certain heuristics are applied to determine whether or not this is
beneficial. For example:

```
func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>)
-> tensor<2x2xf32> {
%1 = tensorrt.transpose {
permutation = affine_map<(d0, d1)->(d1, d0)>
} %arg0 : tensor<2x2xf32> to tensor<2x2xf32>
%2 = tensorrt.element_wise <kSUM> (
%1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
}
```

becomes

```
func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>,
%arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
%0 = tensorrt.transpose {permutation = #map}
%arg1 : tensor<1x2xf32> to tensor<2x1xf32>
%1 = tensorrt.element_wise <kSUM>
(%arg0, %0 : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
%2 = tensorrt.transpose {permutation = #map} %1 : tensor<2x2xf32> to tensor<2x2xf32>
return %2 : tensor<2x2xf32>
}
```

In this case, moving the transpose to the other branch results in lower
memory cost on the inputs, but higher total memory cost (because a transpose
on the result is also added). However, we always prefer to push transpose
operations as far forward as possible in this transformation.

2) Const-folding transpose operations. Often, it is undesirable to let
weights be transposed at runtime. Instead, weights should be pre-transformed
to put them into a form that is suitable for TensorRT convolutions.
Therefore, we apply global transpose-const folding. This can be quite
expensive for large weights but is important to reduce runtime transpose
costs.

fusion. Fusions can also be blocked if a reshape is placed between a
matrix multiplication and an activation.

The pass tries to eliminate transposes and reshapes by pushing them around
the program so that they combine into identity transposes and reshapes,
which can then be removed. To accomplish this, the pass uses several rewrite
rules that push transposes and reshapes from the output of an operation to
the operation's inputs, and vice versa. The rewrites currently included
handle the common cases but do not yet cover every possible scenario---one
may wish to extend this pass in the future as needed. An example of the
element-wise push-down rule follows the list below.

The process is as follows:
1) Normalize transpose, reshape, shuffle, and matrix multiply ops into a common set of ops.
- shuffle(x) -> reshape(transpose(reshape(x)))
- matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y)
- expand_rank(x) -> reshape(x)
- collapse_rank(x) -> reshape(x)
2) Push down reshape and transpose ops, eliminating them when possible. E.g. op(transpose(x)) -> transpose(op(x))
- einsum(transpose(x), ...) -> einsum(x, ...)
- einsum(...) -> transpose(einsum(...)) (pull transposes out of the einsum equation to try to match the matrix multiply pattern)
- einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...)))
- unary(transpose(x)) -> transpose(unary(x))
- activation(transpose(x)) -> transpose(activation(x))
- identity_op(transpose(x)) -> transpose(identity_op(x))
- activation(reshape(x)) -> reshape(activation(x))
- unary(reshape(x)) -> reshape(unary(x))
- identity_op(reshape(x)) -> reshape(identity_op(x))
- reshape(transpose(x)) -> transpose(reshape(x)) (move the reshape before the transpose when possible)
- dequantize(quantize(transpose(x))) -> transpose(dequantize(quantize(x))) if the scale is 0-dim
- dequantize(quantize(reshape(x))) -> reshape(dequantize(quantize(x))) if the scale is 0-dim
- reshape(reshape(x)) -> reshape(x)
- transpose(transpose(x)) -> transpose(x)
- reshape(x) -> x if the reshape is an identity
- transpose(x) -> x if the transpose is an identity
- elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) (conditioned on a heuristic)
- elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))
3) Push up reshape and transpose ops, eliminating them when possible. E.g. transpose(op(x)) -> op(transpose(x))
- transpose(einsum(...)) -> einsum(...)
- einsum(...) -> einsum(transpose(x), ...) (pull transposes out of the einsum equation to try to match the matrix multiply pattern)
- reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...)
- transpose(activation(x)) -> activation(transpose(x))
- transpose(unary(x)) -> unary(transpose(x))
- transpose(identity_op(x)) -> identity_op(transpose(x))
- reshape(activation(x)) -> activation(reshape(x))
- reshape(unary(x)) -> unary(reshape(x))
- reshape(identity_op(x)) -> identity_op(reshape(x))
- reshape(reshape(x)) -> reshape(x)
- transpose(transpose(x)) -> transpose(x)
- reshape(x) -> x if the reshape is an identity
- transpose(x) -> x if the transpose is an identity
- transpose(reshape(x)) -> reshape(transpose(x)) (move the transpose before the reshape when possible)
- transpose(dequantize(quantize(x))) -> dequantize(quantize(transpose(x))) if the scale is 0-dim
- reshape(dequantize(quantize(x))) -> dequantize(quantize(reshape(x))) if the scale is 0-dim
- reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b))
- transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b))
4) Final cleanup: fuse leftover transposes and reshapes into other ops.
- einsum("ij,jk->ik", x, y) -> matrix_multiply(x, y) if the einsum matches a matrix multiply pattern
- matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) (merge the transpose into the matrix multiply when possible)
- transpose(einsum(...)) -> einsum(...)
- einsum(transpose(x), ...) -> einsum(...)
- einsum(collapse_rank(x), ...) -> einsum(...)
- expand_rank(einsum(...)) -> einsum(...)

To avoid an infinite ping-pong application of these patterns, heuristics are
applied to determine when a pattern is beneficial.
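
For example, the element-wise push-down rule in step 2
(elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b))))
rewrites the following; this sketch is adapted from the example that
accompanied the earlier TransposeEliminationPass description:

```
func.func @transpose_pushdown(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>)
    -> tensor<2x2xf32> {
  %1 = tensorrt.transpose {
    permutation = affine_map<(d0, d1)->(d1, d0)>
  } %arg0 : tensor<2x2xf32> to tensor<2x2xf32>
  %2 = tensorrt.element_wise <kSUM> (
    %1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
  return %2 : tensor<2x2xf32>
}
```

into

```
func.func @transpose_pushdown(%arg0: tensor<2x2xf32>,
    %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
  %0 = tensorrt.transpose {permutation = affine_map<(d0, d1)->(d1, d0)>}
    %arg1 : tensor<1x2xf32> to tensor<2x1xf32>
  %1 = tensorrt.element_wise <kSUM>
    (%arg0, %0 : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
  %2 = tensorrt.transpose {permutation = affine_map<(d0, d1)->(d1, d0)>}
    %1 : tensor<2x2xf32> to tensor<2x2xf32>
  return %2 : tensor<2x2xf32>
}
```

so that later rewrites can combine the result transpose with its consumers
or cancel it against another transpose.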
}];
}

//===----------------------------------------------------------------------===//
// ReshapeEliminationPass
//===----------------------------------------------------------------------===//
def ReshapeEliminationPass : Pass<"tensorrt-reshape-elimination"> {
let summary = "try to eliminate tensorrt.reshape operations";

let description = [{
The reshape elimination pass captures patterns with unnecessary reshapes
and simplifies them by eliminating reshape operations.
}];
}

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_TRANSFORMS_PASSES
11 changes: 6 additions & 5 deletions mlir-tensorrt/tensorrt/lib/TensorRT/IR/EinsumHelper.cpp
@@ -143,10 +143,11 @@ static LogicalResult validateInputsSubscript(const IOSubscripts &subscripts,
// match. for example, ('ij,jk->ik', a, b) is valid for a =
// tensor<4x5xf32>, b = tensor<5x6xf32> but invalid for a =
// tensor<4x6xf32>, b = tensor<5x6xf32>
if (allLabelDims.count(label) == 0) {
allLabelDims.insert(std::pair<char, int64_t>(label, dimension));
// Einsum also supports broadcasting
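// For example, ('ij,jk->ik', a, b) with a = tensor<4x1xf32> and
// b = tensor<5x6xf32> is accepted: the size-1 `j` dimension of `a`
// broadcasts against the size-5 `j` dimension of `b`.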
if (allLabelDims.count(label) == 0 || allLabelDims[label] == 1) {
allLabelDims[label] = dimension;
} else {
if (allLabelDims[label] != dimension)
if (allLabelDims[label] != dimension && dimension != 1)
return emitErrorFn(loc, Twine("label `") + Twine(label) +
Twine("` is repeated between inputs but "
"dimensions are not same"));
@@ -203,8 +204,8 @@ static LogicalResult inferOutputShapeImpl(const IOSubscripts &ioSubscripts,
llvm::zip((ioSubscripts).inputs, inputOperands)) {
for (const auto &[label, dims] :
llvm::zip(subscript, cast<RankedTensorType>(operand).getShape()))
if (inputLabelsDims.count(label) == 0)
inputLabelsDims.insert(std::pair<char, int64_t>(label, dims));
if (inputLabelsDims.count(label) == 0 || inputLabelsDims[label] == 1)
inputLabelsDims[label] = dims;
}

for (const auto &label : (ioSubscripts).outputs) {
7 changes: 7 additions & 0 deletions mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
@@ -1850,6 +1850,13 @@ void tensorrt::ReshapeOp::getCanonicalizationPatterns(
SimplifyReshapeToRankExpandCollapse>(context);
}

void tensorrt::ReshapeOp::getCanonicalizationPatternsSameOp(
RewritePatternSet &results, MLIRContext *context) {
results.insert<ConstFoldReshapePattern<ReshapeOp>, SimplifyReshapeReshape
// NOT INCLUDED: SimplifyReshapeToRankExpandCollapse, which rewrites
// ReshapeOp into a different kind of op (ExpandRankOp/CollapseRankOp).
>(context);
}

void tensorrt::ReshapeOp::build(OpBuilder &builder, OperationState &state,
Type result, Value input) {
ReshapeOp::build(builder, state, result, input, Value());
@@ -19,8 +19,7 @@ add_mtrtd_library(MLIRTensorRTTransforms
Passes.cpp
RaiseActivations.cpp
RaiseNormalizations.cpp
ReshapeElimination.cpp
TransposeElimination.cpp
TransposeReshapeElimination.cpp

DEPENDS
MLIRTensorRTTransformsActivationsPdllGen
4 changes: 1 addition & 3 deletions mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp
@@ -86,9 +86,7 @@ void tensorrt::buildTensorRTModuleSimplificationPipeline(OpPassManager &pm) {
// Try to eliminate as many `tensorrt.broadcast` ops as possible.
pm.addPass(tensorrt::createBroadcastEliminationPass());
addCleanupPasses(pm);
pm.addPass(tensorrt::createTransposeEliminationPass());
addCleanupPasses(pm);
pm.addPass(tensorrt::createReshapeEliminationPass());
pm.addPass(tensorrt::createTransposeReshapeEliminationPass());
addCleanupPasses(pm);
pm.addPass(tensorrt::createRaiseNormalizationsPass());
addCleanupPasses(pm);