Skip to content

Commit

Permalink
LibConfigGenerator: add handling when LibConfigDomain is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
hyerinshelly committed Mar 25, 2024
1 parent 23d6a5c commit ddad75b
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 155 deletions.
17 changes: 0 additions & 17 deletions src/main/scala/fhetest/Generate/AbsProgram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,6 @@ case class AbsProgram(
case Mul(_, _) | MulP(_, _) => true; case _ => false
}

// TODO: Change these filters to assertions?
// lazy val isValid: Boolean =
// mulDepthIsSmall(mulDepth, encParams.mulDepth) &&
// firstModSizeIsLargest(libConfig.firstModSize, libConfig.scalingModSize) &&
// modSizeIsUpto60bits(libConfig.firstModSize, libConfig.scalingModSize) &&
// openFHEBFVModuli(
// libConfig.scheme,
// libConfig.firstModSize,
// libConfig.scalingModSize,
// ) &&
// ringDimIsPowerOfTwo(encParams.ringDim) &&
// plainModIsPositive(encParams.plainMod) &&
// plainModEnableBatching(encParams.plainMod, encParams.ringDim) &&
// lenIsLessThanRingDim(len, encParams.ringDim, libConfig.scheme) &&
// boundIsLessThanPowerOfModSize(bound, libConfig.firstModSize) &&
// boundIsLessThanPlainMod(bound, encParams.plainMod)

def stringify: String = absStmts.map(_.stringify()).mkString("")

def assignRandValues(): AbsProgram = {
Expand Down
21 changes: 15 additions & 6 deletions src/main/scala/fhetest/Generate/AbsProgramGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,24 @@ case class ExhaustiveGenerator(encType: ENC_TYPE, validFilter: Boolean)
for {
stmt <- allAbsStmts
libConfigGen <- libConfigGens
stmts = List(stmt)
libConfigOpt = libConfigGen(stmts)
if libConfigOpt.isDefined
} yield {
val stmts = List(stmt)
AbsProgram(stmts, libConfigGen(stmts))
val libConfig = libConfigOpt.get
AbsProgram(stmts, libConfig)
}
case _ =>
for {
stmt <- allAbsStmts
program <- allAbsProgramsOfSize(n - 1)
libConfigGen <- libConfigGens
stmts = stmt :: program.absStmts
libConfigOpt = libConfigGen(stmts)
if libConfigOpt.isDefined
} yield {
val stmts = stmt :: program.absStmts
AbsProgram(stmts, libConfigGen(stmts))
val libConfig = libConfigOpt.get
AbsProgram(stmts, libConfig)
}
}
LazyList.from(1).flatMap(allAbsProgramsOfSize)
Expand All @@ -72,9 +78,12 @@ case class RandomGenerator(encType: ENC_TYPE, validFilter: Boolean)
for {
len <- randomLength
libConfigGen <- libConfigGens
stmts = randomAbsStmtsOfSize(len)
libConfigOpt = libConfigGen(stmts)
if libConfigOpt.isDefined
} yield {
val stmts = randomAbsStmtsOfSize(len)
AbsProgram(stmts, libConfigGen(stmts))
val libConfig = libConfigOpt.get
AbsProgram(stmts, libConfig)
}
}
}
208 changes: 115 additions & 93 deletions src/main/scala/fhetest/Generate/LibConfigGenerator.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package fhetest.Generate

import fhetest.Utils.*
import scala.util.Random
import fhetest.Generate.Utils.combinations
import scala.util.Random
import scala.util.control.Breaks._

val ringDimCandidates: List[Int] = // also in ValidFilter
List(8192, 16384, 32768)
Expand Down Expand Up @@ -38,7 +39,7 @@ def getLibConfigUniverse(scheme: Scheme) = LibConfigDomain(
)

