Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DSL compiler for simple While language #354

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import contrib.scalapblib._

object basil extends RootModule with ScalaModule with antlr.AntlrModule with ScalaPBModule {
def scalaVersion = "3.3.4"
def ammoniteVersion = "3.0.2"

def scalacOptions: T[Seq[String]] = Seq("-deprecation")

Expand Down
43 changes: 31 additions & 12 deletions src/main/scala/ir/dsl/DSL.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package ir.dsl
import ir.*
import translating.PrettyPrinter.*
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.*
import scala.annotation.targetName

/**
* IR construction DSL
Expand Down Expand Up @@ -91,16 +93,10 @@ def exprEq(l: Expr, r: Expr): Expr = (l, r) match {
case _ => FalseLiteral
}

def bv32(i: Int): BitVecLiteral = BitVecLiteral(i, 32)

def bv64(i: Int): BitVecLiteral = BitVecLiteral(i, 64)

def bv8(i: Int): BitVecLiteral = BitVecLiteral(i, 8)

def bv16(i: Int): BitVecLiteral = BitVecLiteral(i, 16)

def R(i: Int): Register = Register(s"R$i", 64)

def bv_t(i: Int) = BitVecType(i)

case class DelayNameResolve(ident: String) {
def resolveProc(prog: Program): Option[Procedure] = prog.collectFirst {
case b: Procedure if b.name == ident => b
Expand Down Expand Up @@ -215,7 +211,7 @@ def indirectCall(tgt: Variable): EventuallyIndirectCall = EventuallyIndirectCall
case class EventuallyBlock(
label: String,
sl: Iterable[EventuallyStatement],
j: EventuallyJump,
var j: EventuallyJump,
address: Option[BigInt] = None
) {

Expand Down Expand Up @@ -249,8 +245,22 @@ def block(label: String, sl: (NonCallStatement | EventuallyStatement | Eventuall
case g: EventuallyJump => None
}
val jump = sl.collect { case j: EventuallyJump => j }
require(jump.length == 1, s"DSL block '$label' must contain exactly one jump statement")
EventuallyBlock(label, statements, jump.head)
require(jump.length <= 1, s"DSL block '$label' must contain no more than one jump statement")
val rjump = if (jump.isEmpty) then unreachable else jump.head
EventuallyBlock(label, statements, rjump)
}

/**
* Construct a block from a list of statements with a default name.
*/
def stmts(sl: (EventuallyCall | NonCallStatement | EventuallyStatement | EventuallyJump)*): EventuallyBlock = {

val stmts =
if (sl.isEmpty) then List(unreachable)
else if (!sl.last.isInstanceOf[EventuallyJump]) then (sl.toList ++ List(unreachable))
else sl

block(Counter.nlabel("block"), stmts: _*)
}

case class EventuallyProcedure(
Expand Down Expand Up @@ -314,13 +324,22 @@ def proc(label: String, blocks: EventuallyBlock*): EventuallyProcedure = {
EventuallyProcedure(label, SortedMap(), SortedMap(), blocks, blocks.headOption.map(_.label))
}

def proc(
label: String,
in: Iterable[(String, IRType)],
out: Iterable[(String, IRType)],
blocks: Iterable[EventuallyBlock]
): EventuallyProcedure = {
EventuallyProcedure(label, in.to(SortedMap), out.to(SortedMap), blocks.toSeq, blocks.headOption.map(_.label))
}

def proc(
label: String,
in: Iterable[(String, IRType)],
out: Iterable[(String, IRType)],
blocks: EventuallyBlock*
): EventuallyProcedure = {
EventuallyProcedure(label, in.to(SortedMap), out.to(SortedMap), blocks, blocks.headOption.map(_.label))
proc(label, in, out, blocks.toSeq)
}

def mem: SharedMemory = SharedMemory("mem", 64, 8)
Expand Down
180 changes: 180 additions & 0 deletions src/main/scala/ir/dsl/InfixConstructors.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package ir.dsl
import ir.*
import translating.PrettyPrinter.*
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.*
import scala.annotation.targetName
import ir.dsl.*

/**
* Expr and statement construction; this defined infix operators which construct
* BASIL IR statements and expressions.
*
* Typically with binops we default to signed ops, and provide named unsigned alternatives.
*/

extension (lvar: Variable)
infix def :=(j: Expr) = LocalAssign(lvar, j)
def :=(j: Int) = lvar.getType match {
case BitVecType(sz) => LocalAssign(lvar, BitVecLiteral(j, sz))
case IntType => LocalAssign(lvar, IntLiteral(j))
case _ => ???
}
def :=(j: Boolean) = lvar.getType match {
case BoolType => LocalAssign(lvar, if j then TrueLiteral else FalseLiteral)
case _ => ???
}

case class call(target: String, actualParams: (String, Expr)*)

extension (lvar: List[(String, Variable)]) infix def :=(j: call) = directCall(lvar, j.target, j.actualParams: _*)
extension (lvar: Seq[(String, Variable)]) infix def :=(j: call) = directCall(lvar, j.target, j.actualParams: _*)

extension (v: Int)
@targetName("ibv64")
def bv64 = BitVecLiteral(v, 64)
@targetName("ibv32")
def bv32 = BitVecLiteral(v, 32)
@targetName("ibv16")
def bv16 = BitVecLiteral(v, 16)
@targetName("ibv8")
def bv8 = BitVecLiteral(v, 8)
@targetName("ibv1")
def bv1 = BitVecLiteral(v, 1)
@targetName("itobv")
def bv(sz: Int) = BitVecLiteral(v, sz)

def bv64 = BitVecType(64)
def bv32 = BitVecType(32)
def bv16 = BitVecType(16)
def bv8 = BitVecType(8)
def bv1 = BitVecType(1)

extension (i: Expr)
infix def ===(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntEQ, i, j)
case b: BitVecType => BinaryExpr(BVEQ, i, j)
case BoolType => BinaryExpr(BoolEQ, i, j)
case m: MapType => ???
}
infix def !==(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntNEQ, i, j)
case b: BitVecType => BinaryExpr(BVNEQ, i, j)
case BoolType => BinaryExpr(BoolNEQ, i, j)
case m: MapType => ???
}
infix def +(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntADD, i, j)
case b: BitVecType => BinaryExpr(BVADD, i, j)
case BoolType => BinaryExpr(BoolOR, i, j)
case m: MapType => ???
}
infix def -(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntSUB, i, j)
case b: BitVecType => BinaryExpr(BVSUB, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def *(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntMUL, i, j)
case b: BitVecType => BinaryExpr(BVMUL, i, j)
case BoolType => BinaryExpr(BoolAND, i, j)
case m: MapType => ???
}
infix def /(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntDIV, i, j)
case b: BitVecType => BinaryExpr(BVSDIV, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def &&(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVAND, i, j)
case BoolType => BinaryExpr(BoolAND, i, j)
case m: MapType => ???
}
infix def ||(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVOR, i, j)
case BoolType => BinaryExpr(BoolOR, i, j)
case m: MapType => ???
}
infix def <<(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVSHL, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def >>(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVASHR, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def >>>(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVLSHR, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def %(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntMOD, i, j)
case b: BitVecType => BinaryExpr(BVSMOD, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def <(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntLT, i, j)
case b: BitVecType => BinaryExpr(BVSLT, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def >(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntGT, i, j)
case b: BitVecType => BinaryExpr(BVSGT, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def <=(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntLE, i, j)
case b: BitVecType => BinaryExpr(BVSLE, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def >=(j: Expr): Expr = i.getType match {
case IntType => BinaryExpr(IntGE, i, j)
case b: BitVecType => BinaryExpr(BVSGE, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def ult(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVULT, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def ugt(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVUGT, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def ule(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVULE, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def uge(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVUGE, i, j)
case BoolType => ???
case m: MapType => ???
}
infix def ++(j: Expr): Expr = i.getType match {
case IntType => ???
case b: BitVecType => BinaryExpr(BVCONCAT, i, j)
case BoolType => ???
case m: MapType => ???
}
Loading