@@ -1423,7 +1423,7 @@ class PushReshapeUpThroughEinsum
 
     LLVM_DEBUG({
       std::stringstream out;
-      out << "==== Einsum Reshape/Transpose Pushdown Debug ====\n";
+      out << "==== Einsum Reshape/Transpose Pushup Debug ====\n";
       for (const auto &entry : charToGroup) {
         out << "charToGroup[" << entry.first << "] = " << entry.second
             << "\n";
@@ -1492,8 +1492,16 @@ class PushReshapeUpThroughEinsum
           for (char c : group->second)
             newInputTranspose.push_back(equation.lhsParts[i].find(c));
           newInputEquation += inputToReshapedMap[group->second].newAxes;
-          for (int64_t v : inputToReshapedMap[group->second].newShape)
-            newInputShape.push_back(v);
+          for (int64_t v : inputToReshapedMap[group->second].newShape) {
+            if (v != 1 && group->second.size() == 1 &&
+                inputType.getDimSize(j) == 1) {
+              // If the group has size 1, its extent can differ between inputs
+              // due to broadcasting, so keep this input's unit extent.
+              newInputShape.push_back(1);
+            } else {
+              newInputShape.push_back(v);
+            }
+          }
         }
       }
 
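For context, a minimal standalone sketch of the broadcast check added above. The helper name `adjustReshapedShape` and the flattened signature are illustrative assumptions, not part of the patch; the real pattern operates on MLIR shaped types and the `charToGroup` bookkeeping seen elsewhere in this diff.

```cpp
// Hypothetical sketch of the broadcast-aware extent selection above.
#include <cstdint>
#include <string>
#include <vector>

// For one einsum label group of a single operand, pick the reshaped extents.
// `targetShape` is the extent layout chosen for the group overall, while
// `inputDimSize` is the extent this particular operand carries for the group.
std::vector<int64_t> adjustReshapedShape(const std::string &group,
                                         const std::vector<int64_t> &targetShape,
                                         int64_t inputDimSize) {
  std::vector<int64_t> result;
  for (int64_t v : targetShape) {
    // A single-label group may broadcast: this operand can hold extent 1 while
    // another holds the full extent, so preserve the unit extent here.
    if (v != 1 && group.size() == 1 && inputDimSize == 1)
      result.push_back(1);
    else
      result.push_back(v);
  }
  return result;
}

int main() {
  // Group "b" resolves to extent 8 overall, but this operand broadcasts along
  // it (extent 1), so the reshaped extent must stay 1.
  auto shape = adjustReshapedShape("b", /*targetShape=*/{8}, /*inputDimSize=*/1);
  return shape[0] == 1 ? 0 : 1;
}
```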
@@ -1516,6 +1524,13 @@ class PushReshapeUpThroughEinsum
             out << ", ";
         }
         out << "]\n";
+        out << "oldShape: [";
+        for (size_t si = 0; si < inputType.getShape().size(); ++si) {
+          out << inputType.getShape()[si];
+          if (si + 1 < inputType.getShape().size())
+            out << ", ";
+        }
+        out << "]\n";
         DBGS() << out.str() << "\n";
       });
 
@@ -1529,11 +1544,13 @@ class PushReshapeUpThroughEinsum
       newInputs.push_back(newReshape);
       newEquation.lhsParts.push_back(newInputEquation);
     }
-    LLVM_DEBUG(
-        { DBGS() << "===============================================\n"; });
-
     std::string newEquationStr = newEquation.generateEquation();
 
+    LLVM_DEBUG({
+      DBGS() << newEquationStr << "\n"
+             << "===============================================\n";
+    });
+
     auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
         einsum.getLoc(), op.getType(), newInputs, newEquationStr);
     assert(newEquation.rhs.size() == newEinsum.getType().getShape().size());
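The relocated debug block now logs the rebuilt equation string. As a rough illustration of what `generateEquation()` is assumed to produce, here is a hedged sketch that reassembles the conventional `lhs0,lhs1,...->rhs` einsum syntax from per-operand parts; it is not the actual `EinsumEquation` implementation.

```cpp
// Assumed reassembly of an einsum equation string such as "bik,bkj->bij".
#include <iostream>
#include <string>
#include <vector>

std::string generateEquation(const std::vector<std::string> &lhsParts,
                             const std::string &rhs) {
  std::string eq;
  for (size_t i = 0; i < lhsParts.size(); ++i) {
    if (i > 0)
      eq += ","; // separate operand subscripts
    eq += lhsParts[i];
  }
  eq += "->"; // output marker
  eq += rhs;
  return eq;
}

int main() {
  // e.g. a batched matmul after the rewrite: prints "bik,bkj->bij"
  std::cout << generateEquation({"bik", "bkj"}, "bij") << "\n";
  return 0;
}
```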
@@ -3056,7 +3073,9 @@ class TransposeReshapeEliminationPass
                     PushUpReshapeUnary<UnaryOp>, PushUpReshapeUnary<ActivationOp>,
                     PushUpOpQuantizeDequantize<tensorrt::TransposeOp>,
                     PushUpOpQuantizeDequantize<tensorrt::ReshapeOp>,
+
                     PushReshapeUpThroughEinsum, PushUpReshapeElementwise,
+
                     PushUpTransposeSoftmax, PushUpReshapeSoftmax,
                     SimpleTransposeToReshape>(ctx, PatternBenefit(2));
     patterns.insert<EinsumPushUpTranspose>(ctx, PatternBenefit(1));