trait LibConfigGenerator(encType: ENC_TYPE) {
def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig]
def getLibConfigGenerators(): LazyList[List[AbsStmt] => Option[LibConfig]]
val validFilters = classOf[ValidFilter].getDeclaredClasses.toList
.filter { cls =>
classOf[ValidFilter]
Expand All @@ -48,7 +49,7 @@ trait LibConfigGenerator(encType: ENC_TYPE) {

case class ValidLibConfigGenerator(encType: ENC_TYPE)
extends LibConfigGenerator(encType) {
def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] = {
def getLibConfigGenerators(): LazyList[List[AbsStmt] => Option[LibConfig]] = {
val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => {
val randomScheme =
if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2))
Expand Down Expand Up @@ -83,108 +84,129 @@ case class InvalidLibConfigGenerator(encType: ENC_TYPE)
// TODO: currently generate only 1 test case for each class
// val numOfTC = 10
val allCombinations_lazy = LazyList.from(allCombinations)
def getLibConfigGenerators(): LazyList[List[AbsStmt] => LibConfig] = for {
combination <- allCombinations_lazy
} yield {
val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => {
val randomScheme =
if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2))
else Scheme.CKKS
val libConfigUniverse = getLibConfigUniverse(randomScheme)
val filteredLibConfigDomain = validFilters.foldLeft(libConfigUniverse)({
case (curLibConfigDomain, curValidFilter) => {
val curValidFilterIdx = validFilters.indexOf(curValidFilter)
val inInValid = combination.contains(curValidFilterIdx)
val constructor = curValidFilter.getDeclaredConstructors.head
constructor.setAccessible(true)
val f = constructor
.newInstance(curLibConfigDomain, !inInValid)
.asInstanceOf[ValidFilter]
f.getFilteredLibConfigDomain()
def getLibConfigGenerators(): LazyList[List[AbsStmt] => Option[LibConfig]] =
for {
combination <- allCombinations_lazy
} yield {
println(combination)
val libConfigGeneratorFromAbsStmts = (absStmts: List[AbsStmt]) => {
val randomScheme =
if encType == ENC_TYPE.ENC_INT then Scheme.values(Random.nextInt(2))
else Scheme.CKKS
val libConfigUniverse = getLibConfigUniverse(randomScheme)
val filteredLibConfigDomain = validFilters.foldLeft(libConfigUniverse)({
case (curLibConfigDomain, curValidFilter) => {
val curValidFilterIdx = validFilters.indexOf(curValidFilter)
val inInValid = combination.contains(curValidFilterIdx)
val constructor = curValidFilter.getDeclaredConstructors.head
constructor.setAccessible(true)
val f = constructor
.newInstance(curLibConfigDomain, !inInValid)
.asInstanceOf[ValidFilter]
f.getFilteredLibConfigDomain()
}
})
val res = randomLibConfigFromDomain(
false,
absStmts,
randomScheme,
filteredLibConfigDomain,
)
res match {
case None => println("NO DOMAIN")
case Some(_) => ()
}
})
randomLibConfigFromDomain(
false,
absStmts,
randomScheme,
filteredLibConfigDomain,
)
res
}
libConfigGeneratorFromAbsStmts
}
libConfigGeneratorFromAbsStmts
}
}

// TODO: No handling for empty domain
def randomLibConfigFromDomain(
validFilter: Boolean,
absStmts: List[AbsStmt],
randomScheme: Scheme,
filteredLibConfigDomain: LibConfigDomain,
): LibConfig = {
val randomRingDim = Random.shuffle(filteredLibConfigDomain.ringDim).head
val randomMulDepth = {
val realMulDepth: Int = absStmts.count {
case Mul(_, _) | MulP(_, _) => true; case _ => false
): Option[LibConfig] = {
var result: Option[LibConfig] = None
breakable {
def getRandomElementOrBreak[T](list: List[T]): T = {
val elem =
if (list.nonEmpty) Some(Random.shuffle(list).head)
else None
elem getOrElse { break }
}
Random.shuffle((filteredLibConfigDomain.mulDepth)(realMulDepth)).head
}
val randomPlainMod =
Random.shuffle((filteredLibConfigDomain.plainMod)(randomRingDim)).head
val randomFirstModSize =
Random
.shuffle((filteredLibConfigDomain.firstModSize)(randomScheme))
.head
val randomScalingModSize = Random
.shuffle(
(filteredLibConfigDomain.scalingModSize)(randomScheme)(
val randomRingDim = getRandomElementOrBreak(filteredLibConfigDomain.ringDim)
val randomMulDepth = {
val realMulDepth: Int = absStmts.count {
case Mul(_, _) | MulP(_, _) => true; case _ => false
}
println(s"realMulDepth: $realMulDepth")
getRandomElementOrBreak(
(filteredLibConfigDomain.mulDepth)(realMulDepth),
)
}
val randomPlainMod = getRandomElementOrBreak(
(filteredLibConfigDomain.plainMod)(randomRingDim),
)
val randomFirstModSize = getRandomElementOrBreak(
(filteredLibConfigDomain.firstModSize)(randomScheme),
)
val randomScalingModSize = getRandomElementOrBreak(
(filteredLibConfigDomain.scalingModSize)(randomScheme)(randomFirstModSize),
)
val randomSecurityLevel =
getRandomElementOrBreak(filteredLibConfigDomain.securityLevel)
val randomScalingTechnique = getRandomElementOrBreak(
(filteredLibConfigDomain.scalingTechnique)(randomScheme),
)
val randomLenOpt: Option[Int] = {
val upper =
(filteredLibConfigDomain.lenMax)(randomScheme)(randomRingDim)
val lower =
(filteredLibConfigDomain.lenMin)(randomScheme)(randomRingDim)
if (lower > upper) break
else Some(Random.between(lower, upper + 1))
}
val randomBoundOpt: Option[Int | Double] = {
val upper = (filteredLibConfigDomain.boundMax)(randomScheme)(
randomPlainMod,
)(randomFirstModSize)
val lower = (filteredLibConfigDomain.boundMin)(randomScheme)(
randomPlainMod,
)(randomFirstModSize)
lower match {
case li: Int =>
upper match {
case ui: Int =>
if (li > ui) break else Some(Random.between(li, ui + 1))
case _ => Some(Random.between(1, 100000 + 1)) // unreachable
}
case ld: Double =>
upper match {
case ud: Int =>
if (ld > ud) break else Some(Random.between(ld, ud))
case _ => Some(Random.between(1, math.pow(2, 64))) // unreachable
}
}
}
val randomRotateBoundOpt: Option[Int] =
val r = getRandomElementOrBreak(filteredLibConfigDomain.rotateBound)
Some(r)

result = Some(
LibConfig(
randomScheme,
EncParams(randomRingDim, randomMulDepth, randomPlainMod),
randomFirstModSize,
randomScalingModSize,
randomSecurityLevel,
randomScalingTechnique,
randomLenOpt,
randomBoundOpt,
randomRotateBoundOpt,
),
)
.head
val randomSecurityLevel =
Random.shuffle(filteredLibConfigDomain.securityLevel).head
val randomScalingTechnique = Random
.shuffle((filteredLibConfigDomain.scalingTechnique)(randomScheme))
.head
val randomLenOpt: Option[Int] = {
val upper =
(filteredLibConfigDomain.lenMax)(randomScheme)(randomRingDim)
val lower =
(filteredLibConfigDomain.lenMin)(randomScheme)(randomRingDim)
Some(Random.between(lower, upper + 1))
}
val randomBoundOpt: Option[Int | Double] = {
val upper = (filteredLibConfigDomain.boundMax)(randomScheme)(
randomPlainMod,
)(randomFirstModSize)
val lower = (filteredLibConfigDomain.boundMin)(randomScheme)(
randomPlainMod,
)(randomFirstModSize)
lower match {
case li: Int =>
upper match {
case ui: Int => Some(Random.between(li, ui + 1))
case _ => Some(Random.between(1, 100000 + 1)) // unreachable
}
case ld: Double =>
upper match {
case ud: Int => Some(Random.between(ld, ud))
case _ => Some(Random.between(1, math.pow(2, 64))) // unreachable
}
}
}
val randomRotateBoundOpt: Option[Int] =
Some(Random.shuffle(filteredLibConfigDomain.rotateBound).head)

LibConfig(
randomScheme,
EncParams(randomRingDim, randomMulDepth, randomPlainMod),
randomFirstModSize,
randomScalingModSize,
randomSecurityLevel,
randomScalingTechnique,
randomLenOpt,
randomBoundOpt,
randomRotateBoundOpt,
)
result
}
19 changes: 11 additions & 8 deletions src/main/scala/fhetest/Generate/ValidFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import fhetest.Checker.schemeDecoder
// * defined & used in LibConfigGenerator
// * automatically arranged in alphabetical order
// val validFilters = List(
// FilterBoundIsLessThanPlainMod,
// FilterBoundIsLessThanPowerOfModSize,
// FilterFirstModSizeIsLargest,
// FilterLenIsLessThanRingDim,
// FilterModSizeIsBeteween14And60bits,
// FilterMulDepthIsEnough,
// FilterOpenFHEBFVModuli,
// FilterBoundIsLessThanPlainMod, // 0
// FilterBoundIsLessThanPowerOfModSize, // 1
// FilterFirstModSizeIsLargest, // 2
// FilterLenIsLessThanRingDim, // 3
// FilterModSizeIsBeteween14And60bits, // 4
// FilterMulDepthIsEnough, // 5
// FilterOpenFHEBFVModuli, // 6
// FilterPlainModEnableBatching, /* commented */
// FilterPlainModIsPositive, /* commented */
// FilterRingDimIsPowerOfTwo, /* commented */
// FilterScalingTechniqueByScheme
// FilterScalingTechniqueByScheme // 7
// )

trait ValidFilter(prev: LibConfigDomain, validFilter: Boolean) {
Expand Down Expand Up @@ -62,6 +62,9 @@ object ValidFilter {
)
}

// TODO: There are 2 options for this implementation
// * Current implementation filters scalingModSize which is not greater than firstModSize
// * Another option is to filter firstModSize to be not smaller than scalingModeSize
// def firstModSizeIsLargest(firstModSize: Int, scalingModSize: Int): Boolean =
// scalingModSize <= firstModSize
case class FilterFirstModSizeIsLargest(
Expand Down
31 changes: 0 additions & 31 deletions src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,7 @@ case class Generate(
val adjusted = assigned.adjustScale(encType)
adjusted
}

val resultAbsPrograms: LazyList[AbsProgram] = adjustedAbsPrograms
// val resultAbsPrograms: LazyList[AbsProgram] = if (validFilter) {
// adjustedAbsPrograms.filter(_.isValid)
// } else {
// // val numOfValidFilter = 10
// // val programsWithEquivClasses: LazyList[(AbsProgram, List[Boolean])] =
// // adjustedAbsPrograms.map({ pgm =>
// // (pgm, pgm.getInvalidEquivClassList())
// // })
// // def filterSequencially(
// // absPrograms: LazyList[(AbsProgram, List[Boolean])],
// // idx: Int,
// // ): LazyList[AbsProgram] =
// // if (absPrograms.isEmpty)
// // LazyList.empty // unreachable
// // else if (idx == numOfValidFilter) filterSequencially(absPrograms, 0)
// // else {
// // val (pgm, equivClassList) = absPrograms.head
// // val equivClass = equivClassList.apply(idx)
// // if (equivClass)
// // pgm #:: filterSequencially(absPrograms.tail, idx + 1)
// // else filterSequencially(absPrograms, idx + 1)
// // }
// // filterSequencially(programsWithEquivClasses, 0)

// val equivClassIdx = LazyList.from(0)
// adjustedAbsPrograms
// .zip(equivClassIdx)
// .filter { case (pgm, idx) => pgm.invalidEquivClass(idx) }
// .map(_._1)
// }
val takenResultAbsPrograms = nOpt match {
case Some(n) => resultAbsPrograms.take(n)
case None => resultAbsPrograms
Expand Down

0 comments on commit ddad75b

Please sign in to comment.