From 881ad9e4e12c1efc69466cf061b25dc06b59fe9c Mon Sep 17 00:00:00 2001 From: Jaeho Choi Date: Thu, 15 Feb 2024 09:08:30 +0000 Subject: [PATCH] Refactor code to assign random values in Generate.scala --- src/main/scala/fhetest/Phase/Generate.scala | 157 ++++++++++++-------- 1 file changed, 92 insertions(+), 65 deletions(-) diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index b056515..3d5e0a1 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -48,38 +48,38 @@ case class Generate(encType: ENC_TYPE) { print("\n") print(toString(template) + "\n") print("=" * 80 + "\n") - for _ <- 0 until 5 do { - val concretized = concretizeTemplate(template) - print(toString(concretized) + "\n") - print("-" * 80 + "\n") - - val ast: Goal = buildTemplate(concretized) - val result = Interp(ast, 32768, 65537) - print("CLEAR" + " : " + result + "\n") - for { - backend <- backends - } { - withBackendTempDir( - backend, - { workspaceDir => - given DirName = workspaceDir - val mulDepth = getMulDepth(concretized) - // Default RingDim, PlainModulus with MulDepth - val encParams = EncParams(32768, mulDepth, 65537) - Print( - ast, - symbolTable, - encType, - backend, - encParamsOpt = Some(encParams), - ) - val result = Execute(backend) - print(backend.toString + " : " + result) - }, - ) - } - print("-" * 80 + "\n") + // for _ <- 0 until 5 do { + val concretized = concretizeTemplate(template) + print(toString(concretized) + "\n") + print("-" * 80 + "\n") + + val ast: Goal = buildTemplate(concretized) + val result = Interp(ast, 32768, 65537) + print("CLEAR" + " : " + result + "\n") + for { + backend <- backends + } { + withBackendTempDir( + backend, + { workspaceDir => + given DirName = workspaceDir + val mulDepth = getMulDepth(concretized) + // Default RingDim, PlainModulus with MulDepth + val encParams = EncParams(32768, mulDepth, 65537) + Print( + ast, + symbolTable, + encType, + backend, + encParamsOpt = Some(encParams), + ) + val result = Execute(backend) + print(backend.toString + " : " + result) + }, + ) } + print("-" * 80 + "\n") + // } } // TODO: current length = 100, it can be changed to larger value @@ -88,49 +88,77 @@ case class Generate(encType: ENC_TYPE) { case ENC_TYPE.ENC_DOUBLE => assignRandValues(template, 100, 2.0) case _ => throw new Exception("encType is not set") - def boilerplate(): (Goal, SymbolTable, _) = - val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8")) - Parse(baseStream) - - // TODO: current print only 10 values, it can be changed to larger value - def createNewBaseTemplate(): Goal = boilerplate()._1 + // TODO: Currently, T2 DSL does not support negative numbers + def assignRandValues(template: Templete, len: Int, bound: Int): Templete = + val l = Random.between(1, len) + val vxs = List.fill(l)(Random.between(0, bound)) + // refactor + val t = assignValues(template, vxs) + t.flatMap(s => + s match + case Add(_, _) | Sub(_, _) | Mul(_, _) => + List( + AssignVec( + "y", + List.fill(Random.between(1, len))(Random.between(0, bound)), + ), + s, + ) + case Rot(_, _) => + // TODO: Currently, c is bounded by 10. It can be changed to larger value + List( + Assign("c", Random.between(0, 10)), + s, + ) + case _ => List(s), + ) - def assignIntValue(template: Templete, vx: Int, vy: Int): Templete = - val assignments = List(Assign("x", vx), Assign("y", vy)) - return assignments ++ template + def assignRandValues(template: Templete, len: Int, bound: Double): Templete = + val l = Random.between(1, len) + val vxs = List.fill(l)(Random.between(0, bound)) + // refactor + val t = assignValues(template, vxs) + t.flatMap(s => + s match + case Add(_, _) | Sub(_, _) | Mul(_, _) => + List( + AssignVec( + "y", + List.fill(Random.between(1, len))(Random.between(0, bound)), + ), + s, + ) + case Rot(_, _) => + List( + Assign("c", Random.between(0, 10)), + s, + ) + case _ => List(s), + ) // vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 }; def assignValues( template: Templete, vxs: List[Int | Double], - vys: List[Int | Double], - vc: Int, ): Templete = - val assignments = - List(AssignVec("x", vxs), AssignVec("y", vys), Assign("c", vc)) - return assignments ++ template + val assignment = AssignVec("x", vxs) + return assignment :: template - def assignRandValue(template: Templete, bound: Int): Templete = - val vx = Random.between(0, bound) - val vy = Random.between(0, bound) - return assignIntValue(template, vx, vy) + def boilerplate(): (Goal, SymbolTable, _) = + val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8")) + Parse(baseStream) - // TODO: Currently, T2 DSL does not support negative numbers - def assignRandValues(template: Templete, len: Int, bound: Int): Templete = - val l = Random.between(1, len) - val vxs = List.fill(l)(Random.between(0, bound)) - val vys = List.fill(l)(Random.between(0, bound)) - // TODO: Currently, c is bounded by 10. It can be changed to larger value - val vc = Random.between(0, 10) - return assignValues(template, vxs, vys, vc) + // TODO: current print only 10 values, it can be changed to larger value + def createNewBaseTemplate(): Goal = boilerplate()._1 - def assignRandValues(template: Templete, len: Int, bound: Double): Templete = - val l = Random.between(1, len) - val vxs = List.fill(l)(Random.between(0, bound)) - val vys = List.fill(l)(Random.between(0, bound)) - // TODO: Currently, c is bounded by 10. It can be changed to larger value - val vc = Random.between(0, 10) - return assignValues(template, vxs, vys, vc) + // def assignIntValue(template: Templete, vx: Int, vy: Int): Templete = + // val assignments = List(Assign("x", vx), Assign("y", vy)) + // return assignments ++ template + + // def assignRandValue(template: Templete, bound: Int): Templete = + // val vx = Random.between(0, bound) + // val vy = Random.between(0, bound) + // return assignIntValue(template, vx, vy) def parseStmt(stmtStr: String): Statement = val input_stream: InputStream = new ByteArrayInputStream( @@ -156,7 +184,6 @@ case class Generate(encType: ENC_TYPE) { case Sub(l, r) => "x -= y;" case Mul(l, r) => "x *= y;" case Rot(l, r) => "rotate_left(x, c);" - def concretize(s: Stmt) = parseStmt(toString(s)) type Templete = List[Stmt] def toString(t: Templete): String = t.map(toString).mkString("")