Skip to content

Commit d6211a3

Browse files
committed
Add calls CSE pass
1 parent d622fe6 commit d6211a3

File tree

10 files changed

+1773
-1787
lines changed

10 files changed

+1773
-1787
lines changed

include/public/marco/Dialect/BaseModelica/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ def CallCSEPass
318318
"mlir::bmodelica::BaseModelicaDialect"
319319
];
320320

321+
let statistics = [
322+
Statistic<"newCSEVariables", "new-cse-variables", "How many CSE variables have been created">,
323+
Statistic<"replacedCalls", "replaced-calls", "How many calls were replaced by a CSE variable usage">
324+
];
325+
321326
let constructor = "mlir::bmodelica::createCallCSEPass()";
322327
}
323328

lib/Dialect/BaseModelica/Transforms/CallCSE.cpp

Lines changed: 194 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,67 +9,230 @@ namespace mlir::bmodelica {
99
using namespace ::mlir::bmodelica;
1010

1111
namespace {
12-
class CallCSEPass : public impl::CallCSEPassBase<CallCSEPass> {
12+
class CallCSEPass final : public impl::CallCSEPassBase<CallCSEPass> {
1313
public:
1414
using CallCSEPassBase<CallCSEPass>::CallCSEPassBase;
1515

1616
void runOnOperation() override;
1717

1818
private:
1919
mlir::LogicalResult processModelOp(ModelOp modelOp);
20+
21+
/// Replace all calls in the equivalence group with gets to a generated
22+
/// variable. The variable will be driven by an equation derived from the
23+
/// first call in the group.
24+
///
25+
/// One variable and driver equation will be emitted per result,
26+
/// if the call is to a function with multiple result values.
27+
void emitCse(llvm::SmallVectorImpl<CallOp> &equivalenceGroup, ModelOp modelOp,
28+
DynamicOp dynamicOp, mlir::SymbolTable &symbolTable,
29+
mlir::RewriterBase &rewriter);
2030
};
21-
} // namespace
2231

23-
void CallCSEPass::runOnOperation() {
24-
llvm::SmallVector<ModelOp, 1> modelOps;
32+
/// Get all call operations in the model.
33+
void collectCallOps(ModelOp modelOp, llvm::SmallVectorImpl<CallOp> &callOps) {
34+
llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;
35+
modelOp.collectMainEquations(dynamicEquationOps);
2536

26-
walkClasses(getOperation(), [&](mlir::Operation *op) {
27-
if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
28-
modelOps.push_back(modelOp);
37+
llvm::DenseSet<EquationTemplateOp> visitedTemplateOps;
38+
for (EquationInstanceOp equationOp : dynamicEquationOps) {
39+
EquationTemplateOp templateOp = equationOp.getTemplate();
40+
if (!templateOp.getInductionVariables().empty() ||
41+
visitedTemplateOps.contains(templateOp)) {
42+
continue;
2943
}
30-
});
31-
32-
if (mlir::failed(mlir::failableParallelForEach(
33-
&getContext(), modelOps, [&](mlir::Operation *op) {
34-
return processModelOp(mlir::cast<ModelOp>(op));
35-
}))) {
36-
return signalPassFailure();
44+
visitedTemplateOps.insert(templateOp);
45+
templateOp->walk([&](CallOp callOp) { callOps.push_back(callOp); });
3746
}
3847
}
3948

40-
mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {
41-
mlir::IRRewriter rewriter(modelOp);
42-
llvm::SmallVector<EquationInstanceOp> initialEquationOps;
43-
llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;
49+
/// Partition the list of call operations into groups given by
50+
/// EquationExpressionOpInterface::isEquivalent
51+
void buildCallEquivalenceGroups(
52+
llvm::SmallVectorImpl<CallOp> &callOps,
53+
llvm::SmallVectorImpl<llvm::SmallVector<CallOp>> &callEquivalenceGroups) {
54+
mlir::SymbolTableCollection symbolTableCollection;
55+
llvm::SmallVector<llvm::SmallVector<CallOp>> tmpCallEquivalenceGroups;
56+
57+
for (CallOp callOp : callOps) {
58+
auto callExpression =
59+
mlir::cast<EquationExpressionOpInterface>(callOp.getOperation());
60+
61+
llvm::SmallVector<CallOp> *equivalenceGroup = find_if(
62+
tmpCallEquivalenceGroups, [&](llvm::SmallVector<CallOp> &group) {
63+
assert(!group.empty() && "groups should never be empty");
64+
return callExpression.isEquivalent(group.front(),
65+
symbolTableCollection);
66+
});
67+
68+
if (equivalenceGroup != tmpCallEquivalenceGroups.end()) {
69+
// Add equivalent call to existing group
70+
equivalenceGroup->push_back(callOp);
71+
} else {
72+
// Create new equivalence group
73+
tmpCallEquivalenceGroups.push_back({callOp});
74+
}
75+
}
4476

45-
modelOp.collectInitialEquations(initialEquationOps);
46-
modelOp.collectMainEquations(dynamicEquationOps);
77+
for (llvm::SmallVector<CallOp> &group : tmpCallEquivalenceGroups) {
78+
if (group.size() > 1) {
79+
callEquivalenceGroups.push_back(std::move(group));
80+
}
81+
}
82+
}
4783

48-
llvm::DenseSet<EquationTemplateOp> templateOps;
84+
/// Clone `op` and its def-use chain, returning the cloned version of `op`.
85+
mlir::Operation *cloneDefUseChain(mlir::Operation *op,
86+
mlir::RewriterBase &rewriter) {
87+
llvm::SmallVector<mlir::Operation *> toClone;
88+
llvm::SmallVector<mlir::Operation *> worklist({op});
89+
90+
// DFS through the def-use chain of `op`
91+
while (!worklist.empty()) {
92+
mlir::Operation *current = worklist.back();
93+
worklist.pop_back();
94+
toClone.push_back(current);
95+
for (mlir::Value operand : current->getOperands()) {
96+
if (mlir::Operation *defOp = operand.getDefiningOp()) {
97+
worklist.push_back(defOp);
98+
}
99+
}
100+
// Find the dependencies on operations not defined within the regions of
101+
// `current`. No need to do this if it is isolated from above.
102+
if (!current->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) {
103+
// Find all uses of values defined outside `current`.
104+
current->walk([&](mlir::Operation *childOp) {
105+
// Walk includes current, so skip it.
106+
if (childOp == current) {
107+
return;
108+
}
109+
for (mlir::Value operand : childOp->getOperands()) {
110+
// If an operand is defined in the same scope as `current`,
111+
// i.e. the equation template scope, add it to the worklist.
112+
mlir::Operation *definingOp = operand.getDefiningOp();
113+
if (definingOp && definingOp->getBlock() == current->getBlock()) {
114+
worklist.push_back(definingOp);
115+
}
116+
}
117+
});
118+
}
119+
}
49120

50-
for (EquationInstanceOp equationOp : initialEquationOps) {
51-
templateOps.insert(equationOp.getTemplate());
121+
mlir::IRMapping mapping;
122+
mlir::Operation *root = nullptr;
123+
for (mlir::Operation *opToClone : llvm::reverse(toClone)) {
124+
// Skip repeated dependencies on the same operation
125+
if (mapping.contains(opToClone)) {
126+
continue;
127+
}
128+
root = rewriter.clone(*opToClone, mapping);
52129
}
130+
return root;
131+
}
53132

54-
for (EquationInstanceOp equationOp : dynamicEquationOps) {
55-
templateOps.insert(equationOp.getTemplate());
133+
void CallCSEPass::emitCse(llvm::SmallVectorImpl<CallOp> &equivalenceGroup,
134+
ModelOp modelOp, DynamicOp dynamicOp,
135+
mlir::SymbolTable &symbolTable,
136+
mlir::RewriterBase &rewriter) {
137+
assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty");
138+
CallOp representative = equivalenceGroup.front();
139+
const mlir::Location loc = representative.getLoc();
140+
141+
// Emit one variable per function result
142+
llvm::SmallVector<VariableOp> cseVariables;
143+
for (auto result : llvm::enumerate(representative.getResults())) {
144+
rewriter.setInsertionPointToStart(modelOp.getBody());
145+
// Emit cse variable
146+
auto cseVariable = rewriter.create<VariableOp>(
147+
loc, "_cse", VariableType::wrap(result.value().getType()));
148+
symbolTable.insert(cseVariable);
149+
cseVariables.push_back(cseVariable);
150+
151+
// Emit driver equation
152+
rewriter.setInsertionPoint(dynamicOp);
153+
auto equationTemplateOp = rewriter.create<EquationTemplateOp>(loc);
154+
rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0));
155+
auto lhsOp = rewriter.create<EquationSideOp>(
156+
loc, rewriter.create<VariableGetOp>(loc, cseVariable)->getResults());
157+
auto rhsOp = rewriter.create<EquationSideOp>(
158+
loc,
159+
cloneDefUseChain(representative, rewriter)->getResult(result.index()));
160+
rewriter.create<EquationSidesOp>(loc, lhsOp, rhsOp);
161+
162+
// Add driver equation to dynamic operation
163+
rewriter.setInsertionPointToEnd(dynamicOp.getBody());
164+
rewriter.create<EquationInstanceOp>(rewriter.getUnknownLoc(),
165+
equationTemplateOp);
166+
}
167+
168+
// Replace calls with get(s) to CSE variable(s)
169+
for (auto &callOp : equivalenceGroup) {
170+
rewriter.setInsertionPoint(callOp);
171+
172+
llvm::SmallVector<mlir::Value> results;
173+
for (VariableOp cseVariable : cseVariables) {
174+
results.push_back(
175+
rewriter.create<VariableGetOp>(loc, cseVariable).getResult());
176+
}
177+
rewriter.replaceOp(callOp, results);
56178
}
57179

180+
this->replacedCalls += equivalenceGroup.size();
181+
++this->newCSEVariables;
182+
}
183+
184+
mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {
185+
mlir::IRRewriter rewriter(modelOp);
186+
mlir::SymbolTable symbolTable(modelOp);
187+
58188
llvm::SmallVector<CallOp> callOps;
189+
collectCallOps(modelOp, callOps);
190+
191+
llvm::SmallVector<llvm::SmallVector<CallOp>> callEquivalenceGroups;
192+
buildCallEquivalenceGroups(callOps, callEquivalenceGroups);
59193

60-
for (EquationTemplateOp templateOp : templateOps) {
61-
templateOp->walk([&](CallOp callOp) {
62-
callOps.push_back(callOp);
63-
});
194+
if (callEquivalenceGroups.empty()) {
195+
return mlir::success();
64196
}
65197

66-
198+
rewriter.setInsertionPointToEnd(modelOp.getBody());
199+
DynamicOp dynamicOp = rewriter.create<DynamicOp>(rewriter.getUnknownLoc());
200+
rewriter.createBlock(&dynamicOp.getRegion());
201+
202+
for (llvm::SmallVector<CallOp> &equivalenceGroup : callEquivalenceGroups) {
203+
// Only emit CSEs that will lead to an equivalent, or lower amount of calls
204+
if (equivalenceGroup.size() >= equivalenceGroup.front().getNumResults()) {
205+
emitCse(equivalenceGroup, modelOp, dynamicOp, symbolTable, rewriter);
206+
}
207+
}
208+
209+
if (dynamicOp.getBody()->empty()) {
210+
rewriter.eraseOp(dynamicOp);
211+
}
67212

68213
return mlir::success();
69214
}
215+
} // namespace
216+
217+
void CallCSEPass::runOnOperation() {
218+
llvm::SmallVector<ModelOp, 1> modelOps;
219+
220+
walkClasses(getOperation(), [&](mlir::Operation *op) {
221+
if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
222+
modelOps.push_back(modelOp);
223+
}
224+
});
225+
226+
if (mlir::failed(mlir::failableParallelForEach(
227+
&getContext(), modelOps, [&](mlir::Operation *op) {
228+
return processModelOp(mlir::cast<ModelOp>(op));
229+
}))) {
230+
return signalPassFailure();
231+
}
232+
}
70233

71234
namespace mlir::bmodelica {
72235
std::unique_ptr<mlir::Pass> createCallCSEPass() {
73236
return std::make_unique<CallCSEPass>();
74237
}
75-
} // namespace mlir::bmodelica
238+
} // namespace mlir::bmodelica

0 commit comments

Comments
 (0)