From 0461e360d6d1dde89eb4498f974556a6ee2afa13 Mon Sep 17 00:00:00 2001 From: Hyerin Park Date: Wed, 3 Jul 2024 03:05:30 +0000 Subject: [PATCH] add using keyword sets --- .../scala/fhetest/Checker/CheckResult.scala | 17 ++ .../fhetest/Checker/ExceptionMsgKeyword.scala | 101 ++++++++ .../scala/fhetest/Checker/ExecuteResult.scala | 3 + src/main/scala/fhetest/Checker/Utils.scala | 227 ++++++++++++------ src/main/scala/fhetest/Command.scala | 90 +++---- src/main/scala/fhetest/Config.scala | 2 + .../scala/fhetest/Generate/AbsProgram.scala | 3 +- .../Generate/AbsProgramGenerator.scala | 44 +++- .../fhetest/Generate/LibConfigDomain.scala | 3 - .../fhetest/Generate/LibConfigGenerator.scala | 36 ++- .../scala/fhetest/Generate/ValidFilter.scala | 52 ++-- src/main/scala/fhetest/Phase/Check.scala | 87 ++++++- src/main/scala/fhetest/Phase/Generate.scala | 3 +- src/main/scala/fhetest/Utils/Utils.scala | 9 + src/main/scala/fhetest/fhetest.scala | 2 +- 15 files changed, 505 insertions(+), 174 deletions(-) create mode 100644 src/main/scala/fhetest/Checker/ExceptionMsgKeyword.scala diff --git a/src/main/scala/fhetest/Checker/CheckResult.scala b/src/main/scala/fhetest/Checker/CheckResult.scala index 3c1d109..8b28a79 100644 --- a/src/main/scala/fhetest/Checker/CheckResult.scala +++ b/src/main/scala/fhetest/Checker/CheckResult.scala @@ -18,5 +18,22 @@ case class Diff( results: List[BackendResultPair], fails: List[BackendResultPair], ) extends CheckResult +// InvalidResults: only for printing results in debug mode case class InvalidResults(results: List[BackendResultPair]) extends CheckResult +case class InvalidNormalResults( + results: List[BackendResultPair], + normals: List[BackendResultPair], +) extends CheckResult +case class InvalidExpectedExceptions( + results: List[BackendResultPair], + expectedExceptions: List[BackendResultPair], +) extends CheckResult +case class InvalidUnexpectedExceptions( + results: List[BackendResultPair], + unexpectedExceptions: List[BackendResultPair], +) extends CheckResult +case class InvalidErrors( + results: List[BackendResultPair], + errors: List[BackendResultPair], +) extends CheckResult case class ParserError(results: List[BackendResultPair]) extends CheckResult diff --git a/src/main/scala/fhetest/Checker/ExceptionMsgKeyword.scala b/src/main/scala/fhetest/Checker/ExceptionMsgKeyword.scala new file mode 100644 index 0000000..76999f5 --- /dev/null +++ b/src/main/scala/fhetest/Checker/ExceptionMsgKeyword.scala @@ -0,0 +1,101 @@ +package fhetest.Checker + +import fhetest.Generate.ValidFilter +import fhetest.Generate.getValidFilterList //TODO +import fhetest.Generate.Utils.InvalidFilterIdx + +type LibConfigArgumentName = String + +val libConfigArguments: Set[LibConfigArgumentName] = + Set( + "Scheme", + "RingDim", + "MulDepth", + "PlaindMod", + "ModSize", + "FirstModSize", + "ScalingModSize", + "SecurityLevel", + "ScalingTechnique", + "Len", + "Bound", + "RotateBound", + ) + +// TODO: fill out +// Note: Keywords should be in lower cases! +def mapLibConfigArgument2Keywords( + argName: LibConfigArgumentName, +): Set[String] = { + val commonKeywords = + Set(" parameters", "ring dimension", " ringdim", " primes", " overflow") + val modSizeKeywords = Set(" moduli", " bit_sizes", "bit length") + val uniqueKeywords = argName match { + case "Scheme" => Set("scheme") + case "RingDim" => Set("ring dimension", " ringdim") + case "MulDepth" => + Set( + "multiplicative depth", + " towers", + "end of modulus switching chain", + "removing last element of dcrtpoly", + "scale out of bounds", + "encoded values are too large", + ) + case "PlainMod" => Set(" plain_modulus", " plaintext_modulus") + case "ModSize" => modSizeKeywords + case "FirstModSize" => modSizeKeywords ++ Set() + case "ScalingModSize" => modSizeKeywords ++ Set() + case "SecurityLevel" => Set("security level", " SecurityLevel") + case "ScalingTechnique" => Set("security mode") + case "Len" => + Set( + "values_matrix size", + " values_size", + // "should not be greater than ringdim", // Currently, overlap with " ringDim" in commonKeywords + ) + case "Bound" => Set("encoded values are too large") + case "RotateBound" => Set("out_of_range", "evalkey for index") + case s: String => + throw new Exception(s"$s is not defined as LibConfigArgument.") + } + commonKeywords ++ uniqueKeywords +} + +val validFilters = + getValidFilterList().map(filter => filter.getSimpleName.replace("$", "")) + +def mapFilterName2LibConfigArgumentMap( + filterName: String, +): LibConfigArgumentName = + filterName match { + case "FilterBoundIsLessThanPlainMod" => "Bound" + case "FilterBoundIsLessThanPowerOfModSize" => "Bound" + case "FilterFirstModSizeIsLargest" => "FirstModSize" + case "FilterLenIsLessThanRingDim" => "Len" + case "FilterModSizeIsBeteween14And60bits" => "ModSize" + case "FilterMulDepthIsEnough" => "MulDepth" + case "FilterMulDepthIsNotNegative" => "MulDepth" + case "FilterOpenFHEBFVModuli" => "ModSize" + case "FilterPlainModEnableBatching" => "PlainMod" + case "FilterPlainModIsPositive" => "PlainMod" + case "FilterRingDimIsPowerOfTwo" => "RingDim" + case "FilterRotateBoundTest" => "RotateBound" + case "FilterScalingTechniqueByScheme" => "ScalingTechnique" + case s: String => throw new Exception(s"Keyword for $s is undifined.") + } + +def getKeywordsFromFilters( + invalidFilterIdxList: List[InvalidFilterIdx], +): Set[String] = { + val invalidFilterNameList = invalidFilterIdxList.map { + case i => validFilters.apply(i) + } + invalidFilterNameList.foldLeft(Set[String]()) { (acc, filterName) => + { + val argName = mapFilterName2LibConfigArgumentMap(filterName) + val keywords = mapLibConfigArgument2Keywords(argName) + acc ++ keywords + } + } +} diff --git a/src/main/scala/fhetest/Checker/ExecuteResult.scala b/src/main/scala/fhetest/Checker/ExecuteResult.scala index 12b2f83..873396c 100644 --- a/src/main/scala/fhetest/Checker/ExecuteResult.scala +++ b/src/main/scala/fhetest/Checker/ExecuteResult.scala @@ -15,6 +15,9 @@ case object PrintError extends ExecuteResult { case class LibraryError(msg: String) extends ExecuteResult { override def toString: String = s"LibraryError: $msg" } +case class LibraryException(msg: String) extends ExecuteResult { + override def toString: String = s"LibraryException: $msg" +} case object ParseError extends ExecuteResult case object TimeoutError extends ExecuteResult { override def toString: String = s"timeout" diff --git a/src/main/scala/fhetest/Checker/Utils.scala b/src/main/scala/fhetest/Checker/Utils.scala index 97757b8..40d5c27 100644 --- a/src/main/scala/fhetest/Checker/Utils.scala +++ b/src/main/scala/fhetest/Checker/Utils.scala @@ -1,7 +1,7 @@ package fhetest.Checker import fhetest.Utils.* -import fhetest.Generate.{T2Program, ValidFilter} +import fhetest.Generate.{T2Program, ValidFilter, getValidFilterList} import fhetest.TEST_DIR import fhetest.Generate.LibConfig @@ -15,20 +15,20 @@ import scala.io.Source import java.io.{File, PrintWriter} import java.nio.file.{Files, Path, Paths, StandardCopyOption} -case class Failure(library: String, failedResult: String) +case class ResultReport(library: String, failedResult: String) trait ResultInfo { val programId: Int val program: T2Program - val result: String + val resultType: String val SEAL: String val OpenFHE: String } case class ResultValidInfo( programId: Int, program: T2Program, - result: String, - failures: List[Failure], + resultType: String, + failures: List[ResultReport], expected: String, SEAL: String, OpenFHE: String, @@ -36,8 +36,9 @@ case class ResultValidInfo( case class ResultInvalidInfo( programId: Int, program: T2Program, - result: String, - outputs: List[Failure], + resultType: String, + results: List[ResultReport], + others: List[ResultReport], invalidFilters: List[String], SEAL: String, OpenFHE: String, @@ -134,14 +135,14 @@ implicit val libConfigDecoder: Decoder[LibConfig] = Decoder.instance { cursor => } implicit val t2ProgramEncoder: Encoder[T2Program] = deriveEncoder implicit val t2ProgramDecoder: Decoder[T2Program] = deriveDecoder -implicit val failureEncoder: Encoder[Failure] = deriveEncoder -implicit val failureDecoder: Decoder[Failure] = deriveDecoder +implicit val resultReportEncoder: Encoder[ResultReport] = deriveEncoder +implicit val resultReportDecoder: Decoder[ResultReport] = deriveDecoder implicit val resultValidInfoEncoder: Encoder[ResultValidInfo] = Encoder.forProduct7( "programId", "program", - "result", + "resultType", "failures", "expected", "SEAL", @@ -150,7 +151,7 @@ implicit val resultValidInfoEncoder: Encoder[ResultValidInfo] = ( ri.programId, ri.program, - ri.result, + ri.resultType, ri.failures, ri.expected, ri.SEAL, @@ -161,7 +162,7 @@ implicit val resultValidInfoDecoder: Decoder[ResultValidInfo] = Decoder.forProduct7( "programId", "program", - "result", + "resultType", "failures", "expected", "SEAL", @@ -169,11 +170,12 @@ implicit val resultValidInfoDecoder: Decoder[ResultValidInfo] = )(ResultValidInfo.apply) implicit val resultInvalidInfoEncoder: Encoder[ResultInvalidInfo] = - Encoder.forProduct7( + Encoder.forProduct8( "programId", "program", - "result", - "outputs", + "resultType", + "results", + "others", "invalidFilters", "SEAL", "OpenFHE", @@ -181,19 +183,21 @@ implicit val resultInvalidInfoEncoder: Encoder[ResultInvalidInfo] = ( ri.programId, ri.program, - ri.result, - ri.outputs, + ri.resultType, + ri.results, + ri.others, ri.invalidFilters, ri.SEAL, ri.OpenFHE, ), ) implicit val resultInvalidInfoDecoder: Decoder[ResultInvalidInfo] = - Decoder.forProduct7( + Decoder.forProduct8( "programId", "program", - "result", - "outputs", + "resultType", + "results", + "others", "invalidFilters", "SEAL", "OpenFHE", @@ -231,69 +235,117 @@ object DumpUtil { res: CheckResult, sealVersion: String, openfheVersion: String, - ): Unit = res match { - case InvalidResults(results) => { - val resultString = "InvalidResults" - val outputs = results.map(backendResultPair => - Failure(backendResultPair.backend, backendResultPair.result.toString()), - ) - val validFilters = classOf[ValidFilter].getDeclaredClasses.toList - .filter { cls => - classOf[ValidFilter] - .isAssignableFrom(cls) && cls != classOf[ValidFilter] - } - val invalidFilters = program.invalidFilterIdxList.map { - case i => validFilters.apply(i).getSimpleName.replace("$", "") + ): Unit = { + val resultString = res match { + case Same(_) => "Success" + case Diff(_, _) => "Fail" + case ParserError(_) => "ParseError" + } + val failures = res match { + case Diff(_, fails) => + fails.map(fail => ResultReport(fail.backend, fail.result.toString())) + case _ => List.empty[ResultReport] + } + val expected = res match { + case Same(results) => + results.headOption.map(_.result.toString()).getOrElse("") + case Diff(results, _) => + results + .partition(_.backend == "CLEAR") + ._1 + .headOption + .map(_.result.toString()) + .getOrElse("") + } + val resultValidInfo = ResultValidInfo( + i, + program, + resultString, + failures, + expected, + sealVersion, + openfheVersion, + ) + val filename = res match + case Same(_) => s"$succDir/$i.json" + case Diff(_, fails) => s"$failDir/$i.json" + case ParserError(_) => s"$psrErrDir/$i.json" + dumpFile(resultValidInfo.asJson.spaces2, filename) + } + + def dumpInvalidResult( + program: T2Program, + i: Int, + resLst: List[CheckResult], + sealVersion: String, + openfheVersion: String, + invalidFilters: List[String] = List[String](), + ): Unit = resLst match { + case head :: next => { + val (resultString, results, interested, filename) = head match { + case InvalidNormalResults(results, interested) => + ( + "InvalidNormal", + results, + interested, + s"$invalidNormalDir/$i.json", + ) + case InvalidExpectedExceptions(results, interested) => + ( + "InvalidExpectedException", + results, + interested, + s"$invalidExpectedExceptionDir/$i.json", + ) + case InvalidUnexpectedExceptions(results, interested) => + ( + "InvalidUnexpectedException", + results, + interested, + s"$invalidUnexpectedExceptionDir/$i.json", + ) + case InvalidErrors(results, interested) => + ("InvalidError", results, interested, s"$invalidErrorDir/$i.json") + } + val interestedResults = interested.map { backendResultPair => + ResultReport( + backendResultPair.backend, + backendResultPair.result.toString(), + ) } - val resultValidInfo = ResultInvalidInfo( + val interestedLibrariesNames = interested.map(_.backend) + val others = results.foldLeft(List[ResultReport]()) { + (acc, backendResultPair) => + if (interestedLibrariesNames.contains(backendResultPair.backend)) + acc + else { + acc :+ ResultReport( + backendResultPair.backend, + backendResultPair.result.toString(), + ) + } + } + val resultInvalidInfo = ResultInvalidInfo( i, program, resultString, - outputs, + interestedResults, + others, invalidFilters, sealVersion, openfheVersion, ) - val filename = s"$testInvalidDir/$i.json" - dumpFile(resultValidInfo.asJson.spaces2, filename) - } - case _ => { - val resultString = res match { - case Same(_) => "Success" - case Diff(_, _) => "Fail" - case ParserError(_) => "ParseError" - } - val failures = res match { - case Diff(_, fails) => - fails.map(fail => Failure(fail.backend, fail.result.toString())) - case _ => List.empty[Failure] - } - val expected = res match { - case Same(results) => - results.headOption.map(_.result.toString()).getOrElse("") - case Diff(results, _) => - results - .partition(_.backend == "CLEAR") - ._1 - .headOption - .map(_.result.toString()) - .getOrElse("") - } - val resultValidInfo = ResultValidInfo( - i, + dumpFile(resultInvalidInfo.asJson.spaces2, filename) + dumpInvalidResult( program, - resultString, - failures, - expected, + i, + next, sealVersion, openfheVersion, + invalidFilters, ) - val filename = res match - case Same(_) => s"$succDir/$i.json" - case Diff(_, fails) => s"$failDir/$i.json" - case ParserError(_) => s"$psrErrDir/$i.json" - dumpFile(resultValidInfo.asJson.spaces2, filename) } + case Nil => () } def dumpCount(dir: String, countMap: Map[List[Int], Int]): Unit = { @@ -320,21 +372,40 @@ val testInvalidDir = s"$TEST_DIR-invalid-$formattedDateTime" val succDir = s"$testDir/succ" val failDir = s"$testDir/fail" val psrErrDir = s"$testDir/psr_err" +val invalidNormalDir = s"$testInvalidDir/normal" +val invalidExceptionDir = s"$testInvalidDir/exception" +val invalidExpectedExceptionDir = s"$invalidExceptionDir/expected" +val invalidUnexpectedExceptionDir = s"$invalidExceptionDir/unexpected" +val invalidErrorDir = s"$testInvalidDir/error" + val testDirPath = Paths.get(testDir) val testInvalidDirPath = Paths.get(testInvalidDir) val succDirPath = Paths.get(succDir) val failDirPath = Paths.get(failDir) val psrErrDirPath = Paths.get(psrErrDir) +val invalidNormalDirPath = Paths.get(invalidNormalDir) +val invalidExceptionDirPath = Paths.get(invalidExceptionDir) +val invalidExpectedExceptionDirPath = Paths.get(invalidExpectedExceptionDir) +val invalidUnexpectedExceptionDirPath = Paths.get(invalidUnexpectedExceptionDir) +val invalidErrorDirPath = Paths.get(invalidErrorDir) -def setValidTestDir(): Unit = { - val pathLst = List(testDirPath, succDirPath, failDirPath, psrErrDirPath) - pathLst.foreach(path => +def setValidTestDir(): Unit = setTestDirs( + List(testDirPath, succDirPath, failDirPath, psrErrDirPath), +) +def setInvalidTestDir(): Unit = setTestDirs( + List( + testInvalidDirPath, + invalidNormalDirPath, + invalidExceptionDirPath, + invalidExpectedExceptionDirPath, + invalidUnexpectedExceptionDirPath, + invalidErrorDirPath, + ), +) + +def setTestDirs(paths: List[Path]): Unit = + paths.foreach(path => if (!Files.exists(path)) { Files.createDirectories(path) }, ) -} -def setInvalidTestDir(): Unit = - if (!Files.exists(testInvalidDirPath)) { - Files.createDirectories(testInvalidDirPath) - } diff --git a/src/main/scala/fhetest/Command.scala b/src/main/scala/fhetest/Command.scala index ee1d9bf..a175318 100644 --- a/src/main/scala/fhetest/Command.scala +++ b/src/main/scala/fhetest/Command.scala @@ -223,9 +223,11 @@ case object CmdTest extends BackendCommand("test") { val genStrategy = config.genStrategy.getOrElse(Strategy.Random) val genCount = config.genCount val validFilter = config.validFilter - val generator = Generate(encType, genStrategy, validFilter) + val noFilterOpt = config.noFilterOpt + val generator = Generate(encType, genStrategy, validFilter, noFilterOpt) val programs = generator(genCount) - val backendList = List(Backend.SEAL, Backend.OpenFHE) + // val backendList = List(Backend.SEAL, Backend.OpenFHE) // commented out for evaluation + val backendList = List(Backend.OpenFHE) // tmp: for evaluation val encParamsOpt = config.libConfigOpt.map(_.encParams) val toJson = config.toJson val sealVersion = config.sealVersion.getOrElse(SEAL_VERSIONS.head) @@ -305,45 +307,45 @@ case object CmdReplay extends Command("replay") { } } -case object CmdCount extends Command("count") { - val help = - "Count the number of programs tested for each combination of valid filters" - val examples = List( - "fhetest count -dir:logs/test-invalid", - ) - def runJob(config: Config): Unit = - val dirString = config.dirName.getOrElseThrow("No directory given.") - if (dirString contains "invalid") { - val dir = new File(dirString) - if (dir.exists() && dir.isDirectory) { - val numOfValidFilters = - classOf[ValidFilter].getDeclaredClasses.toList.filter { cls => - classOf[ValidFilter] - .isAssignableFrom(cls) && cls != classOf[ValidFilter] - }.length - val allCombinations = (1 to numOfValidFilters).toList.flatMap( - combinations(_, numOfValidFilters), - ) - var countMap: Map[List[Int], Int] = - allCombinations.foldLeft(Map.empty[List[Int], Int]) { - case (acc, comb) => - acc + (comb -> 0) - } - val files = Files.list(Paths.get(dirString)) - val fileList = files.iterator().asScala.toList - for { - filePath <- fileList - fileName = filePath.toString() - } yield { - val resultInfo = DumpUtil.readResult(fileName) - val t2Program = resultInfo.program - val invalidFilterIdxList = t2Program.invalidFilterIdxList - countMap = countMap.updatedWith(invalidFilterIdxList) { - case Some(cnt) => Some(cnt + 1) - case None => Some(1) // unreachable - } - } - DumpUtil.dumpCount(dirString, countMap) - } - } else println("Directrory contains test cases of VALID programs") -} +// Deprecated: as reimplementing Checker for invalid program testing +// case object CmdCount extends Command("count") { +// val help = +// "Count the number of programs tested for each combination of valid filters" +// val examples = List( +// "fhetest count -dir:logs/test-invalid", +// ) +// def runJob(config: Config): Unit = +// val dirString = config.dirName.getOrElseThrow("No directory given.") +// if (dirString contains "invalid") { +// val dir = new File(dirString) +// if (dir.exists() && dir.isDirectory) { +// val numOfValidFilters = +// classOf[ValidFilter].getDeclaredClasses.toList.filter { cls => +// classOf[ValidFilter] +// .isAssignableFrom(cls) && cls != classOf[ValidFilter] +// }.length +// val allCombinations = (1 to numOfValidFilters).toList.flatMap( +// combinations(_, numOfValidFilters), +// ) +// var countMap: Map[List[Int], Int] = +// allCombinations.foldLeft(Map.empty[List[Int], Int]) { +// case (acc, comb) => acc + (comb -> 0) +// } +// val files = Files.list(Paths.get(dirString)) +// val fileList = files.iterator().asScala.toList +// for { +// filePath <- fileList +// fileName = filePath.toString() +// } yield { +// val resultInfo = DumpUtil.readResult(fileName) +// val t2Program = resultInfo.program +// val invalidFilterIdxList = t2Program.invalidFilterIdxList +// countMap = countMap.updatedWith(invalidFilterIdxList) { +// case Some(cnt) => Some(cnt + 1) +// case None => Some(1) // unreachable +// } +// } +// DumpUtil.dumpCount(dirString, countMap) +// } +// } else println("Directrory contains test cases of VALID programs") +// } diff --git a/src/main/scala/fhetest/Config.scala b/src/main/scala/fhetest/Config.scala index 9f14d36..c5dc4cf 100644 --- a/src/main/scala/fhetest/Config.scala +++ b/src/main/scala/fhetest/Config.scala @@ -14,6 +14,7 @@ class Config( var genStrategy: Option[Strategy] = None, var genCount: Option[Int] = None, var validFilter: Boolean = true, + var noFilterOpt: Boolean = false, var toJson: Boolean = false, var sealVersion: Option[String] = None, var openfheVersion: Option[String] = None, @@ -68,6 +69,7 @@ object Config { config.libConfigOpt = Some(LibConfig()) case "fromjson" => config.fromJson = Some(value) case "filter" => config.validFilter = value.toBoolean + case "nofilter" => config.noFilterOpt = value.toBoolean case "silent" => config.silent = value.toBoolean case "debug" => config.debug = value.toBoolean case "timeout" => config.timeLimit = Some(value.toInt) diff --git a/src/main/scala/fhetest/Generate/AbsProgram.scala b/src/main/scala/fhetest/Generate/AbsProgram.scala index ea269a8..9788748 100644 --- a/src/main/scala/fhetest/Generate/AbsProgram.scala +++ b/src/main/scala/fhetest/Generate/AbsProgram.scala @@ -23,7 +23,8 @@ case class AbsProgram( def assignRandValues(): AbsProgram = { def lx() = Random.between(1, len + 1) def ly() = Random.between(1, len + 1) - def vc() = Random.between(0, rotateBound + 1) + def vc() = if (rotateBound < 21) Random.between(0, rotateBound + 1) + else Random.between(21, rotateBound + 1) // Generate Random Values def vxs(): (List[Int] | List[Double]) = bound match { diff --git a/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala b/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala index 7465d08..a1fa34b 100644 --- a/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala +++ b/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala @@ -11,23 +11,44 @@ extension (s: Strategy) def getGenerator( encType: ENC_TYPE, validFilter: Boolean, + noFilterOpt: Boolean, ): AbsProgramGenerator = s match { case Strategy.Exhaustive => - ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean) + ExhaustiveGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, + ) case Strategy.Random => - RandomGenerator(encType: ENC_TYPE, validFilter: Boolean) + RandomGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, + ) } // AbsProgram Generator -trait AbsProgramGenerator(encType: ENC_TYPE, validFilter: Boolean) { +trait AbsProgramGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, +) { def generateAbsPrograms(): LazyList[AbsProgram] val lcGen = - if validFilter then ValidLibConfigGenerator(encType) + if noFilterOpt then RandomLibConfigGenerator(encType) + else if validFilter then ValidLibConfigGenerator(encType) else InvalidLibConfigGenerator(encType) } -case class ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean) - extends AbsProgramGenerator(encType: ENC_TYPE, validFilter: Boolean) { +case class ExhaustiveGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, +) extends AbsProgramGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, + ) { val libConfigGens = lcGen.getLibConfigGenerators() def generateAbsPrograms(): LazyList[AbsProgram] = { def allAbsProgramsOfSize(n: Int): LazyList[AbsProgram] = @@ -61,8 +82,15 @@ case class ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean) } -case class RandomGenerator(encType: ENC_TYPE, validFilter: Boolean) - extends AbsProgramGenerator(encType: ENC_TYPE, validFilter: Boolean) { +case class RandomGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, +) extends AbsProgramGenerator( + encType: ENC_TYPE, + validFilter: Boolean, + noFilterOpt: Boolean, + ) { def generateAbsPrograms(): LazyList[AbsProgram] = { def randomAbsStmtsOfSize(n: Int): List[AbsStmt] = (1 to n).map(_ => Random.shuffle(allAbsStmts).head).toList diff --git a/src/main/scala/fhetest/Generate/LibConfigDomain.scala b/src/main/scala/fhetest/Generate/LibConfigDomain.scala index 43f2f89..8094e48 100644 --- a/src/main/scala/fhetest/Generate/LibConfigDomain.scala +++ b/src/main/scala/fhetest/Generate/LibConfigDomain.scala @@ -2,9 +2,6 @@ package fhetest.Generate import fhetest.Utils.* -// import java.util.stream.Stream -// import scala.jdk.CollectionConverters._ - type RingDim = Int type FirstModSize = Int type PlainMod = Int diff --git a/src/main/scala/fhetest/Generate/LibConfigGenerator.scala b/src/main/scala/fhetest/Generate/LibConfigGenerator.scala index 0edd720..265dfb7 100644 --- a/src/main/scala/fhetest/Generate/LibConfigGenerator.scala +++ b/src/main/scala/fhetest/Generate/LibConfigGenerator.scala @@ -42,11 +42,37 @@ def getLibConfigUniverse(scheme: Scheme) = LibConfigDomain( trait LibConfigGenerator(encType: ENC_TYPE) { def getLibConfigGenerators() : LazyList[List[AbsStmt] => Option[(LibConfig, List[InvalidFilterIdx])]] - val validFilters = classOf[ValidFilter].getDeclaredClasses.toList - .filter { cls => - classOf[ValidFilter] - .isAssignableFrom(cls) && cls != classOf[ValidFilter] + val validFilters = getValidFilterList() + // classOf[ValidFilter].getDeclaredClasses.toList + // .filter { cls => + // classOf[ValidFilter] + // .isAssignableFrom(cls) && cls != classOf[ValidFilter] + // } +} + +// for evaluation (no filters) +case class RandomLibConfigGenerator(encType: ENC_TYPE) + extends LibConfigGenerator(encType) { + def getLibConfigGenerators() + : LazyList[List[AbsStmt] => Option[(LibConfig, List[InvalidFilterIdx])]] = { + val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => { + val randomScheme = + if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) + else Scheme.CKKS + val libConfigUniverse = getLibConfigUniverse(randomScheme) + val libConfigOpt = randomLibConfigFromDomain( + true, + absStmts, + randomScheme, + libConfigUniverse, + ) + libConfigOpt match { + case None => None + case Some(libConfig) => Some((libConfig, List[InvalidFilterIdx]())) + } } + LazyList.continually(libConfigGeneratorFromAbsStmts) + } } case class ValidLibConfigGenerator(encType: ENC_TYPE) @@ -89,7 +115,7 @@ case class InvalidLibConfigGenerator(encType: ENC_TYPE) val allCombinations = (1 to totalNumOfFilters).toList.flatMap(combinations(_, totalNumOfFilters)) // TODO: currently generate only 1 test case for each class in each iteration - val numOfTC = 1 + val numOfTC = 20 val allCombinationsNtimes = allCombinations.flatMap { List.fill(numOfTC)(_) } val allCombinations_lazy = LazyList.from(allCombinationsNtimes) def getLibConfigGenerators() diff --git a/src/main/scala/fhetest/Generate/ValidFilter.scala b/src/main/scala/fhetest/Generate/ValidFilter.scala index b62968e..c6dace3 100644 --- a/src/main/scala/fhetest/Generate/ValidFilter.scala +++ b/src/main/scala/fhetest/Generate/ValidFilter.scala @@ -20,6 +20,7 @@ import fhetest.Utils.* // FilterPlainModEnableBatching, /* commented */ // FilterPlainModIsPositive, /* commented */ // FilterRingDimIsPowerOfTwo, /* commented */ +// FilterRotateBoundTest // FilterScalingTechniqueByScheme // ) @@ -27,6 +28,12 @@ trait ValidFilter(prev: LibConfigDomain, validFilter: Boolean) { def getFilteredLibConfigDomain(): LibConfigDomain } +def getValidFilterList() = classOf[ValidFilter].getDeclaredClasses.toList + .filter { cls => + classOf[ValidFilter] + .isAssignableFrom(cls) && cls != classOf[ValidFilter] + } + object ValidFilter { // def mulDepthIsSmall(realMulDepth: Int, configMulDepth: Int): Boolean = // realMulDepth < configMulDepth @@ -585,26 +592,27 @@ object ValidFilter { ) } - // TODO: just for test -> remove later? - // case class FilterRotateBoundTest( - // prev: LibConfigDomain, - // validFilter: Boolean, - // ) extends ValidFilter(prev, validFilter) { - // def getFilteredLibConfigDomain(): LibConfigDomain = - // LibConfigDomain( - // scheme = prev.scheme, - // ringDim = prev.ringDim, - // mulDepth = prev.mulDepth, - // plainMod = prev.plainMod, - // firstModSize = prev.firstModSize, - // scalingModSize = prev.scalingModSize, - // securityLevel = prev.securityLevel, - // scalingTechnique = prev.scalingTechnique, - // lenMin = prev.lenMin, - // lenMax = prev.lenMax, - // boundMin = prev.boundMin, - // boundMax = prev.boundMax, - // rotateBound = (1 to 20).toList, - // ) - // } + // TODO: change name + case class FilterRotateBoundTest( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = if (validFilter) { (1 to 20).toList } + else prev.rotateBound.filter(_ > 20), + ) + } } diff --git a/src/main/scala/fhetest/Phase/Check.scala b/src/main/scala/fhetest/Phase/Check.scala index 8ce513e..ace065a 100644 --- a/src/main/scala/fhetest/Phase/Check.scala +++ b/src/main/scala/fhetest/Phase/Check.scala @@ -42,7 +42,12 @@ case object Check { execute(backend, encParams, parsed, program.libConfig, timeLimit), ), ) - diffResults(interpResPair, executeResPairs, encType, encParams.plainMod) + diffValidResults( + interpResPair, + executeResPairs, + encType, + encParams.plainMod, + ) } result.getOrElse( ParserError(List(BackendResultPair("Parser", ParseError))), @@ -67,7 +72,8 @@ case object Check { (filePath, i) <- fileList.to(LazyList).zipWithIndex } yield { val fileStr = Files.readAllLines(filePath).asScala.mkString("") - val libConfig = LibConfig() // default libConfig for dir testing + val libConfig = + LibConfig() // TODO: now, default libConfig for dir testing val program = T2Program(fileStr, libConfig, List[InvalidFilterIdx]()) val checkResult = apply(program, backends, encParamsOpt, timeLimit) if (toJson) @@ -128,7 +134,7 @@ case object Check { ), ) val checkResult = - diffResults( + diffValidResults( interpResPair, executeResPairs, encType, @@ -170,19 +176,21 @@ case object Check { execute(backend, encParams, parsed, program.libConfig, timeLimit), ), ) - val checkResult: CheckResult = InvalidResults(executeResPairs) + // val checkResult: CheckResult = InvalidResults(executeResPairs) + val (topCheckResult, checkResultLst) = + classifyInvalidResults(executeResPairs, program.invalidFilterIdxList) if (toJson) - DumpUtil.dumpResult( + DumpUtil.dumpInvalidResult( program, i, - checkResult, + checkResultLst, sealVersion, openfheVersion, ) if (debug) { println(s"Program $i:") } - Some(program, checkResult) + Some(program, topCheckResult) } checkResults.flatten } @@ -199,7 +207,7 @@ case object Check { } // Get CheckResult from results of valid programs - def diffResults( + def diffValidResults( expected: BackendResultPair, obtained: List[BackendResultPair], encType: ENC_TYPE, @@ -212,6 +220,56 @@ case object Check { else Diff(results, fails) } + // Get a list of CheckResult from results of invalid programs + def classifyInvalidResults( + obtained: List[BackendResultPair], + invalidFilterIdxList: List[InvalidFilterIdx], + ): (CheckResult, List[CheckResult]) = { + var normals = List[BackendResultPair]() + var expectedExceptions = List[BackendResultPair]() + var unexpectedExceptions = List[BackendResultPair]() + var errors = List[BackendResultPair]() + obtained.map(backendResultPair => + backendResultPair.result match { + case Normal(_) => normals = normals :+ backendResultPair + case LibraryException(msg) => { + val relatedKeywords: Set[String] = + getKeywordsFromFilters(invalidFilterIdxList) + val expected: Boolean = + relatedKeywords.foldLeft(false) { (acc, keyword) => + if (acc) true + else (msg.toLowerCase().contains(keyword)) + } + if (expected) { + expectedExceptions = expectedExceptions :+ backendResultPair + } else { + unexpectedExceptions = unexpectedExceptions :+ backendResultPair + } + } + case _ => + errors = errors :+ backendResultPair // LibraryError & TimeoutError + }, + ) + var checkResultLst = List[CheckResult]() + if (!normals.isEmpty) + checkResultLst = checkResultLst :+ InvalidNormalResults(obtained, normals) + if (!expectedExceptions.isEmpty) + checkResultLst = checkResultLst :+ InvalidExpectedExceptions( + obtained, + expectedExceptions, + ) + if (!unexpectedExceptions.isEmpty) + checkResultLst = checkResultLst :+ InvalidUnexpectedExceptions( + obtained, + unexpectedExceptions, + ) + if (!errors.isEmpty) + checkResultLst = checkResultLst :+ InvalidErrors(obtained, errors) + + val topCheckResult = InvalidResults(obtained) + (topCheckResult, checkResultLst) + } + // parse wrapper def parse(program: T2Program) = { val pgm = program.content @@ -251,12 +309,19 @@ case object Check { libConfigOpt = Some(libConfig), ) try { - val res = Execute(backend, timeLimit) - Normal(res.trim) + val res = Execute(backend, timeLimit).trim + res match { + case _ if res.startsWith("Program terminated") => + LibraryError(res) + case _ if res.split(" ").forall(isNumber) => + Normal(res) + case _ => + LibraryException(res) + } } catch { case _: java.util.concurrent.TimeoutException => TimeoutError // TODO?: classify exception related with parmeters? - case ex: Exception => LibraryError(ex.getMessage) + // case ex: Exception => LibraryError(ex.getMessage) } } catch { case _ => PrintError diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index 272c4f8..85944a2 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -20,6 +20,7 @@ case class Generate( encType: ENC_TYPE, strategy: Strategy = Strategy.Random, validFilter: Boolean = true, + noFilterOpt: Boolean = false, ) { // TODO : This boilerplate code is really ugly. But I cannot find a better way to do this. val baseStrFront = encType match { @@ -33,7 +34,7 @@ case class Generate( val symbolTable = boilerplate()._2 - val absProgGen = strategy.getGenerator(encType, validFilter) + val absProgGen = strategy.getGenerator(encType, validFilter, noFilterOpt) val allAbsPrograms = absProgGen.generateAbsPrograms() diff --git a/src/main/scala/fhetest/Utils/Utils.scala b/src/main/scala/fhetest/Utils/Utils.scala index aca5619..6c09908 100644 --- a/src/main/scala/fhetest/Utils/Utils.scala +++ b/src/main/scala/fhetest/Utils/Utils.scala @@ -139,6 +139,15 @@ def withBackendTempDir[Result]( } } +def isNumber(s: String): Boolean = { + try { + s.toDouble + true + } catch { + case _: NumberFormatException => false + } +} + def compare( obtained: String, expected: String, diff --git a/src/main/scala/fhetest/fhetest.scala b/src/main/scala/fhetest/fhetest.scala index ad029c2..abaac2f 100644 --- a/src/main/scala/fhetest/fhetest.scala +++ b/src/main/scala/fhetest/fhetest.scala @@ -35,7 +35,7 @@ object FHETest { CmdReplay, // Make a json report of invalid program testing // Count the number of programs tested for each combination of valid filters - CmdCount, + // CmdCount, ) val cmdMap = commands.foldLeft[Map[String, Command]](Map()) { case (map, cmd) => map + (cmd.name -> cmd)