Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpreter: callProcedure with return values #352

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 153 additions & 41 deletions src/main/scala/ir/eval/InterpretBasilIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -284,26 +284,42 @@ case object Eval {
case _ => State.setError(Errored(s"not byte $c"))
}
} yield (res)

}

enum InterpretReturn {
case ReturnVal(outs: Map[LocalVar, Literal])
case Void
}

class BASILInterpreter[S](f: Effects[S, InterpreterError])
extends Interpreter[S, InterpretReturn, InterpreterError](f) {

def interpretOne: util.functional.State[S, Next[InterpretReturn], InterpreterError] =
InterpFuns.interpretContinuation(f)

}

class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) {
object InterpFuns {

def interpretOne: State[S, Boolean, InterpreterError] = {
def interpretContinuation[S, T <: Effects[S, InterpreterError]](
f: T
): State[S, Next[InterpretReturn], InterpreterError] = {
val next = for {
next <- f.getNext
_ <- State.pure(Logger.debug(s"$next"))
r: Boolean <- (next match {
case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true)
case Run(c: Statement) => interpretStatement(f)(c).map(_ => true)
case ReturnTo(c) => interpretReturn(f)(c).map(_ => true)
case Run(c: Jump) => interpretJump(f)(c).map(_ => true)
case Stopped() => State.pure(false)
case ErrorStop(e) => State.pure(false)
r: Next[InterpretReturn] <- (next match {
case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => Next.Continue)
case Run(c: Statement) => interpretStatement(f)(c).map(_ => Next.Continue)
case ReturnFrom(c) => evaluateReturn(f)(c).map(v => Next.Stop(InterpretReturn.ReturnVal(v)))
case Run(c: Jump) => interpretJump(f)(c).map(_ => Next.Continue)
case Stopped() => State.pure(Next.Stop(InterpretReturn.Void))
case ErrorStop(e) => State.pure(Next.Stop(InterpretReturn.Void))
})
} yield (r)

next.flatMapE((e: InterpreterError) => {
f.setNext(ErrorStop(e)).map(_ => false)
f.setNext(ErrorStop(e)).map(_ => Next.Stop(InterpretReturn.Void))
})
}

Expand Down Expand Up @@ -350,19 +366,22 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
}
}

def interpretReturn[S, T <: Effects[S, InterpreterError]](f: T)(s: DirectCall): State[S, Unit, InterpreterError] = {
/**
* Evaluates the formal out params and returns a map totheir values.
*/
def evaluateReturn[S, T <: Effects[S, InterpreterError]](
f: T
)(returnFrom: ProcSig): State[S, Map[LocalVar, Literal], InterpreterError] = {
for {
outs <- State.mapM(
((bindout: (LocalVar, Variable)) => {
((bindout: (LocalVar)) => {
for {
rhs <- Eval.evalLiteral(f)(bindout._1)
} yield (bindout._2, rhs)
rhs <- Eval.evalLiteral(f)(bindout)
} yield (bindout, rhs)
}),
s.outParams
returnFrom.formalOutParam
)
c <- State.sequence(State.pure(()), outs.map(m => f.storeVar(m._1.name, m._1.toBoogie.scope, Scalar(m._2))))
_ <- f.setNext(Run(s.successor))
} yield (c)
} yield (outs.toMap)
}

def interpretStatement[S, T <: Effects[S, InterpreterError]](f: T)(s: Statement): State[S, Unit, InterpreterError] = {
Expand Down Expand Up @@ -411,27 +430,28 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
} yield (n)
case dc: DirectCall => {
for {
// eval actual
actualParams <- State.mapM(
(p: (LocalVar, Expr)) =>
for {
v <- Eval.evalLiteral(f)(p._2)
} yield (p._1, v),
dc.actualParams
)
_ <- {
if (LibcIntrinsic.intrinsics.contains(dc.target.procName)) {
f.call(dc.target.name, Intrinsic(dc.target.procName), ReturnTo(dc))
} else if (dc.target.entryBlock.isDefined) {
val block = dc.target.entryBlock.get
f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), ReturnTo(dc))
} else {
State.setError(EscapedControlFlow(dc))
// call procedure (immediately evaluated)
ret <- callProcedure(f)(dc.target, actualParams)
// assign return values of procedure
outs =
dc.outParams.map { (formal: LocalVar, lhs: Variable) =>
(formal, lhs, ret(formal))
}
}
_ <- State.sequence(
State.pure(()),
actualParams.map(m => f.storeVar(m._1.name, m._1.toBoogie.scope, Scalar(m._2)))
outs.map { (formal, lhs, value) =>
f.storeVar(lhs.name, lhs.toBoogie.scope, Scalar(value))
}
)
_ <- f.setNext(Run(s.successor))
} yield ()
}
case ic: IndirectCall => {
Expand All @@ -451,9 +471,6 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
case _: NOP => f.setNext(Run(s.successor))
}
}
}

