From 854e6a43fbc41b7ce3c6b1efbf1d225ef34a4f5b Mon Sep 17 00:00:00 2001 From: Hyerin Park Date: Sat, 23 Mar 2024 17:48:17 +0000 Subject: [PATCH] LibConfigGenerate:re-implement ValidFilter for invalid testing --- src/main/scala/fhetest/Command.scala | 7 +- src/main/scala/fhetest/Config.scala | 4 +- .../scala/fhetest/Generate/AbsProgram.scala | 30 +- .../Generate/AbsProgramGenerator.scala | 64 +- .../fhetest/Generate/LibConfigDomain.scala | 37 ++ .../fhetest/Generate/LibConfigGenerator.scala | 316 ++++----- src/main/scala/fhetest/Generate/Utils.scala | 18 + .../scala/fhetest/Generate/ValidFilter.scala | 601 +++++++++++++++++- src/main/scala/fhetest/Phase/Check.scala | 4 +- src/main/scala/fhetest/Phase/Generate.scala | 40 +- 10 files changed, 886 insertions(+), 235 deletions(-) create mode 100644 src/main/scala/fhetest/Generate/LibConfigDomain.scala diff --git a/src/main/scala/fhetest/Command.scala b/src/main/scala/fhetest/Command.scala index 4df32ae..1cc6dd2 100644 --- a/src/main/scala/fhetest/Command.scala +++ b/src/main/scala/fhetest/Command.scala @@ -171,7 +171,7 @@ case object CmdGen extends Command("gen") { def runJob(config: Config): Unit = val encType = config.encType.getOrElseThrow("No encType given.") val genCount = config.genCount.getOrElse(10) - val generator = Generate(encType, Strategy.Random, config.filter) + val generator = Generate(encType, Strategy.Random, config.validFilter) generator.show(List(Backend.SEAL, Backend.OpenFHE), genCount, encType) } @@ -217,7 +217,8 @@ case object CmdTest extends BackendCommand("test") { val encType = config.encType.getOrElseThrow("No encType given.") val genStrategy = config.genStrategy.getOrElse(Strategy.Random) val genCount = config.genCount - val generator = Generate(encType, genStrategy, config.filter) + val validFilter = config.validFilter + val generator = Generate(encType, genStrategy, validFilter) val programs = generator(genCount) val backendList = List(Backend.SEAL, Backend.OpenFHE) val encParamsOpt = config.libConfigOpt.map(_.encParams) @@ -236,7 +237,7 @@ case object CmdTest extends BackendCommand("test") { toJson, sealVersion, openfheVersion, - config.filter, + validFilter, config.debug, config.timeLimit, ) diff --git a/src/main/scala/fhetest/Config.scala b/src/main/scala/fhetest/Config.scala index e6fa8bd..9f14d36 100644 --- a/src/main/scala/fhetest/Config.scala +++ b/src/main/scala/fhetest/Config.scala @@ -13,12 +13,12 @@ class Config( var encType: Option[ENC_TYPE] = None, var genStrategy: Option[Strategy] = None, var genCount: Option[Int] = None, + var validFilter: Boolean = true, var toJson: Boolean = false, var sealVersion: Option[String] = None, var openfheVersion: Option[String] = None, var libConfigOpt: Option[LibConfig] = None, var fromJson: Option[String] = None, - var filter: Boolean = true, var silent: Boolean = false, var debug: Boolean = false, var timeLimit: Option[Int] = None, @@ -67,7 +67,7 @@ object Config { case "libconfig" => config.libConfigOpt = Some(LibConfig()) case "fromjson" => config.fromJson = Some(value) - case "filter" => config.filter = value.toBoolean + case "filter" => config.validFilter = value.toBoolean case "silent" => config.silent = value.toBoolean case "debug" => config.debug = value.toBoolean case "timeout" => config.timeLimit = Some(value.toInt) diff --git a/src/main/scala/fhetest/Generate/AbsProgram.scala b/src/main/scala/fhetest/Generate/AbsProgram.scala index 8965f81..ccf4395 100644 --- a/src/main/scala/fhetest/Generate/AbsProgram.scala +++ b/src/main/scala/fhetest/Generate/AbsProgram.scala @@ -18,21 +18,21 @@ case class AbsProgram( } // TODO: Change these filters to assertions? - lazy val isValid: Boolean = - mulDepthIsSmall(mulDepth, encParams.mulDepth) && - firstModSizeIsLargest(libConfig.firstModSize, libConfig.scalingModSize) && - modSizeIsUpto60bits(libConfig.firstModSize, libConfig.scalingModSize) && - openFHEBFVModuli( - libConfig.scheme, - libConfig.firstModSize, - libConfig.scalingModSize, - ) && - ringDimIsPowerOfTwo(encParams.ringDim) && - plainModIsPositive(encParams.plainMod) && - plainModEnableBatching(encParams.plainMod, encParams.ringDim) && - lenIsLessThanRingDim(len, encParams.ringDim, libConfig.scheme) && - boundIsLessThanPowerOfModSize(bound, libConfig.firstModSize) && - boundIsLessThanPlainMod(bound, encParams.plainMod) + // lazy val isValid: Boolean = + // mulDepthIsSmall(mulDepth, encParams.mulDepth) && + // firstModSizeIsLargest(libConfig.firstModSize, libConfig.scalingModSize) && + // modSizeIsUpto60bits(libConfig.firstModSize, libConfig.scalingModSize) && + // openFHEBFVModuli( + // libConfig.scheme, + // libConfig.firstModSize, + // libConfig.scalingModSize, + // ) && + // ringDimIsPowerOfTwo(encParams.ringDim) && + // plainModIsPositive(encParams.plainMod) && + // plainModEnableBatching(encParams.plainMod, encParams.ringDim) && + // lenIsLessThanRingDim(len, encParams.ringDim, libConfig.scheme) && + // boundIsLessThanPowerOfModSize(bound, libConfig.firstModSize) && + // boundIsLessThanPlainMod(bound, encParams.plainMod) def stringify: String = absStmts.map(_.stringify()).mkString("") diff --git a/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala b/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala index d78d3e9..e6763ac 100644 --- a/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala +++ b/src/main/scala/fhetest/Generate/AbsProgramGenerator.scala @@ -10,58 +10,58 @@ enum Strategy: extension (s: Strategy) def getGenerator( encType: ENC_TYPE, - validCheck: Boolean, + validFilter: Boolean, ): AbsProgramGenerator = s match { case Strategy.Exhaustive => - ExhaustiveGenerator(encType: ENC_TYPE, validCheck: Boolean) + ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean) case Strategy.Random => - RandomGenerator(encType: ENC_TYPE, validCheck: Boolean) + RandomGenerator(encType: ENC_TYPE, validFilter: Boolean) } // AbsProgram Generator -trait AbsProgramGenerator(encType: ENC_TYPE, validCheck: Boolean) { +trait AbsProgramGenerator(encType: ENC_TYPE, validFilter: Boolean) { def generateAbsPrograms(): LazyList[AbsProgram] val lcGen = - if validCheck then ValidLibConfigGenerator(encType) - else RandomLibConfigGenerator(encType) + if validFilter then ValidLibConfigGenerator(encType) + else InvalidLibConfigGenerator(encType) } -case class ExhaustiveGenerator(encType: ENC_TYPE, validCheck: Boolean) - extends AbsProgramGenerator(encType: ENC_TYPE, validCheck: Boolean) { +case class ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean) + extends AbsProgramGenerator(encType: ENC_TYPE, validFilter: Boolean) { + val libConfigGens = lcGen.getLibConfigGenerators() def generateAbsPrograms(): LazyList[AbsProgram] = { - def allAbsProgramsOfSize(n: Int): LazyList[AbsProgram] = { + def allAbsProgramsOfSize(n: Int): LazyList[AbsProgram] = n match { case 1 => - allAbsStmts - .map(stmt => List(stmt)) - .map( - AbsProgram( - _, - lcGen.generateLibConfig(), - ), - ) + for { + stmt <- allAbsStmts + libConfigGen <- libConfigGens + } yield { + val stmts = List(stmt) + AbsProgram(stmts, libConfigGen(stmts)) + } case _ => for { stmt <- allAbsStmts program <- allAbsProgramsOfSize(n - 1) - } yield AbsProgram( - stmt :: program.absStmts, - lcGen.generateLibConfig(), - ) + libConfigGen <- libConfigGens + } yield { + val stmts = stmt :: program.absStmts + AbsProgram(stmts, libConfigGen(stmts)) + } } - } LazyList.from(1).flatMap(allAbsProgramsOfSize) } } -case class RandomGenerator(encType: ENC_TYPE, validCheck: Boolean) - extends AbsProgramGenerator(encType: ENC_TYPE, validCheck: Boolean) { +case class RandomGenerator(encType: ENC_TYPE, validFilter: Boolean) + extends AbsProgramGenerator(encType: ENC_TYPE, validFilter: Boolean) { def generateAbsPrograms(): LazyList[AbsProgram] = { - def randomAbsProgramOfSize(n: Int): AbsProgram = { - val absStmts = (1 to n).map(_ => Random.shuffle(allAbsStmts).head).toList - AbsProgram(absStmts, lcGen.generateLibConfig()) - } + def randomAbsStmtsOfSize(n: Int): List[AbsStmt] = + (1 to n).map(_ => Random.shuffle(allAbsStmts).head).toList + val libConfigGens = lcGen.getLibConfigGenerators() + // Generate Lengths from 1 to inf // LazyList.from(1) @@ -69,6 +69,12 @@ case class RandomGenerator(encType: ENC_TYPE, validCheck: Boolean) val randomLength: LazyList[Int] = LazyList.continually(Random.nextInt(20) + 1) - randomLength.map(randomAbsProgramOfSize) + for { + len <- randomLength + libConfigGen <- libConfigGens + } yield { + val stmts = randomAbsStmtsOfSize(len) + AbsProgram(stmts, libConfigGen(stmts)) + } } } diff --git a/src/main/scala/fhetest/Generate/LibConfigDomain.scala b/src/main/scala/fhetest/Generate/LibConfigDomain.scala new file mode 100644 index 0000000..43f2f89 --- /dev/null +++ b/src/main/scala/fhetest/Generate/LibConfigDomain.scala @@ -0,0 +1,37 @@ +package fhetest.Generate + +import fhetest.Utils.* + +// import java.util.stream.Stream +// import scala.jdk.CollectionConverters._ + +type RingDim = Int +type FirstModSize = Int +type PlainMod = Int + +case class LibConfigDomain( + scheme: Scheme, + ringDim: List[Int], + mulDepth: Int => List[Int], + plainMod: RingDim => List[Int], + firstModSize: Scheme => List[Int], + scalingModSize: Scheme => FirstModSize => List[Int], + securityLevel: List[SecurityLevel], + scalingTechnique: Scheme => List[ScalingTechnique], + lenMin: Scheme => RingDim => Int, + lenMax: Scheme => RingDim => Int, + boundMin: Scheme => PlainMod => FirstModSize => Int | Double, + boundMax: Scheme => PlainMod => FirstModSize => Int | Double, + rotateBound: List[Int], +) { + def stringify(): String = + s"""{scheme: ${scheme}} +{encParams: EncParams(${ringDim}, ${mulDepth}, ${plainMod})} +{(firstModSize, scalingModSize): (${firstModSize}, ${scalingModSize})} +{securityLevel: ${securityLevel}} +{scalingTechnique: ${scalingTechnique}} +{lenMax: ${lenMax}} +{boundMax: ${boundMax}} +{rotateBoundOpt: ${rotateBound}} +""" +} diff --git a/src/main/scala/fhetest/Generate/LibConfigGenerator.scala b/src/main/scala/fhetest/Generate/LibConfigGenerator.scala index 7e93da1..caaa8da 100644 --- a/src/main/scala/fhetest/Generate/LibConfigGenerator.scala +++ b/src/main/scala/fhetest/Generate/LibConfigGenerator.scala @@ -2,163 +2,191 @@ package fhetest.Generate import fhetest.Utils.* import scala.util.Random +import fhetest.Generate.Utils.combinations + +val ringDimCandidates: List[Int] = // also in ValidFilter + List(8192, 16384, 32768) + // List(8192, 16384, 32768, 65536, 131072) // also in ValidFilter + +def getLibConfigUniverse(scheme: Scheme) = LibConfigDomain( + scheme = scheme, + ringDim = ringDimCandidates, + mulDepth = (realMulDepth: Int) => (-20 to 20).toList, + plainMod = (ringDim: Int) => List(65537), + firstModSize = (scheme: Scheme) => (-100 to 100).toList, + scalingModSize = + (scheme: Scheme) => (firstModSize: Int) => (-100 to 100).toList, + securityLevel = SecurityLevel.values.toList, + scalingTechnique = (scheme: Scheme) => ScalingTechnique.values.toList, + lenMin = (scheme: Scheme) => (ringDim: Int) => 1, + lenMax = (scheme: Scheme) => (ringDim: Int) => 100000, + boundMin = (scheme: Scheme) => + (plainMod: Int) => + (firstModSize: Int) => + scheme match { + case Scheme.CKKS => 1d + case _ => 1 + }, + boundMax = (scheme: Scheme) => + (plainMod: Int) => + (firstModSize: Int) => + scheme match { + case Scheme.CKKS => math.pow(2, 64) + case _ => 1000 + }, + rotateBound = (0 to 40).toList, +) -val ringDimCandidates: List[Int] = - List(8192, 16384, 32768, 65536, 131072) trait LibConfigGenerator(encType: ENC_TYPE) { - def generateLibConfig(): LibConfig + def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] + val validFilters = classOf[ValidFilter].getDeclaredClasses.toList + .filter { cls => + classOf[ValidFilter] + .isAssignableFrom(cls) && cls != classOf[ValidFilter] + } } case class ValidLibConfigGenerator(encType: ENC_TYPE) extends LibConfigGenerator(encType) { - def generateLibConfig(): LibConfig = { - val randomScheme = - if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) - else Scheme.CKKS - val randomEncParams = { - val randomRingDim = Random.shuffle(ringDimCandidates).head - val randomMultDepth = - Random.nextInt(10 + 1) - // TODO: Currently randomPlainMod is fixed - val randomPlainMod = 65537 - EncParams(randomRingDim, randomMultDepth, randomPlainMod) - } - // modSizeIsUpto60bits - val randomFirstModSize: Int = - // https://github.com/openfheorg/openfhe-development/blob/main/src/pke/include/schemerns/rns-parametergeneration.h - Random.between(14, 60 + 1) - // if randomScheme == Scheme.BFV then Random.between(30, 60 + 1) - //// SEAL SEAL_MOD_BIT_COUNT_MIN = 2, SEAL_MOD_BIT_COUNT_MAX = 61 - //// OpenFHE modSize is upto 60 bits - // else Random.between(2, 60 + 1) - // firstModSizeIsLargest - // openFHEBFVModuli - val randomScalingModSize: Int = - // https://github.com/openfheorg/openfhe-development/blob/main/src/pke/include/schemerns/rns-parametergeneration.h - Random.between(14, randomFirstModSize + 1) - // if randomScheme == Scheme.BFV then - // Random.between(30, randomFirstModSize + 1) - //// SEAL SEAL_MOD_BIT_COUNT_MIN = 2, SEAL_MOD_BIT_COUNT_MAX = 61 - //// OpenFHE modSize is upto 60 bits - // else Random.between(2, randomFirstModSize + 1) - val randomSecurityLevel = - SecurityLevel.values(Random.nextInt(SecurityLevel.values.length)) - val randomScalingTechnique = - // Currently, exclude FLEXIBLEAUTOEXT in valid testing because of the following issue - // https://openfhe.discourse.group/t/unexpected-behavior-with-setscalingmodsize-and-flexibleautoext-in-bgv-scheme/1111 - // And, exclude NORESCALE in CKKS scheme because it is not supported - // https://openfhe.discourse.group/t/incorrect-result-when-using-norescale-scaling-technique-in-ckks-scheme/1119 - val scalingTechs = - import ScalingTechnique.* - if randomScheme == Scheme.CKKS then - Array( - FIXEDMANUAL, - FIXEDAUTO, - FLEXIBLEAUTO, - ) - else if randomScheme == Scheme.BFV then - Array( - NORESCALE, - ) - else - Array( - NORESCALE, - FIXEDMANUAL, - FIXEDAUTO, - FLEXIBLEAUTO, - ) - scalingTechs( - Random.nextInt(scalingTechs.length), + def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] = { + val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => { + val randomScheme = + if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) + else Scheme.CKKS + val libConfigUniverse = getLibConfigUniverse(randomScheme) + val filteredLibConfigDomain = validFilters.foldLeft(libConfigUniverse)({ + case (curLibConfigDomain, curValidFilter) => { + val constructor = curValidFilter.getDeclaredConstructors.head + constructor.setAccessible(true) + val f = constructor + .newInstance(curLibConfigDomain, true) + .asInstanceOf[ValidFilter] + f.getFilteredLibConfigDomain() + } + }) + randomLibConfigFromDomain( + true, + absStmts, + randomScheme, + filteredLibConfigDomain, ) - // len must be larger than 0 - // lenIsLessThanRingDim - val randomLenOpt: Option[Int] = - val upper = randomScheme match { - case Scheme.CKKS => (randomEncParams.ringDim / 2) - case _ => randomEncParams.ringDim - } - Some(Random.between(1, upper + 1)) - val randomBoundOpt: Option[Int | Double] = - randomScheme match { - case Scheme.BFV | Scheme.BGV => - // bound ^ (mulDepth + 1) < plainMod = 2^16 + 1 - Some( - Random.between( - 1, - Math.pow(2, 16 % (randomEncParams.mulDepth + 1)).toInt + 1, - ), - ) - case Scheme.CKKS => - Some( - Random.between( - 1, - math.pow( - 2, - randomFirstModSize % (randomEncParams.mulDepth + 1) + 1, - ), - ), - ) - } - val randomRotateBoundOpt: Option[Int] = - Some(Random.between(0, 20 + 1)) - LibConfig( - randomScheme, - randomEncParams, - randomFirstModSize, - randomScalingModSize, - randomSecurityLevel, - randomScalingTechnique, - randomLenOpt, - randomBoundOpt, - randomRotateBoundOpt, - ) + } + LazyList.continually(libConfigGeneratorFromAbsStmts) } } -case class RandomLibConfigGenerator(encType: ENC_TYPE) +case class InvalidLibConfigGenerator(encType: ENC_TYPE) extends LibConfigGenerator(encType) { - def generateLibConfig(): LibConfig = { - val randomScheme = - if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) - else Scheme.CKKS + val totalNumOfFilters = validFilters.length + val allCombinations = + (1 to totalNumOfFilters).toList.flatMap(combinations(_, totalNumOfFilters)) + // TODO: currently generate only 1 test case for each class + // val numOfTC = 10 + val allCombinations_lazy = LazyList.from(allCombinations) + def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] = for { + combination <- allCombinations_lazy + } yield { + println(combination) + val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => { + val randomScheme = + if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2)) + else Scheme.CKKS + val libConfigUniverse = getLibConfigUniverse(randomScheme) + val filteredLibConfigDomain = validFilters.foldLeft(libConfigUniverse)({ + case (curLibConfigDomain, curValidFilter) => { + val curValidFilterIdx = validFilters.indexOf(curValidFilter) + val inInValid = combination.contains(curValidFilterIdx) + val constructor = curValidFilter.getDeclaredConstructors.head + constructor.setAccessible(true) + val f = constructor + .newInstance(curLibConfigDomain, !inInValid) + .asInstanceOf[ValidFilter] + f.getFilteredLibConfigDomain() + } + }) + randomLibConfigFromDomain( + false, + absStmts, + randomScheme, + filteredLibConfigDomain, + ) + } + libConfigGeneratorFromAbsStmts + } +} - val randomEncParams = { - // TODO: Currently only MultDepth is random - val randomRingDim = Random.shuffle(ringDimCandidates).head - val randomMultDepth = - Random.between(-10, 10 + 1) - val randomPlainMod = 65537 - EncParams(randomRingDim, randomMultDepth, randomPlainMod) +// TODO: No handling when +def randomLibConfigFromDomain( + validFilter: Boolean, + absStmts: List[AbsStmt], + randomScheme: Scheme, + filteredLibConfigDomain: LibConfigDomain, +): LibConfig = { + val randomRingDim = Random.shuffle(filteredLibConfigDomain.ringDim).head + val randomMulDepth = { + val realMulDepth: Int = absStmts.count { + case Mul(_, _) | MulP(_, _) => true; case _ => false } - val randomFirstModSize: Int = - Random.between(-100, 100 + 1) - val randomScalingModSize: Int = - Random.between(-100, 100 + 1) - val randomSecurityLevel = - SecurityLevel.values(Random.nextInt(SecurityLevel.values.length)) - val randomScalingTechnique = - ScalingTechnique.values(Random.nextInt(ScalingTechnique.values.length)) - // len must be larger than 0 - val randomLenOpt: Option[Int] = - Some(Random.between(1, 100000 + 1)) - val randomBoundOpt: Option[Int | Double] = - randomScheme match { - case Scheme.BFV | Scheme.BGV => - Some(Random.between(1, 1000 + 1)) - case Scheme.CKKS => Some(Random.between(1, math.pow(2, 64) + 1)) - } - val randomRotateBoundOpt: Option[Int] = - Some(Random.between(0, 40 + 1)) - LibConfig( - randomScheme, - randomEncParams, - randomFirstModSize, - randomScalingModSize, - randomSecurityLevel, - randomScalingTechnique, - randomLenOpt, - randomBoundOpt, - randomRotateBoundOpt, + println(s"realMulDepth: $realMulDepth") + Random.shuffle((filteredLibConfigDomain.mulDepth)(realMulDepth)).head + } + val randomPlainMod = + Random.shuffle((filteredLibConfigDomain.plainMod)(randomRingDim)).head + val randomFirstModSize = + Random + .shuffle((filteredLibConfigDomain.firstModSize)(randomScheme)) + .head + val randomScalingModSize = Random + .shuffle( + (filteredLibConfigDomain.scalingModSize)(randomScheme)( + randomFirstModSize, + ), ) + .head + val randomSecurityLevel = + Random.shuffle(filteredLibConfigDomain.securityLevel).head + val randomScalingTechnique = Random + .shuffle((filteredLibConfigDomain.scalingTechnique)(randomScheme)) + .head + val randomLenOpt: Option[Int] = { + val upper = + (filteredLibConfigDomain.lenMax)(randomScheme)(randomRingDim) + val lower = + (filteredLibConfigDomain.lenMin)(randomScheme)(randomRingDim) + Some(Random.between(lower, upper + 1)) } + val randomBoundOpt: Option[Int | Double] = { + val upper = (filteredLibConfigDomain.boundMax)(randomScheme)( + randomPlainMod, + )(randomFirstModSize) + val lower = (filteredLibConfigDomain.boundMin)(randomScheme)( + randomPlainMod, + )(randomFirstModSize) + lower match { + case li: Int => + upper match { + case ui: Int => Some(Random.between(li, ui + 1)) + case _ => Some(Random.between(1, 100000 + 1)) // unreachable + } + case ld: Double => + upper match { + case ud: Int => Some(Random.between(ld, ud)) + case _ => Some(Random.between(1, math.pow(2, 64))) // unreachable + } + } + } + val randomRotateBoundOpt: Option[Int] = + Some(Random.shuffle(filteredLibConfigDomain.rotateBound).head) + + LibConfig( + randomScheme, + EncParams(randomRingDim, randomMulDepth, randomPlainMod), + randomFirstModSize, + randomScalingModSize, + randomSecurityLevel, + randomScalingTechnique, + randomLenOpt, + randomBoundOpt, + randomRotateBoundOpt, + ) } diff --git a/src/main/scala/fhetest/Generate/Utils.scala b/src/main/scala/fhetest/Generate/Utils.scala index d6ce3c1..565f75b 100644 --- a/src/main/scala/fhetest/Generate/Utils.scala +++ b/src/main/scala/fhetest/Generate/Utils.scala @@ -7,4 +7,22 @@ object Utils { Assign(name, v) def assignValues(name: String, vs: (List[Int] | List[Double])): AbsStmt = AssignVec(name, vs) + + // mCn: all combibation of length n from 0 to (m-1) + def combinations(n: Int, m: Int): List[List[Int]] = { + def combinations_helper( + k: Int, + start: Int, + acc: List[Int], + ): List[List[Int]] = { + if (k == 0) List(acc.reverse) + else { + (start until m).flatMap { i => + combinations_helper(k - 1, i + 1, i :: acc) + }.toList + } + } + + combinations_helper(n, 0, List.empty) + } } diff --git a/src/main/scala/fhetest/Generate/ValidFilter.scala b/src/main/scala/fhetest/Generate/ValidFilter.scala index df9752e..ef54367 100644 --- a/src/main/scala/fhetest/Generate/ValidFilter.scala +++ b/src/main/scala/fhetest/Generate/ValidFilter.scala @@ -1,50 +1,581 @@ package fhetest.Generate -// TODO: This must be removed because currently the Libconfig generation logic handle thees validations import fhetest.Utils.* +import fhetest.Checker.schemeDecoder -def mulDepthIsSmall(realMulDepth: Int, configMulDepth: Int): Boolean = - realMulDepth < configMulDepth +// TODO: currently filters regarding plainMod, ringDim is commented since using fixed plaindMod, ringDim +// TODO: FilterRingDimIsPowerOfTwo just filter into fixed candidates when validFilter = true -def firstModSizeIsLargest(firstModSize: Int, scalingModSize: Int): Boolean = - scalingModSize <= firstModSize +/* Current validFilters */ +// * defined & used in LibConfigGenerator +// * automatically arranged in alphabetical order +// val validFilters = List( +// FilterBoundIsLessThanPlainMod, +// FilterBoundIsLessThanPowerOfModSize, +// FilterFirstModSizeIsLargest, +// FilterLenIsLessThanRingDim, +// FilterModSizeIsBeteween14And60bits, +// FilterMulDepthIsEnough, +// FilterOpenFHEBFVModuli, +// FilterPlainModEnableBatching, /* commented */ +// FilterPlainModIsPositive, /* commented */ +// FilterRingDimIsPowerOfTwo, /* commented */ +// FilterScalingTechniqueByScheme +// ) -def modSizeIsUpto60bits(firstModSize: Int, scalingModSize: Int): Boolean = - (firstModSize <= 60) && (scalingModSize <= 60) +trait ValidFilter(prev: LibConfigDomain, validFilter: Boolean) { + def getFilteredLibConfigDomain(): LibConfigDomain +} + +object ValidFilter { + // def mulDepthIsSmall(realMulDepth: Int, configMulDepth: Int): Boolean = + // realMulDepth < configMulDepth + case class FilterMulDepthIsEnough( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = + if (validFilter) + ( + realMulDepth => + (prev.mulDepth)(realMulDepth).filter(_ > realMulDepth), + ) + else + ( + realMulDepth => + (prev.mulDepth)(realMulDepth).filterNot(_ > realMulDepth), + ), + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } + + // def firstModSizeIsLargest(firstModSize: Int, scalingModSize: Int): Boolean = + // scalingModSize <= firstModSize + case class FilterFirstModSizeIsLargest( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = + if (validFilter) + ( + scheme => + firstModSize => + (prev.scalingModSize)(scheme)(firstModSize) + .filter(_ <= firstModSize), + ) + else + ( + scheme => + firstModSize => + (prev.scalingModSize)(scheme)(firstModSize) + .filterNot(_ <= firstModSize), + ), + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } + + // https://github.com/openfheorg/openfhe-development/blob/main/src/pke/include/schemerns/rns-parametergeneration.h + // Both firstModSize and scalingModSize are >= 14 + // def modSizeIsUpto60bits(firstModSize: Int, scalingModSize: Int): Boolean = + // (14 <= firstModSize <= 60) && (14 <= scalingModSize <= 60) + case class FilterModSizeIsBeteween14And60bits( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = + if (validFilter) + ( + scheme => + (prev + .firstModSize)(scheme) + .filter({ case m => (m >= 14) && (m <= 60) }), + ) + else + ( + scheme => + (prev + .firstModSize)(scheme) + .filterNot({ case m => (m >= 14) && (m <= 60) }), + ), + scalingModSize = + if (validFilter) + ( + scheme => + firstModSize => + (prev.scalingModSize)(scheme)(firstModSize) + .filter({ case m => (m >= 14) && (m <= 60) }), + ) + else + ( + scheme => + firstModSize => + (prev.scalingModSize)(scheme)(firstModSize) + .filterNot({ case m => (m >= 14) && (m <= 60) }), + ), + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } -// OpenFHE/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp:53 -// BFVrns.ParamsGen: Number of bits in CRT moduli should be in the range from 30 to 60 -def openFHEBFVModuli( - scheme: Scheme, - firstModSize: Int, - scalingModSize: Int, -): Boolean = - !(scheme == Scheme.BFV) || ((30 <= firstModSize) && (30 <= scalingModSize)) + // OpenFHE/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp:53 + // BFVrns.ParamsGen: Number of bits in CRT moduli should be in the range from 30 to 60 + // def openFHEBFVModuli( + // scheme: Scheme, + // firstModSize: Int, + // scalingModSize: Int, + // ): Boolean = + // !(scheme == Scheme.BFV) || ((30 <= firstModSize) && (30 <= scalingModSize)) + case class FilterOpenFHEBFVModuli( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = + if (validFilter) + ( + scheme => + if (scheme == Scheme.BFV) + (prev.firstModSize)(scheme).filter(_ >= 30) + else (prev.firstModSize)(scheme), + ) + else + ( + scheme => + if (scheme == Scheme.BFV) + (prev.firstModSize)(scheme).filterNot(_ >= 30) + else (prev.firstModSize)(scheme), + ), + scalingModSize = + if (validFilter) + ( + scheme => + firstModSize => + if (scheme == Scheme.BFV) + (prev.scalingModSize)(scheme)(firstModSize).filter(_ >= 30) + else (prev.scalingModSize)(scheme)(firstModSize), + ) + else + ( + scheme => + firstModSize => + if (scheme == Scheme.BFV) + (prev.scalingModSize)(scheme)(firstModSize) + .filterNot(_ >= 30) + else (prev.scalingModSize)(scheme)(firstModSize), + ), + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } -def ringDimIsPowerOfTwo(n: Int): Boolean = (n > 0) && ((n & (n - 1)) == 0) + // TODO: This filter is not included since using fixed ringDim + // TODO: FilterRingDimIsPowerOfTwo just filter into fixed candidates when validFilter = true + // def ringDimIsPowerOfTwo(n: Int): Boolean = (n > 0) && ((n & (n - 1)) == 0) + // case class FilterRingDimIsPowerOfTwo( + // prev: LibConfigDomain, + // validFilter: Boolean, + // ) extends ValidFilter(prev, validFilter) { + // // TODO: currently use only few candidates in valid testing + // val ringDimCandidates: List[Int] = List(8192, 16384, 32768) + // def getFilteredLibConfigDomain(): LibConfigDomain = + // LibConfigDomain( + // scheme = prev.scheme, + // ringDim = + // if (validFilter) + // ringDimCandidates // TODO: prev.ringDim.filter({ case n => (n > 0) && ((n & (n - 1)) == 0) }) + // else + // prev.ringDim.filterNot({ + // case n => (n > 0) && ((n & (n - 1)) == 0) + // }), + // mulDepth = prev.mulDepth, + // plainMod = prev.plainMod, + // firstModSize = prev.firstModSize, + // scalingModSize = prev.scalingModSize, + // securityLevel = prev.securityLevel, + // scalingTechnique = prev.scalingTechnique, + // lenMin = prev.lenMin, + // lenMax = prev.lenMax, + // boundMin = prev.boundMin, + // boundMax = prev.boundMax, + // rotateBound = prev.rotateBound, + // ) + // } -def plainModIsPositive(m: Int): Boolean = m > 0 + // TODO: This filter is not included since using fixed plainMod + // def plainModIsPositive(m: Int): Boolean = m > 0 + // case class FilterPlainModIsPositive( + // prev: LibConfigDomain, + // validFilter: Boolean, + // ) extends ValidFilter(prev, validFilter) { + // def getFilteredLibConfigDomain(): LibConfigDomain = + // LibConfigDomain( + // scheme = prev.scheme, + // ringDim = prev.ringDim, + // mulDepth = prev.mulDepth, + // plainMod = + // if (validFilter) + // (ringDim => (prev.plainMod)(ringDim).filter(_ > 0)) + // else + // (ringDim => (prev.plainMod)(ringDim).filterNot(_ > 0)), + // firstModSize = prev.firstModSize, + // scalingModSize = prev.scalingModSize, + // securityLevel = prev.securityLevel, + // scalingTechnique = prev.scalingTechnique, + // lenMin = prev.lenMin, + // lenMax = prev.lenMax, + // boundMin = prev.boundMin, + // boundMax = prev.boundMax, + // rotateBound = prev.rotateBound, + // ) + // } -def plainModEnableBatching(m: Int, n: Int): Boolean = (m % (2 * n)) == 1 + // TODO: This filter is not included since using fixed plainMod + // def plainModEnableBatching(m: Int, n: Int): Boolean = (m % (2 * n)) == 1 + // case class FilterPlainModEnableBatching( + // prev: LibConfigDomain, + // validFilter: Boolean, + // ) extends ValidFilter(prev, validFilter) { + // def getFilteredLibConfigDomain(): LibConfigDomain = + // LibConfigDomain( + // scheme = prev.scheme, + // ringDim = prev.ringDim, + // mulDepth = prev.mulDepth, + // plainMod = + // if (validFilter) + // ( + // ringDim => + // (prev + // .plainMod)(ringDim) + // .filter({ case m => (m % (2 * ringDim)) == 1 }), + // ) + // else + // ( + // ringDim => + // (prev + // .plainMod)(ringDim) + // .filterNot({ case m => (m % (2 * ringDim)) == 1 }), + // ), + // firstModSize = prev.firstModSize, + // scalingModSize = prev.scalingModSize, + // securityLevel = prev.securityLevel, + // scalingTechnique = prev.scalingTechnique, + // lenMin = prev.lenMin, + // lenMax = prev.lenMax, + // boundMin = prev.boundMin, + // boundMax = prev.boundMax, + // rotateBound = prev.rotateBound, + // ) + // } -def lenIsLessThanRingDim(len: Int, n: Int, scheme: Scheme): Boolean = - scheme match { - case Scheme.CKKS => len <= (n / 2) - case _ => len <= n + // Currently, exclude FLEXIBLEAUTOEXT in valid testing because of the following issue + // https://openfhe.discourse.group/t/unexpected-behavior-with-setscalingmodsize-and-flexibleautoext-in-bgv-scheme/1111 + // And, exclude NORESCALE in CKKS scheme because it is not supported + // https://openfhe.discourse.group/t/incorrect-result-when-using-norescale-scaling-technique-in-ckks-scheme/1119 + case class FilterScalingTechniqueByScheme( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = + if (validFilter) + ( + scheme => + if (scheme == Scheme.CKKS) + List( + ScalingTechnique.FIXEDMANUAL, + ScalingTechnique.FIXEDAUTO, + ScalingTechnique.FLEXIBLEAUTO, + ) + else if (scheme == Scheme.BFV) List(ScalingTechnique.NORESCALE) + else + List( + ScalingTechnique.NORESCALE, + ScalingTechnique.FIXEDMANUAL, + ScalingTechnique.FIXEDAUTO, + ScalingTechnique.FLEXIBLEAUTO, + ), + ) + else + ( + scheme => + if (scheme == Scheme.CKKS) + List( + ScalingTechnique.NORESCALE, + ScalingTechnique.FLEXIBLEAUTOEXT, + ) + else if (scheme == Scheme.BFV) + List( + ScalingTechnique.FIXEDMANUAL, + ScalingTechnique.FIXEDAUTO, + ScalingTechnique.FLEXIBLEAUTO, + ScalingTechnique.FLEXIBLEAUTOEXT, + ) + else List(ScalingTechnique.FLEXIBLEAUTOEXT), + ), + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) } -def boundIsLessThanPowerOfModSize( - bound: Int | Double, - firstModSize: Int, -): Boolean = bound match { - case intBound: Int => intBound < math.pow(2, firstModSize) - case doubleBound: Double => doubleBound < math.pow(2, firstModSize) -} + // def lenIsLessThanRingDim(len: Int, n: Int, scheme: Scheme): Boolean = + // scheme match { + // case Scheme.CKKS => len <= (n / 2) + // case _ => len <= n + // } + case class FilterLenIsLessThanRingDim( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = if (validFilter) + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = ( + scheme => + ringDim => + if (scheme == Scheme.CKKS) + (prev.lenMax)(scheme)(ringDim).min(ringDim / 2) + else (prev.lenMax)(scheme)(ringDim).min(ringDim), + ), + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + else + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = ( + scheme => + ringDim => + if (scheme == Scheme.CKKS) + (prev.lenMax)(scheme)(ringDim).max(ringDim / 2 + 1) + else (prev.lenMax)(scheme)(ringDim).max(ringDim + 1), + ), + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } + + // def boundIsLessThanPowerOfModSize( + // bound: Int | Double, + // firstModSize: Int, + // ): Boolean = bound match { + // case intBound: Int => intBound < math.pow(2, firstModSize) + // case doubleBound: Double => doubleBound < math.pow(2, firstModSize) + // } + case class FilterBoundIsLessThanPowerOfModSize( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = if (validFilter) + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = ( + scheme => + plainMod => + firstModSize => + (prev.boundMax)(scheme)(plainMod)(firstModSize) match { + case intBound: Int => + intBound.min(math.pow(2, firstModSize).toInt) + case doubleBound: Double => + doubleBound.min(math.pow(2, firstModSize)) + }, + ), + rotateBound = prev.rotateBound, + ) + else + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = ( + scheme => + plainMod => + firstModSize => + (prev.boundMax)(scheme)(plainMod)(firstModSize) match { + case intBound: Int => + intBound.max(math.pow(2, firstModSize).toInt + 1) + case doubleBound: Double => + doubleBound.max(math.pow(2, firstModSize) + 0.1) + }, + ), + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } + + // def boundIsLessThanPlainMod( + // bound: Int | Double, + // plainMod: Int, + // ): Boolean = bound match { + // case intBound: Int => intBound < plainMod + // case _: Double => true + // } + case class FilterBoundIsLessThanPlainMod( + prev: LibConfigDomain, + validFilter: Boolean, + ) extends ValidFilter(prev, validFilter) { + def getFilteredLibConfigDomain(): LibConfigDomain = if (validFilter) + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = prev.boundMin, + boundMax = ( + scheme => + plainMod => + firstModSize => + (prev.boundMax)(scheme)(plainMod)(firstModSize) match { + case intBound: Int => intBound.min(plainMod) + case doubleBound: Double => doubleBound + }, + ), + rotateBound = prev.rotateBound, + ) + else + LibConfigDomain( + scheme = prev.scheme, + ringDim = prev.ringDim, + mulDepth = prev.mulDepth, + plainMod = prev.plainMod, + firstModSize = prev.firstModSize, + scalingModSize = prev.scalingModSize, + securityLevel = prev.securityLevel, + scalingTechnique = prev.scalingTechnique, + lenMin = prev.lenMin, + lenMax = prev.lenMax, + boundMin = ( + scheme => + plainMod => + firstModSize => + (prev.boundMax)(scheme)(plainMod)(firstModSize) match { + case intBound: Int => intBound.max(plainMod + 1) + case doubleBound: Double => doubleBound + }, + ), + boundMax = prev.boundMax, + rotateBound = prev.rotateBound, + ) + } -def boundIsLessThanPlainMod( - bound: Int | Double, - plainMod: Int, -): Boolean = bound match { - case intBound: Int => intBound < plainMod - case _: Double => true + // TODO: just for test -> remove later? + // case class FilterRotateBoundTest( + // prev: LibConfigDomain, + // validFilter: Boolean, + // ) extends ValidFilter(prev, validFilter) { + // def getFilteredLibConfigDomain(): LibConfigDomain = + // LibConfigDomain( + // scheme = prev.scheme, + // ringDim = prev.ringDim, + // mulDepth = prev.mulDepth, + // plainMod = prev.plainMod, + // firstModSize = prev.firstModSize, + // scalingModSize = prev.scalingModSize, + // securityLevel = prev.securityLevel, + // scalingTechnique = prev.scalingTechnique, + // lenMin = prev.lenMin, + // lenMax = prev.lenMax, + // boundMin = prev.boundMin, + // boundMax = prev.boundMax, + // rotateBound = (1 to 20).toList, + // ) + // } } diff --git a/src/main/scala/fhetest/Phase/Check.scala b/src/main/scala/fhetest/Phase/Check.scala index c690402..fc3f427 100644 --- a/src/main/scala/fhetest/Phase/Check.scala +++ b/src/main/scala/fhetest/Phase/Check.scala @@ -60,7 +60,7 @@ case object Check { toJson: Boolean, sealVersion: String, openfheVersion: String, - validCheck: Boolean, + validFilter: Boolean, debug: Boolean, timeLimit: Option[Int], ): LazyList[(T2Program, CheckResult)] = { @@ -78,7 +78,7 @@ case object Check { if program.libConfig.scheme == Scheme.CKKS then math.pow(2, program.libConfig.firstModSize) else program.libConfig.encParams.plainMod.toDouble - if !validCheck || notOverflow(result, overflowBound) then { + if !validFilter || notOverflow(result, overflowBound) then { val encType = parsed._3 val interpResPair = BackendResultPair("CLEAR", result) val executeResPairs = backends.map(backend => diff --git a/src/main/scala/fhetest/Phase/Generate.scala b/src/main/scala/fhetest/Phase/Generate.scala index 8caf9bd..50a8503 100644 --- a/src/main/scala/fhetest/Phase/Generate.scala +++ b/src/main/scala/fhetest/Phase/Generate.scala @@ -10,6 +10,7 @@ import org.twc.terminator.t2dsl_compiler.T2DSLparser.ParseException; import org.twc.terminator.t2dsl_compiler.T2DSLparser.T2DSLParser; import org.twc.terminator.t2dsl_compiler.T2DSLsyntaxtree.*; import java.nio.file.{Files, Paths}; +import scala.util.Random import java.io.*; import javax.print.attribute.EnumSyntax @@ -18,7 +19,7 @@ import scala.jdk.CollectionConverters._ case class Generate( encType: ENC_TYPE, strategy: Strategy = Strategy.Random, - checkValid: Boolean = true, + validFilter: Boolean = true, ) { // TODO : This boilerplate code is really ugly. But I cannot find a better way to do this. val baseStrFront = encType match { @@ -32,7 +33,7 @@ case class Generate( val symbolTable = boilerplate()._2 - val absProgGen = strategy.getGenerator(encType, checkValid) + val absProgGen = strategy.getGenerator(encType, validFilter) val allAbsPrograms = absProgGen.generateAbsPrograms() @@ -45,9 +46,38 @@ case class Generate( val adjusted = assigned.adjustScale(encType) adjusted } - val resultAbsPrograms = if (checkValid) { - adjustedAbsPrograms.filter(_.isValid) - } else { adjustedAbsPrograms.filterNot(_.isValid) } + + val resultAbsPrograms: LazyList[AbsProgram] = adjustedAbsPrograms + // val resultAbsPrograms: LazyList[AbsProgram] = if (validFilter) { + // adjustedAbsPrograms.filter(_.isValid) + // } else { + // // val numOfValidFilter = 10 + // // val programsWithEquivClasses: LazyList[(AbsProgram, List[Boolean])] = + // // adjustedAbsPrograms.map({ pgm => + // // (pgm, pgm.getInvalidEquivClassList()) + // // }) + // // def filterSequencially( + // // absPrograms: LazyList[(AbsProgram, List[Boolean])], + // // idx: Int, + // // ): LazyList[AbsProgram] = + // // if (absPrograms.isEmpty) + // // LazyList.empty // unreachable + // // else if (idx == numOfValidFilter) filterSequencially(absPrograms, 0) + // // else { + // // val (pgm, equivClassList) = absPrograms.head + // // val equivClass = equivClassList.apply(idx) + // // if (equivClass) + // // pgm #:: filterSequencially(absPrograms.tail, idx + 1) + // // else filterSequencially(absPrograms, idx + 1) + // // } + // // filterSequencially(programsWithEquivClasses, 0) + + // val equivClassIdx = LazyList.from(0) + // adjustedAbsPrograms + // .zip(equivClassIdx) + // .filter { case (pgm, idx) => pgm.invalidEquivClass(idx) } + // .map(_._1) + // } val takenResultAbsPrograms = nOpt match { case Some(n) => resultAbsPrograms.take(n) case None => resultAbsPrograms