Skip to content

Commit

Permalink
Refactor generator to support multiple generation strategies (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maokami committed Feb 20, 2024
1 parent 121d4c0 commit ca796ca
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 135 deletions.
24 changes: 24 additions & 0 deletions src/main/scala/fhetest/Generate/AbsStatement.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package fhetest.Generate

trait AbsStmt
case class Var()
case class Assign(l: String, r: (Int | Double)) extends AbsStmt
case class AssignVec(l: String, r: (List[Int] | List[Double])) extends AbsStmt
case class Add(l: Var, r: Var) extends AbsStmt
case class Sub(l: Var, r: Var) extends AbsStmt
case class Mul(l: Var, r: Var) extends AbsStmt
case class Rot(l: Var, r: Var) extends AbsStmt

val V = Var()

def formatNumber(n: Int | Double): String = n match {
case i: Int => i.toString
case d: Double => f"$d%f"
}

def allAbsStmts: LazyList[AbsStmt] = LazyList(
Add(V, V),
Sub(V, V),
Mul(V, V),
Rot(V, V),
)
3 changes: 3 additions & 0 deletions src/main/scala/fhetest/Generate/Template.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package fhetest.Generate

type Template = List[AbsStmt]
37 changes: 37 additions & 0 deletions src/main/scala/fhetest/Generate/TemplateGenerator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package fhetest.Generate

// Template Generation Strategy
enum Strategy:
case Exhaustive, Random

extension (s: Strategy)
def getGenerator: TemplateGenerator = s match {
case Strategy.Exhaustive => ExhaustiveGenerator
case Strategy.Random => RandomGenerator
}

// Template Generator
trait TemplateGenerator {
def generateTemplates(): LazyList[Template]
}

object ExhaustiveGenerator extends TemplateGenerator {
def generateTemplates(): LazyList[Template] = {
def allTemplatesOfSize(n: Int): LazyList[Template] = n match {
case 1 => allAbsStmts.map(stmt => List(stmt))
case _ =>
for {
stmt <- allAbsStmts
program <- allTemplatesOfSize(n - 1)
} yield stmt :: program
}
LazyList.from(1).flatMap(allTemplatesOfSize)
}

}

object RandomGenerator extends TemplateGenerator {
def generateTemplates(): LazyList[Template] = {
???
}
}
47 changes: 47 additions & 0 deletions src/main/scala/fhetest/Generate/Utils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package fhetest.Generate

import scala.util.Random

object Utils {
def assignValues(name: String, vxs: (List[Int] | List[Double])): AbsStmt =
AssignVec(name, vxs)

extension (s: AbsStmt)
def stringify: String = s match
case Assign(l, r) => s"$l = ${formatNumber(r)};"
case AssignVec(l, r) => s"$l = {${r.map(formatNumber).mkString(",")}};"
case Add(l, r) => "x += y;"
case Sub(l, r) => "x -= y;"
case Mul(l, r) => "x *= y;"
case Rot(l, r) => "rotate_left(x, c);"

extension (t: Template)
def stringify: String = t.map(_.stringify).mkString("")
def getMulDepth: Int = t.count {
case Mul(_, _) => true; case _ => false
}

def assignRandValues(len: Int, bound: (Int | Double)): Template = {
val lx = Random.between(1, len + 1)
val ly = Random.between(1, len + 1)
val vxs: (List[Int] | List[Double]) = bound match {
case intBound: Int => List.fill(lx)(Random.between(0, intBound))
case doubleBound: Double =>
List.fill(lx)(Random.between(0.0, doubleBound))
}
val vys: (List[Int] | List[Double]) = bound match {
case intBound: Int => List.fill(ly)(Random.between(0, intBound))
case doubleBound: Double =>
List.fill(ly)(Random.between(0.0, doubleBound))
}

val assigned = assignValues("x", vxs) :: t
assigned.flatMap {
case op @ (Add(_, _) | Sub(_, _) | Mul(_, _)) =>
assignValues("y", vys) :: op :: Nil
case op @ Rot(_, _) =>
Assign("c", Random.between(0, 10)) :: op :: Nil
case s => s :: Nil
}
}
}
158 changes: 24 additions & 134 deletions src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fhetest.Phase

import fhetest.Utils.*
import fhetest.Generate.*
import fhetest.Generate.Utils.*
import org.twc.terminator.SymbolTable;
import org.twc.terminator.t2dsl_compiler.*;

Expand All @@ -12,9 +14,11 @@ import java.nio.file.{Files, Paths};
import java.io.*;
import javax.print.attribute.EnumSyntax
import scala.jdk.CollectionConverters._
import scala.util.Random

case class Generate(encType: ENC_TYPE) {
case class Generate(
encType: ENC_TYPE,
strategy: Strategy = Strategy.Exhaustive,
) {
// TODO : This boilerplate code is really ugly. But I cannot find a better way to do this.
val baseStrFront = encType match {
case ENC_TYPE.ENC_INT =>
Expand All @@ -28,33 +32,36 @@ case class Generate(encType: ENC_TYPE) {

val symbolTable = boilerplate()._2

val tempGen = strategy.getGenerator

val allTemplates = tempGen.generateTemplates()

def apply(nOpt: Option[Int]): LazyList[String] = {
val templates = nOpt match {
case Some(n) => allTempletes.take(n)
case None => allTempletes
case Some(n) => allTemplates.take(n)
case None => allTemplates
}
for {
template <- templates
} yield {
val concretized = concretizeTemplate(template)
toStringWithBaseStr(concretized)
stringifyWithBaseStr(concretized)
}
}

// This is for testing purpose
def show(backends: List[Backend], n: Int) =
for {
template <- allTempletes.take(n)
template <- allTemplates.take(n)
} {
// TODO: currently concretize 5 times for each template
// TODO: store the results into a json file and print it by parsing the json file
print("=" * 80 + "\n")
print("<Program>\n")
print(toString(template) + "\n")
print(s"$template\n")
print("=" * 80 + "\n")
// for _ <- 0 until 5 do {
val concretized = concretizeTemplate(template)
print(toString(concretized) + "\n")
print(s"$concretized\n")
print("-" * 80 + "\n")

val ast: Goal = buildTemplate(concretized)
Expand All @@ -67,7 +74,7 @@ case class Generate(encType: ENC_TYPE) {
backend,
{ workspaceDir =>
given DirName = workspaceDir
val mulDepth = getMulDepth(concretized)
val mulDepth = concretized.getMulDepth
// Default RingDim, PlainModulus with MulDepth
val encParams = EncParams(32768, mulDepth, 65537)
Print(
Expand All @@ -87,147 +94,30 @@ case class Generate(encType: ENC_TYPE) {
}

// TODO: current length = 100, it can be changed to larger value
def concretizeTemplate(template: Templete): Templete = encType match
case ENC_TYPE.ENC_INT => assignRandValues(template, 100, 100)
case ENC_TYPE.ENC_DOUBLE => assignRandValues(template, 100, 100.0)
def concretizeTemplate(template: Template): Template = encType match
case ENC_TYPE.ENC_INT => template.assignRandValues(100, 100)
case ENC_TYPE.ENC_DOUBLE => template.assignRandValues(100, 100.0)
case _ => throw new Exception("encType is not set")

// TODO: Currently, T2 DSL does not support negative numbers
def assignRandValues(template: Templete, len: Int, bound: Int): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
// refactor
val t = assignValues(template, vxs)
t.flatMap(s =>
s match
case Add(_, _) | Sub(_, _) | Mul(_, _) =>
List(
AssignVec(
"y",
List.fill(Random.between(1, len))(Random.between(0, bound)),
),
s,
)
case Rot(_, _) =>
// TODO: Currently, c is bounded by 10. It can be changed to larger value
List(
Assign("c", Random.between(0, 10)),
s,
)
case _ => List(s),
)

def assignRandValues(template: Templete, len: Int, bound: Double): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
// refactor
val t = assignValues(template, vxs)
t.flatMap(s =>
s match
case Add(_, _) | Sub(_, _) | Mul(_, _) =>
List(
AssignVec(
"y",
List.fill(Random.between(1, len))(Random.between(0, bound)),
),
s,
)
case Rot(_, _) =>
List(
Assign("c", Random.between(0, 10)),
s,
)
case _ => List(s),
)

// vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 };
def assignValues(
template: Templete,
vxs: List[Int | Double],
): Templete =
val assignment = AssignVec("x", vxs)
return assignment :: template

def boilerplate(): (Goal, SymbolTable, _) =
val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8"))
Parse(baseStream)

// TODO: current print only 10 values, it can be changed to larger value
def createNewBaseTemplate(): Goal = boilerplate()._1

// def assignIntValue(template: Templete, vx: Int, vy: Int): Templete =
// val assignments = List(Assign("x", vx), Assign("y", vy))
// return assignments ++ template

// def assignRandValue(template: Templete, bound: Int): Templete =
// val vx = Random.between(0, bound)
// val vy = Random.between(0, bound)
// return assignIntValue(template, vx, vy)

def parseStmt(stmtStr: String): Statement =
val input_stream: InputStream = new ByteArrayInputStream(
stmtStr.getBytes("UTF-8"),
)
T2DSLParser(input_stream).Statement()

trait Stmt
case class Var()
case class Assign(l: String, r: (Int | Double)) extends Stmt
case class AssignVec(l: String, r: List[Int | Double]) extends Stmt
case class Add(l: Var, r: Var) extends Stmt
case class Sub(l: Var, r: Var) extends Stmt
case class Mul(l: Var, r: Var) extends Stmt
case class Rot(l: Var, r: Var) extends Stmt

val V = Var()

def formatNumber(n: Int | Double): String = n match {
case i: Int => i.toString
case d: Double => f"$d%f"
}

def toString(s: Stmt) = s match
case Assign(l, r) => s"$l = ${formatNumber(r)};"
case AssignVec(l, r) => s"$l = {${r.map(formatNumber).mkString(",")}};"
case Add(l, r) => "x += y;"
case Sub(l, r) => "x -= y;"
case Mul(l, r) => "x *= y;"
case Rot(l, r) => "rotate_left(x, c);"

type Templete = List[Stmt]
def toString(t: Templete): String = t.map(toString).mkString("")
def toStringWithBaseStr(t: Templete): String =
baseStrFront + toString(t) + baseStrBack
def getMulDepth(t: Templete): Int = t.count {
case Mul(_, _) => true; case _ => false
}

def buildTemplate(temp: Templete): Goal =
val stmts = temp.map(toString).map(parseStmt)
def stringifyWithBaseStr(t: Template): String =
baseStrFront + t.stringify + baseStrBack
def buildTemplate(temp: Template): Goal =
val stmts = temp.map(_.stringify).map(parseStmt)
val base = createNewBaseTemplate()
val baseStmts = base.f0.f7.nodes
baseStmts.addAll(0, stmts.asJava)
return base

// 가능한 모든 Stmt를 생성하는 함수
def allStmts: LazyList[Stmt] = LazyList(
Add(V, V),
Sub(V, V),
Mul(V, V),
Rot(V, V),
)

// 주어진 길이에 대해 가능한 모든 템플릿을 생성하는 함수
def allTempletesOfSize(n: Int): LazyList[Templete] = n match {
case 1 => allStmts.map(stmt => List(stmt))
case _ =>
for {
stmt <- allStmts
program <- allTempletesOfSize(n - 1)
} yield stmt :: program
}

// 모든 길이에 대해 가능한 모든 템플릿을 생성하는 LazyList
val allTempletes: LazyList[Templete] =
LazyList.from(1).flatMap(allTempletesOfSize)
}
2 changes: 1 addition & 1 deletion src/main/scala/fhetest/Utils/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def compare(obtained: String, expected: String): Unit = {
.foreach {
case (obtained, result) =>
assert(
Math.abs(obtained - result) < 0.0001,
Math.abs(obtained - result) < 0.001,
s"$obtained and $result are not close",
)
}
Expand Down

0 comments on commit ca796ca

Please sign in to comment.