Skip to content

Commit 6237397

Browse files
Fix calcValueExpr to take ExpandedExprs
1 parent 90c50ca commit 6237397

File tree

1 file changed

+63
-33
lines changed

1 file changed

+63
-33
lines changed

src/optimization/Combiner.cpp

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
#include "Combiner.h"
88

9+
#include "../Expression.h"
910
#include "../InstructionWalker.h"
1011
#include "../analysis/MemoryAnalysis.h"
1112
#include "../intermediate/Helper.h"
1213
#include "../intermediate/operators.h"
13-
#include "../Expression.h"
1414
#include "../periphery/VPM.h"
1515
#include "../spirv/SPIRVHelper.h"
1616
#include "Eliminator.h"
@@ -1221,8 +1221,20 @@ Optional<int> getIntegerFromExpression(const SubExpression& expr)
12211221
return Optional<int>();
12221222
}
12231223

1224-
// signed, value
1225-
using ExpandedExprs = std::vector<std::pair<bool, SubExpression>>;
1224+
// signed, value
1225+
class ExpandedExprs : public std::vector<std::pair<bool, SubExpression>>
1226+
{
1227+
public:
1228+
std::string to_string() const
1229+
{
1230+
std::stringstream ss;
1231+
for(auto& p : *this)
1232+
{
1233+
ss << (p.first ? "+" : "-") << p.second.to_string();
1234+
}
1235+
return ss.str();
1236+
}
1237+
};
12261238

12271239
void expandExpression(const SubExpression& subExpr, ExpandedExprs& expanded)
12281240
{
@@ -1320,7 +1332,8 @@ void expandExpression(const SubExpression& subExpr, ExpandedExprs& expanded)
13201332
}
13211333
else
13221334
{
1323-
expanded.push_back(std::make_pair(true, SubExpression(std::make_shared<Expression>(op, left, right))));
1335+
expanded.push_back(
1336+
std::make_pair(true, SubExpression(std::make_shared<Expression>(op, left, right))));
13241337
}
13251338
}
13261339
else if(op == Expression::FAKEOP_DIV)
@@ -1337,15 +1350,16 @@ void expandExpression(const SubExpression& subExpr, ExpandedExprs& expanded)
13371350
{
13381351
expanded.push_back(std::make_pair(true, subExpr));
13391352
}
1340-
else {
1353+
else
1354+
{
13411355
throw CompilationError(CompilationStep::OPTIMIZER, "Cannot expand expression", subExpr.to_string());
13421356
}
13431357
}
13441358

1345-
SubExpression calcValueExpr(const SubExpression& expr)
1359+
void calcValueExpr(ExpandedExprs& expanded)
13461360
{
1347-
ExpandedExprs expanded;
1348-
expandExpression(expr, expanded);
1361+
// ExpandedExprs expanded;
1362+
// expandExpression(expr, expanded);
13491363

13501364
// for(auto& p : expanded)
13511365
// logging::debug() << (p.first ? "+" : "-") << p.second->to_string() << " ";
@@ -1367,13 +1381,13 @@ SubExpression calcValueExpr(const SubExpression& expr)
13671381
}
13681382
}
13691383

1370-
SubExpression result(INT_ZERO);
1371-
for(auto& p : expanded)
1372-
{
1373-
result = SubExpression(std::make_shared<Expression>(p.first ? OP_ADD : OP_SUB, result, p.second));
1374-
}
1375-
1376-
return result;
1384+
// SubExpression result(INT_ZERO);
1385+
// for(auto& p : expanded)
1386+
// {
1387+
// result = SubExpression(std::make_shared<Expression>(p.first ? OP_ADD : OP_SUB, result, p.second));
1388+
// }
1389+
//
1390+
// return result;
13771391
}
13781392

13791393
SubExpression replaceLocalToExpr(const SubExpression& expr, const Value& local, SubExpression newExpr)
@@ -1478,44 +1492,61 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
14781492
logging::debug() << pair.first.to_string() << " = " << pair.second.to_string() << logging::endl;
14791493
}
14801494

