Skip to content

Commit

Permalink
Add Timeout (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeho committed Mar 12, 2024
1 parent e53ba27 commit 38ae03f
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 63 deletions.
4 changes: 3 additions & 1 deletion src/main/scala/fhetest/Checker/ExecuteResult.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ case class LibraryError(msg: String) extends ExecuteResult {
override def toString: String = s"LibraryError: $msg"
}
case object ParseError extends ExecuteResult
// case object TimeoutError extends ExecuteResult //TODO: development
case object TimeoutError extends ExecuteResult {
override def toString: String = s"timeout"
}
// case object Throw extends ExecuteResult

case class BackendResultPair(backend: String, result: ExecuteResult)
Expand Down
17 changes: 13 additions & 4 deletions src/main/scala/fhetest/Command.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ case object CmdRun extends BackendCommand("run") {
Some(encParams),
libConfigOpt,
)
val result = Execute(backend)
val result = Execute(backend, config.timeLimit)
print(result)
case None =>
val (ast, _, _) = Parse(fname)
Expand Down Expand Up @@ -155,7 +155,7 @@ case object CmdExecute extends BackendCommand("execute") {
def runJob(config: Config): Unit =
val backend = config.backend.getOrElseThrow("No backend given.")
given DirName = getWorkspaceDir(backend)
val output = Execute(backend)
val output = Execute(backend, config.timeLimit)
println(output)
}

Expand Down Expand Up @@ -189,7 +189,15 @@ case object CmdCheck extends BackendCommand("check") {
val sealVersion = config.sealVersion.getOrElse(SEAL_VERSIONS.head)
val openfheVersion = config.openfheVersion.getOrElse(OPENFHE_VERSIONS.head)
val outputs =
Check(dir, backends, encParamsOpt, toJson, sealVersion, openfheVersion)
Check(
dir,
backends,
encParamsOpt,
toJson,
sealVersion,
openfheVersion,
config.timeLimit,
)
for output <- outputs do {
println(output)
}
Expand Down Expand Up @@ -230,6 +238,7 @@ case object CmdTest extends BackendCommand("test") {
openfheVersion,
config.filter,
config.debug,
config.timeLimit,
)
for (program, output) <- outputs do {
println("=" * 80)
Expand Down Expand Up @@ -280,7 +289,7 @@ case object CmdReplay extends Command("replay") {
Some(encParams),
Some(libConfig),
)
val result = Execute(backend)
val result = Execute(backend, config.timeLimit)
print(result)
case None =>
val (ast, _, _) = Parse(t2Program)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/fhetest/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Config(
var filter: Boolean = true,
var silent: Boolean = false,
var debug: Boolean = false,
var timeLimit: Option[Int] = None,
)

object Config {
Expand Down Expand Up @@ -68,6 +69,7 @@ object Config {
case "filter" => config.filter = value.toBoolean
case "silent" => config.silent = value.toBoolean
case "debug" => config.debug = value.toBoolean
case "timeout" => config.timeLimit = Some(value.toInt)
case _ => throw new Error(s"Unknown option: $key")
}
case _ => // 잘못된 형식의 인자 처리
Expand Down
13 changes: 9 additions & 4 deletions src/main/scala/fhetest/Phase/Check.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ case object Check {
program: T2Program,
backends: List[Backend],
encParamsOpt: Option[EncParams],
timeLimit: Option[Int],
): CheckResult = {
val encParams = encParamsOpt match {
case Some(encParams) => encParams
Expand All @@ -42,7 +43,7 @@ case object Check {
val executeResPairs = backends.map(backend =>
BackendResultPair(
backend.toString,
execute(backend, encParams, parsed, program.libConfig),
execute(backend, encParams, parsed, program.libConfig, timeLimit),
),
)
diffResults(interpResPair, executeResPairs, encType, encParams.plainMod)
Expand All @@ -61,6 +62,7 @@ case object Check {
openfheVersion: String,
validCheck: Boolean,
debug: Boolean,
timeLimit: Option[Int],
): LazyList[(T2Program, CheckResult)] = {
setTestDir()
val checkResults: LazyList[Option[(T2Program, CheckResult)]] = for {
Expand All @@ -82,7 +84,7 @@ case object Check {
val executeResPairs = backends.map(backend =>
BackendResultPair(
backend.toString,
execute(backend, encParams, parsed, program.libConfig),
execute(backend, encParams, parsed, program.libConfig, timeLimit),
),
)
val checkResult =
Expand Down Expand Up @@ -134,6 +136,7 @@ case object Check {
toJson: Boolean,
sealVersion: String,
openfheVersion: String,
timeLimit: Option[Int],
): LazyList[String] = {
val dir = new File(directory)
if (dir.exists() && dir.isDirectory) {
Expand All @@ -146,7 +149,7 @@ case object Check {
val fileStr = Files.readAllLines(filePath).asScala.mkString("")
val libConfig = LibConfig() // default libConfig for dir testing
val program = T2Program(fileStr, libConfig)
val checkResult = apply(program, backends, encParamsOpt)
val checkResult = apply(program, backends, encParamsOpt, timeLimit)
if (toJson)
DumpUtil.dumpResult(
program,
Expand Down Expand Up @@ -204,6 +207,7 @@ case object Check {
encParams: EncParams,
parsed: (Goal, SymbolTable, ENC_TYPE),
libConfig: LibConfig,
timeLimit: Option[Int],
): ExecuteResult = {
val (ast, symbolTable, encType) = parsed
withBackendTempDir(
Expand All @@ -220,9 +224,10 @@ case object Check {
libConfigOpt = Some(libConfig),
)
try {
val res = Execute(backend)
val res = Execute(backend, timeLimit)
Normal(res.trim)
} catch {
case _: java.util.concurrent.TimeoutException => TimeoutError
// TODO?: classify exception related with parmeters?
case ex: Exception => LibraryError(ex.getMessage)
}
Expand Down
108 changes: 55 additions & 53 deletions src/main/scala/fhetest/Phase/Execute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,66 @@ import java.io.File
import fhetest.Utils.*

case object Execute {
def apply(backend: Backend)(using workspaceDir: DirName): String = {
val binPath = s"$workspaceDir/bin"
val cmakeCommand = "cmake ."
val makeCommand = "make -j"
val executeCommand = "./test.out"
def apply(backend: Backend, timeLimit: Option[Int])(using
workspaceDir: DirName,
): String = timeout(
{
val binPath = s"$workspaceDir/bin"
val cmakeCommand = "cmake ."
val makeCommand = "make -j"
val executeCommand = "./test.out"

// remove bin directory if it exists
val binDir = new File(binPath)
if (binDir.exists()) {
deleteDirectoryRecursively(binDir)
}
// TODO : Add option silent (default true)
val cmakeProcess =
Process(cmakeCommand, new File(workspaceDir))
val makeProcess =
Process(makeCommand, new File(workspaceDir))
val executeProcess =
Process(executeCommand, binDir)
// remove bin directory if it exists
val binDir = new File(binPath)
if (binDir.exists()) {
deleteDirectoryRecursively(binDir)
}
// TODO : Add option silent (default true)
val cmakeProcess =
Process(cmakeCommand, new File(workspaceDir))
val makeProcess =
Process(makeCommand, new File(workspaceDir))
val executeProcess =
Process(executeCommand, binDir)

val outputSB = new StringBuilder() // To capture standard output
val errorSB = new StringBuilder() // To capture error output
val outputSB = new StringBuilder() // To capture standard output
val errorSB = new StringBuilder() // To capture error output

// Custom ProcessLogger to append output and errors
val processLogger = ProcessLogger(
(o: String) => outputSB.append(o).append("\n"),
(e: String) => errorSB.append(e).append("\n"),
)
val silentLogger = ProcessLogger(_ => (), errorSB.append(_))
// Custom ProcessLogger to append output and errors
val processLogger = ProcessLogger(
(o: String) => outputSB.append(o).append("\n"),
(e: String) => errorSB.append(e).append("\n"),
)
val silentLogger = ProcessLogger(_ => (), errorSB.append(_))

cmakeProcess.!(silentLogger)
val makeExitCode = makeProcess.!(silentLogger)
cmakeProcess.!(silentLogger)
val makeExitCode = makeProcess.!(silentLogger)

if (makeExitCode == 0) {
// Only proceed if make was successful
val executeExitCode = executeProcess.!(processLogger)
if (executeExitCode == 0) {
// If make and execute were successful, return standard output
return outputSB.toString()
} else if (executeExitCode == 139) {
// If program terminated with segmentation fault, return error message
errorSB.append(
"Program terminated with segmentation fault (exit code 139).\n",
)
return errorSB.toString()
} else if (executeExitCode == 136) {
// If program terminated with segmentation fault, return error message
errorSB.append(
"Program terminated with floating point exception (exit code 136).\n",
)
return errorSB.toString()
var result: String = ""
if (makeExitCode == 0) {
val executeExitCode = executeProcess.!(processLogger)
result = executeExitCode match {
case 0 => outputSB.toString()
case 139 =>
errorSB
.append(
"Program terminated with segmentation fault (exit code 139).\n",
)
.toString
case 136 =>
errorSB
.append(
"Program terminated with floating point exception (exit code 136).\n",
)
.toString
case _ => errorSB.toString()
}
} else {
// If execute failed, append error message
return errorSB.toString()
// If make failed, append error message
result = errorSB.toString()
}
} else {
// If make failed, append error message
return errorSB.toString()
}
outputSB.toString()
}
result
},
timeLimit,
)
}
2 changes: 1 addition & 1 deletion src/main/scala/fhetest/Phase/Generate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ case class Generate(
backend,
encParamsOpt = Some(encParams),
)
val result = Execute(backend)
val result = Execute(backend, None)
print(backend.toString + " : " + result)
},
)
Expand Down
13 changes: 13 additions & 0 deletions src/main/scala/fhetest/Utils/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import java.time.format.DateTimeFormatter
import scala.collection.mutable.StringBuilder
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.Await
import scala.concurrent.duration.*
import scala.io.Source
import scala.util.Try
import scala.util.Using
Expand Down Expand Up @@ -295,4 +297,15 @@ def deleteDirectoryRecursively(file: File): Unit = {
file.delete()
}

/** set timeout with optional limitation */
def timeout[T](f: => T, limit: Option[Int]): T =
limit.fold(f)(l => timeout(f, l.second))

/** set timeout with limitation */
def timeout[T](f: => T, limit: Int): T =
timeout(f, limit.seconds)

/** set timeout with duration */
def timeout[T](f: => T, duration: Duration): T =
Await.result(Future(Try(f)), duration).get
//TODO : move this to somewhere else

0 comments on commit 38ae03f

Please sign in to comment.