From 0cf400f626a623939a0b3cf0844216f89bfa99cd Mon Sep 17 00:00:00 2001 From: Jaeho Choi Date: Tue, 20 Feb 2024 06:50:51 +0000 Subject: [PATCH] Support Plain Ops (#13) --- .../scala/fhetest/Generate/AbsStatement.scala | 6 +++++ src/main/scala/fhetest/Generate/Utils.scala | 25 ++++++++++++++----- src/main/scala/fhetest/Phase/Generate.scala | 6 ++--- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/main/scala/fhetest/Generate/AbsStatement.scala b/src/main/scala/fhetest/Generate/AbsStatement.scala index a9800b6..438a9e1 100644 --- a/src/main/scala/fhetest/Generate/AbsStatement.scala +++ b/src/main/scala/fhetest/Generate/AbsStatement.scala @@ -5,8 +5,11 @@ case class Var() case class Assign(l: String, r: (Int | Double)) extends AbsStmt case class AssignVec(l: String, r: (List[Int] | List[Double])) extends AbsStmt case class Add(l: Var, r: Var) extends AbsStmt +case class AddP(l: Var, r: Var) extends AbsStmt case class Sub(l: Var, r: Var) extends AbsStmt +case class SubP(l: Var, r: Var) extends AbsStmt case class Mul(l: Var, r: Var) extends AbsStmt +case class MulP(l: Var, r: Var) extends AbsStmt case class Rot(l: Var, r: Var) extends AbsStmt val V = Var() @@ -18,7 +21,10 @@ def formatNumber(n: Int | Double): String = n match { def allAbsStmts: LazyList[AbsStmt] = LazyList( Add(V, V), + AddP(V, V), Sub(V, V), + SubP(V, V), Mul(V, V), + MulP(V, V), Rot(V, V), ) diff --git a/src/main/scala/fhetest/Generate/Utils.scala b/src/main/scala/fhetest/Generate/Utils.scala index ae24b97..8020f7c 100644 --- a/src/main/scala/fhetest/Generate/Utils.scala +++ b/src/main/scala/fhetest/Generate/Utils.scala @@ -3,17 +3,22 @@ package fhetest.Generate import scala.util.Random object Utils { - def assignValues(name: String, vxs: (List[Int] | List[Double])): AbsStmt = - AssignVec(name, vxs) + def assignValue(name: String, v: (Int | Double)): AbsStmt = + Assign(name, v) + def assignValues(name: String, vs: (List[Int] | List[Double])): AbsStmt = + AssignVec(name, vs) extension (s: AbsStmt) def stringify: String = s match case Assign(l, r) => s"$l = ${formatNumber(r)};" case AssignVec(l, r) => s"$l = {${r.map(formatNumber).mkString(",")}};" - case Add(l, r) => "x += y;" - case Sub(l, r) => "x -= y;" - case Mul(l, r) => "x *= y;" - case Rot(l, r) => "rotate_left(x, c);" + case Add(_, _) => "x += y;" + case AddP(_, _) => "x += yP;" + case Sub(_, _) => "x -= y;" + case SubP(_, _) => "x -= yP;" + case Mul(_, _) => "x *= y;" + case MulP(_, _) => "x *= yP;" + case Rot(_, _) => "rotate_left(x, c);" extension (t: Template) def stringify: String = t.map(_.stringify).mkString("") @@ -24,6 +29,8 @@ object Utils { def assignRandValues(len: Int, bound: (Int | Double)): Template = { val lx = Random.between(1, len + 1) val ly = Random.between(1, len + 1) + + // Generate Random Values val vxs: (List[Int] | List[Double]) = bound match { case intBound: Int => List.fill(lx)(Random.between(0, intBound)) case doubleBound: Double => @@ -34,11 +41,17 @@ object Utils { case doubleBound: Double => List.fill(ly)(Random.between(0.0, doubleBound)) } + val vyP = bound match { + case intBound: Int => Random.between(0, intBound) + case doubleBound: Double => Random.between(0.0, doubleBound) + } val assigned = assignValues("x", vxs) :: t assigned.flatMap { case op @ (Add(_, _) | Sub(_, _) | Mul(_, _)) => assignValues("y", vys) :: op :: Nil + case op @ (AddP(_, _) | SubP(_, _) | MulP(_, _)) => + assignValue("yP", vyP) :: op :: Nil case op @ Rot(_, _) => Assign("c", Random.between(0, 10)) :: op :: Nil case s => s :: Nil diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index 34dcb99..5ac8ff1 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -22,12 +22,12 @@ case class Generate( // TODO : This boilerplate code is really ugly. But I cannot find a better way to do this. val baseStrFront = encType match { case ENC_TYPE.ENC_INT => - "int main(void) { EncInt x, y; int c; " + "int main(void) { EncInt x, y; int yP; int c; " case ENC_TYPE.ENC_DOUBLE => - "int main(void) { EncDouble x, y; int c; " + "int main(void) { EncDouble x, y; double yP; int c; " case ENC_TYPE.None => throw new Exception("encType is not set") } - val baseStrBack = " print_batched (x, 10); return 0; } " + val baseStrBack = " print_batched (x, 20); return 0; } " val baseStr = baseStrFront + baseStrBack val symbolTable = boilerplate()._2