From ca796ca67608e31622accab9948765c8b7135462 Mon Sep 17 00:00:00 2001 From: Jaeho Choi Date: Tue, 20 Feb 2024 06:16:49 +0000 Subject: [PATCH] Refactor generator to support multiple generation strategies (#13) --- .../scala/fhetest/Generate/AbsStatement.scala | 24 +++ .../scala/fhetest/Generate/Template.scala | 3 + .../fhetest/Generate/TemplateGenerator.scala | 37 ++++ src/main/scala/fhetest/Generate/Utils.scala | 47 ++++++ src/main/scala/fhetest/Phase/Generate.scala | 158 +++--------------- src/main/scala/fhetest/Utils/Utils.scala | 2 +- 6 files changed, 136 insertions(+), 135 deletions(-) create mode 100644 src/main/scala/fhetest/Generate/AbsStatement.scala create mode 100644 src/main/scala/fhetest/Generate/Template.scala create mode 100644 src/main/scala/fhetest/Generate/TemplateGenerator.scala create mode 100644 src/main/scala/fhetest/Generate/Utils.scala diff --git a/src/main/scala/fhetest/Generate/AbsStatement.scala b/src/main/scala/fhetest/Generate/AbsStatement.scala new file mode 100644 index 0000000..a9800b6 --- /dev/null +++ b/src/main/scala/fhetest/Generate/AbsStatement.scala @@ -0,0 +1,24 @@ +package fhetest.Generate + +trait AbsStmt +case class Var() +case class Assign(l: String, r: (Int | Double)) extends AbsStmt +case class AssignVec(l: String, r: (List[Int] | List[Double])) extends AbsStmt +case class Add(l: Var, r: Var) extends AbsStmt +case class Sub(l: Var, r: Var) extends AbsStmt +case class Mul(l: Var, r: Var) extends AbsStmt +case class Rot(l: Var, r: Var) extends AbsStmt + +val V = Var() + +def formatNumber(n: Int | Double): String = n match { + case i: Int => i.toString + case d: Double => f"$d%f" +} + +def allAbsStmts: LazyList[AbsStmt] = LazyList( + Add(V, V), + Sub(V, V), + Mul(V, V), + Rot(V, V), +) diff --git a/src/main/scala/fhetest/Generate/Template.scala b/src/main/scala/fhetest/Generate/Template.scala new file mode 100644 index 0000000..d5f72ed --- /dev/null +++ b/src/main/scala/fhetest/Generate/Template.scala @@ -0,0 +1,3 @@ +package fhetest.Generate + +type Template = List[AbsStmt] diff --git a/src/main/scala/fhetest/Generate/TemplateGenerator.scala b/src/main/scala/fhetest/Generate/TemplateGenerator.scala new file mode 100644 index 0000000..975b1dc --- /dev/null +++ b/src/main/scala/fhetest/Generate/TemplateGenerator.scala @@ -0,0 +1,37 @@ +package fhetest.Generate + +// Template Generation Strategy +enum Strategy: + case Exhaustive, Random + +extension (s: Strategy) + def getGenerator: TemplateGenerator = s match { + case Strategy.Exhaustive => ExhaustiveGenerator + case Strategy.Random => RandomGenerator + } + +// Template Generator +trait TemplateGenerator { + def generateTemplates(): LazyList[Template] +} + +object ExhaustiveGenerator extends TemplateGenerator { + def generateTemplates(): LazyList[Template] = { + def allTemplatesOfSize(n: Int): LazyList[Template] = n match { + case 1 => allAbsStmts.map(stmt => List(stmt)) + case _ => + for { + stmt <- allAbsStmts + program <- allTemplatesOfSize(n - 1) + } yield stmt :: program + } + LazyList.from(1).flatMap(allTemplatesOfSize) + } + +} + +object RandomGenerator extends TemplateGenerator { + def generateTemplates(): LazyList[Template] = { + ??? + } +} diff --git a/src/main/scala/fhetest/Generate/Utils.scala b/src/main/scala/fhetest/Generate/Utils.scala new file mode 100644 index 0000000..ae24b97 --- /dev/null +++ b/src/main/scala/fhetest/Generate/Utils.scala @@ -0,0 +1,47 @@ +package fhetest.Generate + +import scala.util.Random + +object Utils { + def assignValues(name: String, vxs: (List[Int] | List[Double])): AbsStmt = + AssignVec(name, vxs) + + extension (s: AbsStmt) + def stringify: String = s match + case Assign(l, r) => s"$l = ${formatNumber(r)};" + case AssignVec(l, r) => s"$l = {${r.map(formatNumber).mkString(",")}};" + case Add(l, r) => "x += y;" + case Sub(l, r) => "x -= y;" + case Mul(l, r) => "x *= y;" + case Rot(l, r) => "rotate_left(x, c);" + + extension (t: Template) + def stringify: String = t.map(_.stringify).mkString("") + def getMulDepth: Int = t.count { + case Mul(_, _) => true; case _ => false + } + + def assignRandValues(len: Int, bound: (Int | Double)): Template = { + val lx = Random.between(1, len + 1) + val ly = Random.between(1, len + 1) + val vxs: (List[Int] | List[Double]) = bound match { + case intBound: Int => List.fill(lx)(Random.between(0, intBound)) + case doubleBound: Double => + List.fill(lx)(Random.between(0.0, doubleBound)) + } + val vys: (List[Int] | List[Double]) = bound match { + case intBound: Int => List.fill(ly)(Random.between(0, intBound)) + case doubleBound: Double => + List.fill(ly)(Random.between(0.0, doubleBound)) + } + + val assigned = assignValues("x", vxs) :: t + assigned.flatMap { + case op @ (Add(_, _) | Sub(_, _) | Mul(_, _)) => + assignValues("y", vys) :: op :: Nil + case op @ Rot(_, _) => + Assign("c", Random.between(0, 10)) :: op :: Nil + case s => s :: Nil + } + } +} diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index c54dd06..34dcb99 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -1,6 +1,8 @@ package fhetest.Phase import fhetest.Utils.* +import fhetest.Generate.* +import fhetest.Generate.Utils.* import org.twc.terminator.SymbolTable; import org.twc.terminator.t2dsl_compiler.*; @@ -12,9 +14,11 @@ import java.nio.file.{Files, Paths}; import java.io.*; import javax.print.attribute.EnumSyntax import scala.jdk.CollectionConverters._ -import scala.util.Random -case class Generate(encType: ENC_TYPE) { +case class Generate( + encType: ENC_TYPE, + strategy: Strategy = Strategy.Exhaustive, +) { // TODO : This boilerplate code is really ugly. But I cannot find a better way to do this. val baseStrFront = encType match { case ENC_TYPE.ENC_INT => @@ -28,33 +32,36 @@ case class Generate(encType: ENC_TYPE) { val symbolTable = boilerplate()._2 + val tempGen = strategy.getGenerator + + val allTemplates = tempGen.generateTemplates() + def apply(nOpt: Option[Int]): LazyList[String] = { val templates = nOpt match { - case Some(n) => allTempletes.take(n) - case None => allTempletes + case Some(n) => allTemplates.take(n) + case None => allTemplates } for { template <- templates } yield { val concretized = concretizeTemplate(template) - toStringWithBaseStr(concretized) + stringifyWithBaseStr(concretized) } } // This is for testing purpose def show(backends: List[Backend], n: Int) = for { - template <- allTempletes.take(n) + template <- allTemplates.take(n) } { // TODO: currently concretize 5 times for each template // TODO: store the results into a json file and print it by parsing the json file print("=" * 80 + "\n") print("\n") - print(toString(template) + "\n") + print(s"$template\n") print("=" * 80 + "\n") - // for _ <- 0 until 5 do { val concretized = concretizeTemplate(template) - print(toString(concretized) + "\n") + print(s"$concretized\n") print("-" * 80 + "\n") val ast: Goal = buildTemplate(concretized) @@ -67,7 +74,7 @@ case class Generate(encType: ENC_TYPE) { backend, { workspaceDir => given DirName = workspaceDir - val mulDepth = getMulDepth(concretized) + val mulDepth = concretized.getMulDepth // Default RingDim, PlainModulus with MulDepth val encParams = EncParams(32768, mulDepth, 65537) Print( @@ -87,67 +94,11 @@ case class Generate(encType: ENC_TYPE) { } // TODO: current length = 100, it can be changed to larger value - def concretizeTemplate(template: Templete): Templete = encType match - case ENC_TYPE.ENC_INT => assignRandValues(template, 100, 100) - case ENC_TYPE.ENC_DOUBLE => assignRandValues(template, 100, 100.0) + def concretizeTemplate(template: Template): Template = encType match + case ENC_TYPE.ENC_INT => template.assignRandValues(100, 100) + case ENC_TYPE.ENC_DOUBLE => template.assignRandValues(100, 100.0) case _ => throw new Exception("encType is not set") - // TODO: Currently, T2 DSL does not support negative numbers - def assignRandValues(template: Templete, len: Int, bound: Int): Templete = - val l = Random.between(1, len) - val vxs = List.fill(l)(Random.between(0, bound)) - // refactor - val t = assignValues(template, vxs) - t.flatMap(s => - s match - case Add(_, _) | Sub(_, _) | Mul(_, _) => - List( - AssignVec( - "y", - List.fill(Random.between(1, len))(Random.between(0, bound)), - ), - s, - ) - case Rot(_, _) => - // TODO: Currently, c is bounded by 10. It can be changed to larger value - List( - Assign("c", Random.between(0, 10)), - s, - ) - case _ => List(s), - ) - - def assignRandValues(template: Templete, len: Int, bound: Double): Templete = - val l = Random.between(1, len) - val vxs = List.fill(l)(Random.between(0, bound)) - // refactor - val t = assignValues(template, vxs) - t.flatMap(s => - s match - case Add(_, _) | Sub(_, _) | Mul(_, _) => - List( - AssignVec( - "y", - List.fill(Random.between(1, len))(Random.between(0, bound)), - ), - s, - ) - case Rot(_, _) => - List( - Assign("c", Random.between(0, 10)), - s, - ) - case _ => List(s), - ) - - // vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 }; - def assignValues( - template: Templete, - vxs: List[Int | Double], - ): Templete = - val assignment = AssignVec("x", vxs) - return assignment :: template - def boilerplate(): (Goal, SymbolTable, _) = val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8")) Parse(baseStream) @@ -155,79 +106,18 @@ case class Generate(encType: ENC_TYPE) { // TODO: current print only 10 values, it can be changed to larger value def createNewBaseTemplate(): Goal = boilerplate()._1 - // def assignIntValue(template: Templete, vx: Int, vy: Int): Templete = - // val assignments = List(Assign("x", vx), Assign("y", vy)) - // return assignments ++ template - - // def assignRandValue(template: Templete, bound: Int): Templete = - // val vx = Random.between(0, bound) - // val vy = Random.between(0, bound) - // return assignIntValue(template, vx, vy) - def parseStmt(stmtStr: String): Statement = val input_stream: InputStream = new ByteArrayInputStream( stmtStr.getBytes("UTF-8"), ) T2DSLParser(input_stream).Statement() - trait Stmt - case class Var() - case class Assign(l: String, r: (Int | Double)) extends Stmt - case class AssignVec(l: String, r: List[Int | Double]) extends Stmt - case class Add(l: Var, r: Var) extends Stmt - case class Sub(l: Var, r: Var) extends Stmt - case class Mul(l: Var, r: Var) extends Stmt - case class Rot(l: Var, r: Var) extends Stmt - - val V = Var() - - def formatNumber(n: Int | Double): String = n match { - case i: Int => i.toString - case d: Double => f"$d%f" - } - - def toString(s: Stmt) = s match - case Assign(l, r) => s"$l = ${formatNumber(r)};" - case AssignVec(l, r) => s"$l = {${r.map(formatNumber).mkString(",")}};" - case Add(l, r) => "x += y;" - case Sub(l, r) => "x -= y;" - case Mul(l, r) => "x *= y;" - case Rot(l, r) => "rotate_left(x, c);" - - type Templete = List[Stmt] - def toString(t: Templete): String = t.map(toString).mkString("") - def toStringWithBaseStr(t: Templete): String = - baseStrFront + toString(t) + baseStrBack - def getMulDepth(t: Templete): Int = t.count { - case Mul(_, _) => true; case _ => false - } - - def buildTemplate(temp: Templete): Goal = - val stmts = temp.map(toString).map(parseStmt) + def stringifyWithBaseStr(t: Template): String = + baseStrFront + t.stringify + baseStrBack + def buildTemplate(temp: Template): Goal = + val stmts = temp.map(_.stringify).map(parseStmt) val base = createNewBaseTemplate() val baseStmts = base.f0.f7.nodes baseStmts.addAll(0, stmts.asJava) return base - -// 가능한 모든 Stmt를 생성하는 함수 - def allStmts: LazyList[Stmt] = LazyList( - Add(V, V), - Sub(V, V), - Mul(V, V), - Rot(V, V), - ) - -// 주어진 길이에 대해 가능한 모든 템플릿을 생성하는 함수 - def allTempletesOfSize(n: Int): LazyList[Templete] = n match { - case 1 => allStmts.map(stmt => List(stmt)) - case _ => - for { - stmt <- allStmts - program <- allTempletesOfSize(n - 1) - } yield stmt :: program - } - -// 모든 길이에 대해 가능한 모든 템플릿을 생성하는 LazyList - val allTempletes: LazyList[Templete] = - LazyList.from(1).flatMap(allTempletesOfSize) } diff --git a/src/main/scala/fhetest/Utils/Utils.scala b/src/main/scala/fhetest/Utils/Utils.scala index 6aa0c6e..27be177 100644 --- a/src/main/scala/fhetest/Utils/Utils.scala +++ b/src/main/scala/fhetest/Utils/Utils.scala @@ -130,7 +130,7 @@ def compare(obtained: String, expected: String): Unit = { .foreach { case (obtained, result) => assert( - Math.abs(obtained - result) < 0.0001, + Math.abs(obtained - result) < 0.001, s"$obtained and $result are not close", ) }