Skip to content

Commit

Permalink
Add debug option and fix several bugs in Check
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeho committed Mar 5, 2024
1 parent 918eb95 commit ac2ba61
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 33 deletions.
23 changes: 18 additions & 5 deletions src/main/scala/fhetest/Command.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,21 @@ case object CmdCheck extends BackendCommand("check") {
case object CmdTest extends BackendCommand("test") {
val help = "Check after Generate random T2 programs."
val examples = List(
"fhetest test -type:int -stg:random",
"fhetest test -type:int -stg:random -count:10",
"fhetest test -type:double -stg:exhaust -count:10",
"fhetest test -type:double -stg:random -json:true -seal:4.0.0 -openfhe:1.0.4",
"fhetest test -stg:random",
"fhetest test -stg:random -count:10",
"fhetest test -stg:exhaust -count:10",
"fhetest test -stg:random -json:true -seal:4.0.0 -openfhe:1.0.4",
)

def runJob(config: Config): Unit =
val encType = config.encType.getOrElseThrow("No encType given.")
val encType = config.libConfigOpt match {
case Some(libConfig) =>
libConfig.scheme match {
case Scheme.CKKS => ENC_TYPE.ENC_DOUBLE
case _ => ENC_TYPE.ENC_INT
}
case None => config.encType.getOrElseThrow("No encType given.")
}
val genStrategy = config.genStrategy.getOrElse(Strategy.Random)
val genCount = config.genCount
val generator = Generate(encType, genStrategy, config.filter)
Expand All @@ -221,6 +228,11 @@ case object CmdTest extends BackendCommand("test") {
val toJson = config.toJson
val sealVersion = config.sealVersion
val openfheVersion = config.openfheVersion
if (config.debug) {
println(s"EncType : $encType")
println(s"SEAL version : $sealVersion")
println(s"OpenFHE version : $openfheVersion")
}
val outputs = Check(
programs,
backendList,
Expand All @@ -229,6 +241,7 @@ case object CmdTest extends BackendCommand("test") {
sealVersion,
openfheVersion,
config.filter,
config.debug,
)
for (program, output) <- outputs do {
println("=" * 80)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/fhetest/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Config(
var libConfigOpt: Option[LibConfig] = None,
var filter: Boolean = true,
var silent: Boolean = false,
var debug: Boolean = false,
)

object Config {
Expand Down Expand Up @@ -57,6 +58,7 @@ object Config {
config.libConfigOpt = Some(LibConfig())
case "filter" => config.filter = value.toBoolean
case "silent" => config.silent = value.toBoolean
case "debug" => config.debug = value.toBoolean
case _ => throw new Error(s"Unknown option: $key")
}
case _ => // 잘못된 형식의 인자 처리
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/fhetest/LibConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ Ciphertext<DCRTPoly> tmp_;"""

lazy val sealStr: String =
lazy val scaleModsStr = s", ${scalingModSize}" * encParams.mulDepth
lazy val encoderName =
if (scheme == Scheme.CKKS) "encoder"
else "batch_encoder"
lazy val encoderType =
if (scheme == Scheme.CKKS) s"CKKSEncoder"
else "BatchEncoder"
lazy val slotStr =
if (scheme == Scheme.CKKS) s"slot_count"
else "slots"

lazy val moduliStr = s"vector<int> { $firstModSize$scaleModsStr, 60 }"
s"""EncryptionParameters parms(scheme_type::${scheme
.toString()
Expand All @@ -95,8 +105,8 @@ keygen.create_galois_keys(gal_keys);
Encryptor encryptor(context, public_key);
Evaluator evaluator(context);
Decryptor decryptor(context, secret_key);
CKKSEncoder encoder(context);
size_t slot_count = encoder.slot_count();
$encoderType $encoderName(context);
size_t $slotStr = $encoderName.slot_count();
Plaintext tmp;
Ciphertext tmp_;
"""
Expand Down
65 changes: 41 additions & 24 deletions src/main/scala/fhetest/Phase/Check.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ case object Check {
val executeResPairs = backends.map(backend =>
BackendResultPair(
backend.toString,
execute(backend, encParams, parsed),
execute(backend, encParams, parsed, program.libConfig),
),
)
diffResults(interpResPair, executeResPairs, encType, encParams.plainMod)
Expand All @@ -60,40 +60,55 @@ case object Check {
sealVersion: String,
openfheVersion: String,
validCheck: Boolean,
debug: Boolean,
): LazyList[(T2Program, CheckResult)] = {
setTestDir()
val checkResults = for {
val checkResults: LazyList[Option[(T2Program, CheckResult)]] = for {
(program, i) <- programs.zipWithIndex
encParams = encParamsOpt.getOrElse(program.libConfig.encParams)
parsed <- parse(program).toOption
interpResult <- interp(parsed, encParams).toOption
overflowBound = program.libConfig.firstModSize
if !validCheck || notOverflow(interpResult, overflowBound)
} yield {
val encType = parsed._3
val interpResPair = BackendResultPair("CLEAR", interpResult)
val executeResPairs = backends.map(backend =>
BackendResultPair(
backend.toString,
execute(backend, encParams, parsed),
),
)
val checkResult =
diffResults(interpResPair, executeResPairs, encType, encParams.plainMod)
if (toJson)
dumpResult(program, i, checkResult, sealVersion, openfheVersion)
(program, checkResult)
val result: ExecuteResult = interp(parsed, encParams) match {
case Success(interpValue) => interpValue
case Failure(_) => InterpError
}
val overflowBound = program.libConfig.firstModSize
if !validCheck || notOverflow(result, overflowBound) then {
val encType = parsed._3
val interpResPair = BackendResultPair("CLEAR", result)
val executeResPairs = backends.map(backend =>
BackendResultPair(
backend.toString,
execute(backend, encParams, parsed, program.libConfig),
),
)
val checkResult =
diffResults(
interpResPair,
executeResPairs,
encType,
encParams.plainMod,
)
if (toJson)
dumpResult(program, i, checkResult, sealVersion, openfheVersion)
Some(program, checkResult)
} else {
None
}
}
checkResults
checkResults.flatten
}

// TODO: Need to be revised
def notOverflow(interpResult: Normal, overflowBound: Int): Boolean =
def notOverflow(interpResult: ExecuteResult, overflowBound: Int): Boolean =
val limit = math.pow(2, overflowBound)
val lines = interpResult.res.split("\n")
lines.forall { line =>
val max = line.split(" ").map(_.toDouble).max
max < limit
interpResult match {
case Normal(res) =>
res.split("\n").forall { line =>
val max = line.split(" ").map(_.toDouble).max
max < limit
}
case _ => true
}

def apply(
Expand Down Expand Up @@ -166,6 +181,7 @@ case object Check {
backend: Backend,
encParams: EncParams,
parsed: (Goal, SymbolTable, ENC_TYPE),
libConfig: LibConfig,
): ExecuteResult = {
val (ast, symbolTable, encType) = parsed
withBackendTempDir(
Expand All @@ -179,6 +195,7 @@ case object Check {
encType,
backend,
encParamsOpt = Some(encParams),
libConfigOpt = Some(libConfig),
)
try {
val res = Execute(backend)
Expand Down
3 changes: 2 additions & 1 deletion workspace/OpenFHE/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ else()
link_libraries( ${OpenFHE_SHARED_LIBRARIES} )
endif()

add_compile_options(-Wno-unused-variable)
#turn off any warnings
add_compile_options(-w)
add_executable(test.out)
target_sources(test.out PRIVATE
${CMAKE_CURRENT_LIST_DIR}/compiled/test.cpp
Expand Down
3 changes: 2 additions & 1 deletion workspace/SEAL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ else()
message(FATAL_ERROR "Cannot find target SEAL::seal or SEAL::seal_shared")
endif()

add_compile_options(-Wno-unused-variable)
#turn off any warnings
add_compile_options(-w)
add_executable(test.out)
target_sources(test.out PRIVATE
${CMAKE_CURRENT_LIST_DIR}/compiled/test.cpp
Expand Down

0 comments on commit ac2ba61

Please sign in to comment.