diff --git a/include/public/marco/Dialect/BaseModelica/IR/BaseModelicaOpInterfaces.td b/include/public/marco/Dialect/BaseModelica/IR/BaseModelicaOpInterfaces.td index e1dd8474d..0cf0d17fb 100644 --- a/include/public/marco/Dialect/BaseModelica/IR/BaseModelicaOpInterfaces.td +++ b/include/public/marco/Dialect/BaseModelica/IR/BaseModelicaOpInterfaces.td @@ -90,6 +90,14 @@ def EquationExpressionOpInterface "void", "printExpression", (ins "::llvm::raw_ostream&":$os, "const ::llvm::DenseMap<::mlir::Value, int64_t>&":$inductions)>, + InterfaceMethod< + "Check if two expressions are equivalent", + "bool", "isEquivalent", + (ins "mlir::Operation*":$other, + "mlir::SymbolTableCollection&":$symbolTableCollection), "", [{ + // Safely assume that the two expressions are different. + return false; + }]>, InterfaceMethod< "Get the number of elements.", "uint64_t", "getNumOfExpressionElements", diff --git a/include/public/marco/Dialect/BaseModelica/Transforms/CallCSE.h b/include/public/marco/Dialect/BaseModelica/Transforms/CallCSE.h new file mode 100644 index 000000000..08b0abb93 --- /dev/null +++ b/include/public/marco/Dialect/BaseModelica/Transforms/CallCSE.h @@ -0,0 +1,13 @@ +#ifndef MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H +#define MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H + +#include "mlir/Pass/Pass.h" + +namespace mlir::bmodelica { +#define GEN_PASS_DECL_CALLCSEPASS +#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc" + +std::unique_ptr createCallCSEPass(); +} // namespace mlir::bmodelica + +#endif // MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H diff --git a/include/public/marco/Dialect/BaseModelica/Transforms/Passes.h b/include/public/marco/Dialect/BaseModelica/Transforms/Passes.h index 0e086bc80..00f21e87b 100644 --- a/include/public/marco/Dialect/BaseModelica/Transforms/Passes.h +++ b/include/public/marco/Dialect/BaseModelica/Transforms/Passes.h @@ -4,6 +4,7 @@ #include "marco/Dialect/BaseModelica/Transforms/AccessReplacementTest.h" #include "marco/Dialect/BaseModelica/Transforms/AutomaticDifferentiation.h" #include "marco/Dialect/BaseModelica/Transforms/BindingEquationConversion.h" +#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h" #include "marco/Dialect/BaseModelica/Transforms/DerivativeChainRule.h" #include "marco/Dialect/BaseModelica/Transforms/DerivativesMaterialization.h" #include "marco/Dialect/BaseModelica/Transforms/EquationAccessSplit.h" diff --git a/include/public/marco/Dialect/BaseModelica/Transforms/Passes.td b/include/public/marco/Dialect/BaseModelica/Transforms/Passes.td index 83d8591d4..9e0b1ddc3 100644 --- a/include/public/marco/Dialect/BaseModelica/Transforms/Passes.td +++ b/include/public/marco/Dialect/BaseModelica/Transforms/Passes.td @@ -305,6 +305,27 @@ def EquationFunctionLoopHoistingPass let constructor = "mlir::bmodelica::createEquationFunctionLoopHoistingPass()"; } +def CallCSEPass + : Pass<"call-cse", "mlir::ModuleOp"> +{ + let summary = "Move equal function calls to dedicated equation."; + + let description = [{ + Move equal function calls to dedicated equation. + }]; + + let dependentDialects = [ + "mlir::bmodelica::BaseModelicaDialect" + ]; + + let statistics = [ + Statistic<"newCSEVariables", "new-cse-variables", "How many CSE variables have been created">, + Statistic<"replacedCalls", "replaced-calls", "How many calls were replaced by a CSE variable usage"> + ]; + + let constructor = "mlir::bmodelica::createCallCSEPass()"; +} + def ReadOnlyVariablesPropagationPass : Pass<"propagate-read-only-variables", "mlir::ModuleOp"> { diff --git a/lib/Dialect/BaseModelica/Transforms/CMakeLists.txt b/lib/Dialect/BaseModelica/Transforms/CMakeLists.txt index 7439c74fd..395dced5c 100644 --- a/lib/Dialect/BaseModelica/Transforms/CMakeLists.txt +++ b/lib/Dialect/BaseModelica/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRBaseModelicaTransforms AllocationOpInterfaceImpl.cpp BindingEquationConversion.cpp BufferizableOpInterfaceImpl.cpp + CallCSE.cpp ConstantMaterializableTypeInterfaceImpl.cpp DerivableOpInterfaceImpl.cpp DerivableTypeInterfaceImpl.cpp diff --git a/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp b/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp new file mode 100644 index 000000000..4f3cc1526 --- /dev/null +++ b/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp @@ -0,0 +1,238 @@ +#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h" +#include "marco/Dialect/BaseModelica/IR/BaseModelica.h" + +namespace mlir::bmodelica { +#define GEN_PASS_DEF_CALLCSEPASS +#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc" +} // namespace mlir::bmodelica + +using namespace ::mlir::bmodelica; + +namespace { +class CallCSEPass final : public impl::CallCSEPassBase { +public: + using CallCSEPassBase::CallCSEPassBase; + + void runOnOperation() override; + +private: + mlir::LogicalResult processModelOp(ModelOp modelOp); + + /// Replace all calls in the equivalence group with gets to a generated + /// variable. The variable will be driven by an equation derived from the + /// first call in the group. + /// + /// One variable and driver equation will be emitted per result, + /// if the call is to a function with multiple result values. + void emitCse(llvm::SmallVectorImpl &equivalenceGroup, ModelOp modelOp, + DynamicOp dynamicOp, mlir::SymbolTable &symbolTable, + mlir::RewriterBase &rewriter); +}; + +/// Get all call operations in the model. +void collectCallOps(ModelOp modelOp, llvm::SmallVectorImpl &callOps) { + llvm::SmallVector dynamicEquationOps; + modelOp.collectMainEquations(dynamicEquationOps); + + llvm::DenseSet visitedTemplateOps; + for (EquationInstanceOp equationOp : dynamicEquationOps) { + EquationTemplateOp templateOp = equationOp.getTemplate(); + if (!templateOp.getInductionVariables().empty() || + visitedTemplateOps.contains(templateOp)) { + continue; + } + visitedTemplateOps.insert(templateOp); + templateOp->walk([&](CallOp callOp) { callOps.push_back(callOp); }); + } +} + +/// Partition the list of call operations into groups given by +/// EquationExpressionOpInterface::isEquivalent +void buildCallEquivalenceGroups( + llvm::SmallVectorImpl &callOps, + llvm::SmallVectorImpl> &callEquivalenceGroups) { + mlir::SymbolTableCollection symbolTableCollection; + llvm::SmallVector> tmpCallEquivalenceGroups; + + for (CallOp callOp : callOps) { + auto callExpression = + mlir::cast(callOp.getOperation()); + + llvm::SmallVector *equivalenceGroup = find_if( + tmpCallEquivalenceGroups, [&](llvm::SmallVector &group) { + assert(!group.empty() && "groups should never be empty"); + return callExpression.isEquivalent(group.front(), + symbolTableCollection); + }); + + if (equivalenceGroup != tmpCallEquivalenceGroups.end()) { + // Add equivalent call to existing group + equivalenceGroup->push_back(callOp); + } else { + // Create new equivalence group + tmpCallEquivalenceGroups.push_back({callOp}); + } + } + + for (llvm::SmallVector &group : tmpCallEquivalenceGroups) { + if (group.size() > 1) { + callEquivalenceGroups.push_back(std::move(group)); + } + } +} + +/// Clone `op` and its def-use chain, returning the cloned version of `op`. +mlir::Operation *cloneDefUseChain(mlir::Operation *op, + mlir::RewriterBase &rewriter) { + llvm::SmallVector toClone; + llvm::SmallVector worklist({op}); + + // DFS through the def-use chain of `op` + while (!worklist.empty()) { + mlir::Operation *current = worklist.back(); + worklist.pop_back(); + toClone.push_back(current); + for (mlir::Value operand : current->getOperands()) { + if (mlir::Operation *defOp = operand.getDefiningOp()) { + worklist.push_back(defOp); + } + } + // Find the dependencies on operations not defined within the regions of + // `current`. No need to do this if it is isolated from above. + if (!current->hasTrait()) { + // Find all uses of values defined outside `current`. + current->walk([&](mlir::Operation *childOp) { + // Walk includes current, so skip it. + if (childOp == current) { + return; + } + for (mlir::Value operand : childOp->getOperands()) { + // If an operand is defined in the same scope as `current`, + // i.e. the equation template scope, add it to the worklist. + mlir::Operation *definingOp = operand.getDefiningOp(); + if (definingOp && definingOp->getBlock() == current->getBlock()) { + worklist.push_back(definingOp); + } + } + }); + } + } + + mlir::IRMapping mapping; + mlir::Operation *root = nullptr; + for (mlir::Operation *opToClone : llvm::reverse(toClone)) { + // Skip repeated dependencies on the same operation + if (mapping.contains(opToClone)) { + continue; + } + root = rewriter.clone(*opToClone, mapping); + } + return root; +} + +void CallCSEPass::emitCse(llvm::SmallVectorImpl &equivalenceGroup, + ModelOp modelOp, DynamicOp dynamicOp, + mlir::SymbolTable &symbolTable, + mlir::RewriterBase &rewriter) { + assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty"); + CallOp representative = equivalenceGroup.front(); + const mlir::Location loc = representative.getLoc(); + + // Emit one variable per function result + llvm::SmallVector cseVariables; + for (auto result : llvm::enumerate(representative.getResults())) { + rewriter.setInsertionPointToStart(modelOp.getBody()); + // Emit cse variable + auto cseVariable = rewriter.create( + loc, "_cse", VariableType::wrap(result.value().getType())); + symbolTable.insert(cseVariable); + cseVariables.push_back(cseVariable); + + // Emit driver equation + rewriter.setInsertionPoint(dynamicOp); + auto equationTemplateOp = rewriter.create(loc); + rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0)); + auto lhsOp = rewriter.create( + loc, rewriter.create(loc, cseVariable)->getResults()); + auto rhsOp = rewriter.create( + loc, + cloneDefUseChain(representative, rewriter)->getResult(result.index())); + rewriter.create(loc, lhsOp, rhsOp); + + // Add driver equation to dynamic operation + rewriter.setInsertionPointToEnd(dynamicOp.getBody()); + rewriter.create(rewriter.getUnknownLoc(), + equationTemplateOp); + } + + // Replace calls with get(s) to CSE variable(s) + for (auto &callOp : equivalenceGroup) { + rewriter.setInsertionPoint(callOp); + + llvm::SmallVector results; + for (VariableOp cseVariable : cseVariables) { + results.push_back( + rewriter.create(loc, cseVariable).getResult()); + } + rewriter.replaceOp(callOp, results); + } + + this->replacedCalls += equivalenceGroup.size(); + ++this->newCSEVariables; +} + +mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) { + mlir::IRRewriter rewriter(modelOp); + mlir::SymbolTable symbolTable(modelOp); + + llvm::SmallVector callOps; + collectCallOps(modelOp, callOps); + + llvm::SmallVector> callEquivalenceGroups; + buildCallEquivalenceGroups(callOps, callEquivalenceGroups); + + if (callEquivalenceGroups.empty()) { + return mlir::success(); + } + + rewriter.setInsertionPointToEnd(modelOp.getBody()); + DynamicOp dynamicOp = rewriter.create(rewriter.getUnknownLoc()); + rewriter.createBlock(&dynamicOp.getRegion()); + + for (llvm::SmallVector &equivalenceGroup : callEquivalenceGroups) { + // Only emit CSEs that will lead to an equivalent, or lower amount of calls + if (equivalenceGroup.size() >= equivalenceGroup.front().getNumResults()) { + emitCse(equivalenceGroup, modelOp, dynamicOp, symbolTable, rewriter); + } + } + + if (dynamicOp.getBody()->empty()) { + rewriter.eraseOp(dynamicOp); + } + + return mlir::success(); +} +} // namespace + +void CallCSEPass::runOnOperation() { + llvm::SmallVector modelOps; + + walkClasses(getOperation(), [&](mlir::Operation *op) { + if (auto modelOp = mlir::dyn_cast(op)) { + modelOps.push_back(modelOp); + } + }); + + if (mlir::failed(mlir::failableParallelForEach( + &getContext(), modelOps, [&](mlir::Operation *op) { + return processModelOp(mlir::cast(op)); + }))) { + return signalPassFailure(); + } +} + +namespace mlir::bmodelica { +std::unique_ptr createCallCSEPass() { + return std::make_unique(); +} +} // namespace mlir::bmodelica \ No newline at end of file diff --git a/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp b/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp index 6cbfe205d..1cae914a9 100644 --- a/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp +++ b/lib/Dialect/BaseModelica/Transforms/EquationExpressionOpInterfaceImpl.cpp @@ -3,13 +3,13 @@ using namespace ::mlir::bmodelica; -static void -printExpression(llvm::raw_ostream &os, mlir::Value value, - const llvm::DenseMap &inductions) { +namespace { +void printExpression(llvm::raw_ostream &os, const mlir::Value value, + const llvm::DenseMap &inductions) { mlir::Operation *op = value.getDefiningOp(); if (!op) { - if (auto inductionsIt = inductions.find(value); + if (const auto inductionsIt = inductions.find(value); inductionsIt != inductions.end()) { os << "{ind " << inductionsIt->getSecond() << "}"; } else { @@ -28,10 +28,452 @@ printExpression(llvm::raw_ostream &os, mlir::Value value, expressionOp.printExpression(os, inductions); } +template +void printBinaryExpression( + mlir::Operation *op, llvm::raw_ostream &os, + const llvm::DenseMap &inductions, + const llvm::StringLiteral opSymbol) { + auto castedOp = mlir::cast(op); + + os << "("; + ::printExpression(os, castedOp.getLhs(), inductions); + os << " " << opSymbol << " "; + ::printExpression(os, castedOp.getRhs(), inductions); + os << ")"; +} + +template +void printUnaryExpression( + mlir::Operation *op, llvm::raw_ostream &os, + const llvm::DenseMap &inductions, + const llvm::StringLiteral opName) { + + auto castedOp = mlir::cast(op); + + os << opName << "("; + ::printExpression(os, castedOp.getOperand(), inductions); + os << ")"; +} + +bool areExpressionOperandsEquivalent( + mlir::ValueRange firstOperands, mlir::ValueRange secondOperands, + mlir::SymbolTableCollection &symbolTableCollection) { + if (firstOperands.size() != secondOperands.size()) { + return false; + } + + for (auto [firstOperand, secondOperand] : + llvm::zip(firstOperands, secondOperands)) { + auto firstExp = firstOperand.getDefiningOp(); + auto secondExp = + secondOperand.getDefiningOp(); + + if (!firstExp || !secondExp) { + return false; + } + + if (!firstExp.isEquivalent(secondExp, symbolTableCollection)) { + return false; + } + } + + return true; +} + +bool areEquationExpressionsEquivalent( + mlir::Operation *firstOp, mlir::Operation *secondOp, + mlir::SymbolTableCollection &symbolTableCollection) { + if (firstOp->getResultTypes() != secondOp->getResultTypes()) { + return false; + } + + return areExpressionOperandsEquivalent( + firstOp->getOperands(), secondOp->getOperands(), symbolTableCollection); +} + +template +bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) { + auto otherCasted = mlir::dyn_cast(other); + + if (!otherCasted) { + return false; + } + + return areEquationExpressionsEquivalent(op, otherCasted.getOperation(), + symbolTableCollection); +} +} // namespace + namespace { +struct RangeOpInterface + : EquationExpressionOpInterface::ExternalModel { + void printExpression( + mlir::Operation *op, llvm::raw_ostream &os, + const llvm::DenseMap &inductions) const { + os << "range("; + llvm::interleaveComma(op->getOperands(), os, + [&](const mlir::Value operand) { + ::printExpression(os, operand, inductions); + }); + os << ")"; + } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } +}; + +struct ReductionOpInterface + : EquationExpressionOpInterface::ExternalModel { + void printExpression( + mlir::Operation *op, llvm::raw_ostream &os, + const llvm::DenseMap &inductions) const { + auto castedOp = mlir::cast(op); + + // Add the inductions to the inductions map. + llvm::DenseMap expandedInductions(inductions); + const auto inductionValues = castedOp.getInductions(); + + for (mlir::Value inductionValue : inductionValues) { + const auto id = static_cast(expandedInductions.size()); + expandedInductions[inductionValue] = id; + } + + // Print the operation. + os << castedOp.getAction(); + os << "("; + + auto terminator = mlir::cast(castedOp.getBody()->getTerminator()); + + llvm::interleaveComma(terminator.getValues(), os, + [&](const mlir::Value exp) { + ::printExpression(os, exp, expandedInductions); + }); + + os << " for "; + const auto iterables = castedOp.getIterables(); + + for (size_t i = 0, e = inductionValues.size(); i < e; ++i) { + if (i != 0) { + os << ", "; + } + + ::printExpression(os, inductionValues[i], expandedInductions); + } + + os << " in "; + + llvm::interleaveComma(iterables, os, [&](const mlir::Value exp) { + ::printExpression(os, exp, expandedInductions); + }); + + os << ")"; + } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + auto casted = mlir::cast(op); + auto otherCasted = mlir::dyn_cast(other); + + if (!otherCasted) { + return false; + } + + if (casted.getAction() != otherCasted.getAction()) { + return false; + } + + if (!areEquationExpressionsEquivalent(op, otherCasted.getOperation(), + symbolTableCollection)) { + return false; + } + + auto yieldOp = mlir::cast(casted.getBody()->getTerminator()); + auto otherYieldOp = + mlir::cast(otherCasted.getBody()->getTerminator()); + return areExpressionOperandsEquivalent( + yieldOp.getValues(), otherYieldOp.getValues(), symbolTableCollection); + } + + uint64_t getNumOfExpressionElements(mlir::Operation *op) const { + auto castedOp = mlir::cast(op); + + auto terminator = mlir::cast(castedOp.getBody()->getTerminator()); + + return terminator.getValues().size(); + } + + mlir::Value getExpressionElement(mlir::Operation *op, + const uint64_t element) const { + auto castedOp = mlir::cast(op); + + auto terminator = mlir::cast(castedOp.getBody()->getTerminator()); + + return terminator.getValues()[element]; + } + + llvm::SmallVector + getAdditionalInductions(mlir::Operation *op) const { + auto castedOp = mlir::cast(op); + llvm::SmallVector result; + const auto inductions = castedOp.getInductions(); + result.append(inductions.begin(), inductions.end()); + return result; + } + + mlir::LogicalResult + mapAdditionalInductions(mlir::Operation *op, + AdditionalInductions &additionalInductions) const { + auto castedOp = mlir::cast(op); + + IndexSet indices; + llvm::SmallVector> inductionsMap; + + for (const auto &[induction, iterable] : + zip(castedOp.getInductions(), castedOp.getIterables())) { + auto constantOp = iterable.getDefiningOp(); + + if (!constantOp) { + return mlir::failure(); + } + + auto iterableAttr = constantOp.getValue(); + + if (auto rangeAttr = iterableAttr.dyn_cast()) { + assert(rangeAttr.getStep() == 1); + + const auto lowerBound = + static_cast(rangeAttr.getLowerBound()); + + const auto upperBound = + static_cast(rangeAttr.getUpperBound()); + + Range range(lowerBound, upperBound + 1); + indices = indices.append(IndexSet(MultidimensionalRange(range))); + + auto currentDimension = static_cast(indices.rank() - 1); + inductionsMap.emplace_back(induction, currentDimension); + + continue; + } + + if (auto rangeAttr = iterableAttr.dyn_cast()) { + assert(rangeAttr.getStep().convertToDouble() == 1); + + const auto lowerBound = static_cast( + rangeAttr.getLowerBound().convertToDouble()); + + const auto upperBound = static_cast( + rangeAttr.getUpperBound().convertToDouble()); + + Range range(lowerBound, upperBound); + indices = indices.append(IndexSet(MultidimensionalRange(range))); + + auto currentDimension = static_cast(indices.rank() - 1); + inductionsMap.emplace_back(induction, currentDimension); + + continue; + } + + return mlir::failure(); + } + + const uint64_t iterationSpace = + additionalInductions.addIterationSpace(std::move(indices)); + + for (size_t i = 0, e = inductionsMap.size(); i < e; ++i) { + additionalInductions.addInductionVariable( + inductionsMap[i].first, iterationSpace, inductionsMap[i].second); + } + + return mlir::success(); + } +}; + +struct CallOpInterface + : EquationExpressionOpInterface::ExternalModel { + void printExpression( + mlir::Operation *op, llvm::raw_ostream &os, + const llvm::DenseMap &inductions) const { + auto casted = mlir::cast(op); + os << casted.getCallee() << "("; + + llvm::interleaveComma(casted.getArgs(), os, [&](const mlir::Value exp) { + ::printExpression(os, exp, inductions); + }); + + os << ")"; + } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + auto casted = mlir::cast(op); + auto otherCasted = mlir::dyn_cast(other); + + if (!otherCasted) { + return false; + } + + if (casted.getCallee() != otherCasted.getCallee()) { + return false; + } + + if (casted->getResultTypes() != otherCasted->getResultTypes()) { + return false; + } + + const auto argNames = casted.getArgNames(); + const auto otherArgNames = otherCasted.getArgNames(); + + llvm::StringMap argNamesPos; + llvm::StringMap otherArgNamesPos; + + if (argNames) { + getArgNamesPos(*argNames, argNamesPos); + } + + if (otherArgNames) { + getArgNamesPos(*otherArgNames, otherArgNamesPos); + } + + if (argNames && otherArgNames) { + if (!haveSameArgNames(argNamesPos, otherArgNamesPos)) { + return false; + } + + for (const auto &entry : argNamesPos) { + const mlir::Value arg = casted.getArgs()[entry.getValue()]; + const mlir::Value otherArg = + otherCasted.getArgs()[otherArgNamesPos[entry.getKey()]]; + + if (!areExpressionOperandsEquivalent(arg, otherArg, + symbolTableCollection)) { + return false; + } + } + } else if (argNames) { + if (failed(getArgNamesPos(otherCasted, symbolTableCollection, + otherArgNamesPos))) { + return false; + } + + if (!compareNamedUnnamedArgs(casted.getArgs(), argNamesPos, + otherCasted.getArgs(), otherArgNamesPos, + symbolTableCollection)) { + return false; + } + } else if (otherArgNames) { + if (failed(getArgNamesPos(casted, symbolTableCollection, argNamesPos))) { + return false; + } + + if (!compareNamedUnnamedArgs(otherCasted.getArgs(), otherArgNamesPos, + casted.getArgs(), argNamesPos, + symbolTableCollection)) { + return false; + } + } else { + if (!areExpressionOperandsEquivalent( + casted.getArgs(), otherCasted.getArgs(), symbolTableCollection)) { + return false; + } + } + + return true; + } + + static void getArgNamesPos(mlir::ArrayAttr argNames, + llvm::StringMap &pos) { + for (auto argName : llvm::enumerate(argNames)) { + const auto name = + argName.value().cast().getValue(); + pos[name] = argName.index(); + } + } + + static mlir::LogicalResult + getArgNamesPos(CallOp callOp, + mlir::SymbolTableCollection &symbolTableCollection, + llvm::StringMap &pos) { + auto otherFunctionOp = mlir::dyn_cast(callOp.getFunction( + callOp->getParentOfType(), symbolTableCollection)); + + if (!otherFunctionOp) { + return mlir::failure(); + } + + size_t variablePos = 0; + + for (VariableOp variableOp : otherFunctionOp.getVariables()) { + if (variableOp.isInput()) { + pos[variableOp.getSymName()] = variablePos++; + } + } + + return mlir::success(); + } + + bool containsArgNames(const llvm::StringMap &parent, + const llvm::StringMap &child) const { + return all_of(child, [&](const auto &entry) { + return parent.contains(entry.getKey()); + }); + } + + bool haveSameArgNames(const llvm::StringMap &first, + const llvm::StringMap &second) const { + return containsArgNames(first, second) && containsArgNames(second, first); + } + + bool compareNamedUnnamedArgs( + const mlir::ValueRange namedArgs, + const llvm::StringMap &namedArgsPos, mlir::ValueRange unnamedArgs, + const llvm::StringMap &unnamedArgsPos, + mlir::SymbolTableCollection &symbolTableCollection) const { + if (namedArgs.size() != unnamedArgs.size()) { + return false; + } + + llvm::DenseMap inverseUnnamedArgsPos; + + for (const auto &entry : unnamedArgsPos) { + inverseUnnamedArgsPos[entry.getValue()] = entry.getKey().str(); + } + + for (auto unnamedArg : llvm::enumerate(unnamedArgs)) { + auto inverseUnnamedArgPosIt = + inverseUnnamedArgsPos.find(unnamedArg.index()); + + if (inverseUnnamedArgPosIt == inverseUnnamedArgsPos.end()) { + return false; + } + + auto namedArgsPosIt = + namedArgsPos.find(inverseUnnamedArgPosIt->getSecond()); + + if (namedArgsPosIt == namedArgsPos.end()) { + return false; + } + + assert(namedArgsPosIt->getValue() < namedArgs.size()); + mlir::Value namedArg = namedArgs[namedArgsPosIt->getValue()]; + + if (!areExpressionOperandsEquivalent(namedArg, unnamedArg.value(), + symbolTableCollection)) { + return false; + } + } + + return true; + } +}; + struct EquationSidesOpInterface - : public EquationExpressionOpInterface::ExternalModel< - EquationSidesOpInterface, EquationSidesOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -39,22 +481,22 @@ struct EquationSidesOpInterface os << "{"; - llvm::interleaveComma(castedOp.getLhsValues(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getLhsValues(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "} = {"; - llvm::interleaveComma(castedOp.getRhsValues(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getRhsValues(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "}"; } }; struct TensorFromElementsOpInterface - : public EquationExpressionOpInterface::ExternalModel< + : EquationExpressionOpInterface::ExternalModel< TensorFromElementsOpInterface, TensorFromElementsOp> { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, @@ -63,24 +505,30 @@ struct TensorFromElementsOpInterface os << "{"; - llvm::interleaveComma(castedOp.getValues(), os, [&](mlir::Value exp) { + llvm::interleaveComma(castedOp.getValues(), os, [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "}"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, + symbolTableCollection); + } }; struct TensorBroadcastOpInterface - : public EquationExpressionOpInterface::ExternalModel< - TensorBroadcastOpInterface, TensorBroadcastOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { auto castedOp = mlir::cast(op); os << "{"; - mlir::TensorType tensorType = castedOp.getResult().getType(); + const auto tensorType = castedOp.getResult().getType(); for (int64_t i = 0, e = tensorType.getNumElements(); i < e; ++i) { if (i != 0) { @@ -92,11 +540,16 @@ struct TensorBroadcastOpInterface os << "}"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct TensorViewOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -107,11 +560,16 @@ struct TensorViewOpInterface llvm::interleaveComma( castedOp.getSubscriptions(), os, - [&](mlir::Value exp) { ::printExpression(os, exp, inductions); }); + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "]"; } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -121,10 +579,10 @@ struct TensorViewOpInterface &dimensionAccesses, EquationPath path) const { auto castedOp = mlir::cast(op); - auto indices = castedOp.getSubscriptions(); + const auto indices = castedOp.getSubscriptions(); for (size_t i = 0, e = indices.size(); i < e; ++i) { - mlir::Value index = indices[e - 1 - i]; + const mlir::Value index = indices[e - 1 - i]; auto dimensionAccess = getDimensionAccess(explicitInductionsPositionMap, additionalInductions, index); @@ -150,8 +608,8 @@ struct TensorViewOpInterface }; struct TensorExtractOpInterface - : public EquationExpressionOpInterface::ExternalModel< - TensorExtractOpInterface, TensorExtractOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -160,13 +618,18 @@ struct TensorExtractOpInterface ::printExpression(os, castedOp.getTensor(), inductions); os << "["; - llvm::interleaveComma(castedOp.getIndices(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getIndices(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "]"; } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -176,10 +639,10 @@ struct TensorExtractOpInterface &dimensionAccesses, EquationPath path) const { auto castedOp = mlir::cast(op); - auto indices = castedOp.getIndices(); + const auto indices = castedOp.getIndices(); for (size_t i = 0, e = indices.size(); i < e; ++i) { - mlir::Value index = indices[e - 1 - i]; + const mlir::Value index = indices[e - 1 - i]; auto dimensionAccess = getDimensionAccess(explicitInductionsPositionMap, additionalInductions, index); @@ -205,8 +668,8 @@ struct TensorExtractOpInterface }; struct ArrayFromElementsOpInterface - : public EquationExpressionOpInterface::ExternalModel< - ArrayFromElementsOpInterface, ArrayFromElementsOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -214,17 +677,23 @@ struct ArrayFromElementsOpInterface os << "{"; - llvm::interleaveComma(castedOp.getValues(), os, [&](mlir::Value exp) { + llvm::interleaveComma(castedOp.getValues(), os, [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "}"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, + symbolTableCollection); + } }; struct ArrayBroadcastOpInterface - : public EquationExpressionOpInterface::ExternalModel< - ArrayBroadcastOpInterface, ArrayBroadcastOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -243,11 +712,16 @@ struct ArrayBroadcastOpInterface os << "}"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct ArrayCastOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -255,6 +729,11 @@ struct ArrayCastOpInterface ::printExpression(os, castedOp.getOperand(), inductions); } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -264,8 +743,8 @@ struct ArrayCastOpInterface &dimensionAccesses, EquationPath path) const { auto castedOp = mlir::cast(op); - mlir::Value source = castedOp.getSource(); - auto childOp = source.getDefiningOp(); + const mlir::Value source = castedOp.getSource(); + const auto childOp = source.getDefiningOp(); if (!childOp) { return mlir::success(); @@ -277,7 +756,7 @@ struct ArrayCastOpInterface return mlir::failure(); } - if (mlir::failed(expressionInt.getEquationAccesses( + if (failed(expressionInt.getEquationAccesses( accesses, symbolTable, explicitInductionsPositionMap, additionalInductions, dimensionAccesses, path + 0))) { return mlir::failure(); @@ -288,8 +767,7 @@ struct ArrayCastOpInterface }; struct DimOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -302,10 +780,15 @@ struct DimOpInterface os << ")"; } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } + uint64_t getNumOfExpressionElements(mlir::Operation *op) const { return 1; } mlir::Value getExpressionElement(mlir::Operation *op, - uint64_t position) const { + const uint64_t position) const { auto castedOp = mlir::cast(op); assert(position == 0); return castedOp.getDimension(); @@ -313,8 +796,8 @@ struct DimOpInterface }; struct SubscriptionOpInterface - : public EquationExpressionOpInterface::ExternalModel< - SubscriptionOpInterface, SubscriptionOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -323,13 +806,18 @@ struct SubscriptionOpInterface ::printExpression(os, castedOp.getSource(), inductions); os << "["; - llvm::interleaveComma(castedOp.getIndices(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getIndices(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "]"; } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -339,10 +827,10 @@ struct SubscriptionOpInterface &dimensionAccesses, EquationPath path) const { auto castedOp = mlir::cast(op); - auto indices = castedOp.getIndices(); + const auto indices = castedOp.getIndices(); for (size_t i = 0, e = indices.size(); i < e; ++i) { - mlir::Value index = indices[e - 1 - i]; + const mlir::Value index = indices[e - 1 - i]; auto dimensionAccess = getDimensionAccess(explicitInductionsPositionMap, additionalInductions, index); @@ -368,8 +856,7 @@ struct SubscriptionOpInterface }; struct LoadOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -378,13 +865,18 @@ struct LoadOpInterface ::printExpression(os, castedOp.getArray(), inductions); os << "["; - llvm::interleaveComma(castedOp.getIndices(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getIndices(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << "]"; } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -394,10 +886,10 @@ struct LoadOpInterface &dimensionAccesses, EquationPath path) const { auto castedOp = mlir::cast(op); - auto indices = castedOp.getIndices(); + const auto indices = castedOp.getIndices(); for (size_t i = 0, e = indices.size(); i < e; ++i) { - mlir::Value index = indices[e - 1 - i]; + const mlir::Value index = indices[e - 1 - i]; auto dimensionAccess = getDimensionAccess(explicitInductionsPositionMap, additionalInductions, index); @@ -423,8 +915,8 @@ struct LoadOpInterface }; struct VariableGetOpInterface - : public EquationExpressionOpInterface::ExternalModel< - VariableGetOpInterface, VariableGetOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -432,6 +924,22 @@ struct VariableGetOpInterface os << castedOp.getVariable(); } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + auto casted = mlir::cast(op); + auto otherCasted = mlir::dyn_cast(other); + + if (!otherCasted) { + return false; + } + + if (casted.getVariable() != otherCasted.getVariable()) { + return false; + } + + return areEquationExpressionsEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -450,17 +958,17 @@ struct VariableGetOpInterface } // Finalize the accesses. - auto numOfInductions = + const auto numOfInductions = static_cast(explicitInductionsPositionMap.size()); - if (auto tensorType = castedOp.getType().dyn_cast(); + if (const auto tensorType = castedOp.getType().dyn_cast(); tensorType && tensorType.getRank() > static_cast(reverted.size())) { // Access to each scalar variable. for (int64_t i = static_cast(reverted.size()), rank = tensorType.getRank(); i < rank; ++i) { - int64_t dimension = tensorType.getDimSize(i); + const int64_t dimension = tensorType.getDimSize(i); assert(dimension != mlir::ShapedType::kDynamic); reverted.push_back(std::make_unique( @@ -478,8 +986,8 @@ struct VariableGetOpInterface }; struct GlobalVariableGetOpInterface - : public EquationExpressionOpInterface::ExternalModel< - GlobalVariableGetOpInterface, GlobalVariableGetOp> { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -487,6 +995,22 @@ struct GlobalVariableGetOpInterface os << castedOp.getVariable(); } + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + auto casted = mlir::cast(op); + auto otherCasted = mlir::dyn_cast(other); + + if (!otherCasted) { + return false; + } + + if (casted.getVariable() != otherCasted.getVariable()) { + return false; + } + + return areEquationExpressionsEquivalent(op, other, symbolTableCollection); + } + mlir::LogicalResult getEquationAccesses( mlir::Operation *op, llvm::SmallVectorImpl &accesses, mlir::SymbolTableCollection &symbolTable, @@ -500,45 +1024,63 @@ struct GlobalVariableGetOpInterface }; struct ConstantOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { auto castedOp = mlir::cast(op); - if (auto boolAttr = castedOp.getValue().dyn_cast()) { + if (const auto boolAttr = castedOp.getValue().dyn_cast()) { os << (boolAttr.getValue() ? "true" : "false"); return; } - if (auto integerAttr = castedOp.getValue().dyn_cast()) { + if (const auto integerAttr = castedOp.getValue().dyn_cast()) { os << integerAttr.getValue(); return; } - if (auto realAttr = castedOp.getValue().dyn_cast()) { + if (const auto realAttr = castedOp.getValue().dyn_cast()) { os << realAttr.getValue().convertToDouble(); return; } - if (auto integerAttr = castedOp.getValue().dyn_cast()) { + if (const auto integerAttr = + castedOp.getValue().dyn_cast()) { os << integerAttr.getValue(); return; } - if (auto floatAttr = castedOp.getValue().dyn_cast()) { + if (const auto floatAttr = + castedOp.getValue().dyn_cast()) { os << floatAttr.getValueAsDouble(); return; } castedOp.getValue().print(os, true); } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + auto casted = mlir::cast(op); + auto otherCasted = mlir::dyn_cast(other); + + if (!otherCasted) { + return false; + } + + if (casted.getValue() != otherCasted.getValue()) { + return false; + } + + return areEquationExpressionsEquivalent(op, other, symbolTableCollection); + } }; struct NegateOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -548,139 +1090,127 @@ struct NegateOpInterface ::printExpression(os, castedOp.getOperand(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct AddOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "+"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " + "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct AddEWOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, ".+"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " .+ "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SubOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "-"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " - "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SubEWOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, ".-"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " .- "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct MulOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "*"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " * "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct MulEWOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, ".*"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " .* "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct DivOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "/"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " / "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct DivEWOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "./"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " ./ "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct PowOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -692,11 +1222,15 @@ struct PowOpInterface ::printExpression(os, castedOp.getExponent(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct PowEWOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -708,149 +1242,142 @@ struct PowEWOpInterface ::printExpression(os, castedOp.getExponent(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct EqOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "=="); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " == "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct NotEqOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "!="); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " != "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct GtOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, ">"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " > "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct GteOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, ">="); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " >= "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct LtOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "<"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " < "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct LteOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "<="); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " <= "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct NotOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "!"); + } - os << "!("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct AndOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "&&"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " && "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct OrOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printBinaryExpression(op, os, inductions, "||"); + } - os << "("; - ::printExpression(os, castedOp.getLhs(), inductions); - os << " || "; - ::printExpression(os, castedOp.getRhs(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SelectOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -859,79 +1386,83 @@ struct SelectOpInterface ::printExpression(os, castedOp.getCondition(), inductions); os << " ? ("; - llvm::interleaveComma(castedOp.getTrueValues(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getTrueValues(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << ") : ("; - llvm::interleaveComma(castedOp.getFalseValues(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); + llvm::interleaveComma( + castedOp.getFalseValues(), os, + [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct AbsOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "abs"); + } - os << "abs("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct AcosOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "acos"); + } - os << "acos("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct AsinOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "asin"); + } - os << "asin("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct AtanOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "atan"); + } - os << "atan("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct Atan2OpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -943,67 +1474,73 @@ struct Atan2OpInterface ::printExpression(os, castedOp.getX(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct CeilOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "ceil"); + } - os << "ceil("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct CosOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "cos"); + } - os << "cos("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct CoshOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "cosh"); + } - os << "cosh("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct DiagonalOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "diagonal"); + } - os << "diagonal("; - ::printExpression(os, castedOp.getValues(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct DivTruncOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1015,81 +1552,88 @@ struct DivTruncOpInterface ::printExpression(os, castedOp.getY(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct ExpOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "exp"); + } - os << "exp("; - ::printExpression(os, castedOp.getExponent(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct FillOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "fill"); + } - os << "fill("; - ::printExpression(os, castedOp.getValue(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct FloorOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "floor"); + } - os << "floor("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct IdentityOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "identity"); + } - os << "identity("; - ::printExpression(os, castedOp.getSize(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct IntegerOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "integer"); + } - os << "integer("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct LinspaceOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1103,39 +1647,43 @@ struct LinspaceOpInterface ::printExpression(os, castedOp.getAmount(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct LogOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "log"); + } - os << "log("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct Log10OpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "log10"); + } - os << "log10("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct MaxOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1144,18 +1692,22 @@ struct MaxOpInterface os << "max("; ::printExpression(os, castedOp.getFirst(), inductions); - if (mlir::Value second = castedOp.getSecond()) { + if (const mlir::Value second = castedOp.getSecond()) { os << ", "; ::printExpression(os, second, inductions); } os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct MinOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1164,18 +1716,22 @@ struct MinOpInterface os << "min("; ::printExpression(os, castedOp.getFirst(), inductions); - if (mlir::Value second = castedOp.getSecond()) { + if (const mlir::Value second = castedOp.getSecond()) { os << ", "; ::printExpression(os, second, inductions); } os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct ModOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1187,11 +1743,15 @@ struct ModOpInterface ::printExpression(os, castedOp.getY(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct NDimsOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1201,11 +1761,15 @@ struct NDimsOpInterface ::printExpression(os, castedOp.getArray(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct OnesOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1213,17 +1777,22 @@ struct OnesOpInterface os << "ones("; - llvm::interleaveComma(castedOp.getSizes(), os, [&](mlir::Value exp) { + llvm::interleaveComma(castedOp.getSizes(), os, [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct ProductOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1233,11 +1802,15 @@ struct ProductOpInterface ::printExpression(os, castedOp.getArray(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct RemOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1249,53 +1822,57 @@ struct RemOpInterface ::printExpression(os, castedOp.getY(), inductions); os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct SignOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "sign"); + } - os << "sign("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SinOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "sin"); + } - os << "sin("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SinhOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "sinh"); + } - os << "sinh("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SizeOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1304,102 +1881,108 @@ struct SizeOpInterface os << "size("; ::printExpression(os, castedOp.getArray(), inductions); - if (mlir::Value dimension = castedOp.getDimension()) { + if (const mlir::Value dimension = castedOp.getDimension()) { os << ", "; ::printExpression(os, dimension, inductions); } os << ")"; } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; struct SqrtOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "sqrt"); + } - os << "sqrt("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SumOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "sum"); + } - os << "sum("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct SymmetricOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "symmetric"); + } - os << "symmetric("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct TanOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "tan"); + } - os << "tan("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct TanhOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "tanh"); + } - os << "tanh("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct TransposeOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "transpose"); + } - os << "transpose("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct ZerosOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { @@ -1407,308 +1990,162 @@ struct ZerosOpInterface os << "zeros("; - llvm::interleaveComma(castedOp.getSizes(), os, [&](mlir::Value exp) { + llvm::interleaveComma(castedOp.getSizes(), os, [&](const mlir::Value exp) { ::printExpression(os, exp, inductions); }); os << ")"; } -}; - -struct ReductionOpInterface - : public EquationExpressionOpInterface::ExternalModel { - void printExpression( - mlir::Operation *op, llvm::raw_ostream &os, - const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); - - // Add the inductions to the inductions map. - llvm::DenseMap expandedInductions(inductions); - auto inductionValues = castedOp.getInductions(); - - for (mlir::Value inductionValue : inductionValues) { - auto id = static_cast(expandedInductions.size()); - expandedInductions[inductionValue] = id; - } - - // Print the operation. - os << castedOp.getAction(); - os << "("; - - auto terminator = mlir::cast(castedOp.getBody()->getTerminator()); - - llvm::interleaveComma(terminator.getValues(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, expandedInductions); - }); - - os << " for "; - auto iterables = castedOp.getIterables(); - - for (size_t i = 0, e = inductionValues.size(); i < e; ++i) { - if (i != 0) { - os << ", "; - } - - ::printExpression(os, inductionValues[i], expandedInductions); - } - - os << " in "; - - llvm::interleaveComma(iterables, os, [&](mlir::Value exp) { - ::printExpression(os, exp, expandedInductions); - }); - - os << ")"; - } - - uint64_t getNumOfExpressionElements(mlir::Operation *op) const { - auto castedOp = mlir::cast(op); - - auto terminator = mlir::cast(castedOp.getBody()->getTerminator()); - - return terminator.getValues().size(); - } - - mlir::Value getExpressionElement(mlir::Operation *op, - uint64_t element) const { - auto castedOp = mlir::cast(op); - - auto terminator = mlir::cast(castedOp.getBody()->getTerminator()); - - return terminator.getValues()[element]; - } - - llvm::SmallVector - getAdditionalInductions(mlir::Operation *op) const { - auto castedOp = mlir::cast(op); - llvm::SmallVector result; - auto inductions = castedOp.getInductions(); - result.append(inductions.begin(), inductions.end()); - return result; - } - - mlir::LogicalResult - mapAdditionalInductions(mlir::Operation *op, - AdditionalInductions &additionalInductions) const { - auto castedOp = mlir::cast(op); - - IndexSet indices; - llvm::SmallVector> inductionsMap; - - for (const auto &[induction, iterable] : - llvm::zip(castedOp.getInductions(), castedOp.getIterables())) { - auto constantOp = iterable.getDefiningOp(); - - if (!constantOp) { - return mlir::failure(); - } - - auto iterableAttr = constantOp.getValue(); - - if (auto rangeAttr = iterableAttr.dyn_cast()) { - assert(rangeAttr.getStep() == 1); - - auto lowerBound = - static_cast(rangeAttr.getLowerBound()); - - auto upperBound = - static_cast(rangeAttr.getUpperBound()); - - Range range(lowerBound, upperBound + 1); - indices = indices.append(IndexSet(MultidimensionalRange(range))); - - auto currentDimension = static_cast(indices.rank() - 1); - inductionsMap.emplace_back(induction, currentDimension); - - continue; - } - - if (auto rangeAttr = iterableAttr.dyn_cast()) { - assert(rangeAttr.getStep().convertToDouble() == 1); - - auto lowerBound = static_cast( - rangeAttr.getLowerBound().convertToDouble()); - - auto upperBound = static_cast( - rangeAttr.getUpperBound().convertToDouble()); - - Range range(lowerBound, upperBound); - indices = indices.append(IndexSet(MultidimensionalRange(range))); - - auto currentDimension = static_cast(indices.rank() - 1); - inductionsMap.emplace_back(induction, currentDimension); - - continue; - } - - return mlir::failure(); - } - - uint64_t iterationSpace = - additionalInductions.addIterationSpace(std::move(indices)); - - for (size_t i = 0, e = inductionsMap.size(); i < e; ++i) { - additionalInductions.addInductionVariable( - inductionsMap[i].first, iterationSpace, inductionsMap[i].second); - } - return mlir::success(); + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct DerOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { - auto castedOp = mlir::cast(op); + ::printUnaryExpression(op, os, inductions, "der"); + } - os << "der("; - ::printExpression(os, castedOp.getOperand(), inductions); - os << ")"; + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); } }; struct TimeOpInterface - : public EquationExpressionOpInterface::ExternalModel { - void printExpression( - mlir::Operation *op, llvm::raw_ostream &os, - const llvm::DenseMap &inductions) const { + : EquationExpressionOpInterface::ExternalModel { + void printExpression(mlir::Operation *, llvm::raw_ostream &os, + const llvm::DenseMap &) const { os << "time"; } -}; - -struct CallOpInterface - : public EquationExpressionOpInterface::ExternalModel { - void printExpression( - mlir::Operation *op, llvm::raw_ostream &os, - const llvm::DenseMap &inductions) const { - auto casted = mlir::cast(op); - os << casted.getCallee() << "("; - - llvm::interleaveComma(casted.getArgs(), os, [&](mlir::Value exp) { - ::printExpression(os, exp, inductions); - }); - os << ")"; + bool isEquivalent(mlir::Operation *, mlir::Operation *other, + mlir::SymbolTableCollection &) const { + return mlir::isa(other); } }; struct CastOpInterface - : public EquationExpressionOpInterface::ExternalModel { + : EquationExpressionOpInterface::ExternalModel { void printExpression( mlir::Operation *op, llvm::raw_ostream &os, const llvm::DenseMap &inductions) const { auto casted = mlir::cast(op); ::printExpression(os, casted.getValue(), inductions); } + + bool isEquivalent(mlir::Operation *op, mlir::Operation *other, + mlir::SymbolTableCollection &symbolTableCollection) const { + return ::isEquivalent(op, other, symbolTableCollection); + } }; } // namespace namespace mlir::bmodelica { void registerEquationExpressionOpInterfaceExternalModels( - mlir::DialectRegistry ®istry) { - registry.addExtension(+[](mlir::MLIRContext *context, + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, BaseModelicaDialect *dialect) { // clang-format off // Equation root. - EquationSidesOp::attachInterface<::EquationSidesOpInterface>(*context); + EquationSidesOp::attachInterface(*context); // Tensor operations. - TensorFromElementsOp::attachInterface<::TensorFromElementsOpInterface>(*context); - TensorBroadcastOp::attachInterface<::TensorBroadcastOpInterface>(*context); - TensorViewOp::attachInterface<::TensorViewOpInterface>(*context); - TensorExtractOp::attachInterface<::TensorExtractOpInterface>(*context); + TensorFromElementsOp::attachInterface(*context); + TensorBroadcastOp::attachInterface(*context); + TensorViewOp::attachInterface(*context); + TensorExtractOp::attachInterface(*context); // Array operations. - ArrayFromElementsOp::attachInterface<::ArrayFromElementsOpInterface>(*context); - ArrayBroadcastOp::attachInterface<::ArrayBroadcastOpInterface>(*context); - ArrayCastOp::attachInterface<::ArrayCastOpInterface>(*context); - DimOp::attachInterface<::DimOpInterface>(*context); - SubscriptionOp::attachInterface<::SubscriptionOpInterface>(*context); - LoadOp::attachInterface<::LoadOpInterface>(*context); + ArrayFromElementsOp::attachInterface(*context); + ArrayBroadcastOp::attachInterface(*context); + ArrayCastOp::attachInterface(*context); + DimOp::attachInterface(*context); + SubscriptionOp::attachInterface(*context); + LoadOp::attachInterface(*context); // Variable operations. - VariableGetOp::attachInterface<::VariableGetOpInterface>(*context); - GlobalVariableGetOp::attachInterface<::GlobalVariableGetOpInterface>(*context); + VariableGetOp::attachInterface(*context); + GlobalVariableGetOp::attachInterface(*context); // Math operations. - ConstantOp::attachInterface<::ConstantOpInterface>(*context); - NegateOp::attachInterface<::NegateOpInterface>(*context); - AddOp::attachInterface<::AddOpInterface>(*context); - AddEWOp::attachInterface<::AddEWOpInterface>(*context); - SubOp::attachInterface<::SubOpInterface>(*context); - SubEWOp::attachInterface<::SubEWOpInterface>(*context); - MulOp::attachInterface<::MulOpInterface>(*context); - MulEWOp::attachInterface<::MulEWOpInterface>(*context); - DivOp::attachInterface<::DivOpInterface>(*context); - DivEWOp::attachInterface<::DivEWOpInterface>(*context); - PowOp::attachInterface<::PowOpInterface>(*context); - PowEWOp::attachInterface<::PowEWOpInterface>(*context); + ConstantOp::attachInterface(*context); + NegateOp::attachInterface(*context); + AddOp::attachInterface(*context); + AddEWOp::attachInterface(*context); + SubOp::attachInterface(*context); + SubEWOp::attachInterface(*context); + MulOp::attachInterface(*context); + MulEWOp::attachInterface(*context); + DivOp::attachInterface(*context); + DivEWOp::attachInterface(*context); + PowOp::attachInterface(*context); + PowEWOp::attachInterface(*context); // Comparison operations. - EqOp::attachInterface<::EqOpInterface>(*context); - NotEqOp::attachInterface<::NotEqOpInterface>(*context); - GtOp::attachInterface<::GtOpInterface>(*context); - GteOp::attachInterface<::GteOpInterface>(*context); - LtOp::attachInterface<::LtOpInterface>(*context); - LteOp::attachInterface<::LteOpInterface>(*context); + EqOp::attachInterface(*context); + NotEqOp::attachInterface(*context); + GtOp::attachInterface(*context); + GteOp::attachInterface(*context); + LtOp::attachInterface(*context); + LteOp::attachInterface(*context); // Logic operations. - NotOp::attachInterface<::NotOpInterface>(*context); - AndOp::attachInterface<::AndOpInterface>(*context); - OrOp::attachInterface<::OrOpInterface>(*context); - SelectOp::attachInterface<::SelectOpInterface>(*context); + NotOp::attachInterface(*context); + AndOp::attachInterface(*context); + OrOp::attachInterface(*context); + SelectOp::attachInterface(*context); // Built-in operations - AbsOp::attachInterface<::AbsOpInterface>(*context); - AcosOp::attachInterface<::AcosOpInterface>(*context); - AsinOp::attachInterface<::AsinOpInterface>(*context); - AtanOp::attachInterface<::AtanOpInterface>(*context); - Atan2Op::attachInterface<::Atan2OpInterface>(*context); - CeilOp::attachInterface<::CeilOpInterface>(*context); - CosOp::attachInterface<::CosOpInterface>(*context); - CoshOp::attachInterface<::CoshOpInterface>(*context); - DiagonalOp::attachInterface<::DiagonalOpInterface>(*context); - DivTruncOp::attachInterface<::DivTruncOpInterface>(*context); - ExpOp::attachInterface<::ExpOpInterface>(*context); - FillOp::attachInterface<::FillOpInterface>(*context); - FloorOp::attachInterface<::FloorOpInterface>(*context); - IdentityOp::attachInterface<::IdentityOpInterface>(*context); - IntegerOp::attachInterface<::IntegerOpInterface>(*context); - LinspaceOp::attachInterface<::LinspaceOpInterface>(*context); - LogOp::attachInterface<::LogOpInterface>(*context); - Log10Op::attachInterface<::Log10OpInterface>(*context); - MaxOp::attachInterface<::MaxOpInterface>(*context); - MinOp::attachInterface<::MinOpInterface>(*context); - ModOp::attachInterface<::ModOpInterface>(*context); - NDimsOp::attachInterface<::NDimsOpInterface>(*context); - OnesOp::attachInterface<::OnesOpInterface>(*context); - ProductOp::attachInterface<::ProductOpInterface>(*context); - RemOp::attachInterface<::RemOpInterface>(*context); - SignOp::attachInterface<::SignOpInterface>(*context); - SinOp::attachInterface<::SinOpInterface>(*context); - SinhOp::attachInterface<::SinhOpInterface>(*context); - SizeOp::attachInterface<::SizeOpInterface>(*context); - SqrtOp::attachInterface<::SqrtOpInterface>(*context); - SumOp::attachInterface<::SumOpInterface>(*context); - SymmetricOp::attachInterface<::SymmetricOpInterface>(*context); - TanOp::attachInterface<::TanOpInterface>(*context); - TanhOp::attachInterface<::TanhOpInterface>(*context); - TransposeOp::attachInterface<::TransposeOpInterface>(*context); - ZerosOp::attachInterface<::ZerosOpInterface>(*context); + AbsOp::attachInterface(*context); + AcosOp::attachInterface(*context); + AsinOp::attachInterface(*context); + AtanOp::attachInterface(*context); + Atan2Op::attachInterface(*context); + CeilOp::attachInterface(*context); + CosOp::attachInterface(*context); + CoshOp::attachInterface(*context); + DiagonalOp::attachInterface(*context); + DivTruncOp::attachInterface(*context); + ExpOp::attachInterface(*context); + FillOp::attachInterface(*context); + FloorOp::attachInterface(*context); + IdentityOp::attachInterface(*context); + IntegerOp::attachInterface(*context); + LinspaceOp::attachInterface(*context); + LogOp::attachInterface(*context); + Log10Op::attachInterface(*context); + MaxOp::attachInterface(*context); + MinOp::attachInterface(*context); + ModOp::attachInterface(*context); + NDimsOp::attachInterface(*context); + OnesOp::attachInterface(*context); + ProductOp::attachInterface(*context); + RemOp::attachInterface(*context); + SignOp::attachInterface(*context); + SinOp::attachInterface(*context); + SinhOp::attachInterface(*context); + SizeOp::attachInterface(*context); + SqrtOp::attachInterface(*context); + SumOp::attachInterface(*context); + SymmetricOp::attachInterface(*context); + TanOp::attachInterface(*context); + TanhOp::attachInterface(*context); + TransposeOp::attachInterface(*context); + ZerosOp::attachInterface(*context); // Various operations. - ReductionOp::attachInterface<::ReductionOpInterface>(*context); - DerOp::attachInterface<::DerOpInterface>(*context); + ReductionOp::attachInterface(*context); + RangeOp::attachInterface(*context); + DerOp::attachInterface(*context); TimeOp::attachInterface(*context); CallOp::attachInterface<::CallOpInterface>(*context); - CastOp::attachInterface<::CastOpInterface>(*context); + CastOp::attachInterface(*context); // clang-format on }); } diff --git a/lib/Frontend/FrontendActions.cpp b/lib/Frontend/FrontendActions.cpp index 0cd989b5f..cb55ddb72 100644 --- a/lib/Frontend/FrontendActions.cpp +++ b/lib/Frontend/FrontendActions.cpp @@ -808,6 +808,9 @@ void CodeGenAction::buildMLIRLoweringPipeline(mlir::PassManager &pm) { // Lift the equations. pm.addPass(mlir::bmodelica::createEquationTemplatesCreationPass()); + // Eliminate repeated function calls + pm.addPass(mlir::bmodelica::createCallCSEPass()); + // Materialize the derivatives. pm.addPass(mlir::bmodelica::createDerivativesMaterializationPass()); diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir new file mode 100644 index 000000000..b6b7fe826 --- /dev/null +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir @@ -0,0 +1,69 @@ +// RUN: modelica-opt %s --call-cse | FileCheck %s + +// CHECK-LABEL: @Test +module @Test { + bmodelica.function @foo { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @x : f64 + bmodelica.variable_set @y, %0 : f64 + } + } + + // CHECK-LABEL: @M + bmodelica.model @M { + // CHECK-NEXT: bmodelica.variable @[[CSE:.*]] : !bmodelica.variable + // CHECK-NEXT: bmodelica.variable @x + // CHECK-NEXT: bmodelica.variable @y + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + %t0 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %2 = bmodelica.call @foo(%1) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + %t1 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @y : f64 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @y + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %2 = bmodelica.call @foo(%1) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 : !bmodelica.equation + bmodelica.equation_instance %t1 : !bmodelica.equation + } + + // CHECK: %[[TEMPLATE:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES2:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[LHS2:.*]] = bmodelica.equation_side %[[RES2]] + // CHECK-NEXT: %[[RES3:.*]] = bmodelica.constant 1 + // CHECK-NEXT: %[[RES4:.*]] = bmodelica.call @foo(%[[RES3]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES4]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS2]], %[[RHS]] + + // CHECK: bmodelica.dynamic + // CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE]] + } +} \ No newline at end of file diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/complex-expression.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/complex-expression.mlir new file mode 100644 index 000000000..b1d9a7df1 --- /dev/null +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/complex-expression.mlir @@ -0,0 +1,84 @@ +// RUN: modelica-opt %s --call-cse | FileCheck %s + +// CHECK-LABEL: @Test +module @Test { + bmodelica.function @foo { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @x : f64 + bmodelica.variable_set @y, %0 : f64 + } + } + + // CHECK-LABEL: @M + bmodelica.model @M { + // CHECK-NEXT: bmodelica.variable @[[CSE:.*]] : !bmodelica.variable + // CHECK-NEXT: bmodelica.variable @x + // CHECK-NEXT: bmodelica.variable @y + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + %t0 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %c1 = bmodelica.constant 1.0 : f64 + %c2 = bmodelica.constant 2.0 : f64 + %c3 = bmodelica.constant 3.0 : f64 + %sub = bmodelica.sub %c2, %c1 : (f64, f64) -> f64 + %cos = bmodelica.cos %sub : f64 -> f64 + %pow = bmodelica.pow %cos, %c3 : (f64, f64) -> f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %2 = bmodelica.call @foo(%pow) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + %t1 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @y : f64 + %lhs = bmodelica.equation_side %0 : tuple + %c1 = bmodelica.constant 1.0 : f64 + %c2 = bmodelica.constant 2.0 : f64 + %c3 = bmodelica.constant 3.0 : f64 + %sub = bmodelica.sub %c2, %c1 : (f64, f64) -> f64 + %cos = bmodelica.cos %sub : f64 -> f64 + %pow = bmodelica.pow %cos, %c3 : (f64, f64) -> f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @y + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %2 = bmodelica.call @foo(%pow) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 : !bmodelica.equation + bmodelica.equation_instance %t1 : !bmodelica.equation + } + + // CHECK: %[[TEMPLATE:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES2:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[LHS2:.*]] = bmodelica.equation_side %[[RES2]] + // CHECK-DAG: %[[C1:.*]] = bmodelica.constant 1 + // CHECK-DAG: %[[C2:.*]] = bmodelica.constant 2 + // CHECK-DAG: %[[C3:.*]] = bmodelica.constant 3 + // CHECK-DAG: %[[SUB:.*]] = bmodelica.sub %[[C2]], %[[C1]] + // CHECK-DAG: %[[COS:.*]] = bmodelica.cos %[[SUB]] + // CHECK-DAG: %[[POW:.*]] = bmodelica.pow %[[COS]], %[[C3]] + // CHECK-NEXT: %[[RES4:.*]] = bmodelica.call @foo(%[[POW]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES4]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS2]], %[[RHS]] + + // CHECK: bmodelica.dynamic + // CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE]] + } +} \ No newline at end of file diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/idempotent-configurations.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/idempotent-configurations.mlir new file mode 100644 index 000000000..e88b960bb --- /dev/null +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/idempotent-configurations.mlir @@ -0,0 +1,95 @@ +// RUN: modelica-opt %s --split-input-file --call-cse | FileCheck %s + +// CHECK-LABEL: @SingleCall +module @SingleCall { + bmodelica.function @foo { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @x : f64 + bmodelica.variable_set @y, %0 : f64 + } + } + + bmodelica.model @M { + bmodelica.variable @x : !bmodelica.variable + + // CHECK: bmodelica.equation_template + // CHECK-NEXT: %[[R0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[R0]] + // CHECK-NEXT: %[[R1:.*]] = bmodelica.constant 1 + // CHECK-NEXT: %[[R2:.*]] = bmodelica.call @foo(%[[R1]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[R2]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %t0 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + %2 = bmodelica.call @foo(%1) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 : !bmodelica.equation + } + } +} + +// ----- + +// CHECK-LABEL: @InductionVariables +module @InductionVariables { + bmodelica.function @foo { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @x : f64 + bmodelica.variable_set @y, %0 : f64 + } + } + + bmodelica.model @M { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + // CHECK: bmodelica.equation_template + // CHECK-NEXT: %[[R0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[R0]] + // CHECK-NEXT: %[[R1:.*]] = bmodelica.constant 1 + // CHECK-NEXT: %[[R2:.*]] = bmodelica.call @foo(%[[R1]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[R2]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %t0 = bmodelica.equation_template inductions = [%i1] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + %2 = bmodelica.call @foo(%1) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + // CHECK: bmodelica.equation_template + // CHECK-NEXT: %[[R0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[R0]] + // CHECK-NEXT: %[[R1:.*]] = bmodelica.constant 1 + // CHECK-NEXT: %[[R2:.*]] = bmodelica.call @foo(%[[R1]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[R2]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %t1 = bmodelica.equation_template inductions = [%i1] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + %2 = bmodelica.call @foo(%1) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 {indices = #modeling} : !bmodelica.equation + bmodelica.equation_instance %t1 {indices = #modeling} : !bmodelica.equation + } + } +} \ No newline at end of file diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/multiple-calls.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/multiple-calls.mlir new file mode 100644 index 000000000..ebbb84f0f --- /dev/null +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/multiple-calls.mlir @@ -0,0 +1,119 @@ +// RUN: modelica-opt %s --call-cse | FileCheck %s + +// CHECK-LABEL: @Test +module @Test { + bmodelica.function @foo { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @x : f64 + bmodelica.variable_set @y, %0 : f64 + } + } + + // CHECK-LABEL: @M + bmodelica.model @M { + // CHECK-NEXT: bmodelica.variable @[[CSE1:.*]] : !bmodelica.variable + // CHECK-NEXT: bmodelica.variable @[[CSE0:.*]] : !bmodelica.variable + // CHECK-NEXT: bmodelica.variable @x + // CHECK-NEXT: bmodelica.variable @y + // CHECK-NEXT: bmodelica.variable @z + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + bmodelica.variable @z : !bmodelica.variable + + %t0 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %c1 = bmodelica.constant 1.0 : f64 + %c2 = bmodelica.constant 2.0 : f64 + %c3 = bmodelica.constant 3.0 : f64 + %add = bmodelica.add %c1, %c2 : (f64, f64) -> f64 + %sub = bmodelica.sub %c2, %c3 : (f64, f64) -> f64 + + // CHECK: %[[X:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[X]] + + // CHECK: %[[C0:.*]] = bmodelica.variable_get @[[CSE0]] + // CHECK-NEXT: %[[C1:.*]] = bmodelica.variable_get @[[CSE1]] + // CHECK-NEXT: %[[RES:.*]] = bmodelica.add %[[C0]], %[[C1]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + %2 = bmodelica.call @foo(%add) : (f64) -> f64 + %3 = bmodelica.call @foo(%sub) : (f64) -> f64 + %4 = bmodelica.add %2, %3 : (f64, f64) -> f64 + %rhs = bmodelica.equation_side %4 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + %t1 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @y : f64 + %lhs = bmodelica.equation_side %0 : tuple + %c1 = bmodelica.constant 1.0 : f64 + %c2 = bmodelica.constant 2.0 : f64 + %add = bmodelica.add %c1, %c2 : (f64, f64) -> f64 + + // CHECK: %[[Y:.*]] = bmodelica.variable_get @y + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[Y]] + + // CHECK: %[[C0:.*]] = bmodelica.variable_get @[[CSE0]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[C0]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + %2 = bmodelica.call @foo(%add) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + %t2 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @z : f64 + %lhs = bmodelica.equation_side %0 : tuple + %c2 = bmodelica.constant 2.0 : f64 + %c3 = bmodelica.constant 3.0 : f64 + %sub = bmodelica.sub %c2, %c3 : (f64, f64) -> f64 + + // CHECK: %[[Z:.*]] = bmodelica.variable_get @z + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[Z]] + + // CHECK: %[[C1:.*]] = bmodelica.variable_get @[[CSE1]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[C1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + %3 = bmodelica.call @foo(%sub) : (f64) -> f64 + %rhs = bmodelica.equation_side %3 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 : !bmodelica.equation + bmodelica.equation_instance %t1 : !bmodelica.equation + bmodelica.equation_instance %t2 : !bmodelica.equation + } + + // CHECK: %[[T0:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES0:.*]] = bmodelica.variable_get @[[CSE0]] + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + // CHECK-DAG: %[[RES1:.*]] = bmodelica.constant 1 + // CHECK-DAG: %[[RES2:.*]] = bmodelica.constant 2 + // CHECK-DAG: %[[RES3:.*]] = bmodelica.add %[[RES1]], %[[RES2]] + // CHECK-NEXT: %[[RES4:.*]] = bmodelica.call @foo(%[[RES3]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES4]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + // CHECK: %[[T1:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES0:.*]] = bmodelica.variable_get @[[CSE1]] + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + // CHECK-DAG: %[[RES1:.*]] = bmodelica.constant 2 + // CHECK-DAG: %[[RES2:.*]] = bmodelica.constant 3 + // CHECK-DAG: %[[RES3:.*]] = bmodelica.sub %[[RES1]], %[[RES2]] + // CHECK-NEXT: %[[RES4:.*]] = bmodelica.call @foo(%[[RES3]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES4]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + // CHECK: bmodelica.dynamic + // CHECK-DAG: bmodelica.equation_instance %[[T0]] + // CHECK-DAG: bmodelica.equation_instance %[[T1]] + } +} diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/multiple-return-values.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/multiple-return-values.mlir new file mode 100644 index 000000000..73ba8ede8 --- /dev/null +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/multiple-return-values.mlir @@ -0,0 +1,82 @@ +//RUN: modelica-opt %s --call-cse | FileCheck %s + +// CHECK-LABEL: @Test +module @Test { + bmodelica.function @foo { + bmodelica.variable @in : !bmodelica.variable + bmodelica.variable @out1 : !bmodelica.variable + bmodelica.variable @out2 : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @in : f64 + %1 = bmodelica.constant 1.0 : f32 + bmodelica.variable_set @out1, %0 : f64 + bmodelica.variable_set @out2, %1 : f32 + } + } + + // CHECK: bmodelica.model + bmodelica.model @M { + // CHECK-NEXT: bmodelica.variable @[[CSE1:.*]] : !bmodelica.variable + // CHECK-NEXT: bmodelica.variable @[[CSE0:.*]] : !bmodelica.variable + // CHECK-NEXT: bmodelica.variable @x + // CHECK-NEXT: bmodelica.variable @y + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + %t0 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE0]] + // CHECK: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %2:2 = bmodelica.call @foo(%1) : (f64) -> (f64, f32) + %rhs = bmodelica.equation_side %2#0 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + %t1 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @y : f32 + %lhs = bmodelica.equation_side %0 : tuple + %1 = bmodelica.constant 1.0 : f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @y + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE1]] + // CHECK: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + %2:2 = bmodelica.call @foo(%1) : (f64) -> (f64, f32) + %rhs = bmodelica.equation_side %2#1 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 : !bmodelica.equation + bmodelica.equation_instance %t1 : !bmodelica.equation + } + + // CHECK: %[[TEMPLATE1:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES2:.*]] = bmodelica.variable_get @[[CSE0]] + // CHECK-NEXT: %[[LHS2:.*]] = bmodelica.equation_side %[[RES2]] + // CHECK-NEXT: %[[RES3:.*]] = bmodelica.constant 1 + // CHECK-NEXT: %[[RES4:.*]]:2 = bmodelica.call @foo(%[[RES3]]) + // CHECK-NEXT: %[[RHS0:.*]] = bmodelica.equation_side %[[RES4]]#0 + // CHECK-NEXT: bmodelica.equation_sides %[[LHS2]], %[[RHS0]] + + // CHECK: %[[TEMPLATE2:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES2:.*]] = bmodelica.variable_get @[[CSE1]] + // CHECK-NEXT: %[[LHS2:.*]] = bmodelica.equation_side %[[RES2]] + // CHECK-NEXT: %[[RES3:.*]] = bmodelica.constant 1 + // CHECK-NEXT: %[[RES4:.*]]:2 = bmodelica.call @foo(%[[RES3]]) + // CHECK-NEXT: %[[RHS0:.*]] = bmodelica.equation_side %[[RES4]]#1 + // CHECK-NEXT: bmodelica.equation_sides %[[LHS2]], %[[RHS0]] + + // CHECK: bmodelica.dynamic + // CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE1]] + // CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE2]] + } +} diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/op-with-region.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/op-with-region.mlir new file mode 100644 index 000000000..22fe9f528 --- /dev/null +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/op-with-region.mlir @@ -0,0 +1,101 @@ +// RUN: modelica-opt %s --call-cse | FileCheck %s + +// CHECK-LABEL: @Test +module @Test { + bmodelica.function @foo { + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + bmodelica.algorithm { + %0 = bmodelica.variable_get @x : f64 + bmodelica.variable_set @y, %0 : f64 + } + } + + // CHECK-LABEL: @M + bmodelica.model @M { + // CHECK: bmodelica.variable @[[CSE:.*]] : !bmodelica.variable + bmodelica.variable @x : !bmodelica.variable + bmodelica.variable @y : !bmodelica.variable + + %t0 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @x : f64 + %lhs = bmodelica.equation_side %0 : tuple + + %1 = bmodelica.constant 2.0 : f64 + + %lower = bmodelica.constant 5 : index + %upper = bmodelica.constant 10 : index + %step = bmodelica.constant 1 : index + %range = bmodelica.range %lower, %upper, %step : (index, index, index) -> !bmodelica + + %red = bmodelica.reduction "add", iterables = [%range], inductions = [] { + bmodelica.yield %1 : f64 + } : (!bmodelica) -> f64 + + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @x + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + %2 = bmodelica.call @foo(%red) : (f64) -> f64 + + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + %t1 = bmodelica.equation_template inductions = [] { + %0 = bmodelica.variable_get @y : f64 + %lhs = bmodelica.equation_side %0 : tuple + + %1 = bmodelica.constant 2.0 : f64 + + %lower = bmodelica.constant 5 : index + %upper = bmodelica.constant 10 : index + %step = bmodelica.constant 1 : index + %range = bmodelica.range %lower, %upper, %step : (index, index, index) -> !bmodelica + + %red = bmodelica.reduction "add", iterables = [%range], inductions = [] { + bmodelica.yield %1 : f64 + } : (!bmodelica) -> f64 + + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @y + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES0]] + + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + %2 = bmodelica.call @foo(%red) : (f64) -> f64 + + %rhs = bmodelica.equation_side %2 : tuple + bmodelica.equation_sides %lhs, %rhs : tuple, tuple + } + + bmodelica.dynamic { + bmodelica.equation_instance %t0 : !bmodelica.equation + bmodelica.equation_instance %t1 : !bmodelica.equation + } + + // CHECK: %[[TEMPLATE:.*]] = bmodelica.equation_template inductions = [] + // CHECK-NEXT: %[[RES:.*]] = bmodelica.variable_get @[[CSE]] + // CHECK-NEXT: %[[LHS:.*]] = bmodelica.equation_side %[[RES]] + // CHECK-DAG: %[[c:.*]] = bmodelica.constant 2 + // CHECK-DAG: %[[lower:.*]] = bmodelica.constant 5 + // CHECK-DAG: %[[upper:.*]] = bmodelica.constant 10 + // CHECK-DAG: %[[step:.*]] = bmodelica.constant 1 + // CHECK-DAG: %[[range:.*]] = bmodelica.range %[[lower]], %[[upper]], %[[step]] + + // CHECK-NEXT: %[[red:.*]] = bmodelica.reduction "add", iterables = [%[[range]]] + // CHECK-NEXT: bmodelica.yield %[[c]] + + // CHECK: %[[RES1:.*]] = bmodelica.call @foo(%[[red]]) + // CHECK-NEXT: %[[RHS:.*]] = bmodelica.equation_side %[[RES1]] + // CHECK-NEXT: bmodelica.equation_sides %[[LHS]], %[[RHS]] + + // CHECK: bmodelica.dynamic + // CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE]] + } +}