2828#include " mlir/IR/Matchers.h"
2929#include " mlir/Pass/Pass.h"
3030#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31+ #include " llvm/ADT/StringExtras.h"
3132#include " llvm/Support/Debug.h"
3233#include < numeric>
3334
@@ -567,37 +568,31 @@ class RankChangeToReshape : public OpRewritePattern<OpType> {
567568
568569namespace {
569570struct EinsumEquation {
570- std::string equation;
571- SmallVector<std::string > lhsParts;
572- std::string lhs;
573- std::string rhs;
571+ llvm::StringRef equation;
572+ SmallVector<llvm::SmallString< 128 > > lhsParts;
573+ llvm::SmallString< 128 > lhs;
574+ llvm::SmallString< 128 > rhs;
574575
575576 LogicalResult parse (llvm::StringRef einsumEquation) {
576- std::string e{einsumEquation};
577- return parse (e);
578- }
579-
580- LogicalResult parse (const std::string &einsumEquation) {
581577 size_t pos = einsumEquation.find (" ->" );
582578 if (pos == std::string::npos)
583579 return failure ();
584580 equation = einsumEquation;
585- lhs = einsumEquation.substr (0 , pos);
586- rhs = einsumEquation.substr (pos + 2 );
587- std::istringstream lhsStream (lhs);
588- std::string currentPart;
589- while (std::getline (lhsStream, currentPart, ' ,' )) {
590- lhsParts.push_back (currentPart);
591- }
581+ lhs = equation.substr (0 , pos);
582+ rhs = equation.substr (pos + 2 );
583+ SmallVector<llvm::StringRef> parts;
584+ llvm::SplitString (lhs, parts, " ," );
585+ for (llvm::StringRef part : parts)
586+ lhsParts.push_back (part); // cast from StringRef to SmallString
592587 return success ();
593588 }
594589
595- std::string generateEquation () const {
596- std::string ret = lhsParts[0 ];
590+ StringRef generateEquation () const {
591+ llvm::SmallString< 128 > ret = lhsParts[0 ];
597592 for (size_t i = 1 ; i < lhsParts.size (); i++) {
598- ret += " ," + lhsParts[i];
593+ ret. append ({ " ," , lhsParts[i]}) ;
599594 }
600- ret += " ->" + rhs;
595+ ret. append ({ " ->" , rhs}) ;
601596 return ret;
602597 }
603598};
@@ -653,7 +648,7 @@ class PushDownTransposeToEinsum : public OpRewritePattern<tensorrt::EinsumOp> {
653648 if (!hasTransposeInput)
654649 return failure ();
655650
656- std::string newEinsumEquation = einsumEquation.generateEquation ();
651+ StringRef newEinsumEquation = einsumEquation.generateEquation ();
657652
658653 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
659654 newEinsumEquation);
@@ -696,7 +691,7 @@ class PushUpTransposeToEinsum : public OpRewritePattern<tensorrt::TransposeOp> {
696691 for (size_t i = 0 ; i < einsumRhs.size (); i++)
697692 einsumEquation.rhs += (char )einsumRhs[i];
698693
699- std::string newEinsumEquation = einsumEquation.generateEquation ();
694+ StringRef newEinsumEquation = einsumEquation.generateEquation ();
700695
701696 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
702697 op.getLoc (), op.getType (), einsum.getInputs (), newEinsumEquation);
@@ -735,7 +730,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
735730 std::sort (outputAxes.begin (), outputAxes.end (),
736731 [&](const std::pair<char , int64_t > &a,
737732 const std::pair<char , int64_t > &b) {
738- for (std::string &eqLhs : equation.lhsParts ) {
733+ for (auto &eqLhs : equation.lhsParts ) {
739734 if (eqLhs.find (a.first ) != std::string::npos) {
740735 if (eqLhs.find (b.first ) != std::string::npos) {
741736 return eqLhs.find (a.first ) < eqLhs.find (b.first );
@@ -751,7 +746,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
751746
752747 SmallVector<int64_t > newEinsumShape;
753748 SmallVector<int64_t > outputPerm;
754- std::string newEinsumRhs = " " ;
749+ SmallString< 128 > newEinsumRhs{ " " } ;
755750 for (auto &[c, i] : outputAxes) {
756751 newEinsumRhs += c;
757752 newEinsumShape.push_back (op.getType ().getDimSize (i));
@@ -760,7 +755,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
760755 if (newEinsumRhs == equation.rhs )
761756 return failure (); // no change
762757
763- std::string newEinsumEquation = equation.generateEquation ();
758+ StringRef newEinsumEquation = equation.generateEquation ();
764759
765760 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
766761 op.getLoc (), op.getType ().clone (newEinsumShape), op.getInputs (),
@@ -852,7 +847,7 @@ class EinsumPushUpTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
852847 if (!didChange)
853848 return failure ();
854849
855- std::string newEquation = equation.generateEquation ();
850+ StringRef newEquation = equation.generateEquation ();
856851 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
857852 newEquation);
858853 return success ();
@@ -928,7 +923,7 @@ class EinsumEliminate1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
928923 newOutputShape.push_back (outputType.getDimSize (i));
929924 }
930925 }
931- std::string newEquation = newEinsumEquation.generateEquation ();
926+ StringRef newEquation = newEinsumEquation.generateEquation ();
932927
933928 if (changeOutput) {
934929 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
@@ -1004,7 +999,7 @@ class EinsumMergeDown1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
1004999 if (!madeChange)
10051000 return failure ();
10061001
1007- std::string newEquation = equation.generateEquation ();
1002+ StringRef newEquation = equation.generateEquation ();
10081003
10091004 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
10101005 newEquation);
@@ -1064,7 +1059,7 @@ class EinsumMergeUp1Axis : public OpRewritePattern<tensorrt::ExpandRankOp> {
10641059 }
10651060 }
10661061
1067- std::string newEquation = equation.lhs + " ->" + newRhs;
1062+ SmallString< 128 > newEquation{ equation.lhs , " ->" , newRhs} ;
10681063 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(
10691064 op, op.getType (), einsum.getInputs (), newEquation);
10701065 return success ();
@@ -1166,7 +1161,7 @@ class PushReshapeUpThroughEinsum
11661161 // check that all of the inputs are have the right groupping. If this
11671162 // doesn't happen then that means that the reshape can not get pushed
11681163 // through
1169- for (std::string &eqLhs : equation.lhsParts ) {
1164+ for (auto &eqLhs : equation.lhsParts ) {
11701165 for (char c : eqLhs) {
11711166 auto it = charToGroup.find (c);
11721167 if (it == charToGroup.end ())
@@ -1193,7 +1188,7 @@ class PushReshapeUpThroughEinsum
11931188 for (size_t i = 0 ; i < einsum.getInputs ().size (); i++) {
11941189 Value input = einsum.getInputs ()[i];
11951190 auto inputType = cast<RankedTensorType>(input.getType ());
1196- std::string newInputEquation = " " ;
1191+ SmallString< 128 > newInputEquation{ " " } ;
11971192 SmallVector<int64_t > newInputShape;
11981193 SmallVector<int64_t > newInputTranspose;
11991194 for (int j = 0 ; j < inputType.getRank (); j++) {
@@ -1229,7 +1224,7 @@ class PushReshapeUpThroughEinsum
12291224 newEquation.lhsParts .push_back (newInputEquation);
12301225 }
12311226
1232- std::string newEquationStr = newEquation.generateEquation ();
1227+ StringRef newEquationStr = newEquation.generateEquation ();
12331228
12341229 if (has1OutputShape) {
12351230 SmallVector<int64_t > newShape;
@@ -1419,13 +1414,13 @@ class PushReshapeDownThroughEinsum
14191414 }
14201415 }
14211416
1422- for (std::string &part : equation.lhsParts ) {
1417+ for (auto &part : equation.lhsParts ) {
14231418 for (char c : part) {
14241419 auto group = charToGroup.find (c);
14251420 if (group == charToGroup.end ())
14261421 continue ;
14271422 for (char c2 : group->second ) {
1428- if (part.find (c2) == std::string ::npos)
1423+ if (part.find (c2) == StringRef ::npos)
14291424 return failure (
14301425 /* Missing dimensions that need to be reshaped together */ );
14311426 }
@@ -1437,7 +1432,7 @@ class PushReshapeDownThroughEinsum
14371432 if (group == charToGroup.end ())
14381433 continue ;
14391434 for (char c2 : group->second ) {
1440- if (equation.rhs .find (c2) == std::string ::npos)
1435+ if (equation.rhs .find (c2) == StringRef ::npos)
14411436 return failure (
14421437 /* Missing dimensions that need to be reshaped together */ );
14431438 }
@@ -1486,7 +1481,7 @@ class PushReshapeDownThroughEinsum
14861481 RankedTensorType inputType = cast<RankedTensorType>(input.getType ());
14871482 SmallVector<int64_t > newInputShape;
14881483 SmallVector<int64_t > newInputTranspose;
1489- std::string newEinsumStr = " " ;
1484+ SmallString< 128 > newEinsumStr{ " " } ;
14901485 for (int j = 0 ; j < inputType.getRank (); j++) {
14911486 char c = equation.lhsParts [i][j];
14921487 auto it = charToGroup.find (c);
@@ -1504,7 +1499,7 @@ class PushReshapeDownThroughEinsum
15041499 }
15051500 for (char c2 : group->first ) {
15061501 size_t pos = equation.lhsParts [i].find (c2);
1507- assert (pos != std::string ::npos);
1502+ assert (pos != StringRef ::npos);
15081503 newInputTranspose.push_back (pos);
15091504 }
15101505 }
@@ -1555,13 +1550,13 @@ class PushReshapeDownThroughEinsum
15551550 }
15561551 for (char c2 : it->second ) {
15571552 size_t pos = equation.rhs .find (c2);
1558- assert (pos != std::string ::npos);
1553+ assert (pos != StringRef ::npos);
15591554 afterReshapeTranspose.push_back (pos);
15601555 }
15611556 }
15621557 }
15631558
1564- std::string newEinsumEquation = newEquation.generateEquation ();
1559+ StringRef newEinsumEquation = newEquation.generateEquation ();
15651560
15661561 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
15671562 op.getLoc (), outputType.clone (einsumOutputShape), newInputs,
0 commit comments