object InterpFuns {

def initRelocTable[S, T <: Effects[S, InterpreterError]](s: T)(ctx: IRContext): State[S, Unit, InterpreterError] = {

Expand Down Expand Up @@ -588,13 +605,68 @@ object InterpFuns {
s = State.execute(s, f.call("init_activation", Stopped(), Stopped()))
s = initMemory(s, "mem", p.initialMemory.values)
s = initMemory(s, "stack", p.initialMemory.values)
mainfun.entryBlock.foreach(startBlock =>
s = State.execute(s, f.call(mainfun.name, Run(IRWalk.firstInBlock(startBlock)), Stopped()))
)
// l <- State.sequence(State.pure(()), mainfun.formalInParam.toList.map(i => f.storeVar(i.name, i.toBoogie.scope, Scalar(BitVecLiteral(0, size(i).get)))))
s
}

// case class FunctionCall(target: String, inParams: List[(String, Literal)])

/*
* Calls a procedure that has a return, immediately evaluating the procedure
*
* Call a function, possibly dispatching to an intrinsic if the
* procedure is not resolved, or resolves to a stub.
*/
def callProcedure[S, T <: Effects[S, InterpreterError]](f: T)(
targetProc: ProcSig | Procedure,
actualParams: Iterable[(LocalVar, Literal)]
): State[S, Map[LocalVar, Literal], InterpreterError] = {

val target = targetProc match {
case p: Procedure => Some(p)
case _ => None
}
val proc = targetProc match {
case p: ProcSig => p
case p: Procedure => ProcSig(p.name, p.formalInParam.toList, p.formalOutParam.toList)
}

val call = for {
// evaluate actual parms
v <- {
// perform call and push return stack frame and continuation

val intrinsicName = target.map(_.procName).getOrElse(proc.name)
if (LibcIntrinsic.intrinsics.contains(intrinsicName)) {
f.call(intrinsicName, Intrinsic(intrinsicName), ReturnFrom(proc))
} else if (target.exists(_.entryBlock.isDefined)) {
val block = target.get.entryBlock.get
f.call(target.get.name, Run(IRWalk.firstInBlock(block)), ReturnFrom(proc))
} else {
State.setError(Errored(s"call to empty procedure: ${proc.name} / $target"))
}
}
// set actual params in the callee state
_ <- State.sequence(
State.pure(()),
actualParams.map(m => f.storeVar(m._1.name, m._1.toBoogie.scope, Scalar(m._2)))
)
} yield ()

for {
r <- call
ret <- evalInterpreter(f, interpretContinuation(f))
rv <-
if (ret.isDefined) {
ret.get match {
case InterpretReturn.ReturnVal(v) => State.pure(v)
case v => State.setError(Errored(s"Call didn't return value (got $v) $proc"))
}
} else {
State.setError(Errored(s"Call to pure function should have returned a value, $proc"))
}
} yield (rv)
}

def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(is: S, p: IRContext): S = {
val bss = for {
first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value)
Expand Down Expand Up @@ -634,18 +706,50 @@ object InterpFuns {
}
}

def mainDefaultFunctionArguments(
proc: Procedure,
overlay: Map[String, BitVecLiteral] = Map()
): Map[LocalVar, Literal] = {
val SP: BitVecLiteral = BitVecLiteral(0x78000000, 64)
val FP: BitVecLiteral = SP
val LR: BitVecLiteral = BitVecLiteral(BigInt("78000000", 16), 64)

proc.formalInParam.toList.map {
case l: LocalVar if overlay.contains(l.name) => l -> overlay(l.name)
case l: LocalVar if l.name.startsWith("R0") => l -> BitVecLiteral(1, size(l).get)
case l: LocalVar if l.name.startsWith("R31") => l -> SP
case l: LocalVar if l.name.startsWith("R29") => l -> FP
case l: LocalVar if l.name.startsWith("R30") => l -> LR
case l: LocalVar => l -> BitVecLiteral(0, size(l).get)
}.toMap
}

def interpretEvalProcExc(program: IRContext | Program, functionName: String, params: Map[String, Literal]) = {}

/* Intialise from ELF and Interpret program */
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = {
def interpretEvalProg[S, T <: Effects[S, InterpreterError]](
f: T
)(p: IRContext, is: S): (S, Either[InterpreterError, Map[LocalVar, Literal]]) = {
val begin = initProgState(f)(p, is)
val interp = BASILInterpreter(f)
interp.run(begin)
val main = p.program.mainProcedure
callProcedure(f)(main, mainDefaultFunctionArguments(main)).f(begin)
}

/* Interpret IR program */
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = {
def interpretEvalProg[S, T <: Effects[S, InterpreterError]](
f: T
)(p: Program, is: S): (S, Either[InterpreterError, Map[LocalVar, Literal]]) = {
val begin = initialiseProgram(f)(is, p)
val interp = BASILInterpreter(f)
interp.run(begin)
val main = p.mainProcedure
callProcedure(f)(main, mainDefaultFunctionArguments(main)).f(begin)
}

def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = {
interpretEvalProg(f)(p, is)._1
}

def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = {
interpretEvalProg(f)(p, is)._1
}
}

Expand All @@ -656,3 +760,11 @@ def interpret(IRProgram: Program): InterpreterState = {
def interpret(IRProgram: IRContext): InterpreterState = {
InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState())
}

def interpretEval(IRProgram: Program): (InterpreterState, Either[InterpreterError, Map[LocalVar, Literal]]) = {
InterpFuns.interpretEvalProg(NormalInterpreter)(IRProgram, InterpreterState())
}

def interpretEval(IRProgram: IRContext): (InterpreterState, Either[InterpreterError, Map[LocalVar, Literal]]) = {
InterpFuns.interpretEvalProg(NormalInterpreter)(IRProgram, InterpreterState())
}
62 changes: 49 additions & 13 deletions src/main/scala/ir/eval/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,27 @@ import scala.collection.mutable
import scala.collection.immutable
import scala.util.control.Breaks.{break, breakable}

/**
* Procedure signature used when returning from procedures and intrinsics.
* This is mainly used to describe the formalOutparams, where the return values of the procedure
* are stored to be read by the return value.
*
* @param name
* The full name of the procedure (i.e. procedure.name if a real procedure, otherwise the name of the intrinsic)
* @param formalInParam
* The list of formal input parameters (corresponding to procedure.formalInParam)
* @param formalOutParam
* The list of formal outpur params (corresponding to procedure.formalOutParam)
*/
case class ProcSig(name: String, formalInParam: List[LocalVar], formalOutParam: List[LocalVar])

/** Interpreter status type, either stopped, run next command or error
*/
sealed trait ExecutionContinuation
case class Stopped() extends ExecutionContinuation /* normal program stop */
case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* program stop in error state */
case class Run(next: Command) extends ExecutionContinuation /* continue by executing next command */
case class ReturnTo(call: DirectCall) extends ExecutionContinuation /* continue by executing next command */
case class ReturnFrom(target: ProcSig) extends ExecutionContinuation /* return from a call without continuing */
case class Intrinsic(name: String) extends ExecutionContinuation /* a named intrinsic instruction */

sealed trait InterpreterError
Expand All @@ -47,6 +61,12 @@ sealed trait MapValue {
def value: Map[BasilValue, BasilValue]
}

def normalTermination(is: ExecutionContinuation) = is match {
case Stopped() => true
case ReturnFrom(_) => true
case _ => false
}

/* We erase the type of basil values and enforce the invariant that
\exists i . \forall v \in value.keys , v.irType = i and
\exists j . \forall v \in value.values, v.irType = j
Expand Down Expand Up @@ -356,7 +376,8 @@ object LibcIntrinsic {
"__libc_malloc_impl" -> singleArg("malloc"),
"free" -> singleArg("free"),
"#free" -> singleArg("free"),
"calloc" -> calloc
"calloc" -> calloc,
"strlen" -> singleArg("strlen")
)

}
Expand Down Expand Up @@ -630,21 +651,36 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] {
)
}

trait Interpreter[S, E](val f: Effects[S, E]) {
enum Next[+V] {
case Continue
case Stop(value: V)
}

/**
* Force the evaluation of the state monad steps in an explicit iteration to avoid too much buildup
*/
def evalInterpreter[S, V, E](f: Effects[S, E], doStep: State[S, Next[V], E]): State[S, Option[V], E] = {
@tailrec
def runEval(begin: S): (S, Either[E, Option[V]]) = {
val (fs, cont) = doStep.f(begin)

cont match {
case Right(Next.Stop(v)) => (fs, Right(Some(v)))
case Right(Next.Continue) => runEval(fs)
case Left(e) => (fs, Left(e))
}
}

State(begin => runEval(begin))
}

trait Interpreter[S, V, E](val f: Effects[S, E]) {

/*
* Returns value deciding whether to continue.
*/
def interpretOne: State[S, Boolean, E]
def interpretOne: State[S, Next[V], E]

@tailrec
final def run(begin: S): S = {
val (fs, cont) = interpretOne.f(begin)
final def run(begin: S): S = State.execute(begin, (evalInterpreter(f, interpretOne)))

if (cont.contains(true)) then {
run(fs)
} else {
fs
}
}
}
2 changes: 1 addition & 1 deletion src/main/scala/translating/IRExpToSMT2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] {
}

val terms = list(sym("push")) :: BasilIRToSMT2.extractDecls(e)
++ List(assert, list(sym("set-option"), sym(":smt.timeout"), sym("1")), list(sym("check-sat")))
++ List(assert, list(sym("check-sat")))
++ (if (getModel) then
List(list(sym("echo"), sym("\"" + name.getOrElse("") + " :: " + e + "\"")), list(sym("get-model")))
else List())
Expand Down
Loading