Skip to content

Commit f173be8

Browse files
author
Matthew Francis-Landau
committed
Transpose Reshape Elimination pass.
This is a new pass that is designed to replace the existing Transpose and Reshape Elemination passes. This pass adds a lot of new rewrite rules which enable pushing the transposes and reshapes around so that they can be combined and then eliminated. Signed-off-by: Matthew Francis-Landau <[email protected]> fix transpose lit test Signed-off-by: Matthew Francis-Landau <[email protected]> cp cp cp Signed-off-by: Matthew Francis-Landau <[email protected]> cp cp some improvements for matching matmul with reshapes, and update lit tests Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent deb6e79 commit f173be8

File tree

9 files changed

+2789
-48
lines changed

9 files changed

+2789
-48
lines changed

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3790,6 +3790,7 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
37903790
let extraClassDeclaration = [{
37913791
/// Returns true if created op is valid for TensorRT major version.
37923792
bool isValidForTensorRTVersion(int64_t trtMajorVersion);
3793+
static void getCanonicalizationPatternsSameOp(RewritePatternSet &results, MLIRContext *context);
37933794
}] # baseClassDeclaration;
37943795
}
37953796

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,4 +256,18 @@ def ReshapeEliminationPass : Pass<"tensorrt-reshape-elimination"> {
256256
}];
257257
}
258258

259+
//===----------------------------------------------------------------------===//
260+
// TransposeReshapeEliminationPass
261+
//===----------------------------------------------------------------------===//
262+
def TransposeReshapeEliminationPass : Pass<"tensorrt-transpose-reshape-elimination"> {
263+
let summary = "try to eliminate tensorrt.transpose, tensorrt.reshape, and tensorrt.shuffle operations";
264+
265+
let description = [{
266+
Push tensorrt.transpose and tensorrt.reshape operations around to attempt to eleminate them
267+
and merge them with other ops such as matrix multiply. The intention is to improve
268+
pattern matching and fusion inside of TensorRT.
269+
}];
270+
}
271+
272+
259273
#endif // MLIR_TENSORRT_DIALECT_TENSORRT_TRANSFORMS_PASSES

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,13 @@ void tensorrt::ReshapeOp::getCanonicalizationPatterns(
18501850
SimplifyReshapeToRankExpandCollapse>(context);
18511851
}
18521852

1853+
void tensorrt::ReshapeOp::getCanonicalizationPatternsSameOp(
1854+
RewritePatternSet &results, MLIRContext *context) {
1855+
results.insert<ConstFoldReshapePattern<ReshapeOp>, SimplifyReshapeReshape
1856+
// NOT INCLUDED: SimplifyReshapeToRankExpandCollapse
1857+
>(context);
1858+
}
1859+
18531860
void tensorrt::ReshapeOp::build(OpBuilder &builder, OperationState &state,
18541861
Type result, Value input) {
18551862
ReshapeOp::build(builder, state, result, input, Value());

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_mtrtd_library(MLIRTensorRTTransforms
2121
RaiseNormalizations.cpp
2222
ReshapeElimination.cpp
2323
TransposeElimination.cpp
24+
TransposeReshapeElimination.cpp
2425

2526
DEPENDS
2627
MLIRTensorRTTransformsActivationsPdllGen

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@ void tensorrt::buildTensorRTModuleSimplificationPipeline(OpPassManager &pm) {
8686
// Try to eliminate as many `tensorrt.broadcast` ops as possible.
8787
pm.addPass(tensorrt::createBroadcastEliminationPass());
8888
addCleanupPasses(pm);
89-
pm.addPass(tensorrt::createTransposeEliminationPass());
90-
addCleanupPasses(pm);
91-
pm.addPass(tensorrt::createReshapeEliminationPass());
89+
pm.addPass(tensorrt::createTransposeReshapeEliminationPass());
9290
addCleanupPasses(pm);
9391
pm.addPass(tensorrt::createRaiseNormalizationsPass());
9492
addCleanupPasses(pm);

0 commit comments

Comments
 (0)