1010#include " ../analysis/MemoryAnalysis.h"
1111#include " ../intermediate/Helper.h"
1212#include " ../intermediate/operators.h"
13- #include " ../optimization/ValueExpr .h"
13+ #include " ../Expression .h"
1414#include " ../periphery/VPM.h"
1515#include " ../spirv/SPIRVHelper.h"
1616#include " Eliminator.h"
@@ -1125,8 +1125,13 @@ InstructionWalker optimizations::combineArithmeticOperations(
11251125 return it;
11261126}
11271127
1128+ SubExpression makeValueBinaryOpFromLocal (Value& left, const OpCode& binOp, Value& right)
1129+ {
1130+ return SubExpression (std::make_shared<Expression>(binOp, SubExpression (left), SubExpression (right)));
1131+ }
1132+
11281133// try to convert shl to mul and return it as ValueExpr
1129- std::shared_ptr<ValueExpr> shlToMul (Value& value, const intermediate::Operation* op)
1134+ SubExpression shlToMul (const Value& value, const intermediate::Operation* op)
11301135{
11311136 auto left = op->getFirstArg ();
11321137 auto right = *op->getSecondArg ();
@@ -1143,29 +1148,24 @@ std::shared_ptr<ValueExpr> shlToMul(Value& value, const intermediate::Operation*
11431148 if (shiftValue > 0 )
11441149 {
11451150 auto right = Value (Literal (1 << shiftValue), TYPE_INT32);
1146- return makeValueBinaryOpFromLocal (left, ValueBinaryOp::BinaryOp::Mul , right);
1151+ return makeValueBinaryOpFromLocal (left, OP_FMUL , right);
11471152 }
11481153 else
11491154 {
1150- return std::make_shared<ValueTerm> (value);
1155+ return SubExpression (value);
11511156 }
11521157}
11531158
1154- std::shared_ptr<ValueExpr> iiToExpr (Value& value, const LocalUser* inst)
1159+ SubExpression iiToExpr (const Value& value, const LocalUser* inst)
11551160{
1156- using BO = ValueBinaryOp::BinaryOp;
1157- BO binOp = BO::Other;
1158-
11591161 // add, sub, shr, shl, asr
11601162 if (auto op = dynamic_cast <const intermediate::Operation*>(inst))
11611163 {
1162- if (op->op == OP_ADD)
1163- {
1164- binOp = BO::Add;
1165- }
1166- else if (op->op == OP_SUB)
1164+ if (op->op == OP_ADD || op->op == OP_SUB)
11671165 {
1168- binOp = BO::Sub;
1166+ auto left = op->getFirstArg ();
1167+ auto right = *op->getSecondArg ();
1168+ return makeValueBinaryOpFromLocal (left, op->op , right);
11691169 }
11701170 else if (op->op == OP_SHL)
11711171 {
@@ -1176,55 +1176,186 @@ std::shared_ptr<ValueExpr> iiToExpr(Value& value, const LocalUser* inst)
11761176 else
11771177 {
11781178 // If op is neither add nor sub, return value as-is.
1179- return std::make_shared<ValueTerm> (value);
1179+ return SubExpression (value);
11801180 }
1181-
1182- auto left = op->getFirstArg ();
1183- auto right = *op->getSecondArg ();
1184- return makeValueBinaryOpFromLocal (left, binOp, right);
11851181 }
11861182 // mul, div
11871183 else if (auto op = dynamic_cast <const intermediate::IntrinsicOperation*>(inst))
11881184 {
1185+ OpCode binOp = OP_NOP;
11891186 if (op->opCode == " mul" )
11901187 {
1191- binOp = BO::Mul ;
1188+ binOp = Expression::FAKEOP_MUL ;
11921189 }
11931190 else if (op->opCode == " div" )
11941191 {
1195- binOp = BO::Div ;
1192+ binOp = Expression::FAKEOP_DIV ;
11961193 }
11971194 else
11981195 {
11991196 // If op is neither add nor sub, return value as-is.
1200- return std::make_shared<ValueTerm> (value);
1197+ return SubExpression (value);
12011198 }
12021199
12031200 auto left = op->getFirstArg ();
12041201 auto right = *op->getSecondArg ();
12051202 return makeValueBinaryOpFromLocal (left, binOp, right);
12061203 }
12071204
1208- return std::make_shared<ValueTerm> (value);
1205+ return SubExpression (value);
12091206}
12101207
1211- std::shared_ptr<ValueExpr> calcValueExpr (std::shared_ptr<ValueExpr> expr)
1208+ Optional< int > getIntegerFromExpression ( const SubExpression& expr)
12121209{
1213- using BO = ValueBinaryOp::BinaryOp;
1210+ if (auto value = expr.checkValue ())
1211+ {
1212+ if (auto lit = value->checkLiteral ())
1213+ {
1214+ return Optional<int >(lit->signedInt ());
1215+ }
1216+ else if (auto imm = value->checkImmediate ())
1217+ {
1218+ return imm->getIntegerValue ();
1219+ }
1220+ }
1221+ return Optional<int >();
1222+ }
1223+
1224+ // signed, value
1225+ using ExpandedExprs = std::vector<std::pair<bool , SubExpression>>;
12141226
1215- ValueExpr::ExpandedExprs expanded;
1216- expr->expand (expanded);
1227+ void expandExpression (const SubExpression& subExpr, ExpandedExprs& expanded)
1228+ {
1229+ if (auto expr = subExpr.checkExpression ())
1230+ {
1231+ ExpandedExprs leftEE, rightEE;
1232+ auto & left = expr->arg0 ;
1233+ auto & right = expr->arg1 ;
1234+ auto & op = expr->code ;
1235+
1236+ expandExpression (left, leftEE);
1237+ expandExpression (right, rightEE);
1238+
1239+ auto getInteger = [](const std::pair<bool , SubExpression>& v) {
1240+ std::function<Optional<int >(const int &)> addSign = [&](const int & num) {
1241+ return make_optional (v.first ? num : -num);
1242+ };
1243+ return getIntegerFromExpression (v.second ) & addSign;
1244+ };
1245+
1246+ auto leftNum = (leftEE.size () == 1 ) ? getInteger (leftEE[0 ]) : Optional<int >();
1247+ auto rightNum = (rightEE.size () == 1 ) ? getInteger (rightEE[0 ]) : Optional<int >();
1248+
1249+ auto append = [](ExpandedExprs& ee1, ExpandedExprs& ee2) { ee1.insert (ee1.end (), ee2.begin (), ee2.end ()); };
1250+
1251+ if (leftNum && rightNum)
1252+ {
1253+ int l = leftNum.value_or (0 );
1254+ int r = rightNum.value_or (0 );
1255+ int num = 0 ;
1256+
1257+ if (op == OP_ADD)
1258+ {
1259+ num = l + r;
1260+ }
1261+ else if (op == OP_SUB)
1262+ {
1263+ num = l - r;
1264+ }
1265+ else if (op == Expression::FAKEOP_MUL)
1266+ {
1267+ num = l * r;
1268+ }
1269+ else if (op == Expression::FAKEOP_DIV)
1270+ {
1271+ num = l / r;
1272+ }
1273+ else
1274+ {
1275+ throw CompilationError (CompilationStep::OPTIMIZER, " Unknown operation" , op.name );
1276+ }
1277+
1278+ // TODO: Care other types
1279+ auto value = Value (Literal (std::abs (num)), TYPE_INT32);
1280+ SubExpression foldedExpr (value);
1281+ expanded.push_back (std::make_pair (true , foldedExpr));
1282+ }
1283+ else
1284+ {
1285+ if (op == OP_ADD)
1286+ {
1287+ append (expanded, leftEE);
1288+ append (expanded, rightEE);
1289+ }
1290+ else if (op == OP_SUB)
1291+ {
1292+ append (expanded, leftEE);
1293+
1294+ for (auto & e : rightEE)
1295+ {
1296+ e.first = !e.first ;
1297+ }
1298+ append (expanded, rightEE);
1299+ }
1300+ else if (op == Expression::FAKEOP_MUL)
1301+ {
1302+ if (leftNum || rightNum)
1303+ {
1304+ int num = 0 ;
1305+ ExpandedExprs* ee = nullptr ;
1306+ if (leftNum)
1307+ {
1308+ num = leftNum.value_or (0 );
1309+ ee = &rightEE;
1310+ }
1311+ else
1312+ {
1313+ num = rightNum.value_or (0 );
1314+ ee = &leftEE;
1315+ }
1316+ for (int i = 0 ; i < num; i++)
1317+ {
1318+ append (expanded, *ee);
1319+ }
1320+ }
1321+ else
1322+ {
1323+ expanded.push_back (std::make_pair (true , SubExpression (std::make_shared<Expression>(op, left, right))));
1324+ }
1325+ }
1326+ else if (op == Expression::FAKEOP_DIV)
1327+ {
1328+ expanded.push_back (std::make_pair (true , SubExpression (std::make_shared<Expression>(op, left, right))));
1329+ }
1330+ else
1331+ {
1332+ throw CompilationError (CompilationStep::OPTIMIZER, " Unknown operation" , op.name );
1333+ }
1334+ }
1335+ }
1336+ else if (auto value = subExpr.checkValue ())
1337+ {
1338+ expanded.push_back (std::make_pair (true , subExpr));
1339+ }
1340+ else {
1341+ throw CompilationError (CompilationStep::OPTIMIZER, " Cannot expand expression" , subExpr.to_string ());
1342+ }
1343+ }
1344+
1345+ SubExpression calcValueExpr (const SubExpression& expr)
1346+ {
1347+ ExpandedExprs expanded;
1348+ expandExpression (expr, expanded);
12171349
12181350 // for(auto& p : expanded)
12191351 // logging::debug() << (p.first ? "+" : "-") << p.second->to_string() << " ";
12201352 // logging::debug() << logging::endl;
12211353
12221354 for (auto p = expanded.begin (); p != expanded.end ();)
12231355 {
1224- auto comp = std::find_if (
1225- expanded.begin (), expanded.end (), [&p](const std::pair<bool , std::shared_ptr<ValueExpr>>& other) {
1226- return p->first != other.first && *p->second == *other.second ;
1227- });
1356+ auto comp = std::find_if (expanded.begin (), expanded.end (), [&p](const std::pair<bool , SubExpression>& other) {
1357+ return p->first != other.first && p->second == other.second ;
1358+ });
12281359 if (comp != expanded.end ())
12291360 {
12301361 expanded.erase (comp);
@@ -1236,18 +1367,24 @@ std::shared_ptr<ValueExpr> calcValueExpr(std::shared_ptr<ValueExpr> expr)
12361367 }
12371368 }
12381369
1239- std::shared_ptr<ValueExpr> result = std::make_shared<ValueTerm> (INT_ZERO);
1370+ SubExpression result (INT_ZERO);
12401371 for (auto & p : expanded)
12411372 {
1242- result = std::make_shared<ValueBinaryOp>(result, p.first ? BO::Add : BO::Sub, p.second );
1373+ result = SubExpression ( std::make_shared<Expression>( p.first ? OP_ADD : OP_SUB, result, p.second ) );
12431374 }
12441375
12451376 return result;
12461377}
12471378
1379+ SubExpression replaceLocalToExpr (const SubExpression& expr, const Value& local, SubExpression newExpr)
1380+ {
1381+ return expr;
1382+ }
1383+
12481384void optimizations::combineDMALoads (const Module& module , Method& method, const Configuration& config)
12491385{
12501386 using namespace std ;
1387+ using namespace VariantNamespace ;
12511388
12521389 const std::regex vloadReg (" vload(2|3|4|8|16)" );
12531390
@@ -1306,7 +1443,7 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
13061443 logging::debug () << inst->to_string () << logging::endl;
13071444 }
13081445
1309- std::vector<std::pair<Value, std::shared_ptr<ValueExpr> >> addrExprs;
1446+ std::vector<std::pair<Value, SubExpression >> addrExprs;
13101447
13111448 for (auto & addrValue : offsetValues)
13121449 {
@@ -1318,47 +1455,46 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
13181455 }
13191456 else
13201457 {
1321- addrExprs.push_back (std::make_pair (addrValue, std::make_shared<ValueTerm> (addrValue)));
1458+ addrExprs.push_back (std::make_pair (addrValue, SubExpression (addrValue)));
13221459 }
13231460 }
13241461 else
13251462 {
13261463 // TODO: is it ok?
1327- addrExprs.push_back (std::make_pair (addrValue, std::make_shared<ValueTerm> (addrValue)));
1464+ addrExprs.push_back (std::make_pair (addrValue, SubExpression (addrValue)));
13281465 }
13291466 }
13301467
13311468 for (auto & current : addrExprs)
13321469 {
13331470 for (auto & other : addrExprs)
13341471 {
1335- auto replaced = current.second ->replaceLocal (other.first , other.second );
1336- current.second = replaced;
1472+ current.second = replaceLocalToExpr (current.second , other.first , other.second );
13371473 }
13381474 }
13391475
13401476 for (auto & pair : addrExprs)
13411477 {
1342- logging::debug () << pair.first .to_string () << " = " << pair.second -> to_string () << logging::endl;
1478+ logging::debug () << pair.first .to_string () << " = " << pair.second . to_string () << logging::endl;
13431479 }
13441480
1345- std::shared_ptr<ValueExpr> diff = nullptr ;
1481+ SubExpression diff;
13461482 bool eqDiff = true ;
13471483 for (size_t i = 1 ; i < addrExprs.size (); i++)
13481484 {
13491485 auto x = addrExprs[i - 1 ].second ;
13501486 auto y = addrExprs[i].second ;
1351- auto diffExpr = std::make_shared<ValueBinaryOp>(y, ValueBinaryOp::BinaryOp::Sub , x);
1487+ auto diffExpr = SubExpression ( std::make_shared<Expression>(OP_SUB, y , x) );
13521488
13531489 auto currentDiff = calcValueExpr (diffExpr);
13541490 // Apply calcValueExpr again for integer literals.
13551491 currentDiff = calcValueExpr (currentDiff);
13561492
1357- if (diff == nullptr )
1493+ if (! diff)
13581494 {
13591495 diff = currentDiff;
13601496 }
1361- if (* currentDiff != * diff)
1497+ if (currentDiff != diff)
13621498 {
13631499 eqDiff = false ;
13641500 break ;
@@ -1371,16 +1507,16 @@ void optimizations::combineDMALoads(const Module& module, Method& method, const
13711507 if (eqDiff)
13721508 {
13731509 // The form of diff should be "0 (+/-) expressions...", then remove the value 0 at most right.
1374- ValueExpr:: ExpandedExprs expanded;
1375- diff-> expand ( expanded);
1510+ ExpandedExprs expanded;
1511+ expandExpression (diff, expanded);
13761512 if (expanded.size () == 1 )
13771513 {
13781514 diff = expanded[0 ].second ;
13791515
13801516 // logging::debug() << "diff = " << diff->to_string() << logging::endl;
13811517
1382- auto term = std::dynamic_pointer_cast<ValueTerm>( diff);
1383- auto mpValue = ( term != nullptr ) ? term->value . getConstantValue () : Optional<Value>{};
1518+ auto term = diff. getConstantExpression ( );
1519+ auto mpValue = term. has_value ( ) ? term->getConstantValue () : Optional<Value>{};
13841520 auto mpLiteral = mpValue.has_value () ? mpValue->getLiteralValue () : Optional<Literal>{};
13851521
13861522 if (mpLiteral)
0 commit comments