@@ -9,67 +9,230 @@ namespace mlir::bmodelica {
99using namespace ::mlir::bmodelica;
1010
1111namespace {
12- class CallCSEPass : public impl ::CallCSEPassBase<CallCSEPass> {
12+ class CallCSEPass final : public impl::CallCSEPassBase<CallCSEPass> {
1313public:
1414 using CallCSEPassBase<CallCSEPass>::CallCSEPassBase;
1515
1616 void runOnOperation () override ;
1717
1818private:
1919 mlir::LogicalResult processModelOp (ModelOp modelOp);
20+
21+ // / Replace all calls in the equivalence group with gets to a generated
22+ // / variable. The variable will be driven by an equation derived from the
23+ // / first call in the group.
24+ // /
25+ // / One variable and driver equation will be emitted per result,
26+ // / if the call is to a function with multiple result values.
27+ void emitCse (llvm::SmallVectorImpl<CallOp> &equivalenceGroup, ModelOp modelOp,
28+ DynamicOp dynamicOp, mlir::SymbolTable &symbolTable,
29+ mlir::RewriterBase &rewriter);
2030};
21- } // namespace
2231
23- void CallCSEPass::runOnOperation () {
24- llvm::SmallVector<ModelOp, 1 > modelOps;
32+ // / Get all call operations in the model.
33+ void collectCallOps (ModelOp modelOp, llvm::SmallVectorImpl<CallOp> &callOps) {
34+ llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;
35+ modelOp.collectMainEquations (dynamicEquationOps);
2536
26- walkClasses (getOperation (), [&](mlir::Operation *op) {
27- if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
28- modelOps.push_back (modelOp);
37+ llvm::DenseSet<EquationTemplateOp> visitedTemplateOps;
38+ for (EquationInstanceOp equationOp : dynamicEquationOps) {
39+ EquationTemplateOp templateOp = equationOp.getTemplate ();
40+ if (!templateOp.getInductionVariables ().empty () ||
41+ visitedTemplateOps.contains (templateOp)) {
42+ continue ;
2943 }
30- });
31-
32- if (mlir::failed (mlir::failableParallelForEach (
33- &getContext (), modelOps, [&](mlir::Operation *op) {
34- return processModelOp (mlir::cast<ModelOp>(op));
35- }))) {
36- return signalPassFailure ();
44+ visitedTemplateOps.insert (templateOp);
45+ templateOp->walk ([&](CallOp callOp) { callOps.push_back (callOp); });
3746 }
3847}
3948
40- mlir::LogicalResult CallCSEPass::processModelOp (ModelOp modelOp) {
41- mlir::IRRewriter rewriter (modelOp);
42- llvm::SmallVector<EquationInstanceOp> initialEquationOps;
43- llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;
49+ // / Partition the list of call operations into groups given by
50+ // / EquationExpressionOpInterface::isEquivalent
51+ void buildCallEquivalenceGroups (
52+ llvm::SmallVectorImpl<CallOp> &callOps,
53+ llvm::SmallVectorImpl<llvm::SmallVector<CallOp>> &callEquivalenceGroups) {
54+ mlir::SymbolTableCollection symbolTableCollection;
55+ llvm::SmallVector<llvm::SmallVector<CallOp>> tmpCallEquivalenceGroups;
56+
57+ for (CallOp callOp : callOps) {
58+ auto callExpression =
59+ mlir::cast<EquationExpressionOpInterface>(callOp.getOperation ());
60+
61+ llvm::SmallVector<CallOp> *equivalenceGroup = find_if (
62+ tmpCallEquivalenceGroups, [&](llvm::SmallVector<CallOp> &group) {
63+ assert (!group.empty () && " groups should never be empty" );
64+ return callExpression.isEquivalent (group.front (),
65+ symbolTableCollection);
66+ });
67+
68+ if (equivalenceGroup != tmpCallEquivalenceGroups.end ()) {
69+ // Add equivalent call to existing group
70+ equivalenceGroup->push_back (callOp);
71+ } else {
72+ // Create new equivalence group
73+ tmpCallEquivalenceGroups.push_back ({callOp});
74+ }
75+ }
4476
45- modelOp.collectInitialEquations (initialEquationOps);
46- modelOp.collectMainEquations (dynamicEquationOps);
77+ for (llvm::SmallVector<CallOp> &group : tmpCallEquivalenceGroups) {
78+ if (group.size () > 1 ) {
79+ callEquivalenceGroups.push_back (std::move (group));
80+ }
81+ }
82+ }
4783
48- llvm::DenseSet<EquationTemplateOp> templateOps;
84+ // / Clone `op` and its def-use chain, returning the cloned version of `op`.
85+ mlir::Operation *cloneDefUseChain (mlir::Operation *op,
86+ mlir::RewriterBase &rewriter) {
87+ llvm::SmallVector<mlir::Operation *> toClone;
88+ llvm::SmallVector<mlir::Operation *> worklist ({op});
89+
90+ // DFS through the def-use chain of `op`
91+ while (!worklist.empty ()) {
92+ mlir::Operation *current = worklist.back ();
93+ worklist.pop_back ();
94+ toClone.push_back (current);
95+ for (mlir::Value operand : current->getOperands ()) {
96+ if (mlir::Operation *defOp = operand.getDefiningOp ()) {
97+ worklist.push_back (defOp);
98+ }
99+ }
100+ // Find the dependencies on operations not defined within the regions of
101+ // `current`. No need to do this if it is isolated from above.
102+ if (!current->hasTrait <mlir::OpTrait::IsIsolatedFromAbove>()) {
103+ // Find all uses of values defined outside `current`.
104+ current->walk ([&](mlir::Operation *childOp) {
105+ // Walk includes current, so skip it.
106+ if (childOp == current) {
107+ return ;
108+ }
109+ for (mlir::Value operand : childOp->getOperands ()) {
110+ // If an operand is defined in the same scope as `current`,
111+ // i.e. the equation template scope, add it to the worklist.
112+ mlir::Operation *definingOp = operand.getDefiningOp ();
113+ if (definingOp && definingOp->getBlock () == current->getBlock ()) {
114+ worklist.push_back (definingOp);
115+ }
116+ }
117+ });
118+ }
119+ }
49120
50- for (EquationInstanceOp equationOp : initialEquationOps) {
51- templateOps.insert (equationOp.getTemplate ());
121+ mlir::IRMapping mapping;
122+ mlir::Operation *root = nullptr ;
123+ for (mlir::Operation *opToClone : llvm::reverse (toClone)) {
124+ // Skip repeated dependencies on the same operation
125+ if (mapping.contains (opToClone)) {
126+ continue ;
127+ }
128+ root = rewriter.clone (*opToClone, mapping);
52129 }
130+ return root;
131+ }
53132
54- for (EquationInstanceOp equationOp : dynamicEquationOps) {
55- templateOps.insert (equationOp.getTemplate ());
133+ void CallCSEPass::emitCse (llvm::SmallVectorImpl<CallOp> &equivalenceGroup,
134+ ModelOp modelOp, DynamicOp dynamicOp,
135+ mlir::SymbolTable &symbolTable,
136+ mlir::RewriterBase &rewriter) {
137+ assert (!equivalenceGroup.empty () && " equivalenceGroup cannot be empty" );
138+ CallOp representative = equivalenceGroup.front ();
139+ const mlir::Location loc = representative.getLoc ();
140+
141+ // Emit one variable per function result
142+ llvm::SmallVector<VariableOp> cseVariables;
143+ for (auto result : llvm::enumerate (representative.getResults ())) {
144+ rewriter.setInsertionPointToStart (modelOp.getBody ());
145+ // Emit cse variable
146+ auto cseVariable = rewriter.create <VariableOp>(
147+ loc, " _cse" , VariableType::wrap (result.value ().getType ()));
148+ symbolTable.insert (cseVariable);
149+ cseVariables.push_back (cseVariable);
150+
151+ // Emit driver equation
152+ rewriter.setInsertionPoint (dynamicOp);
153+ auto equationTemplateOp = rewriter.create <EquationTemplateOp>(loc);
154+ rewriter.setInsertionPointToStart (equationTemplateOp.createBody (0 ));
155+ auto lhsOp = rewriter.create <EquationSideOp>(
156+ loc, rewriter.create <VariableGetOp>(loc, cseVariable)->getResults ());
157+ auto rhsOp = rewriter.create <EquationSideOp>(
158+ loc,
159+ cloneDefUseChain (representative, rewriter)->getResult (result.index ()));
160+ rewriter.create <EquationSidesOp>(loc, lhsOp, rhsOp);
161+
162+ // Add driver equation to dynamic operation
163+ rewriter.setInsertionPointToEnd (dynamicOp.getBody ());
164+ rewriter.create <EquationInstanceOp>(rewriter.getUnknownLoc (),
165+ equationTemplateOp);
166+ }
167+
168+ // Replace calls with get(s) to CSE variable(s)
169+ for (auto &callOp : equivalenceGroup) {
170+ rewriter.setInsertionPoint (callOp);
171+
172+ llvm::SmallVector<mlir::Value> results;
173+ for (VariableOp cseVariable : cseVariables) {
174+ results.push_back (
175+ rewriter.create <VariableGetOp>(loc, cseVariable).getResult ());
176+ }
177+ rewriter.replaceOp (callOp, results);
56178 }
57179
180+ this ->replacedCalls += equivalenceGroup.size ();
181+ ++this ->newCSEVariables ;
182+ }
183+
184+ mlir::LogicalResult CallCSEPass::processModelOp (ModelOp modelOp) {
185+ mlir::IRRewriter rewriter (modelOp);
186+ mlir::SymbolTable symbolTable (modelOp);
187+
58188 llvm::SmallVector<CallOp> callOps;
189+ collectCallOps (modelOp, callOps);
190+
191+ llvm::SmallVector<llvm::SmallVector<CallOp>> callEquivalenceGroups;
192+ buildCallEquivalenceGroups (callOps, callEquivalenceGroups);
59193
60- for (EquationTemplateOp templateOp : templateOps) {
61- templateOp->walk ([&](CallOp callOp) {
62- callOps.push_back (callOp);
63- });
194+ if (callEquivalenceGroups.empty ()) {
195+ return mlir::success ();
64196 }
65197
66-
198+ rewriter.setInsertionPointToEnd (modelOp.getBody ());
199+ DynamicOp dynamicOp = rewriter.create <DynamicOp>(rewriter.getUnknownLoc ());
200+ rewriter.createBlock (&dynamicOp.getRegion ());
201+
202+ for (llvm::SmallVector<CallOp> &equivalenceGroup : callEquivalenceGroups) {
203+ // Only emit CSEs that will lead to an equivalent, or lower amount of calls
204+ if (equivalenceGroup.size () >= equivalenceGroup.front ().getNumResults ()) {
205+ emitCse (equivalenceGroup, modelOp, dynamicOp, symbolTable, rewriter);
206+ }
207+ }
208+
209+ if (dynamicOp.getBody ()->empty ()) {
210+ rewriter.eraseOp (dynamicOp);
211+ }
67212
68213 return mlir::success ();
69214}
215+ } // namespace
216+
217+ void CallCSEPass::runOnOperation () {
218+ llvm::SmallVector<ModelOp, 1 > modelOps;
219+
220+ walkClasses (getOperation (), [&](mlir::Operation *op) {
221+ if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
222+ modelOps.push_back (modelOp);
223+ }
224+ });
225+
226+ if (mlir::failed (mlir::failableParallelForEach (
227+ &getContext (), modelOps, [&](mlir::Operation *op) {
228+ return processModelOp (mlir::cast<ModelOp>(op));
229+ }))) {
230+ return signalPassFailure ();
231+ }
232+ }
70233
71234namespace mlir ::bmodelica {
72235std::unique_ptr<mlir::Pass> createCallCSEPass () {
73236 return std::make_unique<CallCSEPass>();
74237}
75- } // namespace mlir::bmodelica
238+ } // namespace mlir::bmodelica
0 commit comments