1481-
SubExpression diff;
1495+
ExpandedExprs diff;
14821496
bool eqDiff = true;
14831497
for(size_t i = 1; i < addrExprs.size(); i++)
14841498
{
14851499
auto x = addrExprs[i - 1].second;
14861500
auto y = addrExprs[i].second;
14871501
auto diffExpr = SubExpression(std::make_shared<Expression>(OP_SUB, y, x));
14881502

1489-
auto currentDiff = calcValueExpr(diffExpr);
1503+
ExpandedExprs currentDiff;
1504+
expandExpression(diffExpr, currentDiff);
1505+
1506+
calcValueExpr(currentDiff);
1507+
14901508
// Apply calcValueExpr again for integer literals.
1491-
currentDiff = calcValueExpr(currentDiff);
1509+
SubExpression currentExpr(INT_ZERO);
1510+
for(auto& p : currentDifft)
1511+
{
1512+
currentExpr =
1513+
SubExpression(std::make_shared<Expression>(p.first ? OP_ADD : OP_SUB, currentExpr, p.second));
1514+
}
1515+
currentDiff.clear();
1516+
expandExpression(currentExpr, currentDiff);
1517+
calcValueExpr(currentDiff);
1518+
1519+
// logging::debug() << currentDiff.to_string() << ", " << diff.to_string() << logging::endl;
14921520

1493-
if(!diff)
1521+
if(i == 1)
14941522
{
1495-
diff = currentDiff;
1523+
diff = std::move(currentDiff);
14961524
}
1497-
if(currentDiff != diff)
1525+
else if(currentDiff != diff)
14981526
{
14991527
eqDiff = false;
15001528
break;
15011529
}
15021530
}
15031531

1504-
logging::debug() << addrExprs.size() << " loads are " << (eqDiff ? "" : "not ") << "equal difference"
1505-
<< logging::endl;
1532+
logging::debug() << addrExprs.size() << " loads are " << (eqDiff ? "" : "not ")
1533+
<< "equal difference: " << diff.to_string() << logging::endl;
15061534

15071535
if(eqDiff)
15081536
{
15091537
// The form of diff should be "0 (+/-) expressions...", then remove the value 0 at most right.
1510-
ExpandedExprs expanded;
1511-
expandExpression(diff, expanded);
1512-
if(expanded.size() == 1)
1538+
// ExpandedExprs expanded;
1539+
// expandExpression(diff, expanded);
1540+
// for (auto& ex : expanded) {
1541+
// logging::debug() << "ex = " << ex.second.to_string() << logging::endl;
1542+
// }
1543+
if(diff.size() == 1)
15131544
{
1514-
diff = expanded[0].second;
1545+
auto diffExpr = diff[0].second;
15151546

1516-
// logging::debug() << "diff = " << diff->to_string() << logging::endl;
1547+
// logging::debug() << "diff = " << diff.to_string() << logging::endl;
15171548

1518-
auto term = diff.getConstantExpression();
1549+
auto term = diffExpr.getConstantExpression();
15191550
auto mpValue = term.has_value() ? term->getConstantValue() : Optional<Value>{};
15201551
auto mpLiteral = mpValue.has_value() ? mpValue->getLiteralValue() : Optional<Literal>{};
15211552

@@ -1554,12 +1585,11 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
15541585
elemType.getInMemoryWidth() * DataType::BYTE, vectorLength, false};
15551586

15561587
uint64_t rows = loadInstrs.size();
1557-
VPMArea area(VPMUsage::SCRATCH, 0, static_cast<uint8_t>(rows));
15581588
auto entries = Value(Literal(static_cast<uint32_t>(rows)), TYPE_INT32);
1559-
it =
1560-
method.vpm->insertReadRAM(method, it, addr, VectorType, /* &area */ nullptr,
1561-
true, INT_ZERO, entries, Optional<uint16_t>(memoryPitch));
1589+
it = method.vpm->insertReadRAM(method, it, addr, VectorType, nullptr, true,
1590+
INT_ZERO, entries, Optional<uint16_t>(memoryPitch));
15621591

1592+
VPMArea area(VPMUsage::SCRATCH, 0, static_cast<uint8_t>(rows));
15631593
it = method.vpm->insertReadVPM(method, it, output, &area, true);
15641594
}
15651595
else

0 commit comments

Comments
 (0)