Skip to content

Commit 353322b

Browse files
author
Matthew Francis-Landau
committed
change to using llvm::SplitString for splitting the string
1 parent 725a3f1 commit 353322b

File tree

1 file changed

+34
-39
lines changed

1 file changed

+34
-39
lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/IR/Matchers.h"
2929
#include "mlir/Pass/Pass.h"
3030
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31+
#include "llvm/ADT/StringExtras.h"
3132
#include "llvm/Support/Debug.h"
3233
#include <numeric>
3334

@@ -567,37 +568,31 @@ class RankChangeToReshape : public OpRewritePattern<OpType> {
567568

568569
namespace {
569570
struct EinsumEquation {
570-
std::string equation;
571-
SmallVector<std::string> lhsParts;
572-
std::string lhs;
573-
std::string rhs;
571+
llvm::StringRef equation;
572+
SmallVector<llvm::SmallString<128>> lhsParts;
573+
llvm::SmallString<128> lhs;
574+
llvm::SmallString<128> rhs;
574575

575576
LogicalResult parse(llvm::StringRef einsumEquation) {
576-
std::string e{einsumEquation};
577-
return parse(e);
578-
}
579-
580-
LogicalResult parse(const std::string &einsumEquation) {
581577
size_t pos = einsumEquation.find("->");
582578
if (pos == std::string::npos)
583579
return failure();
584580
equation = einsumEquation;
585-
lhs = einsumEquation.substr(0, pos);
586-
rhs = einsumEquation.substr(pos + 2);
587-
std::istringstream lhsStream(lhs);
588-
std::string currentPart;
589-
while (std::getline(lhsStream, currentPart, ',')) {
590-
lhsParts.push_back(currentPart);
591-
}
581+
lhs = equation.substr(0, pos);
582+
rhs = equation.substr(pos + 2);
583+
SmallVector<llvm::StringRef> parts;
584+
llvm::SplitString(lhs, parts, ",");
585+
for (llvm::StringRef part : parts)
586+
lhsParts.push_back(part); // cast from StringRef to SmallString
592587
return success();
593588
}
594589

595-
std::string generateEquation() const {
596-
std::string ret = lhsParts[0];
590+
StringRef generateEquation() const {
591+
llvm::SmallString<128> ret = lhsParts[0];
597592
for (size_t i = 1; i < lhsParts.size(); i++) {
598-
ret += "," + lhsParts[i];
593+
ret.append({",", lhsParts[i]});
599594
}
600-
ret += "->" + rhs;
595+
ret.append({"->", rhs});
601596
return ret;
602597
}
603598
};
@@ -653,7 +648,7 @@ class PushDownTransposeToEinsum : public OpRewritePattern<tensorrt::EinsumOp> {
653648
if (!hasTransposeInput)
654649
return failure();
655650

656-
std::string newEinsumEquation = einsumEquation.generateEquation();
651+
StringRef newEinsumEquation = einsumEquation.generateEquation();
657652

658653
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
659654
newEinsumEquation);
@@ -696,7 +691,7 @@ class PushUpTransposeToEinsum : public OpRewritePattern<tensorrt::TransposeOp> {
696691
for (size_t i = 0; i < einsumRhs.size(); i++)
697692
einsumEquation.rhs += (char)einsumRhs[i];
698693

699-
std::string newEinsumEquation = einsumEquation.generateEquation();
694+
StringRef newEinsumEquation = einsumEquation.generateEquation();
700695

701696
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
702697
op.getLoc(), op.getType(), einsum.getInputs(), newEinsumEquation);
@@ -735,7 +730,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
735730
std::sort(outputAxes.begin(), outputAxes.end(),
736731
[&](const std::pair<char, int64_t> &a,
737732
const std::pair<char, int64_t> &b) {
738-
for (std::string &eqLhs : equation.lhsParts) {
733+
for (auto &eqLhs : equation.lhsParts) {
739734
if (eqLhs.find(a.first) != std::string::npos) {
740735
if (eqLhs.find(b.first) != std::string::npos) {
741736
return eqLhs.find(a.first) < eqLhs.find(b.first);
@@ -751,7 +746,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
751746

752747
SmallVector<int64_t> newEinsumShape;
753748
SmallVector<int64_t> outputPerm;
754-
std::string newEinsumRhs = "";
749+
SmallString<128> newEinsumRhs{""};
755750
for (auto &[c, i] : outputAxes) {
756751
newEinsumRhs += c;
757752
newEinsumShape.push_back(op.getType().getDimSize(i));
@@ -760,7 +755,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
760755
if (newEinsumRhs == equation.rhs)
761756
return failure(); // no change
762757

763-
std::string newEinsumEquation = equation.generateEquation();
758+
StringRef newEinsumEquation = equation.generateEquation();
764759

765760
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
766761
op.getLoc(), op.getType().clone(newEinsumShape), op.getInputs(),
@@ -852,7 +847,7 @@ class EinsumPushUpTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
852847
if (!didChange)
853848
return failure();
854849

855-
std::string newEquation = equation.generateEquation();
850+
StringRef newEquation = equation.generateEquation();
856851
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
857852
newEquation);
858853
return success();
@@ -928,7 +923,7 @@ class EinsumEliminate1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
928923
newOutputShape.push_back(outputType.getDimSize(i));
929924
}
930925
}
931-
std::string newEquation = newEinsumEquation.generateEquation();
926+
StringRef newEquation = newEinsumEquation.generateEquation();
932927

933928
if (changeOutput) {
934929
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
@@ -1004,7 +999,7 @@ class EinsumMergeDown1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
1004999
if (!madeChange)
10051000
return failure();
10061001

1007-
std::string newEquation = equation.generateEquation();
1002+
StringRef newEquation = equation.generateEquation();
10081003

10091004
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
10101005
newEquation);
@@ -1064,7 +1059,7 @@ class EinsumMergeUp1Axis : public OpRewritePattern<tensorrt::ExpandRankOp> {
10641059
}
10651060
}
10661061

1067-
std::string newEquation = equation.lhs + "->" + newRhs;
1062+
SmallString<128> newEquation{equation.lhs, "->", newRhs};
10681063
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(
10691064
op, op.getType(), einsum.getInputs(), newEquation);
10701065
return success();
@@ -1166,7 +1161,7 @@ class PushReshapeUpThroughEinsum
11661161
// check that all of the inputs are have the right groupping. If this
11671162
// doesn't happen then that means that the reshape can not get pushed
11681163
// through
1169-
for (std::string &eqLhs : equation.lhsParts) {
1164+
for (auto &eqLhs : equation.lhsParts) {
11701165
for (char c : eqLhs) {
11711166
auto it = charToGroup.find(c);
11721167
if (it == charToGroup.end())
@@ -1193,7 +1188,7 @@ class PushReshapeUpThroughEinsum
11931188
for (size_t i = 0; i < einsum.getInputs().size(); i++) {
11941189
Value input = einsum.getInputs()[i];
11951190
auto inputType = cast<RankedTensorType>(input.getType());
1196-
std::string newInputEquation = "";
1191+
SmallString<128> newInputEquation{""};
11971192
SmallVector<int64_t> newInputShape;
11981193
SmallVector<int64_t> newInputTranspose;
11991194
for (int j = 0; j < inputType.getRank(); j++) {
@@ -1229,7 +1224,7 @@ class PushReshapeUpThroughEinsum
12291224
newEquation.lhsParts.push_back(newInputEquation);
12301225
}
12311226

1232-
std::string newEquationStr = newEquation.generateEquation();
1227+
StringRef newEquationStr = newEquation.generateEquation();
12331228

12341229
if (has1OutputShape) {
12351230
SmallVector<int64_t> newShape;
@@ -1419,13 +1414,13 @@ class PushReshapeDownThroughEinsum
14191414
}
14201415
}
14211416

1422-
for (std::string &part : equation.lhsParts) {
1417+
for (auto &part : equation.lhsParts) {
14231418
for (char c : part) {
14241419
auto group = charToGroup.find(c);
14251420
if (group == charToGroup.end())
14261421
continue;
14271422
for (char c2 : group->second) {
1428-
if (part.find(c2) == std::string::npos)
1423+
if (part.find(c2) == StringRef::npos)
14291424
return failure(
14301425
/* Missing dimensions that need to be reshaped together */);
14311426
}
@@ -1437,7 +1432,7 @@ class PushReshapeDownThroughEinsum
14371432
if (group == charToGroup.end())
14381433
continue;
14391434
for (char c2 : group->second) {
1440-
if (equation.rhs.find(c2) == std::string::npos)
1435+
if (equation.rhs.find(c2) == StringRef::npos)
14411436
return failure(
14421437
/* Missing dimensions that need to be reshaped together */);
14431438
}
@@ -1486,7 +1481,7 @@ class PushReshapeDownThroughEinsum
14861481
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
14871482
SmallVector<int64_t> newInputShape;
14881483
SmallVector<int64_t> newInputTranspose;
1489-
std::string newEinsumStr = "";
1484+
SmallString<128> newEinsumStr{""};
14901485
for (int j = 0; j < inputType.getRank(); j++) {
14911486
char c = equation.lhsParts[i][j];
14921487
auto it = charToGroup.find(c);
@@ -1504,7 +1499,7 @@ class PushReshapeDownThroughEinsum
15041499
}
15051500
for (char c2 : group->first) {
15061501
size_t pos = equation.lhsParts[i].find(c2);
1507-
assert(pos != std::string::npos);
1502+
assert(pos != StringRef::npos);
15081503
newInputTranspose.push_back(pos);
15091504
}
15101505
}
@@ -1555,13 +1550,13 @@ class PushReshapeDownThroughEinsum
15551550
}
15561551
for (char c2 : it->second) {
15571552
size_t pos = equation.rhs.find(c2);
1558-
assert(pos != std::string::npos);
1553+
assert(pos != StringRef::npos);
15591554
afterReshapeTranspose.push_back(pos);
15601555
}
15611556
}
15621557
}
15631558

1564-
std::string newEinsumEquation = newEquation.generateEquation();
1559+
StringRef newEinsumEquation = newEquation.generateEquation();
15651560

15661561
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
15671562
op.getLoc(), outputType.clone(einsumOutputShape), newInputs,

0 commit comments

Comments
 (0)