Skip to content

Commit

Permalink
Add functionality for generator to produce a list of T2 program as st…
Browse files Browse the repository at this point in the history
…ring (#11)
  • Loading branch information
Maokami committed Feb 15, 2024
1 parent c489cb3 commit fd90c75
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,28 @@ import scala.jdk.CollectionConverters._
import scala.util.Random

case class Generate(encType: ENC_TYPE) {
// TODO : This boilerplate code is really ugly. But I cannot find a better way to do this.
val baseStrFront = encType match {
case ENC_TYPE.ENC_INT =>
"int main(void) { EncInt x, y; int c; "
case ENC_TYPE.ENC_DOUBLE =>
"int main(void) { EncDouble x, y; int c; "
case ENC_TYPE.None => throw new Exception("encType is not set")
}
val baseStrBack = " print_batched (x, 10); return 0; } "
val baseStr = baseStrFront + baseStrBack

val symbolTable = boilerplate()._2

def apply(n: Int): List[String] = {
for {
template <- allTempletes.take(n).toList
} yield {
val concretized = concretizeTemplate(template)
toStringWithBaseStr(concretized)
}
}

// This is for testing purpose
def show(backends: List[Backend], n: Int) =
for {
Expand Down Expand Up @@ -69,27 +89,6 @@ case class Generate(encType: ENC_TYPE) {
case _ => throw new Exception("encType is not set")

def boilerplate(): (Goal, SymbolTable, _) =
val baseStr = encType match {
case ENC_TYPE.ENC_INT =>
"""
int main(void) {
EncInt x, y;
int c;
print_batched (x, 10);
return 0;
}
"""
case ENC_TYPE.ENC_DOUBLE =>
"""
int main(void) {
EncDouble x, y;
int c;
print_batched (x, 10);
return 0;
}
"""
case ENC_TYPE.None => throw new Exception("encType is not set")
}
val baseStream = new ByteArrayInputStream(baseStr.getBytes("UTF-8"))
Parse(baseStream)

Expand Down Expand Up @@ -160,7 +159,9 @@ case class Generate(encType: ENC_TYPE) {
def concretize(s: Stmt) = parseStmt(toString(s))

type Templete = List[Stmt]
def toString(t: Templete): String = t.map(toString).mkString("\n")
def toString(t: Templete): String = t.map(toString).mkString("")
def toStringWithBaseStr(t: Templete): String =
baseStrFront + toString(t) + baseStrBack
def getMulDepth(t: Templete): Int = t.count {
case Mul(_, _) => true; case _ => false
}
Expand Down

0 comments on commit fd90c75

Please sign in to comment.