
Commit a1c0d1e

Author: Matthew Francis-Landau (committed)

fix issue with pushing reshape when broadcasting

1 parent 35bc5fd · commit a1c0d1e

File tree: 2 files changed (43 additions, 6 deletions)


mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 25 additions & 6 deletions
@@ -1423,7 +1423,7 @@ class PushReshapeUpThroughEinsum
 
     LLVM_DEBUG({
       std::stringstream out;
-      out << "==== Einsum Reshape/Transpose Pushdown Debug ====\n";
+      out << "==== Einsum Reshape/Transpose Pushup Debug ====\n";
       for (const auto &entry : charToGroup) {
         out << " charToGroup[" << entry.first << "] = " << entry.second
             << "\n";
@@ -1492,8 +1492,16 @@ class PushReshapeUpThroughEinsum
       for (char c : group->second)
         newInputTranspose.push_back(equation.lhsParts[i].find(c));
       newInputEquation += inputToReshapedMap[group->second].newAxes;
-      for (int64_t v : inputToReshapedMap[group->second].newShape)
-        newInputShape.push_back(v);
+      for (int64_t v : inputToReshapedMap[group->second].newShape) {
+        if (v != 1 && group->second.size() == 1 &&
+            inputType.getDimSize(j) == 1) {
+          // if the group is of size 1, then it can have different sizes for
+          // each input due to broadcasting
+          newInputShape.push_back(1);
+        } else {
+          newInputShape.push_back(v);
+        }
+      }
     }
   }
 
@@ -1516,6 +1524,13 @@ class PushReshapeUpThroughEinsum
           out << ", ";
       }
       out << "]\n";
+      out << " oldShape: [";
+      for (size_t si = 0; si < inputType.getShape().size(); ++si) {
+        out << inputType.getShape()[si];
+        if (si + 1 < inputType.getShape().size())
+          out << ", ";
+      }
+      out << "]\n";
       DBGS() << out.str() << "\n";
     });
 
@@ -1529,11 +1544,13 @@ class PushReshapeUpThroughEinsum
       newInputs.push_back(newReshape);
       newEquation.lhsParts.push_back(newInputEquation);
     }
-    LLVM_DEBUG(
-        { DBGS() << "===============================================\n"; });
-
     std::string newEquationStr = newEquation.generateEquation();
 
+    LLVM_DEBUG({
+      DBGS() << newEquationStr << "\n"
+             << "===============================================\n";
+    });
+
     auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
         einsum.getLoc(), op.getType(), newInputs, newEquationStr);
     assert(newEquation.rhs.size() == newEinsum.getType().getShape().size());
@@ -3056,7 +3073,9 @@ class TransposeReshapeEliminationPass
         PushUpReshapeUnary<UnaryOp>, PushUpReshapeUnary<ActivationOp>,
         PushUpOpQuantizeDequantize<tensorrt::TransposeOp>,
         PushUpOpQuantizeDequantize<tensorrt::ReshapeOp>,
+
        PushReshapeUpThroughEinsum, PushUpReshapeElementwise,
+
         PushUpTransposeSoftmax, PushUpReshapeSoftmax,
         SimpleTransposeToReshape>(ctx, PatternBenefit(2));
     patterns.insert<EinsumPushUpTranspose>(ctx, PatternBenefit(1));
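
Note: the behavioral change in this file is the guard added around newInputShape.push_back(v) (new lines 1495-1504); the remaining hunks only adjust debug output. A minimal standalone sketch of the same rule follows; the helper name pushedUpDimSize and its parameters are invented here purely for illustration and are not part of the pass.

#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustration only: choose the extent that a reshape pushed above an einsum
// should use for one dimension of one operand.
//   groupSize      - number of einsum labels folded into this reshape group
//   reshapedExtent - the extent the group has after the reshape
//   operandExtent  - the extent this particular operand has for the dimension
static int64_t pushedUpDimSize(std::size_t groupSize, int64_t reshapedExtent,
                               int64_t operandExtent) {
  // A group of size 1 can have a different extent for each operand because of
  // broadcasting; if this operand's extent is 1, the pushed-up reshape must
  // keep it at 1 rather than taking the reshaped extent from another operand.
  if (reshapedExtent != 1 && groupSize == 1 && operandExtent == 1)
    return 1;
  return reshapedExtent;
}

int main() {
  // Extents from the test added below: "bdc,bdc->bd" with operands
  // 6x64x448 and 6x1x448; label d is 64 for one operand and broadcast (1)
  // for the other.
  assert(pushedUpDimSize(1, 64, 64) == 64); // non-broadcast operand
  assert(pushedUpDimSize(1, 64, 1) == 1);   // broadcast operand keeps 1
  return 0;
}

This mirrors the condition v != 1 && group->second.size() == 1 && inputType.getDimSize(j) == 1 in the diff above.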

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 18 additions & 0 deletions
@@ -347,4 +347,22 @@ func.func @can_not_push_reshape_through_einsum(%arg0: tensor<2x20x12x64xf32>, %a
   %0 = tensorrt.einsum {equation = "acbd,abcd->abd"} ins(%arg0, %arg1 : tensor<2x20x12x64xf32>, tensor<2x12x20x1xf32>) -> tensor<2x12x64xf32>
   %1 = tensorrt.reshape %0 : tensor<2x12x64xf32> to tensor<2x1x768xf32>
   return %1 : tensor<2x1x768xf32>
+}
+
+// -----
+
+// CHECK: @push_reshape_broadcast(%[[arg0:.+]]: tensor<6x64x448xf32>, %[[arg1:.+]]: tensor<6x1x448xf32>)
+// CHECK: %[[const:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x1x384x384xf32>
+// CHECK-DAG: %[[v0:.+]] = tensorrt.expand_rank %[[arg0]] : tensor<6x64x448xf32> to tensor<1x1x6x64x448xf32>
+// CHECK-DAG: %[[v1:.+]] = tensorrt.expand_rank %[[arg1]] : tensor<6x1x448xf32> to tensor<1x1x6x1x448xf32>
+// CHECK: %[[v2:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v0]], %[[v1]] : {{.*}})
+// CHECK: %[[v3:.+]] = tensorrt.reshape %[[v2]] : tensor<1x1x6x64xf32> to tensor<1x1x384xf32>
+// CHECK: %[[v4:.+]] = tensorrt.matrix_multiply {{{.*}}} ins(%[[v3]], %[[const]] : {{.*}}) -> tensor<1x1x384xf32>
+// CHECK: return %[[v4]]
+func.func @push_reshape_broadcast(%arg0: tensor<6x64x448xf32>, %arg1: tensor<6x1x448xf32>) -> tensor<1x1x384xf32> {
+  %const = tensorrt.constant dense_resource<__elided__> : tensor<384x6x64xf32>
+  %1 = tensorrt.einsum {equation = "bdc,bdc->bd"} ins(%arg0, %arg1 : tensor<6x64x448xf32>, tensor<6x1x448xf32>) -> tensor<6x64xf32>
+  %2 = tensorrt.einsum {equation = "bd,ebd->e"} ins(%1, %const : tensor<6x64xf32>, tensor<384x6x64xf32>) -> tensor<384xf32>
+  %3 = tensorrt.reshape %2 : tensor<384xf32> to tensor<1x1x384xf32>
+  return %3 : tensor<1x1x384xf32>
 }
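
Note: read together with the fix above, the new @push_reshape_broadcast case exercises the broadcast situation directly: in "bdc,bdc->bd", %arg0 carries d = 64 while %arg1 carries d = 1, so the trailing reshape tensor<384xf32> to tensor<1x1x384xf32> can only be pushed above the einsums if the broadcast operand keeps extent 1 for that axis. The CHECK lines reflect this: %arg1 is expanded to tensor<1x1x6x1x448xf32>, keeping the broadcast extent at 1 instead of inheriting 64 from %arg0.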
