Skip to content

Commit

Permalink
Generate: Add LibConfigGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
hyerinshelly committed Mar 5, 2024
1 parent 99e4205 commit 918eb95
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 37 deletions.
18 changes: 18 additions & 0 deletions src/main/scala/fhetest/Checker/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,27 @@ def dumpResult(
("SEAL" -> JsString(sealVersion)),
("OpenFHE" -> JsString(openfheVersion)),
)
val libConfig = program.libConfig
val encParams = libConfig.encParams
val encParams_info = JsObject(
("ringDim" -> JsString(encParams.ringDim.toString)),
("multDepth" -> JsString(encParams.mulDepth.toString)),
("plainMod" -> JsString(encParams.plainMod.toString)),
)
val libConfig_info = JsObject(
("scheme" -> JsString(libConfig.scheme.toString)),
("encParams" -> encParams_info.toJson),
("firstModSize" -> JsString(libConfig.firstModSize.toString)),
("scalingModSize" -> JsString(libConfig.scalingModSize.toString)),
("securityLevel" -> JsString(libConfig.securityLevel.toString)),
("scalingTechnique" -> JsString(libConfig.scalingTechnique.toString)),
("lenOpt" -> JsString(libConfig.lenOpt.getOrElse(0).toString)),
("boundOpt" -> JsString(libConfig.boundOpt.getOrElse(0).toString)),
)
val pgm_info = Map(
("programId" -> JsString(i.toString)),
("program" -> JsString(program.content)),
("libConfig" -> libConfig_info),
)
val info = pgm_info ++ backend_info
res match {
Expand Down
23 changes: 12 additions & 11 deletions src/main/scala/fhetest/Command.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ case object CmdExecute extends BackendCommand("execute") {
case object CmdGen extends Command("gen") {
val help = "Generate random T2 programs."
val examples = List(
"fhetest gen -type:int -c:10",
"fhetest gen -type:double -c:10",
"fhetest gen -type:int -count:10",
"fhetest gen -type:double -count:10",
"fhetest gen -type:int -stg:exhaust -c:10",
"fhetest gen -type:double -stg:random -c:10",
)
Expand All @@ -180,19 +180,18 @@ case object CmdCheck extends BackendCommand("check") {
val examples = List(
"fhetest check -dir:tmp -json:true",
)
// TODO: json option 추가
def runJob(config: Config): Unit =
val dir = config.dirName.getOrElseThrow("No directory given.")
val encParams = config.libConfigOpt match {
case Some(libConfig) => libConfig.encParams
case None => config.encParams
val encParamsOpt = config.libConfigOpt match {
case Some(libConfig) => Some(libConfig.encParams)
case None => None
}
val backends = List(Backend.SEAL, Backend.OpenFHE)
val toJson = config.toJson
val sealVersion = config.sealVersion
val openfheVersion = config.openfheVersion
val outputs =
Check(dir, backends, encParams, toJson, sealVersion, openfheVersion)
Check(dir, backends, encParamsOpt, toJson, sealVersion, openfheVersion)
for output <- outputs do {
println(output)
}
Expand All @@ -215,17 +214,17 @@ case object CmdTest extends BackendCommand("test") {
val generator = Generate(encType, genStrategy, config.filter)
val programs = generator(genCount)
val backendList = List(Backend.SEAL, Backend.OpenFHE)
val encParams = config.libConfigOpt match {
case Some(libConfig) => libConfig.encParams
case None => config.encParams
val encParamsOpt = config.libConfigOpt match {
case Some(libConfig) => Some(libConfig.encParams)
case None => None
}
val toJson = config.toJson
val sealVersion = config.sealVersion
val openfheVersion = config.openfheVersion
val outputs = Check(
programs,
backendList,
encParams,
encParamsOpt,
toJson,
sealVersion,
openfheVersion,
Expand All @@ -236,6 +235,8 @@ case object CmdTest extends BackendCommand("test") {
if !config.silent then {
println("Program : " + program.content)
println("-" * 80)
println("LibConfig : " + program.libConfig.stringify())
println("-" * 80)
}
println(output)
println("=" * 80)
Expand Down
22 changes: 13 additions & 9 deletions src/main/scala/fhetest/Generate/AbsProgramGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@ extension (s: Strategy)
// AbsProgram Generator
trait AbsProgramGenerator {
def generateAbsPrograms(): LazyList[AbsProgram]
val libConfig: LibConfig = LibConfig()
// val libConfig: LibConfig = LibConfig()
}

object ExhaustiveGenerator extends AbsProgramGenerator {
def generateAbsPrograms(): LazyList[AbsProgram] = {
def allAbsProgramsOfSize(n: Int): LazyList[AbsProgram] = n match {
case 1 =>
allAbsStmts.map(stmt => List(stmt)).map(AbsProgram(_, libConfig))
case _ =>
for {
stmt <- allAbsStmts
program <- allAbsProgramsOfSize(n - 1)
} yield AbsProgram(stmt :: program.absStmts, libConfig)
def allAbsProgramsOfSize(n: Int): LazyList[AbsProgram] = {
val libConfig = generateLibConfig()
n match {
case 1 =>
allAbsStmts.map(stmt => List(stmt)).map(AbsProgram(_, libConfig))
case _ =>
for {
stmt <- allAbsStmts
program <- allAbsProgramsOfSize(n - 1)
} yield AbsProgram(stmt :: program.absStmts, libConfig)
}
}
LazyList.from(1).flatMap(allAbsProgramsOfSize)
}
Expand All @@ -39,6 +42,7 @@ object RandomGenerator extends AbsProgramGenerator {
def generateAbsPrograms(): LazyList[AbsProgram] = {
def randomAbsProgramOfSize(n: Int): AbsProgram = {
val absStmts = (1 to n).map(_ => Random.shuffle(allAbsStmts).head).toList
val libConfig = generateLibConfig()
AbsProgram(absStmts, libConfig)
}
// Generate Lengths from 1 to inf
Expand Down
39 changes: 39 additions & 0 deletions src/main/scala/fhetest/Generate/LibConfigGenerator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package fhetest.Generate

import fhetest.LibConfig
import fhetest.Utils.*
import scala.util.Random

def generateLibConfig(): LibConfig = {
val randomScheme = Scheme.CKKS
// val randomScheme = Scheme.values(Random.nextInt(Scheme.values.length))
val randomEncParams = {
// TODO: Currently only MultDepth is random
val randomRingDim = 32768
val randomMultDepth = Random.nextInt(10)
val randomPlainMod = 65537
EncParams(randomRingDim, randomMultDepth, randomPlainMod)
}
val randomFirstModSize: Int = Random.nextInt(100)
val randomScalingModSize: Int = Random.nextInt(100)
val randomSecurityLevel =
SecurityLevel.values(Random.nextInt(SecurityLevel.values.length))
val randomScalingTechnique =
ScalingTechnique.values(Random.nextInt(ScalingTechnique.values.length))
val randomLenOpt: Option[Int] = Some(Random.nextInt(100000))
val randomBoundOpt: Option[Int | Double] = randomScheme match {
case Scheme.BFV | Scheme.BGV =>
Some(Random.nextInt(BigInt(2).pow(100).toInt))
case Scheme.CKKS => Some(Random.nextDouble() * math.pow(2, 100))
}
LibConfig(
randomScheme,
randomEncParams,
randomFirstModSize,
randomScalingModSize,
randomSecurityLevel,
randomScalingTechnique,
randomLenOpt,
randomBoundOpt,
)
}
9 changes: 9 additions & 0 deletions src/main/scala/fhetest/LibConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ case class LibConfig(
boundOpt: Option[Int | Double] = None,
) {

def stringify(): String =
s"""{scheme: ${scheme}}
{encParams: EncParams(${encParams.ringDim}, ${encParams.mulDepth}, ${encParams.plainMod})}
{(firstModSize, scalingModSize): (${firstModSize}, ${scalingModSize})}
{securityLevel: ${securityLevel}}
{scalingTechnique: ${scalingTechnique}}
{lenOpt: ${lenOpt}}
{boundOpt: ${boundOpt}}"""

val sealConfigs: List[String] = sealStr.split("\n").toList
val openfheConfigs: List[String] = openfheStr.split("\n").toList

Expand Down
26 changes: 17 additions & 9 deletions src/main/scala/fhetest/Phase/Check.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ case object Check {
def apply(
program: T2Program,
backends: List[Backend],
encParams: EncParams,
encParamsOpt: Option[EncParams],
): CheckResult = {
val encParams = encParamsOpt match {
case Some(encParams) => encParams
case None => program.libConfig.encParams
}
val result = for {
parsed <- parse(program)
} yield {
Expand Down Expand Up @@ -51,7 +55,7 @@ case object Check {
def apply(
programs: LazyList[T2Program],
backends: List[Backend],
encParams: EncParams,
encParamsOpt: Option[EncParams],
toJson: Boolean,
sealVersion: String,
openfheVersion: String,
Expand All @@ -60,6 +64,7 @@ case object Check {
setTestDir()
val checkResults = for {
(program, i) <- programs.zipWithIndex
encParams = encParamsOpt.getOrElse(program.libConfig.encParams)
parsed <- parse(program).toOption
interpResult <- interp(parsed, encParams).toOption
overflowBound = program.libConfig.firstModSize
Expand Down Expand Up @@ -91,11 +96,10 @@ case object Check {
max < limit
}

// TODO: Do we need this function?
def apply(
directory: String,
backends: List[Backend],
encParams: EncParams,
encParamsOpt: Option[EncParams],
toJson: Boolean,
sealVersion: String,
openfheVersion: String,
Expand All @@ -109,14 +113,18 @@ case object Check {
(filePath, i) <- fileList.to(LazyList).zipWithIndex
} yield {
val fileStr = Files.readAllLines(filePath).asScala.mkString("")
val program = T2Program(fileStr, LibConfig())
val checkResult = apply(program, backends, encParams)
val libConfig = LibConfig() // default libConfig for dir testing
val program = T2Program(fileStr, libConfig)
val checkResult = apply(program, backends, encParamsOpt)
if (toJson)
dumpResult(program, i, checkResult, sealVersion, openfheVersion)
val pgmStr = "-" * 10 + " Program " + "-" * 10 + "\n" + fileStr + "\n"
val reportStr = checkResult.toString + "\n"
pgmStr + reportStr

val libConfigStr =
"-" * 10 + " LibConfig " + "-" * 10 + "\n" + libConfig
.stringify() + "\n"
val reportStr =
"-" * 10 + " CheckResult " + "-" * 10 + "\n" + checkResult.toString + "\n"
pgmStr + libConfigStr + reportStr
}
checkResults
} else {
Expand Down
14 changes: 6 additions & 8 deletions src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ case class Generate(
} {
print("=" * 80 + "\n")
print("<Program>\n")
print(s"$absProgram\n")
print(s"${absProgram.absStmts}\n")
print("=" * 80 + "\n")
val assigned = absProgram.assignRandValues()
val adjusted = assigned.adjustScale(encType)
print(s"$adjusted\n")
print("-" * 80 + "\n")

val ast: Goal = buildAbsProgram(adjusted)
val result = Interp(ast, 32768, 65537)
val encParams = absProgram.libConfig.encParams
val result = Interp(ast, encParams.ringDim, encParams.plainMod)
print("CLEAR" + " : " + result + "\n")
for {
backend <- backends
Expand All @@ -80,8 +81,6 @@ case class Generate(
{ workspaceDir =>
given DirName = workspaceDir
val mulDepth = adjusted.mulDepth
// Default RingDim, PlainModulus with MulDepth
val encParams = EncParams(32768, mulDepth, 65537)
Print(
ast,
symbolTable,
Expand All @@ -95,7 +94,6 @@ case class Generate(
)
}
print("-" * 80 + "\n")
// }
}

def boilerplate(): (Goal, SymbolTable, _) =
Expand All @@ -111,11 +109,11 @@ case class Generate(
)
T2DSLParser(input_stream).Statement()

def toT2Program(t: AbsProgram): T2Program =
val programStr = baseStrFront + t.absStmts
def toT2Program(absProg: AbsProgram): T2Program =
val programStr = baseStrFront + absProg.absStmts
.map(_.stringify())
.foldLeft("")(_ + _) + baseStrBack
T2Program(programStr, t.libConfig)
T2Program(programStr, absProg.libConfig)

def buildAbsProgram(absProg: AbsProgram): Goal =
val stmts = absProg.absStmts.map(_.stringify()).map(parseStmt)
Expand Down

0 comments on commit 918eb95

Please sign in to comment.