Skip to content

Commit 684b92c

Browse files
authoredMar 10, 2025··
add procedure call inlining (#345)
1 parent 8f8addc commit 684b92c

File tree

5 files changed

+197
-1
lines changed

5 files changed

+197
-1
lines changed
 

‎src/main/scala/ir/Program.scala

+8
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,14 @@ class Block private (
610610
jump.deParent()
611611
}
612612

613+
def createBlockAfter(suffix: String): Block = {
614+
val nb = Block(label + suffix)
615+
parent.addBlock(nb)
616+
val ojump = jump
617+
replaceJump(GoTo(nb))
618+
nb.replaceJump(ojump)
619+
}
620+
613621
def createBlockBetween(b2: Block, label: String = "_goto_"): Block = {
614622
require(nextBlocks.toSet.contains(b2))
615623
val b1 = this

‎src/main/scala/ir/cilvisitor/CILVisitor.scala

+1
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,4 @@ def visit_prog(v: CILVisitor, b: Program): Program = CILVisitorImpl(v).visit_pro
217217
def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v).visit_stmt(e)
218218
def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e)
219219
def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e)
220+
def visit_rvar(v: CILVisitor, e: Variable): Variable = CILVisitorImpl(v).visit_rvar(e)
+163
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package ir.transforms
2+
import ir.*
3+
import util.Logger
4+
import ir.cilvisitor.*
5+
import ir.dsl.IRToDSL.*
6+
import ir.dsl.*
7+
8+
import scala.collection.mutable
9+
10+
object Counter {
11+
var id = 0
12+
def next() = {
13+
id += 1
14+
id
15+
}
16+
}
17+
18+
def memoise[K, V](f: K => V): K => V = {
19+
val rn = mutable.Map[K, V]()
20+
21+
def fun(arg: K): V = {
22+
if (!rn.contains(arg)) {
23+
rn(arg) = f(arg)
24+
}
25+
26+
rn(arg)
27+
}
28+
fun
29+
}
30+
31+
def renameBlock(s: String): String = {
32+
s + "_" + (Counter.next())
33+
}
34+
35+
class VarRenamer(proc: Procedure) extends CILVisitor {
36+
37+
def doRename(v: Variable): Variable = v match {
38+
case l: LocalVar if l.name.endsWith("_in") => {
39+
val name = l.name.stripSuffix("_in")
40+
proc.getFreshSSAVar(name, l.getType)
41+
}
42+
case l: LocalVar if l.index != 0 =>
43+
proc.getFreshSSAVar(l.varName, l.getType)
44+
case _ => v
45+
}
46+
val memoed = memoise(doRename)
47+
48+
def rename(v: Variable) = {
49+
memoed(v)
50+
}
51+
52+
override def vlvar(v: Variable) = {
53+
ChangeTo(rename(v))
54+
}
55+
56+
override def vrvar(v: Variable) = {
57+
ChangeTo(rename(v))
58+
}
59+
60+
}
61+
62+
def convertJumpRenaming(blockName: String => String, varName: CILVisitor, x: Jump): EventuallyJump = x match {
63+
case GoTo(targs, label) => EventuallyGoto(targs.map(t => DelayNameResolve(blockName(t.label))).toList, label)
64+
case Unreachable(label) => EventuallyUnreachable(label)
65+
case Return(label, out) =>
66+
EventuallyReturn(
67+
out.toList.map { case (v: Variable, e: Expr) =>
68+
((v.name, visit_expr(varName, e)))
69+
}.toArray,
70+
label
71+
)
72+
}
73+
74+
def keyToString[T](varRenamer: CILVisitor)(x: (Variable, Expr)): (String, Expr) =
75+
(x(0).name, visit_expr(varRenamer, x(1)))
76+
77+
def convertStatementRenaming(varRenamer: CILVisitor)(x: Statement): EventuallyStatement = x match {
78+
case DirectCall(targ, outs, actuals, label) =>
79+
directCall(
80+
outs.toArray.map(v => (v(0).name, visit_rvar(varRenamer, v(1)))),
81+
targ.name,
82+
actuals.toArray.map(keyToString(varRenamer)): _*
83+
)
84+
case IndirectCall(targ, label) => indirectCall(visit_rvar(varRenamer, targ))
85+
case x: NonCallStatement =>
86+
CloneableStatement(visit_stmt(varRenamer, cloneStatement(x)).head.asInstanceOf[NonCallStatement])
87+
}
88+
89+
def convertBlockRenaming(varRenamer: CILVisitor, blockName: String => String)(x: Block) = {
90+
EventuallyBlock(
91+
blockName(x.label),
92+
x.statements.toArray.map(convertStatementRenaming(varRenamer)),
93+
convertJumpRenaming(blockName, varRenamer, x.jump),
94+
x.address
95+
)
96+
}
97+
98+
/**
99+
* Inline a procedure call to the calling procedure;
100+
*
101+
* - maintains param form
102+
* - maintain dsa form if both procedures are in dsa form
103+
*
104+
* require invariant.SingleCallBlockEnd
105+
*/
106+
def inlineCall(prog: Program, c: DirectCall): Unit = {
107+
require(c.target.entryBlock.isDefined)
108+
require(c.target.returnBlock.isDefined)
109+
110+
val proc = c.parent.parent
111+
val block = c.parent
112+
val target = c.target
113+
// rename ssa variables to maintain DSA in the new procedure
114+
// rename block labels to avoid creating conflicts if inlining the same function multiple times
115+
val varRenamer = VarRenamer(proc)
116+
val blockRenamer: String => String = memoise(renameBlock)
117+
118+
val entry = target.entryBlock.get
119+
val returnBlock = target.returnBlock.get
120+
121+
// the internal blocks to the target
122+
val internalBlocks =
123+
(target.blocks.toSet -- Seq(entry, returnBlock)).map(convertBlockRenaming(varRenamer, blockRenamer))
124+
125+
// clone the target procedure and apply the renaming
126+
val (entryTempBlock, entryResolver) = convertBlockRenaming(varRenamer, blockRenamer)(entry).makeResolver
127+
val eventuallyReturnBlock = convertBlockRenaming(varRenamer, blockRenamer)(returnBlock)
128+
val afterCallBlock = block.createBlockAfter("_inlineret")
129+
proc.addBlock(entryTempBlock)
130+
block.replaceJump(GoTo(entryTempBlock))
131+
132+
// replace return in target with a jump to the aftercall block
133+
val (returnTemp, resolveReturnBlock) = eventuallyReturnBlock.copy(j = unreachable).makeResolver
134+
proc.addBlock(returnTemp)
135+
136+
// resolve internal call blocks
137+
val resolvers = internalBlocks.map(_.makeResolver)
138+
resolvers.foreach { case (block, _) => proc.addBlock(block) }
139+
resolvers.foreach { case (_, resolve) => resolve(prog, proc) }
140+
141+
// remove original call statement
142+
block.statements.remove(c)
143+
// link the inlined blocks to the call block and the aftercall block
144+
entryResolver(prog, proc)
145+
block.replaceJump(GoTo(entryTempBlock))
146+
resolveReturnBlock(prog, proc)
147+
returnTemp.replaceJump(GoTo(afterCallBlock))
148+
149+
// assign the actual parameters in the caller to the renamed formal parameters in the entry block
150+
val targetReturnValues: Map[String, Expr] = eventuallyReturnBlock.j match {
151+
case r: EventuallyReturn => r.params.toMap
152+
case _ => throw Exception("returnblock should have a return statement")
153+
}
154+
val outAssignments = c.outParams.map { case (formal: LocalVar, lvar: Variable) =>
155+
LocalAssign(lvar, targetReturnValues(formal.name))
156+
}
157+
val inAssignments = c.actualParams.map { case (formal: LocalVar, actual: Expr) =>
158+
LocalAssign(visit_rvar(varRenamer, formal), actual)
159+
}
160+
afterCallBlock.statements.prependAll(outAssignments)
161+
entryTempBlock.statements.prependAll(inAssignments)
162+
163+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package ir.transforms
2+
import ir.Program
3+
4+
def inlinePLTLaunchpad(prog: Program) = {
5+
for (p <- prog.procedures) {
6+
7+
val candidate =
8+
(p.blocks.size <= 4)
9+
&& p.calls.size == 1
10+
&& p.calls.forall(_.isExternal.contains(true))
11+
&& p.procName.startsWith("FUN")
12+
&& !p.calls.contains(p)
13+
14+
if (candidate) {
15+
p.incomingCalls().foreach { call =>
16+
inlineCall(prog, call)
17+
}
18+
}
19+
}
20+
21+
applyRPO(prog)
22+
23+
}

‎src/main/scala/util/RunUtils.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,10 @@ object RunUtils {
725725
Logger.info("[!] Simplify :: DynamicSingleAssignment")
726726
DebugDumpIRLogger.writeToFile(File("il-before-dsa.il"), pp_prog(program))
727727

728-
// transforms.DynamicSingleAssignment.applyTransform(program, liveVars)
729728
transforms.OnePassDSA().applyTransform(program)
730729

730+
transforms.inlinePLTLaunchpad(ctx.program)
731+
731732
transforms.removeEmptyBlocks(program)
732733

733734
AnalysisResultDotLogger.writeToFile(

0 commit comments

Comments
 (0)
Please sign in to comment.