Skip to content

Commit

Permalink
quick refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
vighneshiyer committed Aug 30, 2023
1 parent 6f3eae2 commit a31a002
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 187 deletions.
166 changes: 31 additions & 135 deletions src/simcommand/package.scala
Original file line number Diff line number Diff line change
@@ -1,117 +1,16 @@
import chisel3.{Data}

import scala.collection.mutable.ArrayBuffer
import sourcecode.{Line, FileName, Enclosing}
import sourcecode.{Enclosing, FileName, Line}
import simcommand.runtime.Primitives._
import simcommand.runtime.Bridges._
import simcommand.runtime.{Bridges, Config, Imperative, Result}

package object simcommand {

// Below is inspired by the trampolined continuation in TailCalls.scala (in the Scala stdlib)
/**
* This class represents an RTL simulation command and its return value
* @tparam R Type of the command's return value
*/
sealed abstract class Command[R](implicit val line: Line, val filename: FileName, val enclosing: Enclosing) {
final def map[R2](f: R => R2): Command[R2] = {
flatMap(r => Return(f(r)))
}

final def flatMap[R2](f: R => Command[R2]): Command[R2] = {
this match {
case Return(retval) => f(retval)
case c: Command[R] => Cont(c, f)
}
}

// tailRec provides an efficent tail recursion primitive. If the provided
// function takes in arguments and returns Command[Left(newArguments)], it
// flatMaps again with the new arguments. If given a
// Command[Right(result)], it will instead lift the internal result.
final def tailRec[R2](f: R => Command[Either[R, R2]]): Command[R2] = {
this.flatMap(Rec(_, f))
}

// tailRecM is the functional version of tailRecM. This version is
// also safe, but tends to be significantly slower.
// Similar implementation as cats.Free: https://github.com/typelevel/cats/pull/1041/files#diff-7349edfd077f9612f7181fe1f8caca63ac667c847ce83b53dceae4d08040fd55
final def tailRecM[R2](f: R => Command[Either[R, R2]]): Command[R2] = {
this.flatMap(f).flatMap {
// recursion here is lazy so the stack won't blow up
case Left(_) => tailRecM(f)
case Right(result) => lift(result)
}
}

def debugInfo: String = {
this.getClass.getSimpleName + "(" + filename.value + ":" + line.value + ")"
}
}

sealed trait Interactable[I] {
def set(value: I): Unit
def get(): I
def compare(value: I): Command[Boolean]
}

implicit class Chisel3Interactor[I <: Data](value: I) extends Interactable[I] {
val tester = chiseltest.testableData(value)
override def set(p: I): Unit = tester.poke(p)
override def get(): I = tester.peek()
// TODO: Implement a better comparator for chisel3 datatypes
override def compare(v: I): Command[Boolean] = peek(this).map(_.litValue == v.litValue)
}

private case class PrimitiveInteractor[I](var value: I) extends Interactable[I] {
override def set(p: I): Unit = value = p
override def get(): I = value
override def compare(v: I): Command[Boolean] = value match {
case x: Data => peek(this).map(_.asInstanceOf[Data].litValue == v.asInstanceOf[Data].litValue)
case _ => peek(this).map(_ == v)
}
override def hashCode(): Int = System.identityHashCode(this)
}

trait Steppable {
def step(cycles: Int): Unit
}
private case class Chisel3Clock(clock: chisel3.Clock) extends Steppable {
def step(cycles: Int): Unit = chiseltest.testableClock(clock).step(cycles)
}
case class FakeClock() extends Steppable {
def step(cycles: Int): Unit = {}
}

// Command sum type
//// DUT interaction
private[simcommand] case class Poke[I](signal: Interactable[I], value: I)(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[Unit]
private[simcommand] case class Peek[I](signal: Interactable[I])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[I]

//// Simulator synchronization points
private[simcommand] case class Step(cycles: Int)(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[Unit]

//// End of a command sequence / Pure value
private[simcommand] case class Return[R](retval: R)(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[R]

//// Continuation
private[simcommand] case class Cont[R1, R2](a: Command[R1], f: R1 => Command[R2])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[R2]
private[simcommand] case class Rec[R1, R2](st: R1, f: R1 => Command[Either[R1, R2]])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[R2]

//// fork/join synchronization
private[simcommand] case class ThreadHandle[R](id: Int)
private[simcommand] case class Fork[R](c: Command[R], name: Option[String], order: Int)(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[ThreadHandle[R]] {
def makeThreadHandle(id: Int): ThreadHandle[R] = ThreadHandle[R](id)
}
private[simcommand] case class Join[R](threadHandle: ThreadHandle[R])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[R]
private[simcommand] case class Kill[R](threadHandle: ThreadHandle[R])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[R]

// Inter-thread communication channels
private[simcommand] case class ChannelHandle[T](id: Int)
private[simcommand] case class MakeChannel[T](size: Int)(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[ChannelHandle[T]]
private[simcommand] case class Put[T](chan: ChannelHandle[T], data: T)(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[Unit]
private[simcommand] case class GetBlocking[T](chan: ChannelHandle[T])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[T]
private[simcommand] case class NonEmpty[T](chan: ChannelHandle[T])(implicit line: Line, filename: FileName, enclosing: Enclosing) extends Command[Boolean]
type Command[R] = simcommand.runtime.Primitives.Command[R]
type Interactable[I] = simcommand.runtime.Bridges.Interactable[I]
type ChannelHandle[T] = simcommand.runtime.Primitives.ChannelHandle[T]
val FakeClock: Bridges.FakeClock.type = simcommand.runtime.Bridges.FakeClock

// Public API

def unsafeRun[R](cmd: Command[R], clock: chisel3.Clock, cfg: Config = Config()): Result[R] = {
unsafeRun(cmd, Chisel3Clock(clock), cfg)
}
Expand All @@ -120,52 +19,52 @@ package object simcommand {
Imperative.unsafeRun(cmd, clock, cfg)
}

def poke[I](signal: Interactable[I], value: I)(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[Unit] = {
Poke(signal, value)
def poke[I](signal: Interactable[I], value: I)(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[Unit] = {
Poke(signal, value)(SourceInfo(line, fileName, enclosing))
}

def peek[I](signal: Interactable[I])(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[I] = {
Peek(signal)
def peek[I](signal: Interactable[I])(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[I] = {
Peek(signal)(SourceInfo(line, fileName, enclosing))
}

def step(cycles: Int)(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[Unit] = {
Step(cycles)
def step(cycles: Int)(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[Unit] = {
Step(cycles)(SourceInfo(line, fileName, enclosing))
}

def lift[R](value: R)(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[R] = {
Return(value)
def lift[R](value: R)(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[R] = {
Return(value)(SourceInfo(line, fileName, enclosing))
}

def noop()(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[Unit] = {
def noop()(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[Unit] = {
lift(())
}

def fork[R](cmd: Command[R])(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[ThreadHandle[R]] = {
Fork(cmd, None, order = 0)
def fork[R](cmd: Command[R])(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[ThreadHandle[R]] = {
Fork(cmd, None, order = 0)(SourceInfo(line, fileName, enclosing))
}

def fork[R](cmd: Command[R], name: String, order: Int = 0)(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[ThreadHandle[R]] = {
Fork(cmd, Some(name), order)
def fork[R](cmd: Command[R], name: String, order: Int = 0)(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[ThreadHandle[R]] = {
Fork(cmd, Some(name), order)(SourceInfo(line, fileName, enclosing))
}

def join[R](handle: ThreadHandle[R])(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[R] = {
Join(handle)
def join[R](handle: ThreadHandle[R])(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[R] = {
Join(handle)(SourceInfo(line, fileName, enclosing))
}

def makeChannel[R](size: Int)(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[ChannelHandle[R]] = {
MakeChannel(size)
def makeChannel[R](size: Int)(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[ChannelHandle[R]] = {
MakeChannel(size)(SourceInfo(line, fileName, enclosing))
}

def put[R](chan: ChannelHandle[R], data: R)(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[Unit] = {
Put(chan, data)
def put[R](chan: ChannelHandle[R], data: R)(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[Unit] = {
Put(chan, data)(SourceInfo(line, fileName, enclosing))
}

def getBlocking[R](chan: ChannelHandle[R])(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[R] = {
GetBlocking(chan)
def getBlocking[R](chan: ChannelHandle[R])(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[R] = {
GetBlocking(chan)(SourceInfo(line, fileName, enclosing))
}

def nonEmpty[R](chan: ChannelHandle[R])(implicit line: Line, filename: FileName, enclosing: Enclosing): Command[Boolean] = {
NonEmpty(chan)
def nonEmpty[R](chan: ChannelHandle[R])(implicit line: Line, fileName: FileName, enclosing: Enclosing): Command[Boolean] = {
NonEmpty(chan)(SourceInfo(line, fileName, enclosing))
}

def binding[I](value: I): Interactable[I] = {
Expand Down Expand Up @@ -307,7 +206,4 @@ package object simcommand {
cycles
).map(_.forall(x => x))
}

class CombinatorialDependencyException(name: Option[String], order: Int, cmd: Command[_])
extends Exception("Detected combinatorial loop in thread '" + name + "' with order " + order + " caused by command " + cmd.debugInfo)
}
43 changes: 43 additions & 0 deletions src/simcommand/runtime/Bridges.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package simcommand.runtime

import simcommand.runtime.Primitives._
import chisel3.Data

// Bridges between the pure Scala Command primitives and RTL simulator I/Os and clocks
object Bridges {
sealed trait Interactable[I] {
def set(value: I): Unit
def get(): I
def compare(value: I): Command[Boolean]
}

implicit class Chisel3Interactor[I <: Data](value: I) extends Interactable[I] {
val tester = chiseltest.testableData(value)
override def set(p: I): Unit = tester.poke(p)
override def get(): I = tester.peek()
// TODO: Implement a better comparator for chisel3 datatypes, this only works for non-aggregate data types
override def compare(v: I): Command[Boolean] = Peek(this)(SourceInfo.getSourceInfo).map(_.litValue == v.litValue)
}

case class PrimitiveInteractor[I](var value: I) extends Interactable[I] {
override def set(p: I): Unit = value = p
override def get(): I = value
override def compare(v: I): Command[Boolean] = value match {
case x: Data => Peek(this)(SourceInfo.getSourceInfo).map(_.asInstanceOf[Data].litValue == v.asInstanceOf[Data].litValue)
case _ => Peek(this)(SourceInfo.getSourceInfo).map(_ == v)
}
override def hashCode(): Int = System.identityHashCode(this)
}

trait Steppable {
def step(cycles: Int): Unit
}

case class Chisel3Clock(clock: chisel3.Clock) extends Steppable {
def step(cycles: Int): Unit = chiseltest.testableClock(clock).step(cycles)
}

case class FakeClock() extends Steppable {
def step(cycles: Int): Unit = {}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package simcommand
package simcommand.runtime

import simcommand.runtime.Bridges._
import simcommand.runtime.Primitives._

import scala.collection.mutable

class CombinatorialDependencyException(name: Option[String], order: Int, cmd: Command[_])
extends Exception("Detected combinatorial loop in thread '" + name + "' with order " + order + " caused by command " + cmd.sourceInfoString)

case class Config(
// Whether or not to print debug output
print: Boolean = false,
Expand Down
108 changes: 108 additions & 0 deletions src/simcommand/runtime/Primitives.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package simcommand.runtime

import simcommand.runtime.Bridges._
import sourcecode.{Enclosing, FileName, Line}

object Primitives {
case class SourceInfo(line: Line, fileName: FileName, enclosing: Enclosing)
object SourceInfo {
def getSourceInfo(implicit line: Line, fileName: FileName, enclosing: Enclosing): SourceInfo = {
SourceInfo(line, fileName, enclosing)
}
}

/**
* This class represents an RTL simulation command and its return value
*
* @tparam R Type of the command's return value
*/
sealed abstract class Command[R](implicit sourceInfo: SourceInfo) {
val name: String

final def map[R2](f: R => R2): Command[R2] = {
flatMap(r => Return(f(r)))
}

final def flatMap[R2](f: R => Command[R2]): Command[R2] = {
this match {
case Return(retval) => f(retval)
case c: Command[R] => Cont(c, f)
}
}

// tailRec provides an efficient tail recursion primitive. If the provided
// function takes in arguments and returns Command[Left(newArguments)], it
// flatMaps again with the new arguments. If given a
// Command[Right(result)], it will instead lift the internal result.
final def tailRec[R2](f: R => Command[Either[R, R2]]): Command[R2] = {
this.flatMap(Rec(_, f))
}

// tailRecM is the functional version of tailRecM. This version is
// also safe, but tends to be significantly slower.
// Similar implementation as cats.Free: https://github.com/typelevel/cats/pull/1041/files#diff-7349edfd077f9612f7181fe1f8caca63ac667c847ce83b53dceae4d08040fd55
final def tailRecM[R2](f: R => Command[Either[R, R2]]): Command[R2] = {
this.flatMap(f).flatMap {
// recursion here is lazy so the stack won't blow up
case Left(_) => tailRecM(f)
case Right(result) => Return(result)
}
}

def sourceInfoString: String = {
name + "(" + sourceInfo.fileName.value + ":" + sourceInfo.line.value + ")"
}
}

// Lift a pure value into the Command monad
private[simcommand] case class Return[R](retval: R)(implicit sourceInfo: SourceInfo) extends Command[R] {
override val name: String = "return"
}

// Continuations, recursion primitives
private[simcommand] case class Cont[R1, R2](a: Command[R1], f: R1 => Command[R2])(implicit sourceInfo: SourceInfo) extends Command[R2] {
override val name: String = "cont"
}
private[simcommand] case class Rec[R1, R2](st: R1, f: R1 => Command[Either[R1, R2]])(implicit sourceInfo: SourceInfo) extends Command[R2] {
override val name: String = "rec"
}

// Basic DUT interaction
private[simcommand] case class Poke[I](signal: Interactable[I], value: I)(implicit sourceInfo: SourceInfo) extends Command[Unit] {
override val name: String = "poke"
}
private[simcommand] case class Peek[I](signal: Interactable[I])(implicit sourceInfo: SourceInfo) extends Command[I] {
override val name: String = "peek"
}
private[simcommand] case class Step(cycles: Int)(implicit sourceInfo: SourceInfo) extends Command[Unit] {
override val name: String = "step"
}

// Fork/Join simulation threading
private[simcommand] case class ThreadHandle[R](id: Int)
private[simcommand] case class Fork[R](c: Command[R], threadName: Option[String], order: Int)(implicit sourceInfo: SourceInfo) extends Command[ThreadHandle[R]] {
def makeThreadHandle(id: Int): ThreadHandle[R] = ThreadHandle[R](id)
override val name: String = "fork"
}
private[simcommand] case class Join[R](threadHandle: ThreadHandle[R])(implicit sourceInfo: SourceInfo) extends Command[R] {
override val name: String = "join"
}
private[simcommand] case class Kill[R](threadHandle: ThreadHandle[R])(implicit sourceInfo: SourceInfo) extends Command[R] {
override val name: String = "kill"
}

// Inter-thread communication via channels
private[simcommand] case class ChannelHandle[T](id: Int)
private[simcommand] case class MakeChannel[T](size: Int)(implicit sourceInfo: SourceInfo) extends Command[ChannelHandle[T]] {
override val name: String = "makeChannel"
}
private[simcommand] case class Put[T](chan: ChannelHandle[T], data: T)(implicit sourceInfo: SourceInfo) extends Command[Unit] {
override val name: String = "put"
}
private[simcommand] case class GetBlocking[T](chan: ChannelHandle[T])(implicit sourceInfo: SourceInfo) extends Command[T] {
override val name: String = "getBlocking"
}
private[simcommand] case class NonEmpty[T](chan: ChannelHandle[T])(implicit sourceInfo: SourceInfo) extends Command[Boolean] {
override val name: String = "nonEmpty"
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package simcommand
package simcommand.vips

import chisel3._
import chisel3.util.DecoupledIO
import simcommand._

class DecoupledCommands[T <: Data](io: DecoupledIO[T]) {
class Decoupled[T <: Data](io: DecoupledIO[T]) {
def enqueue(data: T): Command[Unit] = for {
_ <- poke(io.bits, data)
_ <- poke(io.valid, true.B)
Expand Down
Loading

0 comments on commit a31a002

Please sign in to comment.