66
77#include " Combiner.h"
88
9+ #include " ../Expression.h"
910#include " ../InstructionWalker.h"
1011#include " ../analysis/MemoryAnalysis.h"
1112#include " ../intermediate/Helper.h"
1213#include " ../intermediate/operators.h"
13- #include " ../Expression.h"
1414#include " ../periphery/VPM.h"
1515#include " ../spirv/SPIRVHelper.h"
1616#include " Eliminator.h"
@@ -1221,8 +1221,20 @@ Optional<int> getIntegerFromExpression(const SubExpression& expr)
12211221 return Optional<int >();
12221222}
12231223
1224- // signed, value
1225- using ExpandedExprs = std::vector<std::pair<bool , SubExpression>>;
1224+ // signed, value
1225+ class ExpandedExprs : public std ::vector<std::pair<bool , SubExpression>>
1226+ {
1227+ public:
1228+ std::string to_string () const
1229+ {
1230+ std::stringstream ss;
1231+ for (auto & p : *this )
1232+ {
1233+ ss << (p.first ? " +" : " -" ) << p.second .to_string ();
1234+ }
1235+ return ss.str ();
1236+ }
1237+ };
12261238
12271239void expandExpression (const SubExpression& subExpr, ExpandedExprs& expanded)
12281240{
@@ -1320,7 +1332,8 @@ void expandExpression(const SubExpression& subExpr, ExpandedExprs& expanded)
13201332 }
13211333 else
13221334 {
1323- expanded.push_back (std::make_pair (true , SubExpression (std::make_shared<Expression>(op, left, right))));
1335+ expanded.push_back (
1336+ std::make_pair (true , SubExpression (std::make_shared<Expression>(op, left, right))));
13241337 }
13251338 }
13261339 else if (op == Expression::FAKEOP_DIV)
@@ -1337,15 +1350,16 @@ void expandExpression(const SubExpression& subExpr, ExpandedExprs& expanded)
13371350 {
13381351 expanded.push_back (std::make_pair (true , subExpr));
13391352 }
1340- else {
1353+ else
1354+ {
13411355 throw CompilationError (CompilationStep::OPTIMIZER, " Cannot expand expression" , subExpr.to_string ());
13421356 }
13431357}
13441358
1345- SubExpression calcValueExpr (const SubExpression& expr )
1359+ void calcValueExpr (ExpandedExprs& expanded )
13461360{
1347- ExpandedExprs expanded;
1348- expandExpression (expr, expanded);
1361+ // ExpandedExprs expanded;
1362+ // expandExpression(expr, expanded);
13491363
13501364 // for(auto& p : expanded)
13511365 // logging::debug() << (p.first ? "+" : "-") << p.second->to_string() << " ";
@@ -1367,13 +1381,13 @@ SubExpression calcValueExpr(const SubExpression& expr)
13671381 }
13681382 }
13691383
1370- SubExpression result (INT_ZERO);
1371- for (auto & p : expanded)
1372- {
1373- result = SubExpression (std::make_shared<Expression>(p.first ? OP_ADD : OP_SUB, result, p.second ));
1374- }
1375-
1376- return result;
1384+ // SubExpression result(INT_ZERO);
1385+ // for(auto& p : expanded)
1386+ // {
1387+ // result = SubExpression(std::make_shared<Expression>(p.first ? OP_ADD : OP_SUB, result, p.second));
1388+ // }
1389+ //
1390+ // return result;
13771391}
13781392
13791393SubExpression replaceLocalToExpr (const SubExpression& expr, const Value& local, SubExpression newExpr)
@@ -1478,44 +1492,61 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
14781492 logging::debug () << pair.first .to_string () << " = " << pair.second .to_string () << logging::endl;
14791493 }
14801494
1481- SubExpression diff;
1495+ ExpandedExprs diff;
14821496 bool eqDiff = true ;
14831497 for (size_t i = 1 ; i < addrExprs.size (); i++)
14841498 {
14851499 auto x = addrExprs[i - 1 ].second ;
14861500 auto y = addrExprs[i].second ;
14871501 auto diffExpr = SubExpression (std::make_shared<Expression>(OP_SUB, y, x));
14881502
1489- auto currentDiff = calcValueExpr (diffExpr);
1503+ ExpandedExprs currentDiff;
1504+ expandExpression (diffExpr, currentDiff);
1505+
1506+ calcValueExpr (currentDiff);
1507+
14901508 // Apply calcValueExpr again for integer literals.
1491- currentDiff = calcValueExpr (currentDiff);
1509+ SubExpression currentExpr (INT_ZERO);
1510+ for (auto & p : currentDifft)
1511+ {
1512+ currentExpr =
1513+ SubExpression (std::make_shared<Expression>(p.first ? OP_ADD : OP_SUB, currentExpr, p.second ));
1514+ }
1515+ currentDiff.clear ();
1516+ expandExpression (currentExpr, currentDiff);
1517+ calcValueExpr (currentDiff);
1518+
1519+ // logging::debug() << currentDiff.to_string() << ", " << diff.to_string() << logging::endl;
14921520
1493- if (!diff )
1521+ if (i == 1 )
14941522 {
1495- diff = currentDiff;
1523+ diff = std::move ( currentDiff) ;
14961524 }
1497- if (currentDiff != diff)
1525+ else if (currentDiff != diff)
14981526 {
14991527 eqDiff = false ;
15001528 break ;
15011529 }
15021530 }
15031531
1504- logging::debug () << addrExprs.size () << " loads are " << (eqDiff ? " " : " not " ) << " equal difference "
1505- << logging::endl;
1532+ logging::debug () << addrExprs.size () << " loads are " << (eqDiff ? " " : " not " )
1533+ << " equal difference: " << diff. to_string () << logging::endl;
15061534
15071535 if (eqDiff)
15081536 {
15091537 // The form of diff should be "0 (+/-) expressions...", then remove the value 0 at most right.
1510- ExpandedExprs expanded;
1511- expandExpression (diff, expanded);
1512- if (expanded.size () == 1 )
1538+ // ExpandedExprs expanded;
1539+ // expandExpression(diff, expanded);
1540+ // for (auto& ex : expanded) {
1541+ // logging::debug() << "ex = " << ex.second.to_string() << logging::endl;
1542+ // }
1543+ if (diff.size () == 1 )
15131544 {
1514- diff = expanded [0 ].second ;
1545+ auto diffExpr = diff [0 ].second ;
15151546
1516- // logging::debug() << "diff = " << diff-> to_string() << logging::endl;
1547+ // logging::debug() << "diff = " << diff. to_string() << logging::endl;
15171548
1518- auto term = diff .getConstantExpression ();
1549+ auto term = diffExpr .getConstantExpression ();
15191550 auto mpValue = term.has_value () ? term->getConstantValue () : Optional<Value>{};
15201551 auto mpLiteral = mpValue.has_value () ? mpValue->getLiteralValue () : Optional<Literal>{};
15211552
@@ -1554,12 +1585,11 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
15541585 elemType.getInMemoryWidth () * DataType::BYTE, vectorLength, false };
15551586
15561587 uint64_t rows = loadInstrs.size ();
1557- VPMArea area (VPMUsage::SCRATCH, 0 , static_cast <uint8_t >(rows));
15581588 auto entries = Value (Literal (static_cast <uint32_t >(rows)), TYPE_INT32);
1559- it =
1560- method.vpm ->insertReadRAM (method, it, addr, VectorType, /* &area */ nullptr ,
1561- true , INT_ZERO, entries, Optional<uint16_t >(memoryPitch));
1589+ it = method.vpm ->insertReadRAM (method, it, addr, VectorType, nullptr , true ,
1590+ INT_ZERO, entries, Optional<uint16_t >(memoryPitch));
15621591
1592+ VPMArea area (VPMUsage::SCRATCH, 0 , static_cast <uint8_t >(rows));
15631593 it = method.vpm ->insertReadVPM (method, it, output, &area, true );
15641594 }
15651595 else
0 commit comments