[mlir-tensorrt] Transpose Reshape Elimination pass #686

Conversation
@christopherbate @shelkesagar29 This is the big PR to eliminate shuffles from the program. Ideally, this should include deleting the existing transpose and reshape elimination passes.

If the pass can subsume the existing two passes, that's fine. I'd say keep the existing test files in place and just update the RUN command and FileCheck directives -- this way we can see what changed with respect to the tests. A follow-on commit can merge the test files if required.
    return success();
  }
};
} // namespace
Note: everything above line 436 of TransposeReshapeElimination.cpp came from the existing TransposeElimination.cpp file
@@ -1,4 +1,4 @@
-// RUN: tensorrt-opt %s -split-input-file -tensorrt-reshape-elimination | FileCheck %s
+// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-reshape-elimination | FileCheck %s
@christopherbate I have updated the lit tests for reshape and transpose to use the new transpose-reshape-elimination pass.
We shouldn't be copy-pasting code and having two versions of the same pattern... either delete the old pattern passes or use a

I'm doing manual copy/paste right now to understand what new code was actually added, which is a bit tedious. In the future, if you could have separate commits for pure code movement vs. changes/new code, that would be great.
Regarding Matmul-to-Einsum and Einsum-to-Matmul:
We should avoid the question of which one is better entirely. I really want to avoid making the TRT dialect transforms reason about which one is better, since it's an impossible question to encode into a heuristic. It's not easy or straightforward for us to figure out a general rule for where einsum might be better, since it depends on the internal details of TensorRT, the TensorRT version, etc., and it is not easy to make robust to future changes.
For that reason, I would recommend that users make this decision much closer to the frontend. For example, in the stablehlo conversions we explicitly have an option for whether or not to prefer use of tensorrt.einsum or tensorrt.matrix_multiply.
The biggest issue I can see for Einsum vs. MatMul is pattern recognition of special fusions internal to TRT (e.g. MHA), or some other niche feature or optimization that a user may be expecting to see based on TRT documentation or some other communication. To my knowledge, there's no special pattern that requires einsum, so I would be hesitant to change all matmuls to einsums in the main TRT dialect compilation pipeline unless it's guarded by an option. There are other users who are using the TRT dialect directly at the frontend, and this would be a surprising change for them IMO.
That said, from my personal experience, it's OK to "prefer einsum" as long as you have a mechanism for side-stepping any potential issues related to fusion optimizations that might be critical to your workloads. In the case of MHA or other critical fusions for popular models right now, we are adding a tensorrt.attention operation, which should obviate the issue with respect to einsum vs. matmul.
Thanks for the thorough review @christopherbate. I will start working to address the comments on the PR.

The original motivation for eliminating shuffles/transposes/reshapes was that they broke pattern matching in TensorRT in the first place. See this bug for additional context: https://partners.nvidia.com/Bug/ViewBug/5381960 Based on observations that I have made of TensorRT, it seems like there is no issue using einsum to represent matrix multiplies, in that it generates the same kernels (but I can't be 100% sure of this given that TRT is closed source and I don't have access to its source).
Commit message: This is a new pass that is designed to replace the existing Transpose and Reshape Elimination passes. This pass adds a lot of new rewrite rules which enable pushing the transposes and reshapes around so that they can be combined and then eliminated.
Follow-up commits: fix transpose lit test; cp; cp; cp; cp; cp; some improvements for matching matmul with reshapes, and update lit tests.
Signed-off-by: Matthew Francis-Landau <[email protected]>
Hi @christopherbate, I have addressed the comments you added to the PR and updated it.
This reverts commit 353322b.
@matthewfl sorry for the delay; I'll re-review tomorrow and check that it passes internal tests.
…se reshape elimination builds on
… on the output for the reshape groups
Currently there are two issues because of which e2e models like Whisper JAX and GPT are failing:
- In the MoveReshapeBeforeTranspose pattern. I have commented there with a minimal failing test case.
- In the EinsumPushDownTranspose pattern. I have commented with a failing test and how to fix the issue.

Once these two are fixed, we are good to merge this one.
This is really good work. Thanks!
// CHECK-DAG: %[[v0:.+]]= tensorrt.collapse_rank %[[arg1]] : tensor<1x2x4x3x3xf32> to tensor<2x4x3x3xf32>
// CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<1x2x4x3x561xf32> to tensor<2x4x561x3x1xf32>
// CHECK-DAG: %[[v2:.+]] = tensorrt.collapse_rank %[[v1]] : tensor<2x4x561x3x1xf32> to tensor<2x4x561x3xf32>
// CHECK: %[[v3:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kTRANSPOSE>} ins(%2, %0 : tensor<2x4x561x3xf32>, tensor<2x4x3x3xf32>) -> tensor<2x4x561x3xf32>
Shouldn't %2 be %[[v2]]? Same for %0.
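A sketch of the likely intended fix (my suggestion, assuming the intent is to refer to the values captured above through FileCheck variables rather than raw SSA names):

// CHECK: %[[v3:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kTRANSPOSE>} ins(%[[v2]], %[[v0]] : tensor<2x4x561x3xf32>, tensor<2x4x3x3xf32>) -> tensor<2x4x561x3xf32>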
int64_t cost2 = memoryCost(consumer.getType()) + memoryCost(op1.getType());
LLVM_DEBUG(DBGS() << "cost1=" << cost1 << ", cost2=" << cost2 << "\n");
if (cost1 == 0 && cost2 == 0)
  return {};
Shouldn't we return one valid op here?
At the call site, we are not checking for a null op.
This fixes ping-ponging and a non-terminating infinite loop on the following test
c3b4d85#diff-e275a939421ea167dbf564d5fe9b866e458b4ee87f70c39e08cccbd333d7b44eR474
// einsum is more flexible with the inputs, broadcasting dimensions, and
// transposing. Hence, we can easily implement rewrites that merge transpose
// into einsum and push reshape through an einsum.
class MatmulToEinsum : public OpRewritePattern<tensorrt::MatrixMultiplyOp> {
As Chris commented before, we don't want to generally convert MatMul to Einsum.
Is the einsum created by this pattern converted back to matmul?
The einsums are converted back to matmul if they can match the matmul pattern after all of the transposes have been eliminated.
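For reference, a minimal hand-written sketch (not taken from the PR's tests; the SSA names, shapes, and batched equation string are illustrative) of a batched matrix multiply and the einsum form it is normalized to. Once the surrounding transposes have been eliminated, an einsum of this shape can be pattern-matched and rewritten back into tensorrt.matrix_multiply:

// The two ops below compute the same result; the pass converts the first form
// into the second while it moves shuffles around, and converts back at the end.
%mm = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kNONE>} ins(%a, %b : tensor<2x16x32xf32>, tensor<2x32x64xf32>) -> tensor<2x16x64xf32>
%es = tensorrt.einsum {equation = "bij,bjk->bik"} ins(%a, %b : tensor<2x16x32xf32>, tensor<2x32x64xf32>) -> tensor<2x16x64xf32>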
for (auto &[c, i] : outputAxes) {
  newEinsumRhs += c;
  newEinsumShape.push_back(op.getType().getDimSize(i));
  outputPerm.push_back(i);
You are using the original index to create the permutation map, but we need to use the inverse permutation.
For example, bcd,bec->ebd is canonicalized to bcd,bec->bde.
The output ebd [0, 1, 2] is now bde [1, 2, 0], i.e. affine_map (0, 1, 2 -> 1, 2, 0).
To come back to the original state, we need the inverse of that map, but here we are using outputPerm, which is not correct.
Below is a test case that fails.
func.func @einsum_failure(%arg0: tensor<1x6x1500x64xf32>, %arg1: tensor<1x6x1500x1500xf32>) -> tensor<1x1500x384xf32> {
%cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<384x384xf32>
%0 = tensorrt.reshape %arg0 : tensor<1x6x1500x64xf32> to tensor<6x1500x64xf32>
%1 = tensorrt.reshape %arg1 : tensor<1x6x1500x1500xf32> to tensor<6x1500x1500xf32>
%2 = tensorrt.einsum {equation = "bcd,bec->ebd"} ins(%0, %1 : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<1500x6x64xf32>
%3 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1x1500x6x64xf32>
%4 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1500x384xf32>
%cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<384x6x64xf32>
%5 = tensorrt.einsum {equation = "bde,cde->bc"} ins(%2, %cst_f32_0 : tensor<1500x6x64xf32>, tensor<384x6x64xf32>) -> tensor<1500x384xf32>
%6 = tensorrt.reshape %5 : tensor<1500x384xf32> to tensor<1x1500x384xf32>
return %6 : tensor<1x1500x384xf32>
}
To resolve this, we can rename outputPerm to forwardPerm.
Instead of using
auto newTranspose = rewriter.create<tensorrt::TransposeOp>(
op.getLoc(), newEinsum.getResult(),
AffineMap::getPermutationMap(outputPerm, op.getLoc().getContext()));
we can use,
auto forwardMap =
AffineMap::getPermutationMap(forwardPerm, op.getLoc().getContext());
auto newTranspose = rewriter.create<tensorrt::TransposeOp>(
op.getLoc(), newEinsum.getResult(), mlir::inversePermutation(forwardMap));
Thanks, I fixed this also
// Create a new transpose op from an einsum. Rearrange the output axes to
// match the ordering of the input axes. This should enable converting the
// einsum back to a matmul: einsum(x1, x2, ...) -> transpose(einsum(x1, x2, ...))
class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
[Nit]
The name of the pattern is confusing. It reads like we are pushing a transpose that is above the einsum down below it.
Perhaps we can use another name like CanonicalizeEinsumForMatmul.
There are a few patterns that help canonicalize einsum to match matmul: EinsumPushDownTranspose, EinsumPushUpTranspose, and EinsumPushUpMultipleMulitipliedAxes.
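To illustrate the overall idea with the equation from the failing test above (a hand-written sketch, not taken from the PR; the exact permutation attribute follows my reading of the inverse-permutation fix discussed earlier, and the SSA names are illustrative):

// Before: the einsum output order "ebd" does not line up with a matmul layout.
%0 = tensorrt.einsum {equation = "bcd,bec->ebd"} ins(%a, %b : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<1500x6x64xf32>
// After: canonicalize the einsum output order to "bde" and restore the original
// order with an explicit transpose using the inverse of the forward permutation
// [1, 2, 0], i.e. [2, 0, 1].
%1 = tensorrt.einsum {equation = "bcd,bec->bde"} ins(%a, %b : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<6x64x1500xf32>
%2 = tensorrt.transpose {permutation = affine_map<(d0, d1, d2) -> (d2, d0, d1)>} %1 : tensor<6x64x1500xf32> to tensor<1500x6x64xf32>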
…ng in the reshape
…can be come matrix multiplies
@shelkesagar29 Hi Sagar, I pushed the code that makes sure that the multihead attention fusion is still able to work. You should be able to test it on your models now. Let me know if you want me to rebase this on the main branch or squash the commits. I only added new commits, so it should be easier to merge with the copy you were testing.
Hi Matthew,
…f one of the axis
@shelkesagar29 Thanks for the report. I just pushed a fix for the issue that you reported.
@shelkesagar29 I just pushed a fix for the bug you sent via email.
All tests pass.
We can merge the PR after the following are done:
- Can you please check one more time for adherence to the LLVM code style? https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements I could find many places where this doesn't hold.
- Address the current comments.
- Rebase into a single commit with a canonical and clear message. This is a big change.

Thank you so much for your time and patience.
// CHECK: @transpose_reshape_reorder(%[[arg0:.+]]: tensor<12x256x8x8x16x8xf32>)
// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<12x256x8x8x16x8xf32> to tensor<12x8x8x16x8x256xf32>
// CHECK: %[[v1:.+]] = tensorrt.reshape %0 : tensor<12x8x8x16x8x256xf32> to tensor<12x64x128x256xf32>
@matthewfl can you please fix this?
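For reference, a sketch of the likely intended fix (assuming the reshape operand should refer to the transpose result captured above as %[[v0]]):

// CHECK: %[[v1:.+]] = tensorrt.reshape %[[v0]] : tensor<12x8x8x16x8x256xf32> to tensor<12x64x128x256xf32>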
std::string generateEquation() const {
  std::string ret = lhsParts[0];
  for (size_t i = 1; i < lhsParts.size(); i++) {
Remove braces {}
return failure();

SmallVector<std::pair<char, int64_t>> outputAxes;
for (size_t i = 0; i < equation.rhs.size(); i++) {
Remove braces
LLVM_DEBUG({
  std::stringstream out;
  out << "outputAxes: [";
  for (auto x : outputAxes) {
Remove braces
auto input = cast<TypedValue<RankedTensorType>>(op.getInputs()[i]);
RankedTensorType inputType = input.getType();
SmallVector<std::pair<char, int64_t>> inputAxes;
for (int j = 0; j < inputType.getRank(); j++) {
Remove braces
  }
});
std::string newEquation = "";
for (size_t j = 0; j < inputAxes.size(); j++) {
Same here
equation.lhsParts[i] = newEquation;
didChange = true;
SmallVector<int64_t> perm;
for (size_t j = 0; j < inputAxes.size(); j++) {
same here
SmallVector<int64_t> outShape;
SmallVector<int64_t> inShape;
for (size_t i = 0, j = 0; i < reshapeOutShape.size(); i++) {
  if (reshapeOutShape[i] == 0) {
+1
The TransposeReshapeElimination pass is designed to subsume the existing Transpose and Reshape Elimination passes. This pass adds a lot of new rewrite rules which enable pushing the transposes and reshapes around so that they can be combined and then eliminated. The rules from the existing TransposeElimination pass are copied into the TransposeReshapeElimination.cpp file. The rules from the ReshapeElimination pass should be subsumed by the rules added to TransposeReshapeElimination.

The motivation for this pass is that there are some cases where shuffles can get inserted around matrix multiplications and elementwise ops, which breaks various fusions inside of TensorRT.
The process is as follows:

1. Convert ops into reshape/transpose/einsum form:
   - shuffle(x) -> reshape(transpose(reshape(x)))
   - matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y)
   - expand_rank(x) -> reshape(x)
   - collapse_rank(x) -> reshape(x)

2. Push reshape and transpose ops down (toward the outputs) as much as possible, merging and eliminating them when possible:
   - einsum(transpose(x), ...) -> einsum(x, ...) (merge transpose into einsum)
   - einsum(...) -> transpose(einsum(...)) (pull transpose out of einsum, to try to match the matrix multiply pattern)
   - einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...))) (push reshape down, possibly adding reshapes and transposes to other inputs as needed; conditioned on a heuristic checking whether this is "better")
   - unary(transpose(x)) -> transpose(unary(x))
   - activation(transpose(x)) -> transpose(activation(x))
   - identity_op(transpose(x)) -> transpose(identity_op(x))
   - activation(reshape(x)) -> reshape(activation(x))
   - unary(reshape(x)) -> reshape(unary(x))
   - identity_op(reshape(x)) -> reshape(identity_op(x))
   - reshape(transpose(x)) -> transpose(reshape(x)) (if possible, put the reshape before the transpose)
   - qdq(transpose(x)) -> transpose(qdq(x)) (if the scale is 0-dim)
   - qdq(reshape(x)) -> reshape(qdq(x)) (if the scale is 0-dim)
   - reshape(reshape(x)) -> reshape(x)
   - transpose(transpose(x)) -> transpose(x)
   - reshape(x) -> x (if the reshape is an identity)
   - transpose(x) -> x (if the transpose is an identity)
   - elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) (conditioned on a heuristic)
   - elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))

3. Push reshape and transpose ops up (toward the inputs) as much as possible, merging and eliminating them when possible:
   - transpose(einsum(...)) -> einsum(...) (merge transpose into einsum)
   - einsum(...) -> einsum(transpose(x), ...) (pull transposes out of einsum, to try to match the matrix multiply pattern)
   - reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...) (push reshapes up through einsum, adding transposes as needed)
   - transpose(activation(x)) -> activation(transpose(x))
   - transpose(unary(x)) -> unary(transpose(x))
   - transpose(identity_op(x)) -> identity_op(transpose(x))
   - reshape(activation(x)) -> activation(reshape(x))
   - reshape(unary(x)) -> unary(reshape(x))
   - reshape(identity_op(x)) -> identity_op(reshape(x))
   - reshape(reshape(x)) -> reshape(x)
   - transpose(transpose(x)) -> transpose(x)
   - reshape(x) -> x (if the reshape is an identity)
   - transpose(x) -> x (if the transpose is an identity)
   - transpose(reshape(x)) -> reshape(transpose(x)) (if possible, put the transpose before the reshape)
   - transpose(qdq(x)) -> qdq(transpose(x)) (if the scale is 0-dim)
   - reshape(qdq(x)) -> qdq(reshape(x)) (if the scale is 0-dim)
   - reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b))
   - transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b))

4. Convert einsums back into matrix multiplies and clean up:
   - einsum(x, y) -> matrix_multiply(x, y) (if the einsum matches a matrix multiply pattern)
   - matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) (merge the transpose if possible)
   - transpose(einsum(...)) -> einsum(...)
   - einsum(transpose(x), ...) -> einsum(...)
   - einsum(collapse_rank(x), ...) -> einsum(...)
   - expand_rank(einsum(...)) -> einsum(...)

NOTE: The overarching goal of this PR is to improve the pattern matching inside of TensorRT (and therefore the quality of the kernels that TensorRT can generate and the fusions that TensorRT will perform). I have some empirical evidence that the MLIR that is generated seems to be an improvement; however, I am still not 100% sure what the best way is to generate MLIR in some of these edge cases when it comes to getting the fastest model out of TensorRT.
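As a concrete hand-written illustration of one of the rewrite rules listed above (the shapes, SSA names, and affine map are made up for this sketch, not taken from the PR's lit tests), merging a transpose into a matrix multiply's operand attributes looks roughly like this:

// Before: an explicit transpose feeds the matrix multiply.
%t = tensorrt.transpose {permutation = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} %arg0 : tensor<2x32x16xf32> to tensor<2x16x32xf32>
%r = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kNONE>} ins(%t, %arg1 : tensor<2x16x32xf32>, tensor<2x32x64xf32>) -> tensor<2x16x64xf32>

// After: the transpose is folded into the matmul's op0 attribute, so no shuffle
// remains between the producer of %arg0 and the matrix multiply to break
// TensorRT's fusion pattern matching.
%r = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kTRANSPOSE>, op1 = #tensorrt.matrix_operation<kNONE>} ins(%arg0, %arg1 : tensor<2x32x16xf32>, tensor<2x32x64xf32>) -> tensor<2x16x64xf32>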
shuffle(x)->reshape(transpose(reshape(x)))matrix_multiply(x, y)->einsum("ij,jk->ik", x, y)expand_rank(x)->reshape(x)collapse_rank(x)->reshape(x)reshapeandtransposeops as much as possible. Merging and eliminating when possibleeinsum(transpose(x), ...)->einsum(x, ...)Merge transpose into einsumeinsum(...)->transpose(einsum(...))Pull transpose out of einsum (to try to match matrix multiply pattern)einsum(reshape(x), y, ...)->transpose(reshape(einsum(x, reshape(transpose(y)), ...))Push reshape down. Possibly add reshape and transposes to other inputs as needed. Conditioned on heuristic checking if "better"unary(transpose(x))->transpose(unary(x))activation(transpose(x))->transpose(activation(x))identity_op(transpose(x))->transpose(identity_op(x))activation(reshape(x))->reshape(activation(x))unary(reshape(x))->reshape(unary(x))identity_op(reshape(x))->reshape(identity_op(x))reshape(transpose(x))->transpose(reshape(x))if possible put reshape before transposeqdq(transpose(x))->transpose(qdq(x))if the scale is 0-dimqdq(reshape(x))->reshape(qdq(x))if the scale is 0-dimreshape(reshape(x))->reshape(x)transpose(transpose(x))->transpose(x)reshape(x)->xifreshapeis identitytranspose(x)->xiftransposeis identityelementwise(reshape(a), b)->reshape(elementwise(a, reshape(b)))conditioned on heuristicelementwise(transpose(a), b)->transpose(elementwise(a, transpose(b)))reshapeandtransposeops as much as possible. Merging and eliminating when possibletranspose(einsum(...))->einsum(...). Merge transpose into einsumeinsum(...)->einsum(transpose(x), ...). Pull transposes out of einsum (to try to match matrix multiply pattern)reshape(einsum(...))->einsum(reshape(transpose(x)), ...)Push reshapes up through einsum. Adding transposes as neededtranspose(activation(x))->activation(transpose(x))transpose(unary(x))->unary(transpose(x))transpose(identity_op(x))->identity_op(transpose(x))reshape(activation(x))->activation(reshape(x))reshape(unary(x))->unary(reshape(x))reshape(identity_op(x))->identity_op(reshape(x))reshape(reshape(x))->reshape(x)transpose(transpose(x))->transpose(x)reshape(x)->xifreshapeis identitytranspose(x)->xiftransposeis identitytranspose(reshape(x))->reshape(transpose(x))if possible put transpose before reshapetranspose(qdq(x))->qdq(transpose(x))if the scale is 0-dimreshape(qdq(x))->qdq(reshape(x))if the scale is 0-dimreshape(elementwise(a, b))->elementwise(reshape(a), reshape(b))transpose(elementwise(a, b))->elementwise(transpose(a), transpose(b))einsum(x, y)->matrix_multiply(x, y)if einsum matches a matrix multiply patternmatrix_multiply(transpose(x), y)->matrix_multiply(x, y)merge transpose if possibleeinsum(x, y)->matrix_multiply(x, y)if einsum matches a matrix multiply patternmatrix_multiply(transpose(x), y)->matrix_multiply(x, y)merge transpose if possibletranspose(einsum(...))->einsum(...)einsum(tranpose(x), ...)->einsum(...)einsum(collapse_rank(x), ...)->einsum(...)expand_rank(einsum(...))->einsum(...)NOTE: The overarching goal of this PR is to improve the pattern matching inside of TensorRT (and therefore the quality of kernel's that TensorRT can generate, and fusion that TensorRT will generate). I have some empirical evidence that the mlir that is generated seems to be be an improvement, however I am still not 100% sure what is the best way to generate mlir in some of these edge cases when it comes to getting the fastest model out of TensorRT.