Conversation

Contributor

@matthewfl matthewfl commented Aug 1, 2025

The TransposeReshapeElimination pass is designed to subsume the existing TransposeElimination and ReshapeElimination passes. It adds a large number of new rewrite rules that push transposes and reshapes around so that they can be combined and then eliminated. The rules from the existing TransposeElimination pass are copied into the TransposeReshapeElimination.cpp file, and the rules from the ReshapeElimination pass should be subsumed by the rules added to TransposeReshapeElimination.

The motivation for this pass is that there are some cases where shuffles get inserted around matrix multiplications and element-wise ops, which breaks various fusions inside of TensorRT.

The process is as follows (a small illustrative sketch of one of these rules as a rewrite pattern follows the list):

  1. "canonicalize" the network into simpler ops
    • shuffle(x) -> reshape(transpose(reshape(x)))
    • matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y)
    • expand_rank(x) -> reshape(x)
    • collapse_rank(x) -> reshape(x)
  2. Push down reshape and transpose ops as much as possible, merging and eliminating them when possible
    • einsum(transpose(x), ...) -> einsum(x, ...) Merge transpose into einsum
    • einsum(...) -> transpose(einsum(...)) Pull transpose out of einsum (to try to match matrix multiply pattern)
    • einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...))) Push reshape down. Possibly add reshapes and transposes to other inputs as needed. Conditioned on a heuristic checking if the result is "better"
    • EXISTING unary(transpose(x)) -> transpose(unary(x))
    • EXISTING activation(transpose(x)) -> transpose(activation(x))
    • EXISTING identity_op(transpose(x)) -> transpose(identity_op(x))
    • activation(reshape(x)) -> reshape(activation(x))
    • unary(reshape(x)) -> reshape(unary(x))
    • identity_op(reshape(x)) -> reshape(identity_op(x))
    • reshape(transpose(x)) -> transpose(reshape(x)) if possible put reshape before transpose
    • qdq(transpose(x)) -> transpose(qdq(x)) if the scale is 0-dim
    • qdq(reshape(x)) -> reshape(qdq(x)) if the scale is 0-dim
    • EXISTING reshape(reshape(x)) -> reshape(x)
    • EXISTING transpose(transpose(x)) -> transpose(x)
    • EXISTING reshape(x) -> x if reshape is identity
    • EXISTING transpose(x) -> x if transpose is identity
    • elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) conditioned on heuristic
    • EXISTING elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))
  3. Push up reshape and transpose ops as much as possible, merging and eliminating them when possible
    • transpose(einsum(...)) -> einsum(...). Merge transpose into einsum
    • einsum(...) -> einsum(transpose(x), ...). Pull transposes out of einsum (to try to match matrix multiply pattern)
    • reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...) Push reshapes up through einsum, adding transposes as needed
    • EXISTING transpose(activation(x)) -> activation(transpose(x))
    • EXISTING transpose(unary(x)) -> unary(transpose(x))
    • EXISTING transpose(identity_op(x)) -> identity_op(transpose(x))
    • reshape(activation(x)) -> activation(reshape(x))
    • reshape(unary(x)) -> unary(reshape(x))
    • reshape(identity_op(x)) -> identity_op(reshape(x))
    • EXISTING reshape(reshape(x)) -> reshape(x)
    • EXISTING transpose(transpose(x)) -> transpose(x)
    • EXISTING reshape(x) -> x if reshape is identity
    • EXISTING transpose(x) -> x if transpose is identity
    • transpose(reshape(x)) -> reshape(transpose(x)) if possible put transpose before reshape
    • transpose(qdq(x)) -> qdq(transpose(x)) if the scale is 0-dim
    • reshape(qdq(x)) -> qdq(reshape(x)) if the scale is 0-dim
    • reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b))
    • EXISTING transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b))
  4. Convert back to matrix multiplication form to assist with TRT's pattern matching
    • einsum(x, y) -> matrix_multiply(x, y) if einsum matches a matrix multiply pattern
    • matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) merge transpose if possible
  5. Final cleanups: additional merging of transposes/reshapes into leftover einsums
    • einsum(x, y) -> matrix_multiply(x, y) if einsum matches a matrix multiply pattern
    • matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) merge transpose if possible
    • transpose(einsum(...)) -> einsum(...)
    • einsum(transpose(x), ...) -> einsum(...)
    • einsum(collapse_rank(x), ...) -> einsum(...)
    • expand_rank(einsum(...)) -> einsum(...)
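To make the rules concrete, below is a minimal sketch (not code from this PR) of how the simplest rule above, "transpose(x) -> x if transpose is identity", looks when written as a greedy rewrite pattern. The header path and the getPermutation()/getInput() accessor names are assumptions about the TensorRT dialect API, not quotes from the actual sources.

#include "mlir/IR/PatternMatch.h"
// Assumed header path for the TensorRT dialect ops used below.
#include "mlir-tensorrt/Dialect/TensorRT/IR/TensorRTDialect.h"

namespace {
// Sketch of the rule "transpose(x) -> x if transpose is identity".
struct EliminateIdentityTranspose
    : public mlir::OpRewritePattern<tensorrt::TransposeOp> {
  using mlir::OpRewritePattern<tensorrt::TransposeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(tensorrt::TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only fire when the permutation maps every dimension to itself.
    if (!op.getPermutation().isIdentity())
      return mlir::failure();
    // Forward the transpose's input to all of its users and erase the op.
    rewriter.replaceOp(op, op.getInput());
    return mlir::success();
  }
};
} // namespace

The remaining rules follow the same matchAndRewrite shape, just with more bookkeeping for composing permutations, recomputing reshape extents, and rewriting einsum equations.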

NOTE: The overarching goal of this PR is to improve the pattern matching inside of TensorRT (and therefore the quality of the kernels and fusions that TensorRT can generate). I have some empirical evidence that the generated MLIR is an improvement; however, I am still not 100% sure of the best way to generate MLIR in some of these edge cases when it comes to getting the fastest model out of TensorRT.

@matthewfl
Contributor Author

@christopherbate @shelkesagar29 This is the big PR to eliminate shuffles from the program. Ideally, this should include deleting the ReshapeElimination and TransposeElimination passes and replacing them with the new pass (including updating their lit tests where necessary to match the new patterns). I'm not sure what strategy you would like to take to get this merged upstream.

@matthewfl matthewfl changed the title from "Transpose Reshape Elimination pass" to "[mlir-tensorrt] Transpose Reshape Elimination pass" on Aug 1, 2025
@christopherbate
Collaborator

If the pass can subsume the existing two passes, that's fine. I'd say keep the existing test files in place and just update the RUN command and FileCheck directives -- this way we can see what changed with respect to tests. A follow-on commit can merge the test files if required.

return success();
}
};
} // namespace
Contributor Author

Note: everything above line 436 of TransposeReshapeElimination.cpp came from the existing TransposeElimination.cpp file

@matthewfl matthewfl force-pushed the mfl/shuffle-elimination branch 2 times, most recently from f173be8 to b6d32fd on August 7, 2025 at 20:11
@@ -1,4 +1,4 @@
-// RUN: tensorrt-opt %s -split-input-file -tensorrt-reshape-elimination | FileCheck %s
+// RUN: tensorrt-opt %s -split-input-file -tensorrt-transpose-reshape-elimination | FileCheck %s
Contributor Author

@christopherbate I have updated the lit test for reshape and transpose to use the new transpose-reshape-elimination pass

@christopherbate
Collaborator

We shouldn't be copy-pasting code and having two versions of the same pattern... either delete the old pattern passes or use a void populateXPatterns(RewritePatternSet &) function to reference the existing patterns that were duplicated.

I'm doing manual copy/paste right now to understand what new code was actually added, which is a bit tedious. In the future, if you could have separate commits for pure code movement vs. changes/new code, then that would be great.
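For reference, here is a rough sketch of the populate-style layout being suggested; the function names and pass wiring below are hypothetical, not the repo's actual API.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Declared once (e.g. in a shared header) so both the old passes and the new
// pass reference the same pattern implementations instead of duplicating them.
void populateTransposeEliminationPatterns(mlir::RewritePatternSet &patterns);
void populateTransposeReshapeEliminationPatterns(mlir::RewritePatternSet &patterns);

// Sketch of the new pass body pulling both sets in; the surrounding pass class
// and error handling are omitted here.
static void runTransposeReshapeElimination(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  populateTransposeEliminationPatterns(patterns);        // existing, shared rules
  populateTransposeReshapeEliminationPatterns(patterns); // new rules from this PR
  // Apply greedily until a fixpoint is reached.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}

This keeps one copy of each pattern while still letting the old passes exist during a transition period.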

Collaborator

@christopherbate christopherbate left a comment

Regarding Matmul-to-Einsum and Einsum-to-Matmul:

We should avoid the question of which one is better entirely. I really want to avoid making the TRT dialect transforms reason about which one is better, since it's an impossible question to encode into a heuristic. It's not easy or straightforward for us to figure out a general rule for where einsum might be better, since it depends on the internal details of TensorRT, the TensorRT version, etc., and it is not easy to make robust to future changes.

For that reason, I would recommend that users make this decision much closer to the frontend. For example, in the stablehlo conversions we explicitly have an option for whether or not to prefer use of tensorrt.einsum or tensorrt.matrix_multiply.

The biggest issue I can see for Einsum vs. MatMul concerns pattern recognition of special fusions internal to TRT (e.g. MHA), or some other niche feature or optimization that a user may be expecting to see based off of TRT documentation or some other communication. To my knowledge, there's no special pattern that requires einsum, so I would be hesitant to change all matmul to einsum in the main TRT dialect compilation pipeline unless it's guarded by an option. There are other users who are using the TRT dialect directly at the frontend, and this would be a surprising change for them IMO.

That said, from my personal experience, it's OK to "prefer einsum" as long as you have a mechanism for side-stepping any potential issues related to fusion optimizations that might be critical to your workloads. In the case of MHA or other critical fusions for popular models right now, we are adding a tensorrt.attention operation, which should obviate the issue with respect to einsum vs. matmul.

@matthewfl
Contributor Author

Thanks for the thorough review @christopherbate.

I will start working to address the comments on the PR.

We should avoid the question of which one is better entirely. I really want to avoid making the TRT dialect transforms reason about which one is better, since it's an impossible question to encode into a heuristic.

The original motivation for eliminating shuffles/transposes/reshapes was that they broke pattern matching in TensorRT in the first place. See this bug for additional context: https://partners.nvidia.com/Bug/ViewBug/5381960
WRT einsum vs matrix multiply: the design of the PR is to do matmul -> einsum, apply optimizations on the einsum representation, and then convert back einsum -> matmul. This means that what the op originally was (einsum or matrix multiply) is no longer available at the time of the einsum -> matmul conversion.

Based on observations that I have made of TensorRT, it seems like there is no issue using einsum to represent matrix multiplies, in that it generates the same kernels (but I can't be 100% sure of this given that TRT is closed source and I don't have access to its source).

This is a new pass that is designed to replace the existing Transpose
and Reshape Elimination passes.  This pass adds a lot of new rewrite
rules which enable pushing the transposes and reshapes around so that
they can be combined and then eliminated.

Signed-off-by: Matthew Francis-Landau <[email protected]>

fix transpose lit test

Signed-off-by: Matthew Francis-Landau <[email protected]>

cp

cp

cp

Signed-off-by: Matthew Francis-Landau <[email protected]>

cp

cp

some improvements for matching matmul with reshapes, and update lit tests

Signed-off-by: Matthew Francis-Landau <[email protected]>
@matthewfl matthewfl force-pushed the mfl/shuffle-elimination branch from 353322b to 255995c on August 25, 2025 at 19:27
@matthewfl matthewfl force-pushed the mfl/shuffle-elimination branch 2 times, most recently from c960e12 to 21f60d1 on August 25, 2025 at 19:32
@matthewfl
Contributor Author

Hi @christopherbate, I have addressed the comments you added and updated the PR.

@matthewfl matthewfl force-pushed the mfl/shuffle-elimination branch from 21f60d1 to 5486d22 on August 27, 2025 at 15:41
@christopherbate
Collaborator

@matthewfl sorry for the delay; I'll re-review tomorrow and check that it passes internal tests

Collaborator

@shelkesagar29 shelkesagar29 left a comment

Currently there are two issues that cause e2e models like Whisper JAX and GPT to fail:

  1. In the MoveReshapeBeforeTranspose pattern. I have commented there with a minimal failing test case.
  2. In the EinsumPushDownTranspose pattern. I have commented with a failing test and how to fix the issue.

Once these two are fixed, we are good to merge this one.
This is really good work. Thanks

// CHECK-DAG: %[[v0:.+]]= tensorrt.collapse_rank %[[arg1]] : tensor<1x2x4x3x3xf32> to tensor<2x4x3x3xf32>
// CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<1x2x4x3x561xf32> to tensor<2x4x561x3x1xf32>
// CHECK-DAG: %[[v2:.+]] = tensorrt.collapse_rank %[[v1]] : tensor<2x4x561x3x1xf32> to tensor<2x4x561x3xf32>
// CHECK: %[[v3:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kTRANSPOSE>} ins(%2, %0 : tensor<2x4x561x3xf32>, tensor<2x4x3x3xf32>) -> tensor<2x4x561x3xf32>
Collaborator

Shouldn't %2 be %[[v2]]? Same for %0.

int64_t cost2 = memoryCost(consumer.getType()) + memoryCost(op1.getType());
LLVM_DEBUG(DBGS() << "cost1=" << cost1 << ", cost2=" << cost2 << "\n");
if (cost1 == 0 && cost2 == 0)
  return {};
Collaborator

Shouldn't we return one valid op here?
At the call site, we are not checking for a null op.

Contributor Author

@matthewfl matthewfl Oct 1, 2025

This fixes ping-ponging (and the resulting non-terminating rewrite loop) on the following test:
c3b4d85#diff-e275a939421ea167dbf564d5fe9b866e458b4ee87f70c39e08cccbd333d7b44eR474

// einsum is more flexible with its inputs, broadcasting dimensions and
// transposing. Hence, we can easily implement rewrites that merge a transpose
// into an einsum and push a reshape through an einsum.
class MatmulToEinsum : public OpRewritePattern<tensorrt::MatrixMultiplyOp> {
Collaborator

As Chris commented before, we don't want to generally convert MatMul to Einsum.
Is the einsum created by this pattern converted back to matmul?

Contributor Author

The einsums are converted back to matmul if they can match the matmul pattern after all of the transposes have been eliminated.

for (auto &[c, i] : outputAxes) {
  newEinsumRhs += c;
  newEinsumShape.push_back(op.getType().getDimSize(i));
  outputPerm.push_back(i);
Collaborator

You are using the original index to create the permutation map, but we need to use the inverse permutation.
For example, bcd,bec->ebd is canonicalized to bcd,bec->bde.
The output ebd [0, 1, 2] is now bde [1, 2, 0], i.e. affine_map (0, 1, 2 -> 1, 2, 0).
To come back to the original state, we need the inverse of that map, but here we are using outputPerm, which is not correct.
Below is a test case that fails.

func.func @einsum_failure(%arg0: tensor<1x6x1500x64xf32>, %arg1: tensor<1x6x1500x1500xf32>) -> tensor<1x1500x384xf32> {
  %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<384x384xf32>
  %0 = tensorrt.reshape %arg0 : tensor<1x6x1500x64xf32> to tensor<6x1500x64xf32>
  %1 = tensorrt.reshape %arg1 : tensor<1x6x1500x1500xf32> to tensor<6x1500x1500xf32>
  %2 = tensorrt.einsum {equation = "bcd,bec->ebd"} ins(%0, %1 : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<1500x6x64xf32>
  %3 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1x1500x6x64xf32>
  %4 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1500x384xf32>
  %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<384x6x64xf32>
  %5 = tensorrt.einsum {equation = "bde,cde->bc"} ins(%2, %cst_f32_0 : tensor<1500x6x64xf32>, tensor<384x6x64xf32>) -> tensor<1500x384xf32>
  %6 = tensorrt.reshape %5 : tensor<1500x384xf32> to tensor<1x1500x384xf32>
  return %6 : tensor<1x1500x384xf32>
}

To resolve this, we can rename outputPerm to forwardPerm.
Instead of using

auto newTranspose = rewriter.create<tensorrt::TransposeOp>(
        op.getLoc(), newEinsum.getResult(),
        AffineMap::getPermutationMap(outputPerm, op.getLoc().getContext()));

we can use,

auto forwardMap =
        AffineMap::getPermutationMap(forwardPerm, op.getLoc().getContext());
auto newTranspose = rewriter.create<tensorrt::TransposeOp>(
        op.getLoc(), newEinsum.getResult(), mlir::inversePermutation(forwardMap));

Contributor Author

Thanks, I fixed this also

// Create a new transpose op from an einsum. Rearrange the output axes to
// match the ordering of the input axes. This should enable converting the
// einsum back to a matmul: einsum(x1, x2, ...) -> transpose(einsum(x1, x2, ...))
class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
Collaborator

[Nit]
The name of the pattern is confusing. It reads as if we are pushing a transpose that is above the einsum down below it.
Perhaps we could use another name like CanonicalizeEinsumForMatmul.

Contributor Author

There are a few patterns that help canonicalize an einsum to match matmul: EinsumPushDownTranspose, EinsumPushUpTranspose, and EinsumPushUpMultipleMulitipliedAxes.

@matthewfl
Contributor Author

@shelkesagar29 Hi Sagar, I pushed the code that makes sure the multi-head attention fusion still works. You should be able to test it on your models now.

Let me know if you want me to rebase this on the main branch or squash the commits. I only added new commits, so it should be easier to merge with the copy you were testing.

@shelkesagar29
Collaborator

Hi Matthew,
There are still a few issues.
The PushReshapeUpThroughEinsum pattern is creating incorrect IR (--debug output shows it):

func.func @repro(%arg0: tensor<2x20x12x64xf32>, %arg1: tensor<2x12x20x1xf32>) -> tensor<2x1x768xf32>{
  %0 = tensorrt.einsum {equation = "acbd,abcd->abd"} ins(%arg0, %arg1 : tensor<2x20x12x64xf32>, tensor<2x12x20x1xf32>) -> tensor<2x12x64xf32>
  %1 = tensorrt.reshape %0 : tensor<2x12x64xf32> to tensor<2x1x768xf32>
  return %1 : tensor<2x1x768xf32>
}

@matthewfl
Contributor Author

@shelkesagar29 Thanks for the report. I just pushed a fix for the issue that you reported.

@matthewfl matthewfl force-pushed the mfl/shuffle-elimination branch from a1c0d1e to 95f586c on October 30, 2025 at 18:51
@matthewfl
Contributor Author

@shelkesagar29 I just pushed a fix for the bug you sent via email.

Collaborator

@shelkesagar29 shelkesagar29 left a comment

All tests pass.
We can merge the PR after the following are done:

  1. Can you please check one more time for adherence to the LLVM code style? https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements I could find many places where this doesn't hold.
  2. Address the current comments.
  3. Rebase into a single commit with a canonical and clear message. This is a big change.

Thank you so much for your time and patience.


// CHECK: @transpose_reshape_reorder(%[[arg0:.+]]: tensor<12x256x8x8x16x8xf32>)
// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<12x256x8x8x16x8xf32> to tensor<12x8x8x16x8x256xf32>
// CHECK: %[[v1:.+]] = tensorrt.reshape %0 : tensor<12x8x8x16x8x256xf32> to tensor<12x64x128x256xf32>
Collaborator

@matthewfl can you please fix this?


std::string generateEquation() const {
  std::string ret = lhsParts[0];
  for (size_t i = 1; i < lhsParts.size(); i++) {
Collaborator

Remove braces {}
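For readers unfamiliar with the rule: when a loop or if body is a single statement, LLVM style omits the braces. A hypothetical illustration for the loop quoted above (the real body is truncated in the hunk, so the single append statement here is only an assumption):

// LLVM style: a single-statement body drops its braces.
for (size_t i = 1; i < lhsParts.size(); i++)
  ret += "," + lhsParts[i];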

return failure();

SmallVector<std::pair<char, int64_t>> outputAxes;
for (size_t i = 0; i < equation.rhs.size(); i++) {
Collaborator

Remove braces

LLVM_DEBUG({
  std::stringstream out;
  out << "outputAxes: [";
  for (auto x : outputAxes) {
Collaborator

Remove braces

auto input = cast<TypedValue<RankedTensorType>>(op.getInputs()[i]);
RankedTensorType inputType = input.getType();
SmallVector<std::pair<char, int64_t>> inputAxes;
for (int j = 0; j < inputType.getRank(); j++) {
Collaborator

Remove braces

}
});
std::string newEquation = "";
for (size_t j = 0; j < inputAxes.size(); j++) {
Collaborator

Same here

equation.lhsParts[i] = newEquation;
didChange = true;
SmallVector<int64_t> perm;
for (size_t j = 0; j < inputAxes.size(); j++) {
Collaborator

same here

SmallVector<int64_t> outShape;
SmallVector<int64_t> inShape;
for (size_t i = 0, j = 0; i < reshapeOutShape.size(); i++) {
  if (reshapeOutShape[i] == 0) {
Collaborator

+1
