Skip to content

Commit 9b8fc6e

Browse files
authored
Interpreter: callProcedure with return values (#352)
* make interpreter return value * make default interpret path return value when outparams set * cleanup * remove ReturnTo(DirectCall) continuation * trace.empty
1 parent b1eb086 commit 9b8fc6e

File tree

6 files changed

+219
-65
lines changed

6 files changed

+219
-65
lines changed

src/main/scala/ir/eval/InterpretBasilIR.scala

+153-41
Original file line numberDiff line numberDiff line change
@@ -284,26 +284,42 @@ case object Eval {
284284
case _ => State.setError(Errored(s"not byte $c"))
285285
}
286286
} yield (res)
287+
288+
}
289+
290+
enum InterpretReturn {
291+
case ReturnVal(outs: Map[LocalVar, Literal])
292+
case Void
293+
}
294+
295+
class BASILInterpreter[S](f: Effects[S, InterpreterError])
296+
extends Interpreter[S, InterpretReturn, InterpreterError](f) {
297+
298+
def interpretOne: util.functional.State[S, Next[InterpretReturn], InterpreterError] =
299+
InterpFuns.interpretContinuation(f)
300+
287301
}
288302

289-
class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) {
303+
object InterpFuns {
290304

291-
def interpretOne: State[S, Boolean, InterpreterError] = {
305+
def interpretContinuation[S, T <: Effects[S, InterpreterError]](
306+
f: T
307+
): State[S, Next[InterpretReturn], InterpreterError] = {
292308
val next = for {
293309
next <- f.getNext
294310
_ <- State.pure(Logger.debug(s"$next"))
295-
r: Boolean <- (next match {
296-
case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true)
297-
case Run(c: Statement) => interpretStatement(f)(c).map(_ => true)
298-
case ReturnTo(c) => interpretReturn(f)(c).map(_ => true)
299-
case Run(c: Jump) => interpretJump(f)(c).map(_ => true)
300-
case Stopped() => State.pure(false)
301-
case ErrorStop(e) => State.pure(false)
311+
r: Next[InterpretReturn] <- (next match {
312+
case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => Next.Continue)
313+
case Run(c: Statement) => interpretStatement(f)(c).map(_ => Next.Continue)
314+
case ReturnFrom(c) => evaluateReturn(f)(c).map(v => Next.Stop(InterpretReturn.ReturnVal(v)))
315+
case Run(c: Jump) => interpretJump(f)(c).map(_ => Next.Continue)
316+
case Stopped() => State.pure(Next.Stop(InterpretReturn.Void))
317+
case ErrorStop(e) => State.pure(Next.Stop(InterpretReturn.Void))
302318
})
303319
} yield (r)
304320

305321
next.flatMapE((e: InterpreterError) => {
306-
f.setNext(ErrorStop(e)).map(_ => false)
322+
f.setNext(ErrorStop(e)).map(_ => Next.Stop(InterpretReturn.Void))
307323
})
308324
}
309325

@@ -350,19 +366,22 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
350366
}
351367
}
352368

353-
def interpretReturn[S, T <: Effects[S, InterpreterError]](f: T)(s: DirectCall): State[S, Unit, InterpreterError] = {
369+
/**
370+
* Evaluates the formal out params and returns a map totheir values.
371+
*/
372+
def evaluateReturn[S, T <: Effects[S, InterpreterError]](
373+
f: T
374+
)(returnFrom: ProcSig): State[S, Map[LocalVar, Literal], InterpreterError] = {
354375
for {
355376
outs <- State.mapM(
356-
((bindout: (LocalVar, Variable)) => {
377+
((bindout: (LocalVar)) => {
357378
for {
358-
rhs <- Eval.evalLiteral(f)(bindout._1)
359-
} yield (bindout._2, rhs)
379+
rhs <- Eval.evalLiteral(f)(bindout)
380+
} yield (bindout, rhs)
360381
}),
361-
s.outParams
382+
returnFrom.formalOutParam
362383
)
363-
c <- State.sequence(State.pure(()), outs.map(m => f.storeVar(m._1.name, m._1.toBoogie.scope, Scalar(m._2))))
364-
_ <- f.setNext(Run(s.successor))
365-
} yield (c)
384+
} yield (outs.toMap)
366385
}
367386

368387
def interpretStatement[S, T <: Effects[S, InterpreterError]](f: T)(s: Statement): State[S, Unit, InterpreterError] = {
@@ -411,27 +430,28 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
411430
} yield (n)
412431
case dc: DirectCall => {
413432
for {
433+
// eval actual
414434
actualParams <- State.mapM(
415435
(p: (LocalVar, Expr)) =>
416436
for {
417437
v <- Eval.evalLiteral(f)(p._2)
418438
} yield (p._1, v),
419439
dc.actualParams
420440
)
421-
_ <- {
422-
if (LibcIntrinsic.intrinsics.contains(dc.target.procName)) {
423-
f.call(dc.target.name, Intrinsic(dc.target.procName), ReturnTo(dc))
424-
} else if (dc.target.entryBlock.isDefined) {
425-
val block = dc.target.entryBlock.get
426-
f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), ReturnTo(dc))
427-
} else {
428-
State.setError(EscapedControlFlow(dc))
441+
// call procedure (immediately evaluated)
442+
ret <- callProcedure(f)(dc.target, actualParams)
443+
// assign return values of procedure
444+
outs =
445+
dc.outParams.map { (formal: LocalVar, lhs: Variable) =>
446+
(formal, lhs, ret(formal))
429447
}
430-
}
431448
_ <- State.sequence(
432449
State.pure(()),
433-
actualParams.map(m => f.storeVar(m._1.name, m._1.toBoogie.scope, Scalar(m._2)))
450+
outs.map { (formal, lhs, value) =>
451+
f.storeVar(lhs.name, lhs.toBoogie.scope, Scalar(value))
452+
}
434453
)
454+
_ <- f.setNext(Run(s.successor))
435455
} yield ()
436456
}
437457
case ic: IndirectCall => {
@@ -451,9 +471,6 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
451471
case _: NOP => f.setNext(Run(s.successor))
452472
}
453473
}
454-
}
455-
456-
object InterpFuns {
457474

458475
def initRelocTable[S, T <: Effects[S, InterpreterError]](s: T)(ctx: IRContext): State[S, Unit, InterpreterError] = {
459476

@@ -588,13 +605,68 @@ object InterpFuns {
588605
s = State.execute(s, f.call("init_activation", Stopped(), Stopped()))
589606
s = initMemory(s, "mem", p.initialMemory.values)
590607
s = initMemory(s, "stack", p.initialMemory.values)
591-
mainfun.entryBlock.foreach(startBlock =>
592-
s = State.execute(s, f.call(mainfun.name, Run(IRWalk.firstInBlock(startBlock)), Stopped()))
593-
)
594-
// l <- State.sequence(State.pure(()), mainfun.formalInParam.toList.map(i => f.storeVar(i.name, i.toBoogie.scope, Scalar(BitVecLiteral(0, size(i).get)))))
595608
s
596609
}
597610

611+
// case class FunctionCall(target: String, inParams: List[(String, Literal)])
612+
613+
/*
614+
* Calls a procedure that has a return, immediately evaluating the procedure
615+
*
616+
* Call a function, possibly dispatching to an intrinsic if the
617+
* procedure is not resolved, or resolves to a stub.
618+
*/
619+
def callProcedure[S, T <: Effects[S, InterpreterError]](f: T)(
620+
targetProc: ProcSig | Procedure,
621+
actualParams: Iterable[(LocalVar, Literal)]
622+
): State[S, Map[LocalVar, Literal], InterpreterError] = {
623+
624+
val target = targetProc match {
625+
case p: Procedure => Some(p)
626+
case _ => None
627+
}
628+
val proc = targetProc match {
629+
case p: ProcSig => p
630+
case p: Procedure => ProcSig(p.name, p.formalInParam.toList, p.formalOutParam.toList)
631+
}
632+
633+
val call = for {
634+
// evaluate actual parms
635+
v <- {
636+
// perform call and push return stack frame and continuation
637+
638+
val intrinsicName = target.map(_.procName).getOrElse(proc.name)
639+
if (LibcIntrinsic.intrinsics.contains(intrinsicName)) {
640+
f.call(intrinsicName, Intrinsic(intrinsicName), ReturnFrom(proc))
641+
} else if (target.exists(_.entryBlock.isDefined)) {
642+
val block = target.get.entryBlock.get
643+
f.call(target.get.name, Run(IRWalk.firstInBlock(block)), ReturnFrom(proc))
644+
} else {
645+
State.setError(Errored(s"call to empty procedure: ${proc.name} / $target"))
646+
}
647+
}
648+
// set actual params in the callee state
649+
_ <- State.sequence(
650+
State.pure(()),
651+
actualParams.map(m => f.storeVar(m._1.name, m._1.toBoogie.scope, Scalar(m._2)))
652+
)
653+
} yield ()
654+
655+
for {
656+
r <- call
657+
ret <- evalInterpreter(f, interpretContinuation(f))
658+
rv <-
659+
if (ret.isDefined) {
660+
ret.get match {
661+
case InterpretReturn.ReturnVal(v) => State.pure(v)
662+
case v => State.setError(Errored(s"Call didn't return value (got $v) $proc"))
663+
}
664+
} else {
665+
State.setError(Errored(s"Call to pure function should have returned a value, $proc"))
666+
}
667+
} yield (rv)
668+
}
669+
598670
def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(is: S, p: IRContext): S = {
599671
val bss = for {
600672
first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value)
@@ -634,18 +706,50 @@ object InterpFuns {
634706
}
635707
}
636708

709+
def mainDefaultFunctionArguments(
710+
proc: Procedure,
711+
overlay: Map[String, BitVecLiteral] = Map()
712+
): Map[LocalVar, Literal] = {
713+
val SP: BitVecLiteral = BitVecLiteral(0x78000000, 64)
714+
val FP: BitVecLiteral = SP
715+
val LR: BitVecLiteral = BitVecLiteral(BigInt("78000000", 16), 64)
716+
717+
proc.formalInParam.toList.map {
718+
case l: LocalVar if overlay.contains(l.name) => l -> overlay(l.name)
719+
case l: LocalVar if l.name.startsWith("R0") => l -> BitVecLiteral(1, size(l).get)
720+
case l: LocalVar if l.name.startsWith("R31") => l -> SP
721+
case l: LocalVar if l.name.startsWith("R29") => l -> FP
722+
case l: LocalVar if l.name.startsWith("R30") => l -> LR
723+
case l: LocalVar => l -> BitVecLiteral(0, size(l).get)
724+
}.toMap
725+
}
726+
727+
def interpretEvalProcExc(program: IRContext | Program, functionName: String, params: Map[String, Literal]) = {}
728+
637729
/* Intialise from ELF and Interpret program */
638-
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = {
730+
def interpretEvalProg[S, T <: Effects[S, InterpreterError]](
731+
f: T
732+
)(p: IRContext, is: S): (S, Either[InterpreterError, Map[LocalVar, Literal]]) = {
639733
val begin = initProgState(f)(p, is)
640-
val interp = BASILInterpreter(f)
641-
interp.run(begin)
734+
val main = p.program.mainProcedure
735+
callProcedure(f)(main, mainDefaultFunctionArguments(main)).f(begin)
642736
}
643737

644738
/* Interpret IR program */
645-
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = {
739+
def interpretEvalProg[S, T <: Effects[S, InterpreterError]](
740+
f: T
741+
)(p: Program, is: S): (S, Either[InterpreterError, Map[LocalVar, Literal]]) = {
646742
val begin = initialiseProgram(f)(is, p)
647-
val interp = BASILInterpreter(f)
648-
interp.run(begin)
743+
val main = p.mainProcedure
744+
callProcedure(f)(main, mainDefaultFunctionArguments(main)).f(begin)
745+
}
746+
747+
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = {
748+
interpretEvalProg(f)(p, is)._1
749+
}
750+
751+
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = {
752+
interpretEvalProg(f)(p, is)._1
649753
}
650754
}
651755

@@ -656,3 +760,11 @@ def interpret(IRProgram: Program): InterpreterState = {
656760
def interpret(IRProgram: IRContext): InterpreterState = {
657761
InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState())
658762
}
763+
764+
def interpretEval(IRProgram: Program): (InterpreterState, Either[InterpreterError, Map[LocalVar, Literal]]) = {
765+
InterpFuns.interpretEvalProg(NormalInterpreter)(IRProgram, InterpreterState())
766+
}
767+
768+
def interpretEval(IRProgram: IRContext): (InterpreterState, Either[InterpreterError, Map[LocalVar, Literal]]) = {
769+
InterpFuns.interpretEvalProg(NormalInterpreter)(IRProgram, InterpreterState())
770+
}

src/main/scala/ir/eval/Interpreter.scala

+49-13
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,27 @@ import scala.collection.mutable
1414
import scala.collection.immutable
1515
import scala.util.control.Breaks.{break, breakable}
1616

17+
/**
18+
* Procedure signature used when returning from procedures and intrinsics.
19+
* This is mainly used to describe the formalOutparams, where the return values of the procedure
20+
* are stored to be read by the return value.
21+
*
22+
* @param name
23+
* The full name of the procedure (i.e. procedure.name if a real procedure, otherwise the name of the intrinsic)
24+
* @param formalInParam
25+
* The list of formal input parameters (corresponding to procedure.formalInParam)
26+
* @param formalOutParam
27+
* The list of formal outpur params (corresponding to procedure.formalOutParam)
28+
*/
29+
case class ProcSig(name: String, formalInParam: List[LocalVar], formalOutParam: List[LocalVar])
30+
1731
/** Interpreter status type, either stopped, run next command or error
1832
*/
1933
sealed trait ExecutionContinuation
2034
case class Stopped() extends ExecutionContinuation /* normal program stop */
2135
case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* program stop in error state */
2236
case class Run(next: Command) extends ExecutionContinuation /* continue by executing next command */
23-
case class ReturnTo(call: DirectCall) extends ExecutionContinuation /* continue by executing next command */
37+
case class ReturnFrom(target: ProcSig) extends ExecutionContinuation /* return from a call without continuing */
2438
case class Intrinsic(name: String) extends ExecutionContinuation /* a named intrinsic instruction */
2539

2640
sealed trait InterpreterError
@@ -47,6 +61,12 @@ sealed trait MapValue {
4761
def value: Map[BasilValue, BasilValue]
4862
}
4963

64+
def normalTermination(is: ExecutionContinuation) = is match {
65+
case Stopped() => true
66+
case ReturnFrom(_) => true
67+
case _ => false
68+
}
69+
5070
/* We erase the type of basil values and enforce the invariant that
5171
\exists i . \forall v \in value.keys , v.irType = i and
5272
\exists j . \forall v \in value.values, v.irType = j
@@ -356,7 +376,8 @@ object LibcIntrinsic {
356376
"__libc_malloc_impl" -> singleArg("malloc"),
357377
"free" -> singleArg("free"),
358378
"#free" -> singleArg("free"),
359-
"calloc" -> calloc
379+
"calloc" -> calloc,
380+
"strlen" -> singleArg("strlen")
360381
)
361382

362383
}
@@ -630,21 +651,36 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] {
630651
)
631652
}
632653

633-
trait Interpreter[S, E](val f: Effects[S, E]) {
654+
enum Next[+V] {
655+
case Continue
656+
case Stop(value: V)
657+
}
658+
659+
/**
660+
* Force the evaluation of the state monad steps in an explicit iteration to avoid too much buildup
661+
*/
662+
def evalInterpreter[S, V, E](f: Effects[S, E], doStep: State[S, Next[V], E]): State[S, Option[V], E] = {
663+
@tailrec
664+
def runEval(begin: S): (S, Either[E, Option[V]]) = {
665+
val (fs, cont) = doStep.f(begin)
666+
667+
cont match {
668+
case Right(Next.Stop(v)) => (fs, Right(Some(v)))
669+
case Right(Next.Continue) => runEval(fs)
670+
case Left(e) => (fs, Left(e))
671+
}
672+
}
673+
674+
State(begin => runEval(begin))
675+
}
676+
677+
trait Interpreter[S, V, E](val f: Effects[S, E]) {
634678

635679
/*
636680
* Returns value deciding whether to continue.
637681
*/
638-
def interpretOne: State[S, Boolean, E]
682+
def interpretOne: State[S, Next[V], E]
639683

640-
@tailrec
641-
final def run(begin: S): S = {
642-
val (fs, cont) = interpretOne.f(begin)
684+
final def run(begin: S): S = State.execute(begin, (evalInterpreter(f, interpretOne)))
643685

644-
if (cont.contains(true)) then {
645-
run(fs)
646-
} else {
647-
fs
648-
}
649-
}
650686
}

src/main/scala/translating/IRExpToSMT2.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] {
205205
}
206206

207207
val terms = list(sym("push")) :: BasilIRToSMT2.extractDecls(e)
208-
++ List(assert, list(sym("set-option"), sym(":smt.timeout"), sym("1")), list(sym("check-sat")))
208+
++ List(assert, list(sym("check-sat")))
209209
++ (if (getModel) then
210210
List(list(sym("echo"), sym("\"" + name.getOrElse("") + " :: " + e + "\"")), list(sym("get-model")))
211211
else List())

0 commit comments

Comments
 (0)