diff --git a/src/main/scala/fhetest/Checker/Utils.scala b/src/main/scala/fhetest/Checker/Utils.scala index 26762ef..4d6f183 100644 --- a/src/main/scala/fhetest/Checker/Utils.scala +++ b/src/main/scala/fhetest/Checker/Utils.scala @@ -16,6 +16,14 @@ import java.io.{File, PrintWriter} import java.nio.file.{Files, Path, Paths, StandardCopyOption} case class Failure(library: String, failedResult: String) + +trait ResultInfo { + val programId: Int + val program: T2Program + val result: String + val SEAL: String + val OpenFHE: String +} case class ResultValidInfo( programId: Int, program: T2Program, @@ -24,7 +32,7 @@ case class ResultValidInfo( expected: String, SEAL: String, OpenFHE: String, -) +) extends ResultInfo case class ResultInvalidInfo( programId: Int, program: T2Program, @@ -33,7 +41,7 @@ case class ResultInvalidInfo( invalidFilters: List[String], SEAL: String, OpenFHE: String, -) +) extends ResultInfo // Define en/decoders using Circe // Scheme @@ -277,10 +285,12 @@ object DumpUtil { } } - def readResult(filePath: String): ResultValidInfo = { + def readResult(filePath: String): ResultInfo = { val fileContents = readFile(filePath) - val resultValidInfo = decode[ResultValidInfo](fileContents) - resultValidInfo match { + val resultInfo = + if (filePath contains "invalid") decode[ResultInvalidInfo](fileContents) + else decode[ResultValidInfo](fileContents) + resultInfo match { case Right(info) => info case Left(error) => throw new Exception(s"Error: $error") } diff --git a/src/main/scala/fhetest/Command.scala b/src/main/scala/fhetest/Command.scala index 1cc6dd2..a531680 100644 --- a/src/main/scala/fhetest/Command.scala +++ b/src/main/scala/fhetest/Command.scala @@ -257,7 +257,8 @@ case object CmdTest extends BackendCommand("test") { } case object CmdReplay extends Command("replay") { - val help = "Replay the given json." + val help = + "Replay the given json with specified backend (default: interpreter)." val examples = List( "fhetest replay -fromjson:logs/test/success/2.json", "fhetest replay -fromjson:logs/test/success/2.json -b:OpenFHE", diff --git a/src/main/scala/fhetest/Generate/ValidFilter.scala b/src/main/scala/fhetest/Generate/ValidFilter.scala index 6edb326..b62968e 100644 --- a/src/main/scala/fhetest/Generate/ValidFilter.scala +++ b/src/main/scala/fhetest/Generate/ValidFilter.scala @@ -15,6 +15,7 @@ import fhetest.Utils.* // FilterLenIsLessThanRingDim, // FilterModSizeIsBeteween14And60bits, // FilterMulDepthIsEnough, +// FilterMulDepthIsNotNegative, // FilterOpenFHEBFVModuli, // FilterPlainModEnableBatching, /* commented */ // FilterPlainModIsPositive, /* commented */ @@ -61,6 +62,32 @@ object ValidFilter { ) } + case class FilterMulDepthIsNotNegative( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = + if (validFilter) + (realMulDepth => (prev.mulDepth)(realMulDepth).filter(_ >= 0)) + else + (realMulDepth => (prev.mulDepth)(realMulDepth).filterNot(_ >= 0)), + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } + // TODO: There are 2 options for this implementation // * Current implementation filters scalingModSize which is not greater than firstModSize // * Another option is to filter firstModSize to be not smaller than scalingModeSize