Skip to content

Commit 075826e

Browse files
author
Jaeho Choi
committed
Update t2 program generator (#11)
1 parent a20077a commit 075826e

File tree

5 files changed

+107
-68
lines changed

5 files changed

+107
-68
lines changed

src/main/scala/fhetest/Command.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package fhetest
22

33
import fhetest.Utils.*
4-
import fhetest.Phase.{Parse, Interp, Print, Execute}
4+
import fhetest.Phase.{Parse, Interp, Print, Execute, Generate}
55

66
sealed abstract class Command(
77
/** command name */
@@ -143,3 +143,19 @@ case object CmdExecute extends Command("execute") {
143143
case Nil => println("No backend given.")
144144
}
145145
}
146+
147+
/** `gen` command */
148+
case object CmdGen extends Command("gen") {
149+
val help = "Generate random T2 programs."
150+
val examples = List(
151+
"fhetest gen",
152+
"fhetest gen -n 10",
153+
)
154+
def apply(args: List[String]): Unit = args match {
155+
case Nil => println("No argument given.")
156+
case n :: _ => {
157+
val num = n.toInt
158+
Generate(List(Backend.SEAL, Backend.OpenFHE), ENC_TYPE.ENC_INT, num)
159+
}
160+
}
161+
}
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package fhetest.Phase
22

3+
import sys.process.*
4+
import java.io.File
5+
36
import fhetest.Utils.*
47

58
case object Execute {
@@ -8,14 +11,16 @@ case object Execute {
811
val cmakeCommand = "cmake ."
912
val makeCommand = "make -j"
1013
val executeCommand = "./test.out"
14+
15+
// TODO : Add option silent (default true)
1116
val cmakeProcess =
12-
sys.process.Process(cmakeCommand, new java.io.File(workspaceDir))
17+
Process(cmakeCommand, new File(workspaceDir))
1318
val makeProcess =
14-
sys.process.Process(makeCommand, new java.io.File(workspaceDir))
19+
Process(makeCommand, new File(workspaceDir))
1520
val executeProcess =
16-
sys.process.Process(executeCommand, new java.io.File(binPath))
17-
cmakeProcess.!
18-
makeProcess.!
21+
Process(executeCommand, new File(binPath))
22+
cmakeProcess.!(silentLogger)
23+
makeProcess.!(silentLogger)
1924
executeProcess.!!
2025
}
2126
}

src/main/scala/fhetest/Phase/Generate.scala

Lines changed: 75 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,43 @@ import scala.util.Random
1616

1717
case object Generate {
1818

19-
def apply(backend: Backend, encType: ENC_TYPE, n: Int) =
19+
def apply(backends: List[Backend], encType: ENC_TYPE, n: Int) =
2020
for {
21-
program <- allPrograms.take(n)
21+
template <- allTempletes.take(n)
2222
} {
23-
// withBackendTempDir(
24-
// backend,
25-
// { workspaceDir =>
26-
// given DirName = workspaceDir
27-
// val ast = baseAst
28-
// // appendToBaseAst(exampleStmt)
29-
// appendToBaseAst(program)
30-
// Print(ast, symbolTable, encType, backend)
31-
// Execute(backend)
32-
// },
33-
// )
34-
given DirName = getWorkspaceDir(backend)
35-
val template = buildTemplate(program)
36-
val ast = concretizeTemplate(template)
37-
Print(ast, symbolTable, encType, backend)
38-
Execute(backend)
23+
// TODO: currently concretize 5 times for each template
24+
// TODO: store the results into a json file and print it by parsing the json file
25+
print("=" * 80 + "\n")
26+
print("<Program>\n")
27+
print(toString(template) + "\n")
28+
print("=" * 80 + "\n")
29+
for _ <- 0 until 5 do {
30+
val concretized = concretizeTemplate(template)
31+
print(toString(concretized) + "\n")
32+
print("-" * 80 + "\n")
33+
34+
val ast = buildTemplate(concretized)
35+
val result = Interp(ast, 32768, 65537)
36+
print("CLEAR" + " : " + result + "\n")
37+
for {
38+
backend <- backends
39+
} {
40+
withBackendTempDir(
41+
backend,
42+
{ workspaceDir =>
43+
given DirName = workspaceDir
44+
Print(ast, symbolTable, encType, backend)
45+
val result = Execute(backend)
46+
print(backend.toString + " : " + result)
47+
},
48+
)
49+
}
50+
print("-" * 80 + "\n")
51+
}
3952
}
4053

41-
def concretizeTemplate(template: Goal): Goal =
42-
return assignRandIntValues(template, 1000)
54+
def concretizeTemplate(template: Templete): Templete =
55+
return assignRandIntValues(template, 100)
4356

4457
// FIXME: This is just a temporary solution for making symbolTable and encType available
4558
val (_: Goal, symbolTable, encType) =
@@ -65,34 +78,28 @@ case object Generate {
6578
Parse(baseStream)._1
6679
}
6780

68-
def assignIntValue(template: Goal, vx: Int, vy: Int): Goal =
69-
val XStr = s"x = $vx;"
70-
val YStr = s"y = $vy;"
71-
print(XStr)
72-
print(YStr)
73-
val assignments = List(XStr, YStr)
74-
val stmts = assignments.map(parseStmt)
75-
val templateStmts = template.f0.f7.nodes
76-
templateStmts.addAll(0, stmts.asJava)
77-
return template
81+
def assignIntValue(template: Templete, vx: Int, vy: Int): Templete =
82+
val assignments = List(Assign("x", vx), Assign("y", vy))
83+
return assignments ++ template
7884

7985
// vxs = [1, 2, 3], vys = [4, 5, 6] => x = { 1, 2, 3 }; y = { 4, 5, 6 };
80-
def assignIntValues(template: Goal, vxs: List[Int], vys: List[Int]): Goal =
81-
val XStr = s"x = {${vxs.mkString(",")}};"
82-
val YStr = s"y = {${vys.mkString(",")}};"
83-
val assignments = List(XStr, YStr)
84-
val stmts = assignments.map(parseStmt)
85-
val templateStmts = template.f0.f7.nodes
86-
templateStmts.addAll(0, stmts.asJava)
87-
return template
88-
89-
def assignRandIntValue(template: Goal, bound: Int): Goal =
86+
def assignIntValues(
87+
template: Templete,
88+
vxs: List[Int],
89+
vys: List[Int],
90+
): Templete =
91+
val assignments = List(AssignVec("x", vxs), AssignVec("y", vys))
92+
return assignments ++ template
93+
94+
def assignRandIntValue(template: Templete, bound: Int): Templete =
9095
val vx = Random.between(0, bound)
9196
val vy = Random.between(0, bound)
9297
return assignIntValue(template, vx, vy)
9398

9499
// current length = 5
95-
def assignRandIntValues(template: Goal, bound: Int): Goal =
100+
// TODO : Currently, it only supports the length of 5
101+
// TODO : Currently, T2 DSL does not support negative numbers
102+
def assignRandIntValues(template: Templete, bound: Int): Templete =
96103
val vxs = List.fill(5)(Random.between(0, bound))
97104
val vys = List.fill(5)(Random.between(0, bound))
98105
return assignIntValues(template, vxs, vys)
@@ -103,48 +110,54 @@ case object Generate {
103110
)
104111
T2DSLParser(input_stream).Statement()
105112

106-
def buildTemplate(stmts: List[Statement]): Goal =
107-
val base = createNewBaseTemplate()
108-
val baseStmts = base.f0.f7.nodes
109-
baseStmts.addAll(0, stmts.asJava)
110-
return base
111-
112113
trait Stmt
113114
case class Var()
115+
case class Assign(l: String, r: Int) extends Stmt
116+
case class AssignVec(l: String, r: List[Int]) extends Stmt
114117
case class Add(l: Var, r: Var) extends Stmt
115118
case class Sub(l: Var, r: Var) extends Stmt
116119
case class Mul(l: Var, r: Var) extends Stmt
117120
case class Rot(l: Var, r: Var) extends Stmt
118121

119-
def concretize(s: Stmt) = s match
120-
case Add(l, r) => parseStmt("x += y;")
121-
case Sub(l, r) => parseStmt("x -= y;")
122-
case Mul(l, r) => parseStmt("x *= y;")
123-
// case Rot(l, r) => parseStmt("rotate_left(x, c);")
124-
125-
type Program = List[Stmt]
126122
val V = Var()
127123

124+
def toString(s: Stmt) = s match
125+
case Assign(l, r) => s"$l = $r;"
126+
case AssignVec(l, r) => s"$l = {${r.mkString(",")}};"
127+
case Add(l, r) => "x += y;"
128+
case Sub(l, r) => "x -= y;"
129+
case Mul(l, r) => "x *= y;"
130+
// case Rot(l, r) => "rotate_left(x, c);"
131+
def concretize(s: Stmt) = parseStmt(toString(s))
132+
133+
type Templete = List[Stmt]
134+
def toString(t: Templete): String = t.map(toString).mkString("\n")
135+
136+
def buildTemplate(temp: Templete): Goal =
137+
val stmts = temp.map(toString).map(parseStmt)
138+
val base = createNewBaseTemplate()
139+
val baseStmts = base.f0.f7.nodes
140+
baseStmts.addAll(0, stmts.asJava)
141+
return base
142+
128143
// 가능한 모든 Stmt를 생성하는 함수
129144
def allStmts: LazyList[Stmt] = LazyList(
130145
Add(V, V),
131146
Sub(V, V),
132147
Mul(V, V),
133148
)
134149

135-
// 주어진 길이에 대해 가능한 모든 프로그램을 생성하는 함수
136-
def allProgramsOfSize(n: Int): LazyList[Program] = n match {
150+
// 주어진 길이에 대해 가능한 모든 템플릿을 생성하는 함수
151+
def allTempletesOfSize(n: Int): LazyList[Templete] = n match {
137152
case 1 => allStmts.map(stmt => List(stmt))
138153
case _ =>
139154
for {
140155
stmt <- allStmts
141-
program <- allProgramsOfSize(n - 1)
156+
program <- allTempletesOfSize(n - 1)
142157
} yield stmt :: program
143158
}
144159

145-
// 모든 길이에 대해 가능한 모든 프로그램을 생성하는 LazyList
146-
val allPrograms: LazyList[List[Statement]] =
147-
LazyList.from(1).flatMap(allProgramsOfSize).map(_.map(concretize))
148-
149-
def generateTemplate() = ???
160+
// 모든 길이에 대해 가능한 모든 템플릿을 생성하는 LazyList
161+
val allTempletes: LazyList[Templete] =
162+
LazyList.from(1).flatMap(allTempletesOfSize)
150163
}

src/main/scala/fhetest/Utils/Utils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package fhetest.Utils
22

33
import org.twc.terminator.Main.ENC_TYPE as T2ENC_TYPE
44

5+
import sys.process.*
56
import java.nio.file.{Files, Path, Paths, StandardCopyOption}
67
import java.util.Comparator
78
import scala.util.Try
@@ -101,3 +102,5 @@ def withBackendTempDir[Result](
101102
deleteTemp()
102103
}
103104
}
105+
106+
val silentLogger = ProcessLogger(_ => (), _ => ())

src/main/scala/fhetest/fhetest.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ object FHETest {
2525
// Run the given T2 program
2626
// Compile -> Execute (SEAL, OpenFHE) / Interp (if no backend is given)
2727
CmdRun,
28+
// Generate random T2 programs
29+
CmdGen,
2830
)
2931
val cmdMap = commands.foldLeft[Map[String, Command]](Map()) {
3032
case (map, cmd) => map + (cmd.name -> cmd)

0 commit comments

Comments
 (0)