Skip to content

Commit

Permalink
Refactor code to assign random values in Generate.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
Maokami committed Feb 15, 2024
1 parent fd90c75 commit 881ad9e
Showing 1 changed file with 92 additions and 65 deletions.
157 changes: 92 additions & 65 deletions src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,38 +48,38 @@ case class Generate(encType: ENC_TYPE) {
print("<Program>\n")
print(toString(template) + "\n")
print("=" * 80 + "\n")
for _ <- 0 until 5 do {
val concretized = concretizeTemplate(template)
print(toString(concretized) + "\n")
print("-" * 80 + "\n")

val ast: Goal = buildTemplate(concretized)
val result = Interp(ast, 32768, 65537)
print("CLEAR" + " : " + result + "\n")
for {
backend <- backends
} {
withBackendTempDir(
backend,
{ workspaceDir =>
given DirName = workspaceDir
val mulDepth = getMulDepth(concretized)
// Default RingDim, PlainModulus with MulDepth
val encParams = EncParams(32768, mulDepth, 65537)
Print(
ast,
symbolTable,
encType,
backend,
encParamsOpt = Some(encParams),
)
val result = Execute(backend)
print(backend.toString + " : " + result)
},
)
}
print("-" * 80 + "\n")
// for _ <- 0 until 5 do {
val concretized = concretizeTemplate(template)
print(toString(concretized) + "\n")
print("-" * 80 + "\n")

val ast: Goal = buildTemplate(concretized)
val result = Interp(ast, 32768, 65537)
print("CLEAR" + " : " + result + "\n")
for {
backend <- backends
} {
withBackendTempDir(
backend,
{ workspaceDir =>
given DirName = workspaceDir
val mulDepth = getMulDepth(concretized)
// Default RingDim, PlainModulus with MulDepth
val encParams = EncParams(32768, mulDepth, 65537)
Print(
ast,
symbolTable,
encType,
backend,
encParamsOpt = Some(encParams),
)
val result = Execute(backend)
print(backend.toString + " : " + result)
},
)
}
print("-" * 80 + "\n")
// }
}

// TODO: current length = 100, it can be changed to larger value
Expand All @@ -88,49 +88,77 @@ case class Generate(encType: ENC_TYPE) {
case ENC_TYPE.ENC_DOUBLE => assignRandValues(template, 100, 2.0)
case _ => throw new Exception("encType is not set")

def boilerplate(): (Goal, SymbolTable, _) =
val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8"))
Parse(baseStream)

// TODO: current print only 10 values, it can be changed to larger value
def createNewBaseTemplate(): Goal = boilerplate()._1
// TODO: Currently, T2 DSL does not support negative numbers
def assignRandValues(template: Templete, len: Int, bound: Int): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
// refactor
val t = assignValues(template, vxs)
t.flatMap(s =>
s match
case Add(_, _) | Sub(_, _) | Mul(_, _) =>
List(
AssignVec(
"y",
List.fill(Random.between(1, len))(Random.between(0, bound)),
),
s,
)
case Rot(_, _) =>
// TODO: Currently, c is bounded by 10. It can be changed to larger value
List(
Assign("c", Random.between(0, 10)),
s,
)
case _ => List(s),
)

def assignIntValue(template: Templete, vx: Int, vy: Int): Templete =
val assignments = List(Assign("x", vx), Assign("y", vy))
return assignments ++ template
def assignRandValues(template: Templete, len: Int, bound: Double): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
// refactor
val t = assignValues(template, vxs)
t.flatMap(s =>
s match
case Add(_, _) | Sub(_, _) | Mul(_, _) =>
List(
AssignVec(
"y",
List.fill(Random.between(1, len))(Random.between(0, bound)),
),
s,
)
case Rot(_, _) =>
List(
Assign("c", Random.between(0, 10)),
s,
)
case _ => List(s),
)

// vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 };
def assignValues(
template: Templete,
vxs: List[Int | Double],
vys: List[Int | Double],
vc: Int,
): Templete =
val assignments =
List(AssignVec("x", vxs), AssignVec("y", vys), Assign("c", vc))
return assignments ++ template
val assignment = AssignVec("x", vxs)
return assignment :: template

def assignRandValue(template: Templete, bound: Int): Templete =
val vx = Random.between(0, bound)
val vy = Random.between(0, bound)
return assignIntValue(template, vx, vy)
def boilerplate(): (Goal, SymbolTable, _) =
val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8"))
Parse(baseStream)

// TODO: Currently, T2 DSL does not support negative numbers
def assignRandValues(template: Templete, len: Int, bound: Int): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
val vys = List.fill(l)(Random.between(0, bound))
// TODO: Currently, c is bounded by 10. It can be changed to larger value
val vc = Random.between(0, 10)
return assignValues(template, vxs, vys, vc)
// TODO: current print only 10 values, it can be changed to larger value
def createNewBaseTemplate(): Goal = boilerplate()._1

def assignRandValues(template: Templete, len: Int, bound: Double): Templete =
val l = Random.between(1, len)
val vxs = List.fill(l)(Random.between(0, bound))
val vys = List.fill(l)(Random.between(0, bound))
// TODO: Currently, c is bounded by 10. It can be changed to larger value
val vc = Random.between(0, 10)
return assignValues(template, vxs, vys, vc)
// def assignIntValue(template: Templete, vx: Int, vy: Int): Templete =
// val assignments = List(Assign("x", vx), Assign("y", vy))
// return assignments ++ template

// def assignRandValue(template: Templete, bound: Int): Templete =
// val vx = Random.between(0, bound)
// val vy = Random.between(0, bound)
// return assignIntValue(template, vx, vy)

def parseStmt(stmtStr: String): Statement =
val input_stream: InputStream = new ByteArrayInputStream(
Expand All @@ -156,7 +184,6 @@ case class Generate(encType: ENC_TYPE) {
case Sub(l, r) => "x -= y;"
case Mul(l, r) => "x *= y;"
case Rot(l, r) => "rotate_left(x, c);"
def concretize(s: Stmt) = parseStmt(toString(s))

type Templete = List[Stmt]
def toString(t: Templete): String = t.map(toString).mkString("")
Expand Down

0 comments on commit 881ad9e

Please sign in to comment.