diff --git a/hail/hail/src/is/hail/backend/Backend.scala b/hail/hail/src/is/hail/backend/Backend.scala index 34a85bfde93..71a0ef4e75e 100644 --- a/hail/hail/src/is/hail/backend/Backend.scala +++ b/hail/hail/src/is/hail/backend/Backend.scala @@ -9,7 +9,6 @@ import is.hail.io.fs.FS import is.hail.types.RTable import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.utils.ExecutionTimer.Timings import is.hail.utils.fatal import scala.reflect.ClassTag @@ -105,7 +104,7 @@ abstract class Backend extends Closeable { def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage - def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) - def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] + + def backendContext(ctx: ExecuteContext): BackendContext } diff --git a/hail/hail/src/is/hail/backend/BackendRpc.scala b/hail/hail/src/is/hail/backend/BackendRpc.scala index bb2df117ad1..31eca3df798 100644 --- a/hail/hail/src/is/hail/backend/BackendRpc.scala +++ b/hail/hail/src/is/hail/backend/BackendRpc.scala @@ -249,8 +249,4 @@ trait HttpLikeBackendRpc extends BackendRpc { ) } } - - implicit protected def Reader: Routing - implicit protected def Writer: Writer - implicit protected def Context: Context } diff --git a/hail/hail/src/is/hail/backend/BackendServer.scala b/hail/hail/src/is/hail/backend/BackendServer.scala deleted file mode 100644 index 984d24f4709..00000000000 --- a/hail/hail/src/is/hail/backend/BackendServer.scala +++ /dev/null @@ -1,124 +0,0 @@ -package is.hail.backend - -import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} -import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings - -import java.io.Closeable -import java.net.InetSocketAddress -import java.util.concurrent._ - -import com.google.api.client.http.HttpStatusCodes -import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} -import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} - -class BackendServer(backend: Backend) extends Closeable { - // 0 => let the OS pick an available port - private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) - - private[this] val thread = { - // This HTTP server *must not* start non-daemon threads because such threads keep the JVM - // alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest - // when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of the - // JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel - // explicitly regardless of the JVM). It *does* manifest when submitting jobs with - // - // gcloud dataproc submit ... - // - // or - // - // spark-submit - // - // setExecutor(null) ensures the server creates no new threads: - // - // > If this method is not called (before start()) or if it is called with a null Executor, then - // > a default implementation is used, which uses the thread which was created by the start() - // > method. 
- // - /* Source: - * https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ - // - httpServer.createContext("/", Handler) - httpServer.setExecutor(null) - val t = Executors.defaultThreadFactory().newThread(new Runnable() { - def run(): Unit = - httpServer.start() - }) - t.setDaemon(true) - t - } - - def port: Int = httpServer.getAddress.getPort - - def start(): Unit = - thread.start() - - override def close(): Unit = - httpServer.stop(10) - - private[this] object Handler extends HttpHandler with HttpLikeBackendRpc { - - case class Env(exchange: HttpExchange, payload: JValue) - - override def handle(exchange: HttpExchange): Unit = { - val payload = using(exchange.getRequestBody)(JsonMethods.parse(_)) - runRpc(Env(exchange, payload)) - } - - implicit override object Reader extends Routing { - - import Routes._ - - override def route(env: Env): Route = - env.exchange.getRequestURI.getPath match { - case "/value/type" => TypeOf(Value) - case "/table/type" => TypeOf(Table) - case "/matrixtable/type" => TypeOf(Matrix) - case "/blockmatrix/type" => TypeOf(BlockMatrix) - case "/execute" => Execute - case "/vcf/metadata/parse" => ParseVcfMetadata - case "/fam/import" => ImportFam - case "/references/load" => LoadReferencesFromDataset - case "/references/from_fasta" => LoadReferencesFromFASTA - } - - override def payload(env: Env): JValue = env.payload - } - - implicit override object Writer extends Writer with ErrorHandling { - - override def timings(env: Env, t: Timings): Unit = { - val ts = Serialization.write(Map("timings" -> t)) - env.exchange.getResponseHeaders.add("X-Hail-Timings", ts) - } - - override def result(env: Env, result: Array[Byte]): Unit = - respond(env, HttpStatusCodes.STATUS_CODE_OK, result) - - override def error(env: Env, t: Throwable): Unit = - respond( - env, - HttpStatusCodes.STATUS_CODE_SERVER_ERROR, - jsonToBytes { - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - JObject( - "short" -> JString(shortMessage), - "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId), - ) - }, - ) - - private[this] def respond(env: Env, code: Int, payload: Array[Byte]): Unit = { - env.exchange.sendResponseHeaders(code, payload.length) - using(env.exchange.getResponseBody)(_.write(payload)) - } - } - - implicit override protected object Context extends Context { - override def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) = - backend.withExecuteContext(f) - } - } -} diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index ca37f5752e4..b307e16aa75 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -1,6 +1,6 @@ package is.hail.backend -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailFeatureFlags import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext @@ -23,7 +23,7 @@ trait TempFileManager extends AutoCloseable { def newTmpPath(tmpdir: String, prefix: String, extension: String = null): String } -class OwningTempFileManager(fs: FS) extends TempFileManager { +class OwningTempFileManager(val fs: FS) extends TempFileManager { private[this] val tmpPaths = mutable.ArrayBuffer[String]() override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = { @@ -55,11 +55,6 @@ object NonOwningTempFileManager { } 
object ExecuteContext { - def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - HailContext.sparkBackend.withExecuteContext( - selfContainedExecution = false - )(f) - def scoped[T]( tmpdir: String, localTmpdir: String, @@ -70,7 +65,6 @@ object ExecuteContext { tempFileManager: TempFileManager, theHailClassLoader: HailClassLoader, flags: HailFeatureFlags, - backendContext: BackendContext, irMetadata: IrMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix], codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], @@ -92,7 +86,6 @@ object ExecuteContext { tempFileManager, theHailClassLoader, flags, - backendContext, irMetadata, blockMatrixCache, codeCache, @@ -125,7 +118,6 @@ class ExecuteContext( _tempFileManager: TempFileManager, val theHailClassLoader: HailClassLoader, val flags: HailFeatureFlags, - val backendContext: BackendContext, val irMetadata: IrMetadata, val BlockMatrixCache: mutable.Map[String, BlockMatrix], val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], @@ -196,7 +188,6 @@ class ExecuteContext( tempFileManager: TempFileManager = NonOwningTempFileManager(this.tempFileManager), theHailClassLoader: HailClassLoader = this.theHailClassLoader, flags: HailFeatureFlags = this.flags, - backendContext: BackendContext = this.backendContext, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, @@ -216,7 +207,6 @@ class ExecuteContext( tempFileManager, theHailClassLoader, flags, - backendContext, irMetadata, blockMatrixCache, codeCache, diff --git a/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala b/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala new file mode 100644 index 00000000000..584d7c72931 --- /dev/null +++ b/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala @@ -0,0 +1,417 @@ +package is.hail.backend.api + +import is.hail.{linalg, HailFeatureFlags} +import is.hail.asm4s.HailClassLoader +import is.hail.backend._ +import is.hail.backend.spark.SparkBackend +import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} +import is.hail.expr.ir._ +import is.hail.expr.ir.IRParser.parseType +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer +import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx} +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs._ +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.types.physical.PStruct +import is.hail.types.virtual.{TArray, TInterval} +import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} +import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.annotation.nowarn +import scala.collection.mutable +import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} + +import java.io.Closeable +import java.net.InetSocketAddress +import java.util + +import com.google.api.client.http.HttpStatusCodes +import com.sun.net.httpserver.{HttpExchange, HttpServer} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.DataFrame +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} +import sourcecode.Enclosing + +final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { + + private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() + private[this] val hcl = new 
HailClassLoader(getClass.getClassLoader) + private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*) + private[this] val bmCache = mutable.Map[String, linalg.BlockMatrix]() + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) + private[this] val persistedIr = mutable.Map[Int, BaseIR]() + private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) + private[this] var irID: Int = 0 + + private[this] var tmpdir: String = _ + private[this] var localTmpdir: String = _ + + private[this] var tmpFileManager = new OwningTempFileManager( + newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) + ) + + def pyFs: FS = + synchronized(tmpFileManager.fs) + + def pyGetFlag(name: String): String = + synchronized(flags.get(name)) + + def pySetFlag(name: String, value: String): Unit = + synchronized(flags.set(name, value)) + + def pyAvailableFlags: java.util.ArrayList[String] = + flags.available + + def pySetRemoteTmp(tmp: String): Unit = + synchronized { tmpdir = tmp } + + def pySetLocalTmp(tmp: String): Unit = + synchronized { localTmpdir = tmp } + + def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit = + synchronized { + tmpFileManager.close() + + val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) + + val rpConfig: Option[RequesterPaysConfig] = + ( + Option(project).filter(_.nonEmpty), + Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty), + ) match { + case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets)) + case (None, Some(_)) => fatal( + "A non-empty, non-null requester pays google project is required to configure requester pays buckets." + ) + case (None, None) => None + } + + val fs = newFs( + cloudfsConf.copy( + google = (cloudfsConf.google, rpConfig) match { + case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig)) + case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig)) + case _ => None + } + ) + ) + + tmpFileManager = new OwningTempFileManager(fs) + } + + def pyRemoveJavaIR(id: Int): Unit = + synchronized(persistedIr.remove(id)) + + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = + synchronized { + val seq = IndexedFastaSequenceFile(tmpFileManager.fs, fastaFile, indexFile) + references(name).addSequence(seq) + } + + def pyRemoveSequence(name: String): Unit = + synchronized(references(name).removeSequence()) + + def pyExportBlockMatrix( + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: java.lang.Integer, + entries: String, + ): Unit = + withExecuteContext() { ctx => + val rm = linalg.RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) + entries match { + case "full" => + rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "lower" => + rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_lower" => + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "upper" => + rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_upper" => + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + } + } + + def pyRegisterIR( + name: String, + typeParamStrs: java.util.ArrayList[String], + argNameStrs: java.util.ArrayList[String], + argTypeStrs: 
java.util.ArrayList[String], + returnType: String, + bodyStr: String, + ): Unit = + withExecuteContext() { ctx => + IRFunctionRegistry.registerIR( + ctx, + name, + typeParamStrs.asScala.toArray, + argNameStrs.asScala.toArray, + argTypeStrs.asScala.toArray, + returnType, + bodyStr, + ) + } + + def pyExecuteLiteral(irStr: String): Int = + withExecuteContext() { ctx => + val ir = IRParser.parse_value_ir(ctx, irStr) + assert(ir.typ.isRealizable) + backend.execute(ctx, ir) match { + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + addJavaIR(ctx, field) + } + }._1 + + def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = + withExecuteContext(selfContainedExecution = false) { ctx => + val key = jKey.asScala.toArray.toFastSeq + val signature = + SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] + val tir = TableLiteral( + TableValue( + ctx, + signature.virtualType, + key, + df.rdd, + Some(signature), + ), + ctx.theHailClassLoader, + ) + val id = addJavaIR(ctx, tir) + (id, JsonMethods.compact(tir.typ.toJSON)) + }._1 + + def pyToDF(s: String): DataFrame = + withExecuteContext() { ctx => + val tir = IRParser.parse_table_ir(ctx, s) + Interpret(tir, ctx).toDF() + }._1 + + def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = + withExecuteContext() { ctx => + implicit val fmts: Formats = DefaultFormats + log.info("pyReadMultipleMatrixTables: got query") + + val kvs = JsonMethods.parse(jsonQuery).extract[Map[String, JValue]] + val paths = kvs("paths").extract[IndexedSeq[String]] + val intervalPointType = parseType(kvs("intervalPointType").extract[String]) + val intervalObjects = + JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) + .asInstanceOf[IndexedSeq[Interval]] + + val opts = NativeReaderOptions(intervalObjects, intervalPointType) + val matrixReaders: util.List[MatrixIR] = + paths.map { p => + log.info(s"creating MatrixRead node for $p") + val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) + MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR + }.asJava + + log.info("pyReadMultipleMatrixTables: returning N matrix tables") + matrixReaders + }._1 + + def pyAddReference(jsonConfig: String): Unit = + synchronized(addReference(ReferenceGenome.fromJSON(jsonConfig))) + + def pyRemoveReference(name: String): Unit = + synchronized(removeReference(name)) + + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = + synchronized { + references(name).addLiftover( + references(destRGName), + LiftOver(tmpFileManager.fs, chainFile), + ) + } + + def pyRemoveLiftover(name: String, destRGName: String): Unit = + synchronized(references(name).removeLiftover(destRGName)) + + def parse_blockmatrix_ir(s: String): BlockMatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_blockmatrix_ir(ctx, s) + }._1 + + private[this] def addReference(rg: ReferenceGenome): Unit = + ReferenceGenome.addFatalOnCollision(references, FastSeq(rg)) + + override def close(): Unit = + synchronized { + bmCache.values.foreach(_.unpersist()) + codeCache.clear() + persistedIr.clear() + coercerCache.clear() + backend.close() + } + + private[this] def removeReference(name: String): Unit = + references -= name + + private[this] def withExecuteContext[T]( + selfContainedExecution: Boolean = true + )( + f: ExecuteContext => T + )(implicit 
E: Enclosing + ): (T, Timings) = + synchronized { + ExecutionTimer.time { timer => + ExecuteContext.scoped( + tmpdir = tmpdir, + localTmpdir = localTmpdir, + backend = backend, + references = references.toMap, + fs = tmpFileManager.fs, + timer = timer, + tempFileManager = + if (selfContainedExecution) null + else NonOwningTempFileManager(tmpFileManager), + theHailClassLoader = hcl, + flags = flags, + irMetadata = new IrMetadata(), + blockMatrixCache = bmCache, + codeCache = codeCache, + irCache = persistedIr, + coercerCache = coercerCache, + )(f) + } + } + + private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS = + backend match { + case s: SparkBackend => + val conf = new Configuration(s.sc.hadoopConfiguration) + cloudfsConfig.google.flatMap(_.requester_pays_config).foreach { + case RequesterPaysConfig(prj, bkts) => + bkts + .map { buckets => + conf.set("fs.gs.requester.pays.mode", "CUSTOM") + conf.set("fs.gs.requester.pays.project.id", prj) + conf.set("fs.gs.requester.pays.buckets", buckets.mkString(",")) + } + .getOrElse { + conf.set("fs.gs.requester.pays.mode", "AUTO") + conf.set("fs.gs.requester.pays.project.id", prj) + } + } + new HadoopFS(new SerializableHadoopConfiguration(conf)) + + case _ => + RouterFS.buildRoutes(cloudfsConfig) + } + + private[this] def nextIRID(): Int = { + irID += 1 + irID + } + + private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { + val id = nextIRID() + ctx.IrCache += (id -> ir) + id + } + + def pyHttpServer: HttpLikeBackendRpc with Closeable = + new HttpLikeBackendRpc with Closeable { + + override type Env = HttpExchange + + implicit object Reader extends Routing { + override def route(req: HttpExchange): Route = + req.getRequestURI.getPath match { + case "/value/type" => Routes.TypeOf(Value) + case "/table/type" => Routes.TypeOf(Table) + case "/matrixtable/type" => Routes.TypeOf(Matrix) + case "/blockmatrix/type" => Routes.TypeOf(BlockMatrix) + case "/execute" => Routes.Execute + case "/vcf/metadata/parse" => Routes.ParseVcfMetadata + case "/fam/import" => Routes.ImportFam + case "/references/load" => Routes.LoadReferencesFromDataset + case "/references/from_fasta" => Routes.LoadReferencesFromFASTA + } + + override def payload(req: HttpExchange): JValue = + using(req.getRequestBody)(JsonMethods.parse(_)) + } + + implicit object Writer extends Writer { + override def timings(req: HttpExchange, t: Timings): Unit = { + val ts = Serialization.write(Map("timings" -> t)) + req.getResponseHeaders.add("X-Hail-Timings", ts) + } + + override def result(req: HttpExchange, result: Array[Byte]): Unit = + respond(req, HttpStatusCodes.STATUS_CODE_OK, result) + + override def error(req: HttpExchange, t: Throwable): Unit = + respond( + req, + HttpStatusCodes.STATUS_CODE_SERVER_ERROR, + jsonToBytes { + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId), + ) + }, + ) + + private[this] def respond(req: HttpExchange, code: Int, payload: Array[Byte]): Unit = { + req.sendResponseHeaders(code, payload.length) + using(req.getResponseBody)(_.write(payload)) + } + } + + implicit object Context extends Context { + override def scoped[A](req: HttpExchange)(f: ExecuteContext => A): (A, Timings) = + withExecuteContext()(f) + } + + // 0 => let the OS pick an available port + private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) + + @nowarn def port: Int = httpServer.getAddress.getPort + override def 
close(): Unit = httpServer.stop(10) + + // This HTTP server *must not* start non-daemon threads because such threads keep the JVM + // alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest + // when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of + // the JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel + // explicitly regardless of the JVM). It *does* manifest when submitting jobs with + // + // gcloud dataproc submit ... + // + // or + // + // spark-submit + httpServer.createContext("/", runRpc(_: HttpExchange)) + httpServer.setExecutor(null) // ensures the server creates no new threads + httpServer.start() + } +} diff --git a/hail/hail/src/is/hail/backend/api/ServiceBackendApi.scala b/hail/hail/src/is/hail/backend/api/ServiceBackendApi.scala new file mode 100644 index 00000000000..d663b84ab26 --- /dev/null +++ b/hail/hail/src/is/hail/backend/api/ServiceBackendApi.scala @@ -0,0 +1,271 @@ +package is.hail.backend.api + +import is.hail.{HailContext, HailFeatureFlags} +import is.hail.annotations.Memory +import is.hail.asm4s.HailClassLoader +import is.hail.backend.{Backend, ExecuteContext, HttpLikeBackendRpc} +import is.hail.backend.service._ +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS} +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.services._ +import is.hail.types.virtual.Kinds._ +import is.hail.utils.{ + toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, ImmutableMap, Logging, +} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.annotation.switch + +import java.io.OutputStream +import java.nio.charset.StandardCharsets +import java.nio.file.Path + +import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods + +object ServiceBackendApi extends HttpLikeBackendRpc with Logging { + + case class Env( + backend: Backend, + flags: HailFeatureFlags, + hcl: HailClassLoader, + rpcConfig: ServiceBackendRPCPayload, + fs: FS, + references: Map[String, ReferenceGenome], + outputUrl: String, + action: Int, + payload: JValue, + ) + + implicit object Reader extends Routing { + import Routes._ + + override def route(env: Env): Route = + (env.action: @switch) match { + case 1 => TypeOf(Value) + case 2 => TypeOf(Table) + case 3 => TypeOf(Matrix) + case 4 => TypeOf(BlockMatrix) + case 5 => Execute + case 6 => ParseVcfMetadata + case 7 => ImportFam + case 8 => LoadReferencesFromDataset + case 9 => LoadReferencesFromFASTA + } + + override def payload(env: Env): JValue = env.payload + } + + implicit object Writer extends Writer with ErrorHandling { + // service backend doesn't support sending timings back to the python client + override def timings(env: Env, t: Timings): Unit = + () + + override def result(env: Env, result: Array[Byte]): Unit = + retryTransientErrors { + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + + override def error(env: Env, t: Throwable): Unit = + retryTransientErrors { + val (shortMessage, expandedMessage, errorId) = + t match { + case t: HailWorkerException => + log.error( + "A worker failed. 
The exception was written for Python but we will also throw an exception to fail this driver job.", + t, + ) + (t.shortMessage, t.expandedMessage, t.errorId) + case _ => + log.error( + "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job.", + t, + ) + handleForPython(t) + } + + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + + throw t + } + } + + implicit object Context extends Context { + override def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) = + ExecutionTimer.time { timer => + ExecuteContext.scoped( + env.rpcConfig.tmp_dir, + env.rpcConfig.tmp_dir, + env.backend, + env.references, + env.fs, + timer, + null, + env.hcl, + env.flags, + new IrMetadata(), + ImmutableMap.empty, + ImmutableMap.empty, + ImmutableMap.empty, + ImmutableMap.empty, + )(f) + } + } + + def main(argv: Array[String]): Unit = { + assert(argv.length == 7, argv.toFastSeq) + + val scratchDir = argv(0) + // val logFile = argv(1) + val jarLocation = argv(2) + val kind = argv(3) + assert(kind == Main.DRIVER) + val name = argv(4) + val inputURL = argv(5) + val outputURL = argv(6) + + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") + DeployConfig.set(deployConfig) + sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) + + var fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) + ) + + val (rpcConfig, jobConfig, action, payload) = + using(fs.openNoCompression(inputURL)) { is => + val input = JsonMethods.parse(is) + ( + (input \ "rpc_config").extract[ServiceBackendRPCPayload], + (input \ "job_config").extract[BatchJobConfig], + (input \ "action").extract[Int], + input \ "payload", + ) + } + + // requester pays config is conveyed in feature flags currently + val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) + fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + featureFlags, + ) + ) + + val references = + ReferenceGenome.addFatalOnCollision( + ReferenceGenome.builtinReferences(), + rpcConfig.custom_references.map(ReferenceGenome.fromJSON), + ) + + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } + + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } + + val backend = new ServiceBackend( + name, + BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), + JarUrl(jarLocation), + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), + jobConfig, + ) + + log.info("ServiceBackend allocated.") + if (HailContext.isInitialized) { + HailContext.get.backend = backend + log.info("Default references added to already initialized HailContexet.") + } else { + HailContext(backend, 50, 3) + log.info("HailContexet initialized.") + } + + // FIXME: when can the classloader be shared? (optimizer benefits!) 
+ runRpc( + Env( + backend, + featureFlags, + new HailClassLoader(getClass.getClassLoader), + rpcConfig, + fs, + references, + outputURL, + action, + payload, + ) + ) + } +} + +private class HailSocketAPIOutputStream( + private[this] val out: OutputStream +) extends AutoCloseable { + private[this] var closed: Boolean = false + private[this] val dummy = new Array[Byte](8) + + def writeBool(b: Boolean): Unit = + out.write(if (b) 1 else 0) + + def writeInt(v: Int): Unit = { + Memory.storeInt(dummy, 0, v) + out.write(dummy, 0, 4) + } + + def writeLong(v: Long): Unit = { + Memory.storeLong(dummy, 0, v) + out.write(dummy) + } + + def writeBytes(bytes: Array[Byte]): Unit = { + writeInt(bytes.length) + out.write(bytes) + } + + def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) + + def close(): Unit = + if (!closed) { + out.close() + closed = true + } +} + +case class SequenceConfig(fasta: String, index: String) + +case class ServiceBackendRPCPayload( + tmp_dir: String, + flags: Map[String, String], + custom_references: Array[String], + liftovers: Map[String, Map[String, String]], + sequences: Map[String, SequenceConfig], +) + +case class BatchJobConfig( + token: String, + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], +) diff --git a/hail/hail/src/is/hail/backend/local/LocalBackend.scala b/hail/hail/src/is/hail/backend/local/LocalBackend.scala index 67bdfe350ee..fcc40daab83 100644 --- a/hail/hail/src/is/hail/backend/local/LocalBackend.scala +++ b/hail/hail/src/is/hail/backend/local/LocalBackend.scala @@ -1,13 +1,11 @@ package is.hail.backend.local -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{CancellingExecutorService, HailContext} import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ -import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple @@ -18,18 +16,13 @@ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings -import is.hail.variant.ReferenceGenome -import scala.collection.mutable import scala.reflect.ClassTag import java.io.PrintWriter import com.fasterxml.jackson.core.StreamReadConstraints import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop -import sourcecode.Enclosing class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable @@ -37,8 +30,7 @@ class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskC override def attemptNumber(): Int = 0 } -object LocalBackend { - private var theLocalBackend: LocalBackend = _ +object LocalBackend extends Backend { // From https://github.com/hail-is/hail/issues/14580 : // IR can get quite big, especially as it can contain an arbitrary @@ -53,91 +45,32 @@ object LocalBackend { ) def apply( - tmpdir: String, logFile: String = "hail.log", quiet: Boolean = false, append: Boolean = false, skipLoggingConfiguration: Boolean = false, - ): LocalBackend = synchronized { - require(theLocalBackend == null) - - if (!skipLoggingConfiguration) - HailContext.configureLogging(logFile, quiet, 
append) - - theLocalBackend = new LocalBackend( - tmpdir, - mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), - ) - - theLocalBackend - } - - def stop(): Unit = synchronized { - if (theLocalBackend != null) { - theLocalBackend = null - // Hadoop does not honor the hadoop configuration as a component of the cache key for file - // systems, so we blow away the cache so that a new configuration can successfully take - // effect. - // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 - hadoop.fs.FileSystem.closeAll() + ): LocalBackend.type = + synchronized { + if (!skipLoggingConfiguration) HailContext.configureLogging(logFile, quiet, append) + this } - } -} -class LocalBackend( - val tmpdir: String, - override val references: mutable.Map[String, ReferenceGenome], -) extends Backend with Py4JBackendExtensions { - - override def backend: Backend = this - override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - override def longLifeTempFileManager: TempFileManager = null - - private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) - private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) - private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() - private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - - // flags can be set after construction from python - def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = - ExecutionTimer.time { timer => - val fs = this.fs - ExecuteContext.scoped( - tmpdir, - tmpdir, - this, - references.toMap, - fs, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }, - new IrMetadata(), - ImmutableMap.empty, - codeCache, - persistedIR, - coercerCache, - )(f) - } + private case class Context(hcl: HailClassLoader, override val executionCache: ExecutionCache) + extends BackendContext def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value) private[this] var stageIdx: Int = 0 - private[this] def nextStageId(): Int = { - val current = stageIdx - stageIdx += 1 - current - } + private[this] def nextStageId(): Int = + synchronized { + val current = stageIdx + stageIdx += 1 + current + } override def parallelizeAndComputeWithIndex( - backendContext: BackendContext, + ctx: BackendContext, fs: FS, contexts: IndexedSeq[Array[Byte]], stageIdentifier: String, @@ -148,22 +81,24 @@ class LocalBackend( ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { val stageId = nextStageId() + val hcl = ctx.asInstanceOf[Context].hcl runAllKeepFirstError(new CancellingExecutorService(MoreExecutors.newDirectExecutorService())) { partitions.getOrElse(contexts.indices).map { i => ( - () => - using(new LocalTaskContext(i, stageId)) { - f(contexts(i), _, theHailClassLoader, fs) - }, + () => using(new LocalTaskContext(i, stageId))(f(contexts(i), _, hcl, fs)), i, ) } } } + override def backendContext(ctx: ExecuteContext): BackendContext = + Context(ctx.theHailClassLoader, ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir)) + def defaultParallelism: Int = 1 - def close(): Unit = LocalBackend.stop() + def close(): Unit = + synchronized { stageIdx = 0 } private[this] def _jvmLowerAndExecute( ctx: ExecuteContext, @@ -171,8 +106,7 @@ class LocalBackend( print: Option[PrintWriter] = None, ): 
Either[Unit, (PTuple, Long)] = ctx.time { - val ir = - LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, ir0).asInstanceOf[IR] + val ir = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All)(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") diff --git a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala deleted file mode 100644 index 16b570cb9df..00000000000 --- a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ /dev/null @@ -1,229 +0,0 @@ -package is.hail.backend.py4j - -import is.hail.HailFeatureFlags -import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} -import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} -import is.hail.expr.ir.{ - BaseIR, BlockMatrixIR, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, - NativeReaderOptions, TableLiteral, TableValue, -} -import is.hail.expr.ir.IRParser.parseType -import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx} -import is.hail.expr.ir.functions.IRFunctionRegistry -import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} -import is.hail.linalg.RowMatrix -import is.hail.types.physical.PStruct -import is.hail.types.virtual.{TArray, TInterval} -import is.hail.utils.{log, toRichIterable, FastSeq, HailException, Interval} -import is.hail.variant.ReferenceGenome - -import scala.collection.mutable -import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} - -import java.util - -import org.apache.spark.sql.DataFrame -import org.json4s -import org.json4s.jackson.JsonMethods -import sourcecode.Enclosing - -trait Py4JBackendExtensions { - def backend: Backend - def references: mutable.Map[String, ReferenceGenome] - def flags: HailFeatureFlags - def longLifeTempFileManager: TempFileManager - - def pyGetFlag(name: String): String = - flags.get(name) - - def pySetFlag(name: String, value: String): Unit = - flags.set(name, value) - - def pyAvailableFlags: java.util.ArrayList[String] = - flags.available - - private[this] var irID: Int = 0 - - private[this] def nextIRID(): Int = { - irID += 1 - irID - } - - private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { - val id = nextIRID() - ctx.IrCache += (id -> ir) - id - } - - def pyRemoveJavaIR(id: Int): Unit = - backend.withExecuteContext(_.IrCache.remove(id)) - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - backend.withExecuteContext { ctx => - references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile)) - } - - def pyRemoveSequence(name: String): Unit = - references(name).removeSequence() - - def pyExportBlockMatrix( - pathIn: String, - pathOut: String, - delimiter: String, - header: String, - addIndex: Boolean, - exportType: String, - partitionSize: java.lang.Integer, - entries: String, - ): Unit = - backend.withExecuteContext { ctx => - val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) - entries match { - case "full" => - rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "lower" => - rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_lower" => - rm.exportStrictLowerTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - case "upper" => - rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), 
addIndex, exportType) - case "strict_upper" => - rm.exportStrictUpperTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - } - } - - def pyRegisterIR( - name: String, - typeParamStrs: java.util.ArrayList[String], - argNameStrs: java.util.ArrayList[String], - argTypeStrs: java.util.ArrayList[String], - returnType: String, - bodyStr: String, - ): Unit = - backend.withExecuteContext { ctx => - IRFunctionRegistry.registerIR( - ctx, - name, - typeParamStrs.asScala.toArray, - argNameStrs.asScala.toArray, - argTypeStrs.asScala.toArray, - returnType, - bodyStr, - ) - } - - def pyExecuteLiteral(irStr: String): Int = - backend.withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(ctx, irStr) - assert(ir.typ.isRealizable) - backend.execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(ctx, field) - } - }._1 - - def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { - val key = jKey.asScala.toArray.toFastSeq - val signature = - SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = TableLiteral( - TableValue( - ctx, - signature.virtualType, - key, - df.rdd, - Some(signature), - ), - ctx.theHailClassLoader, - ) - val id = addJavaIR(ctx, tir) - (id, JsonMethods.compact(tir.typ.toJSON)) - } - } - - def pyToDF(s: String): DataFrame = - backend.withExecuteContext { ctx => - val tir = IRParser.parse_table_ir(ctx, s) - Interpret(tir, ctx).toDF() - }._1 - - def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = - backend.withExecuteContext { ctx => - log.info("pyReadMultipleMatrixTables: got query") - val kvs = JsonMethods.parse(jsonQuery) match { - case json4s.JObject(values) => values.toMap - } - - val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { - case json4s.JString(s) => s - } - - val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) - val intervalObjects = - JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) - .asInstanceOf[IndexedSeq[Interval]] - - val opts = NativeReaderOptions(intervalObjects, intervalPointType) - val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => - log.info(s"creating MatrixRead node for $p") - val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) - MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR - } - log.info("pyReadMultipleMatrixTables: returning N matrix tables") - matrixReaders.asJava - }._1 - - def pyAddReference(jsonConfig: String): Unit = - addReference(ReferenceGenome.fromJSON(jsonConfig)) - - def pyRemoveReference(name: String): Unit = - removeReference(name) - - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - backend.withExecuteContext { ctx => - references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile)) - } - - def pyRemoveLiftover(name: String, destRGName: String): Unit = - references(name).removeLiftover(destRGName) - - private[this] def addReference(rg: ReferenceGenome): Unit = - ReferenceGenome.addFatalOnCollision(references, FastSeq(rg)) - - private[this] def removeReference(name: String): Unit = - references -= name - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - 
IRParser.parse_blockmatrix_ir(ctx, s) - } - - def withExecuteContext[T]( - selfContainedExecution: Boolean = true - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = - backend.withExecuteContext { ctx => - val tempFileManager = longLifeTempFileManager - if (selfContainedExecution && tempFileManager != null) f(ctx) - else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) - }._1 -} diff --git a/hail/hail/src/is/hail/backend/service/Main.scala b/hail/hail/src/is/hail/backend/service/Main.scala index a31f99e7946..391fefcdaad 100644 --- a/hail/hail/src/is/hail/backend/service/Main.scala +++ b/hail/hail/src/is/hail/backend/service/Main.scala @@ -1,5 +1,7 @@ package is.hail.backend.service +import is.hail.backend.api.ServiceBackendApi + import com.fasterxml.jackson.core.StreamReadConstraints object Main { @@ -18,8 +20,7 @@ object Main { def main(argv: Array[String]): Unit = argv(3) match { case WORKER => Worker.main(argv) - case DRIVER => ServiceBackendAPI.main(argv) - + case DRIVER => ServiceBackendApi.main(argv) // Batch's "JvmJob" is a special kind of job that can only call `Main.main`. // TEST is used for integration testing the `BatchClient` to verify that we // can create JvmJobs without having to mock the payload to a `Worker` job. diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index a1283c1fc6c..91f19ec3ea5 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -1,9 +1,10 @@ package is.hail.backend.service -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{CancellingExecutorService, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.api.BatchJobConfig import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate import is.hail.expr.ir.{IR, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck} @@ -12,37 +13,20 @@ import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.lowering._ import is.hail.io.fs._ -import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.services.{BatchClient, JobGroupRequest, _} import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual.{Kinds, TVoid} +import is.hail.types.virtual.TVoid import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings -import is.hail.variant.ReferenceGenome -import scala.annotation.switch -import scala.collection.mutable import scala.reflect.ClassTag import java.io._ import java.nio.charset.StandardCharsets -import java.nio.file.Path import java.util.concurrent._ -import org.json4s.{DefaultFormats, Formats} -import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods -import sourcecode.Enclosing - -case class ServiceBackendContext( - remoteTmpDir: String, - jobConfig: BatchJobConfig, - override val executionCache: ExecutionCache, -) extends BackendContext with Serializable - object ServiceBackend { val MaxAvailableGcsConnections = 1000 } @@ -51,17 +35,23 @@ class ServiceBackend( val name: String, batchClient: BatchClient, jarSpec: JarSpec, - theHailClassLoader: HailClassLoader, val batchConfig: BatchConfig, - rpcConfig: 
ServiceBackendRPCPayload, jobConfig: BatchJobConfig, - flags: HailFeatureFlags, - val fs: FS, - references: mutable.Map[String, ReferenceGenome], ) extends Backend with Logging { + case class Context( + remoteTmpDir: String, + batchConfig: BatchConfig, + jobConfig: BatchJobConfig, + flags: HailFeatureFlags, + override val executionCache: ExecutionCache, + ) extends BackendContext + private[this] var stageCount = 0 - private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections) + + private[this] val executor = lazily { + Executors.newFixedThreadPool(MaxAvailableGcsConnections) + } def defaultParallelism: Int = 4 @@ -79,6 +69,15 @@ class ServiceBackend( } } + override def backendContext(ctx: ExecuteContext): BackendContext = + Context( + remoteTmpDir = ctx.tmpdir, + batchConfig = batchConfig, + jobConfig = jobConfig, + flags = ctx.flags, + executionCache = ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir), + ) + private[this] def readString(in: DataInputStream): String = { val n = in.readInt() val bytes = new Array[Byte](n) @@ -87,6 +86,7 @@ class ServiceBackend( } private[this] def submitJobGroupAndWait( + ctx: Context, collection: IndexedSeq[Array[Byte]], token: String, root: String, @@ -96,7 +96,7 @@ class ServiceBackend( JvmJob( command = null, spec = jarSpec, - profile = flags.get("profile") != null, + profile = ctx.flags.get("profile") != null, ) val defaultJob = @@ -142,7 +142,7 @@ class ServiceBackend( batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) } - private[this] def readResult(root: String, i: Int): Array[Byte] = { + private[this] def readResult(fs: FS, root: String, i: Int): Array[Byte] = { val bytes = fs.readNoCompression(s"$root/result.$i") if (bytes(0) != 0) { bytes.slice(1, bytes.length) @@ -167,7 +167,7 @@ class ServiceBackend( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { - val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] + val backendContext = _backendContext.asInstanceOf[Context] val token = tokenUrlSafe val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" @@ -206,13 +206,13 @@ class ServiceBackend( uploadFunction.get() uploadContexts.get() - val jobGroup = submitJobGroupAndWait(parts, token, root, stageIdentifier) + val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() var r @ (err, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) { (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => - (() => readResult(root, jobIndex), partIdx) + (() => readResult(fs, root, jobIndex), partIdx) } } @@ -241,7 +241,7 @@ class ServiceBackend( } override def close(): Unit = { - executor.shutdownNow() + if (executor.isEvaluated) executor.shutdownNow() batchClient.close() } @@ -300,253 +300,6 @@ class ServiceBackend( def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = - ExecutionTimer.time { timer => - ExecuteContext.scoped( - rpcConfig.tmp_dir, - rpcConfig.tmp_dir, - this, - references.toMap, - fs, - timer, - null, - theHailClassLoader, - flags, - ServiceBackendContext( - rpcConfig.tmp_dir, - jobConfig, - 
ExecutionCache.fromFlags(flags, fs, rpcConfig.tmp_dir), - ), - new IrMetadata(), - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - )(f) - } } -class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) - -object ServiceBackendAPI extends HttpLikeBackendRpc with Logging { - - def main(argv: Array[String]): Unit = { - assert(argv.length == 7, argv.toFastSeq) - - val scratchDir = argv(0) - // val logFile = argv(1) - val jarLocation = argv(2) - val kind = argv(3) - assert(kind == Main.DRIVER) - val name = argv(4) - val inputURL = argv(5) - val outputURL = argv(6) - - implicit val fmts: Formats = DefaultFormats - - val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") - DeployConfig.set(deployConfig) - sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - - var fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - HailFeatureFlags.fromEnv(), - ) - ) - - val (rpcConfig, jobConfig, action, payload) = - using(fs.openNoCompression(inputURL)) { is => - val input = JsonMethods.parse(is) - ( - (input \ "rpc_config").extract[ServiceBackendRPCPayload], - (input \ "job_config").extract[BatchJobConfig], - (input \ "action").extract[Int], - input \ "payload", - ) - } - - // requester pays config is conveyed in feature flags currently - val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) - fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - featureFlags, - ) - ) - - val references = mutable.Map[String, ReferenceGenome]() - references ++= ReferenceGenome.builtinReferences() - ReferenceGenome.addFatalOnCollision( - references, - rpcConfig.custom_references.map(ReferenceGenome.fromJSON), - ) - - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) - } - } - - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) - } - - // FIXME: when can the classloader be shared? (optimizer benefits!) 
- val backend = new ServiceBackend( - name, - BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), - JarUrl(jarLocation), - new HailClassLoader(getClass.getClassLoader), - BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), - rpcConfig, - jobConfig, - featureFlags, - fs, - references, - ) - - log.info("ServiceBackend allocated.") - if (HailContext.isInitialized) { - HailContext.get.backend = backend - log.info("Default references added to already initialized HailContexet.") - } else { - HailContext(backend, 50, 3) - log.info("HailContexet initialized.") - } - - runRpc(Env(backend, fs, outputURL, action, payload)) - } - - case class Env( - backend: ServiceBackend, - fs: FS, - outputUrl: String, - action: Int, - payload: JValue, - ) - - implicit override object Reader extends Routing { - import Routes._ - - override def route(a: Env): Route = - (a.action: @switch) match { - case 1 => TypeOf(Kinds.Value) - case 2 => TypeOf(Kinds.Table) - case 3 => TypeOf(Kinds.Matrix) - case 4 => TypeOf(Kinds.BlockMatrix) - case 5 => Execute - case 6 => ParseVcfMetadata - case 7 => ImportFam - case 8 => LoadReferencesFromDataset - case 9 => LoadReferencesFromFASTA - } - - override def payload(a: Env): JValue = a.payload - } - - implicit override object Writer extends Writer { - - // service backend doesn't support sending timings back to the python client - override def timings(env: Env, t: Timings): Unit = - () - - override def result(env: Env, result: Array[Byte]): Unit = - retryTransientErrors { - using(env.fs.createNoCompression(env.outputUrl)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(true) - output.writeBytes(result) - } - } - - override def error(env: Env, t: Throwable): Unit = - retryTransientErrors { - val (shortMessage, expandedMessage, errorId) = - t match { - case t: HailWorkerException => - log.error( - "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", - t, - ) - (t.shortMessage, t.expandedMessage, t.errorId) - case _ => - log.error( - "An exception occurred in the driver. 
The exception was written for Python but we will re-throw to fail this driver job.", - t, - ) - handleForPython(t) - } - - using(env.fs.createNoCompression(env.outputUrl)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(shortMessage) - output.writeString(expandedMessage) - output.writeInt(errorId) - } - - throw t - } - } - - implicit override object Context extends Context { - override def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) = - env.backend.withExecuteContext(f) - } -} - -private class HailSocketAPIOutputStream( - private[this] val out: OutputStream -) extends AutoCloseable { - private[this] var closed: Boolean = false - private[this] val dummy = new Array[Byte](8) - - def writeBool(b: Boolean): Unit = - out.write(if (b) 1 else 0) - - def writeInt(v: Int): Unit = { - Memory.storeInt(dummy, 0, v) - out.write(dummy, 0, 4) - } - - def writeLong(v: Long): Unit = { - Memory.storeLong(dummy, 0, v) - out.write(dummy) - } - - def writeBytes(bytes: Array[Byte]): Unit = { - writeInt(bytes.length) - out.write(bytes) - } - - def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) - - def close(): Unit = - if (!closed) { - out.close() - closed = true - } -} - -case class SequenceConfig(fasta: String, index: String) - -case class ServiceBackendRPCPayload( - tmp_dir: String, - flags: Map[String, String], - custom_references: Array[String], - liftovers: Map[String, Map[String, String]], - sequences: Map[String, SequenceConfig], -) - -case class BatchJobConfig( - token: String, - billing_project: String, - worker_cores: String, - worker_memory: String, - storage: String, - cloudfuse_configs: Array[CloudfuseConfig], - regions: Array[String], -) diff --git a/hail/hail/src/is/hail/backend/service/Worker.scala b/hail/hail/src/is/hail/backend/service/Worker.scala index 9c4a459591a..0b3ce24cdc9 100644 --- a/hail/hail/src/is/hail/backend/service/Worker.scala +++ b/hail/hail/src/is/hail/backend/service/Worker.scala @@ -168,19 +168,7 @@ object Worker { timer.start("executeFunction") // FIXME: workers should not have backends, but some things do need hail contexts - val backend = new ServiceBackend( - null, - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - ) - + val backend = new ServiceBackend(null, null, null, null, null) if (HailContext.isInitialized) HailContext.get.backend = backend else HailContext(backend) diff --git a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala index 305a7809640..8ea2b4a2b58 100644 --- a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala +++ b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala @@ -1,28 +1,23 @@ package is.hail.backend.spark -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailContext import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ -import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ -import is.hail.linalg.BlockMatrix import is.hail.rvd.RVD import is.hail.types._ import is.hail.types.physical.{PStruct, PTuple} import 
is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings -import is.hail.variant.ReferenceGenome import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -34,7 +29,6 @@ import java.io.PrintWriter import com.fasterxml.jackson.core.StreamReadConstraints import org.apache.hadoop -import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -219,15 +213,11 @@ object SparkBackend { append: Boolean = false, skipLoggingConfiguration: Boolean = false, minBlockSize: Long = 1L, - tmpdir: String = "/tmp", - localTmpdir: String = "file:///tmp", - gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { if (theSparkBackend == null) return SparkBackend(sc, appName, master, local, logFile, quiet, append, skipLoggingConfiguration, - minBlockSize, tmpdir, localTmpdir, gcsRequesterPaysProject, gcsRequesterPaysBuckets) + minBlockSize) // there should be only one SparkContext assert(sc == null || (sc eq theSparkBackend.sc)) @@ -263,10 +253,6 @@ object SparkBackend { append: Boolean = false, skipLoggingConfiguration: Boolean = false, minBlockSize: Long = 1L, - tmpdir: String, - localTmpdir: String, - gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { require(theSparkBackend == null) @@ -281,32 +267,18 @@ object SparkBackend { checkSparkConfiguration(sc1) - if (!quiet) - ProgressBarBuilder.build(sc1) + if (!quiet) ProgressBarBuilder.build(sc1) sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) - theSparkBackend = - new SparkBackend( - tmpdir, - localTmpdir, - sc1, - mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), - gcsRequesterPaysProject, - gcsRequesterPaysBuckets, - ) + theSparkBackend = new SparkBackend(sc1) theSparkBackend } def stop(): Unit = synchronized { if (theSparkBackend != null) { - theSparkBackend.sc.stop() + theSparkBackend.close() theSparkBackend = null - // Hadoop does not honor the hadoop configuration as a component of the cache key for file - // systems, so we blow away the cache so that a new configuration can successfully take - // effect. 
- // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 - hadoop.fs.FileSystem.closeAll() } } } @@ -317,104 +289,31 @@ class AnonymousDependency[T](val _rdd: RDD[T]) extends NarrowDependency[T](_rdd) override def getParents(partitionId: Int): Seq[Int] = Seq.empty } -class SparkBackend( - val tmpdir: String, - val localTmpdir: String, - val sc: SparkContext, - override val references: mutable.Map[String, ReferenceGenome], - gcsRequesterPaysProject: String, - gcsRequesterPaysBuckets: String, -) extends Backend with Py4JBackendExtensions { - - assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) - lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() +class SparkBackend(val sc: SparkContext) extends Backend { - private[this] val theHailClassLoader: HailClassLoader = - new HailClassLoader(getClass().getClassLoader()) + private case class Context( + maxStageParallelism: Int, + override val executionCache: ExecutionCache, + ) extends BackendContext - override def canExecuteParallelTasksOnDriver: Boolean = false - - val fs: HadoopFS = { - val conf = new Configuration(sc.hadoopConfiguration) - if (gcsRequesterPaysProject != null) { - if (gcsRequesterPaysBuckets == null) { - conf.set("fs.gs.requester.pays.mode", "AUTO") - conf.set("fs.gs.requester.pays.project.id", gcsRequesterPaysProject) - } else { - conf.set("fs.gs.requester.pays.mode", "CUSTOM") - conf.set("fs.gs.requester.pays.project.id", gcsRequesterPaysProject) - conf.set("fs.gs.requester.pays.buckets", gcsRequesterPaysBuckets) - } + val sparkSession: Lazy[SparkSession] = + lazily { + SparkSession.builder().config(sc.getConf).getOrCreate() } - new HadoopFS(new SerializableHadoopConfiguration(conf)) - } - - override def backend: Backend = this - override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - - override val longLifeTempFileManager: TempFileManager = - new OwningTempFileManager(fs) - - private[this] val bmCache = mutable.Map.empty[String, BlockMatrix] - private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) - private[this] val persistedIr = mutable.Map.empty[Int, BaseIR] - private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - - def createExecuteContextForTests( - timer: ExecutionTimer, - region: Region, - ): ExecuteContext = - new ExecuteContext( - tmpdir, - localTmpdir, - this, - references.toMap, - fs, - region, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.forTesting - }, - new IrMetadata(), - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - ) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = - ExecutionTimer.time { timer => - ExecuteContext.scoped( - tmpdir, - localTmpdir, - this, - references.toMap, - fs, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }, - new IrMetadata(), - bmCache, - codeCache, - persistedIr, - coercerCache, - )(f) - } + override def canExecuteParallelTasksOnDriver: Boolean = false def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new SparkBroadcastValue[T](sc.broadcast(value)) + override def backendContext(ctx: ExecuteContext): BackendContext = + Context( + ctx.flags.get(SparkBackend.Flags.MaxStageParallelism).toInt, + ExecutionCache.fromFlags(ctx.flags, ctx.fs, 
ctx.tmpdir), + ) + override def parallelizeAndComputeWithIndex( - backendContext: BackendContext, + ctx: BackendContext, fs: FS, contexts: IndexedSeq[Array[Byte]], stageIdentifier: String, @@ -444,13 +343,12 @@ class SparkBackend( } } - val chunkSize = flags.get(SparkBackend.Flags.MaxStageParallelism).toInt val partsToRun = partitions.getOrElse(contexts.indices) val buffer = new ArrayBuffer[(Array[Byte], Int)](partsToRun.length) var failure: Option[Throwable] = None try { - for (subparts <- partsToRun.grouped(chunkSize)) { + for (subparts <- partsToRun.grouped(ctx.asInstanceOf[Context].maxStageParallelism)) { sc.runJob( rdd, (_: TaskContext, it: Iterator[Array[Byte]]) => it.next(), @@ -459,8 +357,11 @@ class SparkBackend( ) } } catch { - case e: ExecutionException => failure = failure.orElse(Some(e.getCause)) case NonFatal(t) => failure = failure.orElse(Some(t)) + case e: ExecutionException => failure = failure.orElse(Some(e.getCause)) + case _: InterruptedException => + sc.cancelAllJobs() + Thread.currentThread().interrupt() } (failure, buffer.sortBy(_._2)) @@ -470,11 +371,16 @@ class SparkBackend( override def asSpark(implicit E: Enclosing): SparkBackend = this - def close(): Unit = { - bmCache.values.foreach(_.unpersist()) - SparkBackend.stop() - longLifeTempFileManager.close() - } + def close(): Unit = + synchronized { + if (sparkSession.isEvaluated) sparkSession.close() + sc.stop() + // Hadoop does not honor the hadoop configuration as a component of the cache key for file + // systems, so we blow away the cache so that a new configuration can successfully take + // effect. + // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 + hadoop.fs.FileSystem.closeAll() + } def startProgressBar(): Unit = ProgressBarBuilder.build(sc) @@ -507,8 +413,7 @@ class SparkBackend( case (false, true) => DArrayLowering.BMOnly case (false, false) => throw new LowererUnsupportedOperation("no lowering enabled") } - val ir = - LoweringPipeline.darrayLowerer(optimize)(typesToLower).apply(ctx, ir0).asInstanceOf[IR] + val ir = LoweringPipeline.darrayLowerer(optimize)(typesToLower)(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") @@ -546,11 +451,11 @@ class SparkBackend( Validate(ir) ctx.irMetadata.semhash = SemanticHash(ctx)(ir) try { - val lowerTable = flags.get("lower") != null - val lowerBM = flags.get("lower_bm") != null + val lowerTable = ctx.flags.get("lower") != null + val lowerBM = ctx.flags.get("lower_bm") != null _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM) } catch { - case e: LowererUnsupportedOperation if flags.get("lower_only") != null => throw e + case e: LowererUnsupportedOperation if ctx.flags.get("lower_only") != null => throw e case _: LowererUnsupportedOperation => CompileAndEvaluate._apply(ctx, ir, optimize = true) } @@ -563,7 +468,7 @@ class SparkBackend( rt: RTable, nPartitions: Option[Int], ): TableReader = { - if (flags.get("use_new_shuffle") != null) + if (ctx.flags.get("use_new_shuffle") != null) return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) val (globals, rvd) = TableStageToRVD(ctx, stage) diff --git a/hail/hail/src/is/hail/expr/ir/Emit.scala b/hail/hail/src/is/hail/expr/ir/Emit.scala index 07cabb71695..3b472e48bb8 100644 --- a/hail/hail/src/is/hail/expr/ir/Emit.scala +++ b/hail/hail/src/is/hail/expr/ir/Emit.scala @@ -3381,7 +3381,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { Array[Array[Byte]], ]( 
"collectDArray", - mb.getObject(ctx.executeContext.backendContext), + mb.getObject(ctx.executeContext.backend.backendContext(ctx.executeContext)), mb.getHailClassLoader, mb.getFS, functionID, diff --git a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala index e516052ce58..f40ae888796 100644 --- a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala +++ b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala @@ -1806,7 +1806,7 @@ object MatrixVCFReader { val fsConfigBC = backend.broadcast(fs.getConfiguration()) val (failureOpt, _) = backend.parallelizeAndComputeWithIndex( - ctx.backendContext, + ctx.backend.backendContext(ctx), fs, files.tail.map(_.getBytes), "load_vcf_parse_header", diff --git a/hail/hail/src/is/hail/utils/package.scala b/hail/hail/src/is/hail/utils/package.scala index 8b6affd9801..5b5d2b33022 100644 --- a/hail/hail/src/is/hail/utils/package.scala +++ b/hail/hail/src/is/hail/utils/package.scala @@ -8,7 +8,7 @@ import scala.collection.{mutable, GenTraversableOnce, TraversableOnce} import scala.collection.generic.CanBuildFrom import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionException -import scala.language.higherKinds +import scala.language.{higherKinds, implicitConversions} import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -90,6 +90,23 @@ package utils { b.result() } } + + class Lazy[A] private[utils] (f: => A) { + private[this] var option: Option[A] = None + + def apply(): A = + synchronized { + option match { + case Some(a) => a + case None => val a = f; option = Some(a); a + } + } + + def isEvaluated: Boolean = + synchronized { + option.isDefined + } + } } package object utils @@ -1027,7 +1044,7 @@ package object utils val buffer = new mutable.ArrayBuffer[(A, Int)](tasks.length) val completer = new ExecutorCompletionService[(A, Int)](executor) - tasks.foreach { case (t, k) => completer.submit(() => t() -> k) } + val futures = tasks.map { case (t, k) => completer.submit(() => t() -> k) } tasks.foreach { _ => try buffer += completer.take().get() catch { @@ -1035,6 +1052,9 @@ package object utils err = accum(err, e.getCause) case NonFatal(ex) => err = accum(err, ex) + case _: InterruptedException => + futures.foreach(_.cancel(true)) + Thread.currentThread().interrupt() } } @@ -1044,6 +1064,12 @@ package object utils def runAllKeepFirstError[A](executor: ExecutorService) : IndexedSeq[(() => A, Int)] => (Option[Throwable], IndexedSeq[(A, Int)]) = runAll[Option, A](executor) { case (opt, e) => opt.orElse(Some(e)) }(None) + + def lazily[A](f: => A): Lazy[A] = + new Lazy(f) + + implicit def evalLazy[A](f: Lazy[A]): A = + f() } class CancellingExecutorService(delegate: ExecutorService) extends AbstractExecutorService { diff --git a/hail/hail/src/is/hail/variant/ReferenceGenome.scala b/hail/hail/src/is/hail/variant/ReferenceGenome.scala index 88e07e04073..20549122427 100644 --- a/hail/hail/src/is/hail/variant/ReferenceGenome.scala +++ b/hail/hail/src/is/hail/variant/ReferenceGenome.scala @@ -662,4 +662,14 @@ object ReferenceGenome { existing += (rg.name -> rg) } + + def addFatalOnCollision( + existing: Map[String, ReferenceGenome], + newReferences: IndexedSeq[ReferenceGenome], + ): Map[String, ReferenceGenome] = { + val mut = mutable.Map(existing.toSeq: _*) + addFatalOnCollision(mut, newReferences) + mut.toMap + } + } diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index bff392dcd6a..08f544388f8 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ 
b/hail/hail/test/src/is/hail/HailSuite.scala @@ -1,9 +1,9 @@ package is.hail import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations._ -import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.asm4s.HailClassLoader +import is.hail.backend.{Backend, ExecuteContext} import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ bindIR, freshName, rangeIR, BindingEnv, BlockMatrixIR, Compilable, Env, Forall, IR, Interpret, @@ -13,22 +13,33 @@ import is.hail.expr.ir.defs.{ BlockMatrixCollect, Cast, MakeTuple, NDArrayRef, Ref, StreamMap, ToArray, } import is.hail.io.fs.FS +import is.hail.expr.ir.{TestUtils => _, _} +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs.{FS, HadoopFS} import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.variant.ReferenceGenome import java.io.{File, PrintWriter} import breeze.linalg.DenseMatrix +import org.apache.hadoop +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.sql.Row import org.scalatestplus.testng.TestNGSuite import org.testng.ITestContext -import org.testng.annotations.{AfterMethod, BeforeClass, BeforeMethod} +import org.testng.annotations.{AfterClass, AfterMethod, BeforeClass, BeforeMethod} object HailSuite { - val theHailClassLoader = TestUtils.theHailClassLoader - def withSparkBackend(): HailContext = { + val theHailClassLoader: HailClassLoader = + new HailClassLoader(getClass.getClassLoader) + + val flags: HailFeatureFlags = + HailFeatureFlags.fromEnv(sys.env + ("lower" -> "1")) + + lazy val hc: HailContext = { HailContext.configureLogging("/tmp/hail.log", quiet = false, append = false) val backend = SparkBackend( sc = new SparkContext( @@ -40,63 +51,71 @@ object HailSuite { ) .set("spark.unsafe.exceptionOnMemoryLeak", "true") ), - tmpdir = "/tmp", - localTmpdir = "file:///tmp", skipLoggingConfiguration = true, ) - HailContext(backend) - } - - lazy val hc: HailContext = { - val hc = withSparkBackend() - hc.backend.asSpark.flags.set("lower", "1") + val hc = HailContext(backend) hc.checkRVDKeys = true hc } } -class HailSuite extends TestNGSuite { - val theHailClassLoader = HailSuite.theHailClassLoader +class HailSuite extends TestNGSuite with TestUtils { def hc: HailContext = HailSuite.hc - @BeforeClass def ensureHailContextInitialized(): Unit = hc - - def backend: SparkBackend = hc.backend.asSpark - - def sc: SparkContext = backend.sc - - def fs: FS = backend.fs - - def fsBc: BroadcastValue[FS] = fs.broadcast - - var timer: ExecutionTimer = _ + @BeforeClass + def initFs(): Unit = { + val conf = new Configuration(sc.hadoopConfiguration) + fs = new HadoopFS(new SerializableHadoopConfiguration(conf)) + } - var ctx: ExecuteContext = _ + @AfterClass + def closeFS(): Unit = + hadoop.fs.FileSystem.closeAll() + var fs: FS = _ var pool: RegionPool = _ + private[this] var ctx_ : ExecuteContext = _ + + def backend: Backend = hc.backend + def sc: SparkContext = hc.backend.asSpark.sc + def timer: ExecutionTimer = ctx.timer + def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader + override def ctx: ExecuteContext = ctx_ def getTestResource(localPath: String): String = s"hail/test/resources/$localPath" @BeforeMethod def setupContext(context: ITestContext): Unit = { - assert(timer == null) - timer = new ExecutionTimer("HailSuite") - assert(ctx == null) pool = RegionPool() - ctx = backend.createExecuteContextForTests(timer, Region(pool = pool)) + ctx_ = new ExecuteContext( + tmpdir = "/tmp", + localTmpdir = 
"file:///tmp", + backend = hc.backend, + references = ReferenceGenome.builtinReferences(), + fs = fs, + r = Region(pool = pool), + timer = new ExecutionTimer(context.getName), + _tempFileManager = null, + theHailClassLoader = HailSuite.theHailClassLoader, + flags = HailSuite.flags, + irMetadata = new IrMetadata(), + BlockMatrixCache = ImmutableMap.empty, + CodeCache = ImmutableMap.empty, + IrCache = ImmutableMap.empty, + CoercerCache = ImmutableMap.empty, + ) } @AfterMethod def tearDownContext(context: ITestContext): Unit = { - ctx.close() - ctx = null - timer.finish() - timer = null + ctx_.timer.finish() + ctx_.close() + ctx_ = null pool.close() - if (backend.sc.isStopped) - throw new RuntimeException(s"method stopped spark context!") + if (sc.isStopped) + throw new RuntimeException(s"'${context.getName}' stopped spark context!") } def assertEvalsTo( @@ -113,73 +132,71 @@ class HailSuite extends TestNGSuite { val t = x.typ assert(t == TVoid || t.typeCheck(expected), s"$t, $expected") - ExecuteContext.scoped { ctx => - val filteredExecStrats: Set[ExecStrategy] = - if (HailContext.backend.isInstanceOf[SparkBackend]) - execStrats - else { - info("skipping interpret and non-lowering compile steps on non-spark backend") - execStrats.intersect(ExecStrategy.backendOnly) - } + val filteredExecStrats: Set[ExecStrategy] = + if (HailContext.backend.isInstanceOf[SparkBackend]) + execStrats + else { + info("skipping interpret and non-lowering compile steps on non-spark backend") + execStrats.intersect(ExecStrategy.backendOnly) + } - filteredExecStrats.foreach { strat => - try { - val res = strat match { - case ExecStrategy.Interpret => - assert(agg.isEmpty) - Interpret[Any](ctx, x, env, args) - case ExecStrategy.InterpretUnoptimized => - assert(agg.isEmpty) - Interpret[Any](ctx, x, env, args, optimize = false) - case ExecStrategy.JvmCompile => - assert(Forall(x, node => Compilable(node))) - eval( - x, - env, - args, - agg, - bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - true, - ctx, - ) - case ExecStrategy.JvmCompileUnoptimized => - assert(Forall(x, node => Compilable(node))) - eval( - x, - env, - args, - agg, - bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - optimize = false, - ctx, - ) - case ExecStrategy.LoweredJVMCompile => - loweredExecute(ctx, x, env, args, agg) - } - if (t != TVoid) { - assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") - assert( - t.valuesSimilar(res, expected), - s"\n result=$res\n expect=$expected\n strategy=$strat)", + filteredExecStrats.foreach { strat => + try { + val res = strat match { + case ExecStrategy.Interpret => + assert(agg.isEmpty) + Interpret[Any](ctx, x, env, args) + case ExecStrategy.InterpretUnoptimized => + assert(agg.isEmpty) + Interpret[Any](ctx, x, env, args, optimize = false) + case ExecStrategy.JvmCompile => + assert(Forall(x, node => Compilable(node))) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + true, + ctx, + ) + case ExecStrategy.JvmCompileUnoptimized => + assert(Forall(x, node => Compilable(node))) + eval( + x, + env, + args, 
+ agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + optimize = false, + ctx, ) - } - } catch { - case e: Exception => - error(s"error from strategy $strat") - if (execStrats.contains(strat)) throw e + case ExecStrategy.LoweredJVMCompile => + loweredExecute(ctx, x, env, args, agg) + } + if (t != TVoid) { + assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") + assert( + t.valuesSimilar(res, expected), + s"\n result=$res\n expect=$expected\n strategy=$strat)", + ) } + } catch { + case e: Exception => + error(s"error from strategy $strat") + if (execStrats.contains(strat)) throw e } } } @@ -258,35 +275,33 @@ class HailSuite extends TestNGSuite { expected: DenseMatrix[Double], )(implicit execStrats: Set[ExecStrategy] ): Unit = { - ExecuteContext.scoped { ctx => - val filteredExecStrats: Set[ExecStrategy] = - if (HailContext.backend.isInstanceOf[SparkBackend]) execStrats - else { - info("skipping interpret and non-lowering compile steps on non-spark backend") - execStrats.intersect(ExecStrategy.backendOnly) - } - filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => - try { - val res = strat match { - case ExecStrategy.Interpret => - Interpret(bm, ctx, optimize = true) - case ExecStrategy.InterpretUnoptimized => - Interpret(bm, ctx, optimize = false) - } - assert(res.toBreezeMatrix() == expected) - } catch { - case e: Exception => - error(s"error from strategy $strat") - if (execStrats.contains(strat)) throw e + val filteredExecStrats: Set[ExecStrategy] = + if (HailContext.backend.isInstanceOf[SparkBackend]) execStrats + else { + info("skipping interpret and non-lowering compile steps on non-spark backend") + execStrats.intersect(ExecStrategy.backendOnly) + } + filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => + try { + val res = strat match { + case ExecStrategy.Interpret => + Interpret(bm, ctx, optimize = true) + case ExecStrategy.InterpretUnoptimized => + Interpret(bm, ctx, optimize = false) } + assert(res.toBreezeMatrix() == expected) + } catch { + case e: Exception => + error(s"error from strategy $strat") + if (execStrats.contains(strat)) throw e } - val expectedArray = Array.tabulate(expected.rows)(i => - Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq - ).toFastSeq - assertNDEvals(BlockMatrixCollect(bm), expectedArray)( - filteredExecStrats.filterNot(ExecStrategy.interpretOnly) - ) } + val expectedArray = Array.tabulate(expected.rows)(i => + Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq + ).toFastSeq + assertNDEvals(BlockMatrixCollect(bm), expectedArray)( + filteredExecStrats.filterNot(ExecStrategy.interpretOnly) + ) } def assertAllEvalTo( diff --git a/hail/hail/test/src/is/hail/TestUtils.scala b/hail/hail/test/src/is/hail/TestUtils.scala index 550be49398a..cd8965d514a 100644 --- a/hail/hail/test/src/is/hail/TestUtils.scala +++ b/hail/hail/test/src/is/hail/TestUtils.scala @@ -45,8 +45,9 @@ object ExecStrategy extends Enumeration { val allRelational: Set[ExecStrategy] = interpretOnly.union(lowering) } -object TestUtils { - val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) +trait TestUtils { + + def ctx: ExecuteContext = ??? 
import org.scalatest.Assertions._ @@ -99,7 +100,7 @@ object TestUtils { def removeConstantCols(A: DenseMatrix[Int]): DenseMatrix[Int] = { val data = (0 until A.cols).flatMap { j => val col = A(::, j) - if (TestUtils.isConstant(col)) + if (isConstant(col)) Array[Int]() else col.toArray @@ -128,18 +129,14 @@ object TestUtils { print = bytecodePrinter) } - def eval(x: IR): Any = ExecuteContext.scoped { ctx => - eval(x, Env.empty, FastSeq(), None, None, true, ctx) - } - def eval( x: IR, - env: Env[(Any, Type)], - args: IndexedSeq[(Any, Type)], - agg: Option[(IndexedSeq[Row], TStruct)], + env: Env[(Any, Type)] = Env.empty, + args: IndexedSeq[(Any, Type)] = FastSeq(), + agg: Option[(IndexedSeq[Row], TStruct)] = None, bytecodePrinter: Option[PrintWriter] = None, optimize: Boolean = true, - ctx: ExecuteContext, + ctx: ExecuteContext = ctx, ): Any = { val inputTypesB = new BoxedArrayBuilder[Type]() val inputsB = new mutable.ArrayBuffer[Any]() @@ -211,26 +208,29 @@ object TestUtils { assert(resultType2.virtualType == resultType) ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 + ctx.local(r = region) { ctx => + val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 + } + rvb.endTuple() + val argsOff = rvb.end() + + rvb.start(aggArrayPType) + rvb.startArray(aggElements.length) + aggElements.foreach(r => rvb.addAnnotation(aggType, r)) + rvb.endArray() + val aggOff = rvb.end() + + ctx.scopedExecution { (hcl, fs, tc, r) => + val off = f(hcl, fs, tc, r)(r, argsOff, aggOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], off).get(0) + } } - rvb.endTuple() - val argsOff = rvb.end() - - rvb.start(aggArrayPType) - rvb.startArray(aggElements.length) - aggElements.foreach(r => rvb.addAnnotation(aggType, r)) - rvb.endArray() - val aggOff = rvb.end() - - val resultOff = - f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff, aggOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) } case None => @@ -250,19 +250,22 @@ object TestUtils { assert(resultType2.virtualType == resultType) ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 + ctx.local(r = region) { ctx => + val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 + } + rvb.endTuple() + val argsOff = rvb.end() + ctx.scopedExecution { (hcl, fs, tc, r) => + val resultOff = f(hcl, fs, tc, r)(r, argsOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + } } - rvb.endTuple() - val argsOff = rvb.end() - - val resultOff = f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) } } } @@ -276,7 +279,7 @@ object TestUtils { def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]): Unit = { val t = x.typ - val (i, i2, c) = ExecuteContext.scoped { ctx => + val (i, i2, c) = { val i = Interpret[Any](ctx, x, env, args) val i2 = Interpret[Any](ctx, x, env, 
args, optimize = false) val c = eval(x, env, args, None, None, true, ctx) @@ -299,12 +302,11 @@ object TestUtils { env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String, - ): Unit = - ExecuteContext.scoped { ctx => - interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) - interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) - interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) - } + ): Unit = { + interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) + interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) + interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) + } def assertFatal(x: IR, regex: String): Unit = assertThrows[HailException](x, regex) @@ -322,9 +324,7 @@ object TestUtils { args: IndexedSeq[(Any, Type)], regex: String, ): Unit = - ExecuteContext.scoped { ctx => - interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) - } + interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) def assertCompiledThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertCompiledThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) diff --git a/hail/hail/test/src/is/hail/TestUtilsSuite.scala b/hail/hail/test/src/is/hail/TestUtilsSuite.scala index a6c91c5aef5..2c2b0c08406 100644 --- a/hail/hail/test/src/is/hail/TestUtilsSuite.scala +++ b/hail/hail/test/src/is/hail/TestUtilsSuite.scala @@ -11,21 +11,21 @@ class TestUtilsSuite extends HailSuite { val V = DenseVector(0d, 1d) val V1 = DenseVector(0d, 0.5d) - TestUtils.assertMatrixEqualityDouble(M, DenseMatrix.eye(2)) - TestUtils.assertMatrixEqualityDouble(M, M1, 0.001) - TestUtils.assertVectorEqualityDouble(V, 2d * V1) + assertMatrixEqualityDouble(M, DenseMatrix.eye(2)) + assertMatrixEqualityDouble(M, M1, 0.001) + assertVectorEqualityDouble(V, 2d * V1) - intercept[Exception](TestUtils.assertVectorEqualityDouble(V, V1)) - intercept[Exception](TestUtils.assertMatrixEqualityDouble(M, M1)) + intercept[Exception](assertVectorEqualityDouble(V, V1)) + intercept[Exception](assertMatrixEqualityDouble(M, M1)) } @Test def constantVectorTest(): Unit = { - assert(TestUtils.isConstant(DenseVector())) - assert(TestUtils.isConstant(DenseVector(0))) - assert(TestUtils.isConstant(DenseVector(0, 0))) - assert(TestUtils.isConstant(DenseVector(0, 0, 0))) - assert(!TestUtils.isConstant(DenseVector(0, 1))) - assert(!TestUtils.isConstant(DenseVector(0, 0, 1))) + assert(isConstant(DenseVector())) + assert(isConstant(DenseVector(0))) + assert(isConstant(DenseVector(0, 0))) + assert(isConstant(DenseVector(0, 0, 0))) + assert(!isConstant(DenseVector(0, 1))) + assert(!isConstant(DenseVector(0, 0, 1))) } @Test def removeConstantColsTest(): Unit = { @@ -33,6 +33,6 @@ class TestUtilsSuite extends HailSuite { val M1 = DenseMatrix((0, 1, 0), (1, 0, 1)) - assert(TestUtils.removeConstantCols(M) == M1) + assert(removeConstantCols(M) == M1) } } diff --git a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala index dd5d5b4e1e6..8c14e2bb7e7 100644 --- a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala @@ -54,7 +54,7 @@ class UnsafeSuite extends HailSuite { @DataProvider(name = "codecs") def codecs(): Array[Array[Any]] = - ExecuteContext.scoped(ctx => codecs(ctx)) + codecs(ctx) def codecs(ctx: ExecuteContext): Array[Array[Any]] = (BufferSpec.specs ++ Array(TypedCodecSpec( 
diff --git a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala index 3139849a54f..9b9c2a2c67e 100644 --- a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala @@ -1,17 +1,13 @@ package is.hail.backend -import is.hail.HailFeatureFlags -import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ - BatchJobConfig, ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload, -} -import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} +import is.hail.HailSuite +import is.hail.backend.api.BatchJobConfig +import is.hail.backend.service.ServiceBackend import is.hail.services._ import is.hail.services.JobGroupStates.Success import is.hail.utils.{tokenUrlSafe, using} -import scala.collection.mutable -import scala.reflect.io.{Directory, Path} +import scala.reflect.io.Directory import scala.util.Random import java.io.Closeable @@ -21,135 +17,105 @@ import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when import org.scalatest.OptionValues import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { +class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { case (rpcConfig, jobConfig) => + withObjectSpied[is.hail.utils.UtilsType] { + // not obvious how to pull out `tokenUrlSafe` and inject this directory + // using a spy is a hack and i don't particularly like it. + when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" + + val jobConfig = BatchJobConfig( + token = tokenUrlSafe, + billing_project = "fancy", + worker_cores = "128", + worker_memory = "a lot.", + storage = "a big ssd?", + cloudfuse_configs = Array(), + regions = Array("lunar1"), + ) + val batchClient = mock[BatchClient] - using(ServiceBackend(batchClient, rpcConfig, jobConfig)) { backend => - val contexts = Array.tabulate(1)(_.toString.getBytes) - - // verify that - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - - when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { - jobGroup: JobGroupRequest => - jobGroup.batch_id shouldBe backend.batchConfig.batchId - jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId - val jobs = jobGroup.jobs - jobs.length shouldEqual contexts.length - jobs.foreach { payload => - payload.regions.value shouldBe jobConfig.regions - payload.resources.value shouldBe JobResources( - preemptible = true, - cpu = Some(jobConfig.worker_cores), - memory = Some(jobConfig.worker_memory), - storage = Some(jobConfig.storage), + using(ServiceBackend(batchClient, jobConfig)) { backend => + using(LocalTmpFolder) { tmp => + val contexts = Array.tabulate(1)(_.toString.getBytes) + + // verify that + // - the number of jobs matches the number of partitions, and + // - each job is created in the specified region, and + // - each job's resource configuration matches the rpc config + + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + jobGroup: JobGroupRequest => + jobGroup.batch_id shouldBe backend.batchConfig.batchId + jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId + val jobs = jobGroup.jobs + 
jobs.length shouldEqual contexts.length + jobs.foreach { payload => + payload.regions.value shouldBe jobConfig.regions + payload.resources.value shouldBe JobResources( + preemptible = true, + cpu = Some(jobConfig.worker_cores), + memory = Some(jobConfig.worker_memory), + storage = Some(jobConfig.storage), + ) + } + + backend.batchConfig.jobGroupId + 1 + } + + // the service backend expects that each job write its output to a well-known + // location when it finishes. + when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + id shouldEqual backend.batchConfig.batchId + jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 + + val resultsDir = tmp / "parallelizeAndComputeWithIndex" / tokenUrlSafe + + resultsDir.createDirectory() + for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Success, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = contexts.length, + n_failed = 0, + n_cancelled = 0, ) - } - - backend.batchConfig.jobGroupId + 1 - } + } + + ctx.local(tmpdir = tmp.toString()) { ctx => + val (failure, _) = + backend.parallelizeAndComputeWithIndex( + backend.backendContext(ctx), + ctx.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) + + failure.foreach(throw _) + } + batchClient.newJobGroup(any) wasCalled once + batchClient.waitForJobGroup(any, any) wasCalled once - // the service backend expects that each job write its output to a well-known - // location when it finishes. - when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { - (id: Int, jobGroupId: Int) => - id shouldEqual backend.batchConfig.batchId - jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 - - val resultsDir = - Path(rpcConfig.tmp_dir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe - - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JobGroupResponse( - batch_id = id, - job_group_id = jobGroupId, - state = Success, - complete = true, - n_jobs = contexts.length, - n_completed = contexts.length, - n_succeeded = contexts.length, - n_failed = 0, - n_cancelled = 0, - ) } - - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - ServiceBackendContext( - remoteTmpDir = rpcConfig.tmp_dir, - jobConfig = jobConfig, - executionCache = ExecutionCache.noCache, - ), - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) - - failure.foreach(throw _) - - batchClient.newJobGroup(any) wasCalled once - batchClient.waitForJobGroup(any, any) wasCalled once } } - def ServiceBackend( - client: BatchClient, - rpcConfig: ServiceBackendRPCPayload, - jobConfig: BatchJobConfig, - ): ServiceBackend = { - val flags = HailFeatureFlags.fromEnv() - val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) + def ServiceBackend(client: BatchClient, jobConfig: BatchJobConfig): ServiceBackend = new ServiceBackend( name = "name", batchClient = client, jarSpec = GitRevision("123"), - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), - rpcConfig = rpcConfig, jobConfig = jobConfig, - flags = flags, - fs = fs, - references = mutable.Map.empty, ) - } - - def withMockDriverContext(test: (ServiceBackendRPCPayload, BatchJobConfig) => Any): Any = - using(LocalTmpFolder) { tmp => - withObjectSpied[is.hail.utils.UtilsType] { - // not obvious how to pull out 
`tokenUrlSafe` and inject this directory - // using a spy is a hack and i don't particularly like it. - when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" - - test( - ServiceBackendRPCPayload( - tmp_dir = tmp.path, - flags = Map(), - custom_references = Array(), - liftovers = Map(), - sequences = Map(), - ), - BatchJobConfig( - token = tokenUrlSafe, - billing_project = "fancy", - worker_cores = "128", - worker_memory = "a lot.", - storage = "a big ssd?", - cloudfuse_configs = Array(), - regions = Array("lunar1"), - ), - ) - } - } def LocalTmpFolder: Directory with Closeable = new Directory(Directory.makeTemp("hail-testing-tmp").jfile) with Closeable { diff --git a/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala index 4273f6e2dd6..e74fae67136 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{ArraySlice, F32, F64, I32, In, MakeArray, NA, Str} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala index e63415f6ddf..bd0dae68f6a 100644 --- a/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.TestUtils.IRCall import is.hail.expr.ir.defs.{False, I32, Str, True} +import is.hail.expr.ir.TestUtils.{IRArray, IRCall} import is.hail.types.virtual.{TArray, TBoolean, TCall, TInt32} import is.hail.variant._ @@ -61,7 +62,7 @@ class CallFunctionsSuite extends HailSuite { assertEvalsTo(invoke("Call", TCall, I32(1), False()), Call1(1, false)) assertEvalsTo(invoke("Call", TCall, I32(0), I32(0), False()), Call2(0, 0, false)) assertEvalsTo( - invoke("Call", TCall, TestUtils.IRArray(0, 1), False()), + invoke("Call", TCall, IRArray(0, 1), False()), CallN(Array(0, 1), false), ) assertEvalsTo(invoke("Call", TCall, Str("0|1")), Call2(0, 1, true)) diff --git a/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala index b211701f942..b780c67cc7d 100644 --- a/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{NA, ToSet, ToStream} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala index 7ca4f22567f..acf32fc65ee 100644 --- a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala @@ -1,10 +1,8 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ -import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} +import is.hail.annotations.{Region, RegionPool, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ -import is.hail.backend.ExecuteContext import is.hail.expr.ir.agg.{CollectStateSig, PhysicalAggSig, TypedStateSig} import is.hail.expr.ir.compile.Compile import 
is.hail.expr.ir.defs._ @@ -1064,14 +1062,14 @@ class EmitStreamSuite extends HailSuite { def assertMemoryDoesNotScaleWithStreamSize(lowSize: Int = 50, highSize: Int = 2500)(f: IR => IR) : Unit = { - val memUsed1 = ExecuteContext.scoped { ctx => - eval(f(lowSize), Env.empty, FastSeq(), None, None, false, ctx) - ctx.r.pool.getHighestTotalUsage + val memUsed1 = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool))(ctx => eval(f(lowSize), optimize = false, ctx = ctx)) + pool.getHighestTotalUsage } - val memUsed2 = ExecuteContext.scoped { ctx => - eval(f(highSize), Env.empty, FastSeq(), None, None, false, ctx) - ctx.r.pool.getHighestTotalUsage + val memUsed2 = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool))(ctx => eval(f(highSize), optimize = false, ctx = ctx)) + pool.getHighestTotalUsage } if (memUsed1 != memUsed2) diff --git a/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala index b158fb2a506..5f39f4bd01f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.TestUtils._ import is.hail.expr.Nat import is.hail.expr.ir.defs._ import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala index 521922af0e7..f6ba5f123a3 100644 --- a/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual.TFloat64 import is.hail.utils.FastSeq diff --git a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala index 9258c69cb85..0709c4a39e3 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala @@ -2,10 +2,10 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ -import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} +import is.hail.annotations.{BroadcastRow, ExtendedOrdering, Region, RegionPool, SafeNDArray} import is.hail.backend.ExecuteContext import is.hail.expr.Nat +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.agg._ import is.hail.expr.ir.defs._ import is.hail.expr.ir.defs.ArrayZipBehavior.ArrayZipBehavior @@ -1403,11 +1403,11 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(zip(ArrayZipBehavior.AssumeSameLength, range8, range8)), 8) // https://github.com/hail-is/hail/issues/8359 - is.hail.TestUtils.assertThrows[HailException]( + assertThrows[HailException]( zipToTuple(ArrayZipBehavior.AssertSameLength, range6, range8): IR, "zip: length mismatch": String, ) - is.hail.TestUtils.assertThrows[HailException]( + assertThrows[HailException]( zipToTuple(ArrayZipBehavior.AssertSameLength, range12, lit6): IR, "zip: length mismatch": String, ) @@ -1562,7 +1562,7 @@ class IRSuite extends HailSuite { val na = NA(TDict(TInt32, TString)) assertEvalsTo(LowerBoundOnOrderedCollection(na, I32(0), onKey = true), null) - val dwna = TestUtils.IRDict((1, 3), (3, null), (null, 5)) + val dwna = IRDict((1, 3), (3, null), (null, 5)) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(-1), onKey = 
true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(2), onKey = true), 1) @@ -1570,7 +1570,7 @@ class IRSuite extends HailSuite { assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(5), onKey = true), 2) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, NA(TInt32), onKey = true), 2) - val dwoutna = TestUtils.IRDict((1, 3), (3, null)) + val dwoutna = IRDict((1, 3), (3, null)) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, I32(-1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, I32(4), onKey = true), 2) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, NA(TInt32), onKey = true), 2) @@ -1801,18 +1801,18 @@ class IRSuite extends HailSuite { @Test def testStreamFold(): Unit = { assertEvalsTo(foldIR(StreamRange(1, 2, 1), NA(TBoolean))((accum, elt) => IsNA(accum)), true) - assertEvalsTo(foldIR(TestUtils.IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) + assertEvalsTo(foldIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) assertEvalsTo( - foldIR(TestUtils.IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt), + foldIR(IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt), null, ) assertEvalsTo( - foldIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt), + foldIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt), null, ) - assertEvalsTo(foldIR(TestUtils.IRStream(1, null, 3), 0)((accum, elt) => accum + elt), null) + assertEvalsTo(foldIR(IRStream(1, null, 3), 0)((accum, elt) => accum + elt), null) assertEvalsTo( - foldIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => I32(5) + I32(5)), + foldIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => I32(5) + I32(5)), 10, ) } @@ -1846,15 +1846,15 @@ class IRSuite extends HailSuite { FastSeq(null, true, false, false), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, 2, 3), 0)((accum, elt) => accum + elt)), + ToArray(streamScanIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt)), FastSeq(0, 1, 3, 6), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt)), + ToArray(streamScanIR(IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt)), FastSeq(null, null, null, null), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => + ToArray(streamScanIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt )), FastSeq(null, null, null, null), @@ -3252,7 +3252,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "valueIRs") def valueIRs(): Array[Array[Object]] = - ExecuteContext.scoped(ctx => valueIRs(ctx)) + valueIRs(ctx) def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { val fs = ctx.fs @@ -3314,8 +3314,7 @@ class IRSuite extends HailSuite { val table = TableRange(100, 10) val mt = MatrixIR.range(20, 2, Some(3)) - val vcf = is.hail.TestUtils.importVCF(ctx, getTestResource("sample.vcf")) - + val vcf = importVCF(ctx, getTestResource("sample.vcf")) val bgenReader = MatrixBGENReader( ctx, FastSeq(getTestResource("example.8bits.bgen")), @@ -3597,7 +3596,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "tableIRs") def tableIRs(): Array[Array[TableIR]] = - ExecuteContext.scoped(ctx => tableIRs(ctx)) + tableIRs(ctx) def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { try { @@ -3706,7 +3705,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "matrixIRs") def matrixIRs(): Array[Array[MatrixIR]] = - 
ExecuteContext.scoped(ctx => matrixIRs(ctx)) + matrixIRs(ctx) def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { try { @@ -3730,7 +3729,7 @@ class IRSuite extends HailSuite { val read = MatrixIR.read(fs, getTestResource("backward_compatability/1.0.0/matrix_table/0.hmt")) val range = MatrixIR.range(3, 7, None) - val vcf = is.hail.TestUtils.importVCF(ctx, getTestResource("sample.vcf")) + val vcf = importVCF(ctx, getTestResource("sample.vcf")) val bgenReader = MatrixBGENReader( ctx, @@ -3909,20 +3908,13 @@ class IRSuite extends HailSuite { try { val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3) - backend.withExecuteContext { ctx => - ctx.local(blockMatrixCache = cache) { ctx => - backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) - } - } + ctx.local(blockMatrixCache = cache) { ctx => + backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) - backend.withExecuteContext { ctx => - ctx.local(blockMatrixCache = cache) { ctx => - val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) - - val s = Pretty.sexprStyle(persist, elideLiterals = false) - val x2 = IRParser.parse_blockmatrix_ir(ctx, s) - assert(x2 == persist) - } + val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) + val s = Pretty.sexprStyle(persist, elideLiterals = false) + val x2 = IRParser.parse_blockmatrix_ir(ctx, s) + assert(x2 == persist) } } finally cache.values.foreach(_.unpersist()) @@ -3931,18 +3923,14 @@ class IRSuite extends HailSuite { @Test def testCachedIR(): Unit = { val cached = Literal(TSet(TInt32), Set(1)) val s = s"(JavaIR 1)" - val x2 = ExecuteContext.scoped { ctx => - ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) - } + val x2 = ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) assert(x2 eq cached) } @Test def testCachedTableIR(): Unit = { val cached = TableRange(1, 1) val s = s"(JavaTable 1)" - val x2 = ExecuteContext.scoped { ctx => - ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) - } + val x2 = ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) assert(x2 eq cached) } @@ -4224,16 +4212,18 @@ class IRSuite extends HailSuite { val startingArg = SafeNDArray(IndexedSeq[Long](4L, 4L), (0 until 16).toFastSeq) - var memUsed = 0L - - ExecuteContext.scoped { ctx => - eval(ndSum, Env.empty, FastSeq(2 -> TInt32, startingArg -> ndType), None, None, true, ctx) - memUsed = ctx.r.pool.getHighestTotalUsage + val memUsed = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval(ndSum, Env.empty, FastSeq(2 -> TInt32, startingArg -> ndType), None, None, true, ctx) + pool.getHighestTotalUsage + } } - ExecuteContext.scoped { ctx => - eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> ndType), None, None, true, ctx) - assert(memUsed == ctx.r.pool.getHighestTotalUsage) + RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> ndType), None, None, true, ctx) + assert(memUsed == pool.getHighestTotalUsage) + } } } diff --git a/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala index 243c6a92a32..f868039fd29 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir 
import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{ ErrorIDs, False, GetTupleElement, I32, If, Literal, MakeTuple, NA, True, } diff --git a/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala index 6ec3d983de6..309ae863d59 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{stats, ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{ErrorIDs, F32, F64, False, I32, I64, Str, True} import is.hail.types.virtual._ import is.hail.utils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala index 5e6b03bb47c..9ae90fa1ebc 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations.BroadcastRow import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.TestUtils._ @@ -381,7 +380,7 @@ class MatrixIRSuite extends HailSuite { } @Test def testMatrixMultiWriteDifferentTypesRaisesError(): Unit = { - val vcf = is.hail.TestUtils.importVCF(ctx, getTestResource("sample.vcf")) + val vcf = importVCF(ctx, getTestResource("sample.vcf")) val range = rangeMatrix(10, 2, None) val path1 = ctx.createTmpPath("test1") val path2 = ctx.createTmpPath("test2") diff --git a/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala index 90cb2d4713f..5fa628f8312 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala @@ -1,9 +1,10 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.TestUtils.eval import is.hail.backend.ExecuteContext import is.hail.expr.ir.defs.{Literal, ToArray, ToStream} +import is.hail.annotations.{Region, RegionPool} +import is.hail.expr.ir import is.hail.types.virtual.{TArray, TBoolean, TSet, TString} import is.hail.utils._ @@ -17,19 +18,22 @@ class MemoryLeakSuite extends HailSuite { def run(size: Int): Long = { val lit = Literal(TSet(TString), (0 until litSize).map(_.toString).toSet) val queries = Literal(TArray(TString), (0 until size).map(_.toString).toFastSeq) - ExecuteContext.scoped { ctx => - eval( - ToArray( - mapIR(ToStream(queries))(r => invoke("contains", TBoolean, lit, r)) - ), - Env.empty, - FastSeq(), - None, - None, - false, - ctx, - ) - ctx.r.pool.getHighestTotalUsage + RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval( + ToArray( + mapIR(ToStream(queries))(r => ir.invoke("contains", TBoolean, lit, r)) + ), + Env.empty, + FastSeq(), + None, + None, + false, + ctx, + ) + } + + pool.getHighestTotalUsage } } diff --git a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala index 1f41d54260c..c7827fa214e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.check.{Gen, Prop} diff --git 
a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala index e62319bc968..a4e02032bb8 100644 --- a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.agg.CallStatsState import is.hail.expr.ir.defs._ @@ -104,7 +103,7 @@ class RequirednessSuite extends HailSuite { def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) @DataProvider(name = "valueIR") - def valueIR(): Array[Array[Any]] = ExecuteContext.scoped { ctx => + def valueIR(): Array[Array[Any]] = { val nodes = new BoxedArrayBuilder[Array[Any]](50) val allRequired = Array( diff --git a/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala index dd3ec81340e..e1bfbbddc7c 100644 --- a/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{I32, NA, ToSet, ToStream} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala b/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala index a169e56dccc..8d9c78bd747 100644 --- a/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.TestUtils.IRAggCount import is.hail.expr.ir.defs._ +import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval} import is.hail.variant.Locus diff --git a/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala index 09211f5cd83..c5ad27273c2 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{F32, I32, I64, MakeTuple, NA, Str} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala index d532f9f9ed9..697747dfbf7 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{I32, In, NA, Str} import is.hail.types.virtual.TString import is.hail.utils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala index 62901f74263..39974a502a5 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations.SafeNDArray import is.hail.expr.Nat import is.hail.expr.ir.TestUtils._ 
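
[Reviewer note, not part of the patch] The test-suite hunks above all apply the same refactor: the static `is.hail.TestUtils._` import is dropped, `TestUtils` becomes a trait that `HailSuite` mixes in (with `HailSuite` overriding its `ctx` member), and per-test `ExecuteContext.scoped { ... }` blocks are replaced by `ctx.local(...)`, which runs a function against a copy of the suite's per-method context with selected fields overridden. The sketch below is only an illustration of the post-refactor style: the suite name, the IRs under test, and the assertions are invented, while the `ctx.local`, `eval`, and `RegionPool.scoped` usage follows the calls visible in the hunks above.

    package is.hail.expr.ir

    import is.hail.{ExecStrategy, HailSuite}
    import is.hail.annotations.{Region, RegionPool}
    import is.hail.expr.ir.defs.I32

    import org.testng.annotations.Test

    // Hypothetical suite: HailSuite already mixes in TestUtils and supplies `ctx`,
    // so assertEvalsTo/eval are called unqualified and without ExecuteContext.scoped.
    class CtxLocalExampleSuite extends HailSuite {
      implicit val execStrats = ExecStrategy.compileOnly

      @Test def usesPerMethodContext(): Unit =
        // `ctx` is the ExecuteContext built in HailSuite.setupContext before each test.
        assertEvalsTo(I32(5), 5)

      @Test def overridesSelectedContextFields(): Unit =
        RegionPool.scoped { pool =>
          // Copy the suite's context, overriding only the region; the other fields
          // (fs, flags, tmpdir, caches) are inherited unchanged.
          ctx.local(r = Region(pool = pool)) { ctx =>
            eval(I32(1), optimize = false, ctx = ctx)
          }
          // Illustrates access to the pool metric used by the memory tests above.
          assert(pool.getHighestTotalUsage >= 0L)
        }
    }
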
diff --git a/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala index dc6c42b31cf..f08a632e4e1 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{ArrayRef, Die, GetTupleElement, I32, If, IsNA, Literal, Str, Trap} import is.hail.types.virtual._ import is.hail.utils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala index 3c7eb8636ae..cc02aa909a4 100644 --- a/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{Die, False, MakeStream, NA, Str, True} import is.hail.types.virtual.{TBoolean, TInt32, TStream} diff --git a/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala b/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala index 68ba43506f2..e31431c6e2f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir.analyses import is.hail.{HAIL_PRETTY_VERSION, HailSuite} -import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.expr.ir.defs._ import is.hail.io.fs.{FS, FakeFS, FakeURL, FileListEntry} @@ -292,12 +291,14 @@ class SemanticHashSuite extends HailSuite { @Test(dataProvider = "isBaseIRSemanticallyEquivalent") def testSemanticEquivalence(a: BaseIR, b: BaseIR, isEqual: Boolean, comment: String): Unit = - assertResult( - isEqual, - s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", - )( - semhash(fakeFs)(a) == semhash(fakeFs)(b) - ) + ctx.local(fs = fakeFs) { ctx => + assertResult( + isEqual, + s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", + )( + SemanticHash(ctx)(a) == SemanticHash(ctx)(b) + ) + } @Test def testFileNotFoundExceptions(): Unit = { @@ -309,14 +310,13 @@ class SemanticHashSuite extends HailSuite { val ir = importMatrix("gs://fake-bucket/fake-matrix") - assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( - semhash(fs)(ir) - ) + ctx.local(fs = fs) { ctx => + assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( + SemanticHash(ctx)(ir) + ) + } } - def semhash(fs: FS)(ir: BaseIR): Option[SemanticHash.Type] = - ExecuteContext.scoped(_.local(fs = fs)(SemanticHash(_)(ir))) - val fakeFs: FS = new FakeFS { override def eTag(url: FakeURL): Option[String] = diff --git a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 6bd31b6d4b4..72d5c29a580 100644 --- a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -7,6 +7,7 @@ import is.hail.expr.ir.{ import is.hail.expr.ir.defs.{ Apply, ErrorIDs, GetField, I32, Literal, MakeStruct, Ref, SelectFields, ToArray, ToStream, } +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.LowerDistributedSort.samplePartition import is.hail.types.RTable import 
is.hail.types.virtual.{TArray, TInt32, TStruct} @@ -15,7 +16,7 @@ import is.hail.utils.FastSeq import org.apache.spark.sql.Row import org.testng.annotations.Test -class LowerDistributedSortSuite extends HailSuite { +class LowerDistributedSortSuite extends HailSuite with TestUtils { implicit val execStrats = ExecStrategy.compileOnly @Test def testSamplePartition(): Unit = { @@ -71,14 +72,14 @@ class LowerDistributedSortSuite extends HailSuite { val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) .lower(ctx, myTable.typ.copy(key = FastSeq())) val res = - TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ + eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ IndexedSeq[Row] ]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2 - val unsortedCollect = is.hail.expr.ir.TestUtils.collect(myTable) + val unsortedCollect = collect(myTable) val unsortedAnalyses = LoweringAnalyses.apply(unsortedCollect, ctx) - val unsorted = TestUtils.eval(LowerTableIR.apply( + val unsorted = eval(LowerTableIR.apply( unsortedCollect, DArrayLowering.All, ctx, diff --git a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala index df686e1b7ed..25dca918bf2 100644 --- a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala @@ -1,14 +1,13 @@ package is.hail.expr.ir.table import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils.loweredExecute -import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.expr.ir.TestUtils.IRAggCollect import is.hail.expr.ir.defs.{ ApplyBinaryPrimOp, ErrorIDs, GetField, MakeStream, MakeStruct, Ref, Str, StreamRange, TableAggregate, TableGetGlobals, } +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ @@ -99,7 +98,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testLowering(): Unit = { - val table = TestUtils.collect(mkTableGen()) + val table = collect(mkTableGen()) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } @@ -107,13 +106,13 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testNumberOfContextsMatchesPartitions(): Unit = { val errorId = 42 - val table = TestUtils.collect(mkTableGen( + val table = collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), errorId = Some(errorId), )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[HailException] { - ExecuteContext.scoped(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) + loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) } ex.errorId shouldBe errorId ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") @@ -122,7 +121,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testRowsAreCorrectlyKeyed(): Unit = { val errorId = 56 - val table = TestUtils.collect(mkTableGen( + val table = collect(mkTableGen( partitioner = Some(new RVDPartitioner( ctx.stateManager, TStruct("a" -> TInt32), @@ -135,7 +134,7 @@ class TableGenSuite extends HailSuite { )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, 
LoweringAnalyses(table, ctx)) val ex = intercept[SparkException] { - ExecuteContext.scoped(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) + loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) }.getCause.asInstanceOf[HailException] ex.errorId shouldBe errorId diff --git a/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala b/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala index 7f85df6d302..a310abdb5bc 100644 --- a/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala +++ b/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala @@ -1,7 +1,6 @@ package is.hail.io.compress import is.hail.HailSuite -import is.hail.TestUtils._ import is.hail.check.Gen import is.hail.check.Prop.forAll import is.hail.expr.ir.GenericLines diff --git a/hail/hail/test/src/is/hail/io/fs/FSSuite.scala b/hail/hail/test/src/is/hail/io/fs/FSSuite.scala index 3173b646b06..fa834fd5de4 100644 --- a/hail/hail/test/src/is/hail/io/fs/FSSuite.scala +++ b/hail/hail/test/src/is/hail/io/fs/FSSuite.scala @@ -12,13 +12,10 @@ import org.apache.hadoop.fs.FileAlreadyExistsException import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -trait FSSuite extends TestNGSuite { +trait FSSuite extends TestNGSuite with TestUtils { val root: String = System.getenv("HAIL_TEST_STORAGE_URI") - def fsResourcesRoot: String = System.getenv("HAIL_FS_TEST_CLOUD_RESOURCES_URI") - def tmpdir: String = System.getenv("HAIL_TEST_STORAGE_URI") - def fs: FS /* Structure of src/test/resources/fs: @@ -73,7 +70,7 @@ trait FSSuite extends TestNGSuite { @Test def testFileStatusOnDirIsFailure(): Unit = { val f = r("/dir") - TestUtils.interceptException[FileNotFoundException](f)( + interceptException[FileNotFoundException](f)( fs.fileStatus(f) ) } diff --git a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala index 04a3a8387f4..da85781d8b2 100644 --- a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala @@ -1,6 +1,7 @@ package is.hail.linalg import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.check._ import is.hail.check.Arbitrary._ import is.hail.check.Gen._ @@ -926,9 +927,9 @@ class BlockMatrixSuite extends HailSuite { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize = 2) val expected = new BDM[Double](2, 3, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) - TestUtils.assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) - TestUtils.assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) - TestUtils.assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) + assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) + assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) + assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) } def filteredEquals(bm1: BlockMatrix, bm2: BlockMatrix): Boolean = @@ -1219,17 +1220,17 @@ class BlockMatrixSuite extends HailSuite { val v0 = Array(0.0, Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) - TestUtils.interceptFatal(notSupported)(bm0 / bm0) - TestUtils.interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) - TestUtils.interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) - TestUtils.interceptFatal(notSupported)(1 / bm0) + interceptFatal(notSupported)(bm0 / bm0) + interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) + interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) + 
interceptFatal(notSupported)(1 / bm0) - TestUtils.interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) - TestUtils.interceptFatal(notSupported)(bm0.colVectorDiv(v0)) - TestUtils.interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) - TestUtils.interceptFatal("division by scalar 0.0")(bm0 / 0) + interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) + interceptFatal(notSupported)(bm0.colVectorDiv(v0)) + interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) + interceptFatal("division by scalar 0.0")(bm0 / 0) - TestUtils.interceptFatal(notSupported)(bm0.pow(-1)) + interceptFatal(notSupported)(bm0.pow(-1)) } @Test diff --git a/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala b/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala index 63ce2cd71c1..85d007d753f 100644 --- a/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala +++ b/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala @@ -1,6 +1,6 @@ package is.hail.methods -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.annotations.Annotation import is.hail.check.{Gen, Properties} import is.hail.expr.ir.{Interpret, MatrixValue, TableValue} @@ -110,7 +110,7 @@ class LocalLDPruneSuite extends HailSuite { val nCores = 4 lazy val mt = Interpret( - TestUtils.importVCF(ctx, getTestResource("sample.vcf.bgz"), nPartitions = Option(10)), + importVCF(ctx, getTestResource("sample.vcf.bgz"), nPartitions = Option(10)), ctx, false, ).toMatrixValue(Array("s")) diff --git a/hail/hail/test/src/is/hail/methods/SkatSuite.scala b/hail/hail/test/src/is/hail/methods/SkatSuite.scala index a9b5e723f21..59ca348b593 100644 --- a/hail/hail/test/src/is/hail/methods/SkatSuite.scala +++ b/hail/hail/test/src/is/hail/methods/SkatSuite.scala @@ -1,6 +1,6 @@ package is.hail.methods -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.expr.ir.DoubleArrayBuilder import is.hail.utils._ @@ -31,6 +31,6 @@ class SkatSuite extends HailSuite { val (qLarge, gramianLarge) = Skat.computeGramianLargeN(st) assert(D_==(qSmall, qLarge)) - TestUtils.assertMatrixEqualityDouble(gramianSmall, gramianLarge) + assertMatrixEqualityDouble(gramianSmall, gramianLarge) } } diff --git a/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala b/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala index 73bcb32e9e4..3a51cb32f4a 100644 --- a/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala +++ b/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala @@ -1,6 +1,6 @@ package is.hail.stats -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils._ import breeze.linalg.{eigSym, svd, DenseMatrix, DenseVector} @@ -109,7 +109,7 @@ class eigSymDSuite extends HailSuite { val x = DenseVector.fill[Double](n)(rand.nextGaussian()) - TestUtils.assertVectorEqualityDouble(x, TriSolve(A, A * x)) + assertVectorEqualityDouble(x, TriSolve(A, A * x)) } } } diff --git a/hail/hail/test/src/is/hail/utils/RichArraySuite.scala b/hail/hail/test/src/is/hail/utils/RichArraySuite.scala index df1a5809f51..b7793450052 100644 --- a/hail/hail/test/src/is/hail/utils/RichArraySuite.scala +++ b/hail/hail/test/src/is/hail/utils/RichArraySuite.scala @@ -1,6 +1,6 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils.richUtils.RichArray import org.testng.annotations.Test @@ -15,7 +15,7 @@ class RichArraySuite extends HailSuite { RichArray.importFromDoubles(fs, file, a2, bufSize = 16) assert(a === a2) - TestUtils.interceptFatal("Premature") { + 
interceptFatal("Premature") { RichArray.importFromDoubles(fs, file, new Array[Double](101), bufSize = 64) } } diff --git a/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala b/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala index a8d05321bf4..282e71d77e0 100644 --- a/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala +++ b/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala @@ -1,6 +1,6 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.linalg.BlockMatrix import is.hail.utils.richUtils.RichDenseMatrixDouble @@ -33,7 +33,7 @@ class RichDenseMatrixDoubleSuite extends HailSuite { val lmT2 = RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 50, rowMajor = true) assert(mT === lmT2) - TestUtils.interceptFatal("Premature") { + interceptFatal("Premature") { RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 100, rowMajor = true) } } diff --git a/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala b/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala index a2a057b30fc..1f7374d180d 100644 --- a/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala @@ -8,8 +8,7 @@ import is.hail.utils._ import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class GenotypeSuite extends TestNGSuite { - +class GenotypeSuite extends TestNGSuite with TestUtils { val v = Variant("1", 1, "A", "T") @Test def gtPairGtIndexIsId(): Unit = @@ -118,6 +117,6 @@ class GenotypeSuite extends TestNGSuite { assert(Call.parse("0|1") == Call2(0, 1, phased = true)) intercept[UnsupportedOperationException](Call.parse("1/1/1")) intercept[UnsupportedOperationException](Call.parse("1|1|1")) - TestUtils.interceptFatal("invalid call expression:")(Call.parse("0/")) + interceptFatal("invalid call expression:")(Call.parse("0/")) } } diff --git a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala index 0816b4afa89..7371b2d7939 100644 --- a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala +++ b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala @@ -1,6 +1,6 @@ package is.hail.variant -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils._ import org.testng.annotations.Test @@ -161,11 +161,11 @@ class LocusIntervalSuite extends HailSuite { true, true, )) - TestUtils.interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( + interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( "[X:0-5)", rg, )) - TestUtils.interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( + interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( s"[X:1-${xMax + 1}]", rg, )) @@ -208,19 +208,19 @@ class LocusIntervalSuite extends HailSuite { false, )) - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("4::start-5:end", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("4:start-", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("1:1.1111K-2k", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("1:1.1111111M-2M", rg) } 
diff --git a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala index c579894a642..823fc3c8a1a 100644 --- a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,6 +1,7 @@ package is.hail.variant -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite +import is.hail.backend.HailStateManager import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder @@ -46,18 +47,18 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testAssertions(): Unit = { - TestUtils.interceptFatal("Must have at least one contig in the reference genome.")( + interceptFatal("Must have at least one contig in the reference genome.")( ReferenceGenome("test", Array.empty[String], Map.empty[String, Int]) ) - TestUtils.interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( + interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5), )) - TestUtils.interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( + interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5, "4" -> 100)) ) - TestUtils.interceptFatal("The following X contig names are absent from the reference:")( + interceptFatal("The following X contig names are absent from the reference:")( ReferenceGenome( "test", Array("1", "2", "3"), @@ -65,7 +66,7 @@ class ReferenceGenomeSuite extends HailSuite { xContigs = Set("X"), ) ) - TestUtils.interceptFatal("The following Y contig names are absent from the reference:")( + interceptFatal("The following Y contig names are absent from the reference:")( ReferenceGenome( "test", Array("1", "2", "3"), @@ -73,7 +74,7 @@ class ReferenceGenomeSuite extends HailSuite { yContigs = Set("Y"), ) ) - TestUtils.interceptFatal( + interceptFatal( "The following mitochondrial contig names are absent from the reference:" )(ReferenceGenome( "test", @@ -81,13 +82,13 @@ class ReferenceGenomeSuite extends HailSuite { Map("1" -> 5, "2" -> 5, "3" -> 5), mtContigs = Set("MT"), )) - TestUtils.interceptFatal("The contig name for PAR interval")(ReferenceGenome( + interceptFatal("The contig name for PAR interval")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), parInput = Array((Locus("X", 1), Locus("X", 5))), )) - TestUtils.interceptFatal("in both X and Y contigs.")(ReferenceGenome( + interceptFatal("in both X and Y contigs.")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), @@ -98,7 +99,7 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testContigRemap(): Unit = { val mapping = Map("23" -> "foo") - TestUtils.interceptFatal("have remapped contigs in reference genome")( + interceptFatal("have remapped contigs in reference genome")( ctx.references(ReferenceGenome.GRCh37).validateContigRemap(mapping) ) } diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 7bcb1145259..311a5c7e9bb 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -74,7 +74,6 @@ def __init__( hail_package = getattr(self._gateway.jvm, 'is').hail jbackend = hail_package.backend.local.LocalBackend.apply( - tmpdir, log, True, append, @@ -82,23 +81,14 @@ def __init__( ) jhc = 
hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) - super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc) + super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_configuration self._fs = self._exit_stack.enter_context( RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration}) ) self._logger = None - - flags = {} - if gcs_requester_pays_configuration is not None: - if isinstance(gcs_requester_pays_configuration, str): - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration - else: - assert isinstance(gcs_requester_pays_configuration, tuple) - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0] - flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1]) - - self._initialize_flags(flags) + self._initialize_flags({}) def validate_file(self, uri: str) -> None: async_to_blocking(validate_file(uri, self._fs.afs)) diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 2d787a0f339..e3d33e97c17 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -13,9 +13,11 @@ import hail from hail.expr import construct_expr +from hail.fs.hadoop_fs import HadoopFS from hail.ir import JavaIR from hail.utils.java import Env, FatalError, scala_package_object from hail.version import __version__ +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from ..hail_logging import Logger from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet @@ -171,8 +173,15 @@ def parse(node): class Py4JBackend(Backend): @abc.abstractmethod - def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject): - super(Py4JBackend, self).__init__() + def __init__( + self, + jvm: JVMView, + jbackend: JavaObject, + jhc: JavaObject, + tmpdir: str, + remote_tmpdir: str, + ): + super().__init__() import base64 def decode_bytearray(encoded): @@ -185,14 +194,19 @@ def decode_bytearray(encoded): self._jvm = jvm self._hail_package = getattr(self._jvm, 'is').hail self._utils_package_object = scala_package_object(self._hail_package.utils) - self._jbackend = jbackend self._jhc = jhc - self._backend_server = self._hail_package.backend.BackendServer(self._jbackend) - self._backend_server_port: int = self._backend_server.port() - self._backend_server.start() + self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) + self._jbackend.pySetLocalTmp(tmpdir) + self._jbackend.pySetRemoteTmp(remote_tmpdir) + + self._jhttp_server = self._jbackend.pyHttpServer() + self._backend_server_port: int = self._jhttp_server.port() self._requests_session = requests.Session() + self._gcs_requester_pays_config = None + self._fs = None + # This has to go after creating the SparkSession. Unclear why. # Maybe it does its own patch? 
install_exception_handler() @@ -214,6 +228,23 @@ def hail_package(self): def utils_package_object(self): return self._utils_package_object + @property + def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]: + return self._gcs_requester_pays_config + + @gcs_requester_pays_configuration.setter + def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]): + self._gcs_requester_pays_config = config + project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config + self._jbackend.pySetGcsRequesterPaysConfig(project, buckets) + self._fs = None # stale + + @property + def fs(self): + if self._fs is None: + self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs()) + return self._fs + @property def logger(self): if self._logger is None: @@ -288,7 +319,7 @@ def _to_java_blockmatrix_ir(self, ir): return self._parse_blockmatrix_ir(self._render_ir(ir)) def stop(self): - self._backend_server.close() + self._jhttp_server.close() self._jbackend.close() self._jhc.stop() self._jhc = None diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index 69f53292443..192ca57eb75 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -7,11 +7,11 @@ import pyspark.sql from hail.expr.table_type import ttable -from hail.fs.hadoop_fs import HadoopFS from hail.ir import BaseIR from hail.ir.renderer import CSERenderer from hail.table import Table from hail.utils import copy_log +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.aiotools.validators import validate_file from hailtop.utils import async_to_blocking @@ -47,12 +47,9 @@ def __init__( skip_logging_configuration, optimizer_iterations, *, - gcs_requester_pays_project: Optional[str] = None, - gcs_requester_pays_buckets: Optional[str] = None, + gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None, copy_log_on_error: bool = False, ): - assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None - try: local_jar_info = local_jar_information() except ValueError: @@ -120,10 +117,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations) else: @@ -137,10 +130,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) @@ -149,12 +138,12 @@ def __init__( self.sc = sc else: self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc)) - self._jspark_session = jbackend.sparkSession() + self._jspark_session = jbackend.sparkSession().apply() self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session) - super(SparkBackend, self).__init__(jvm, jbackend, jhc) + super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_config - self._fs = None self._logger = None if not quiet: @@ -167,7 +156,7 @@ def __init__( self._initialize_flags({}) self._router_async_fs = RouterAsyncFS( - gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project} + 
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config} ) self._tmpdir = tmpdir @@ -181,12 +170,6 @@ def stop(self): self.sc.stop() self.sc = None - @property - def fs(self): - if self._fs is None: - self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs()) - return self._fs - def from_spark(self, df, key): result_tuple = self._jbackend.pyFromDF(df._jdf, key) tir_id, type_json = result_tuple._1(), result_tuple._2() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 6c96599a24b..b3d72e2cef4 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -472,14 +472,10 @@ def init_spark( optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3) app_name = app_name or 'Hail' - ( - gcs_requester_pays_project, - gcs_requester_pays_buckets, - ) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style( - get_gcs_requester_pays_configuration( - gcs_requester_pays_configuration=gcs_requester_pays_configuration, - ) + gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( + gcs_requester_pays_configuration=gcs_requester_pays_configuration, ) + backend = SparkBackend( idempotent, sc, @@ -496,8 +492,7 @@ def init_spark( local_tmpdir, skip_logging_configuration, optimizer_iterations, - gcs_requester_pays_project=gcs_requester_pays_project, - gcs_requester_pays_buckets=gcs_requester_pays_buckets, + gcs_requester_pays_config=gcs_requester_pays_configuration, copy_log_on_error=copy_log_on_error, ) if not backend.fs.exists(tmpdir):