diff --git a/hail/hail/src/is/hail/HailContext.scala b/hail/hail/src/is/hail/HailContext.scala index 3de89ce13cd..f52efe60ee7 100644 --- a/hail/hail/src/is/hail/HailContext.scala +++ b/hail/hail/src/is/hail/HailContext.scala @@ -19,6 +19,7 @@ import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.RDD import org.json4s.Extraction import org.json4s.jackson.JsonMethods +import sourcecode.Enclosing case class FilePartition(index: Int, file: String) extends Partition @@ -41,7 +42,7 @@ object HailContext { def backend: Backend = get.backend - def sparkBackend(op: String): SparkBackend = get.sparkBackend(op) + def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = { org.apache.log4j.helpers.LogLog.setInternalDebugging(true) @@ -152,7 +153,7 @@ object HailContext { val fsBc = fs.broadcast - new RDD[T](SparkBackend.sparkContext("readPartition"), Nil) { + new RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i))) @@ -175,8 +176,6 @@ class HailContext private ( ) { def stop(): Unit = HailContext.stop() - def sparkBackend(op: String): SparkBackend = backend.asSpark(op) - var checkRVDKeys: Boolean = false def version: String = is.hail.HAIL_PRETTY_VERSION @@ -188,7 +187,7 @@ class HailContext private ( maxLines: Int, ): Map[String, Array[WithContext[String]]] = { val regexp = regex.r - SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath)) + SparkBackend.sparkContext.textFilesLines(fs.globAll(files).map(_.getPath)) .filter(line => regexp.findFirstIn(line.value).isDefined) .take(maxLines) .groupBy(_.source.file) diff --git a/hail/hail/src/is/hail/backend/Backend.scala b/hail/hail/src/is/hail/backend/Backend.scala index 666b3c0d2b6..71a0ef4e75e 100644 --- a/hail/hail/src/is/hail/backend/Backend.scala +++ b/hail/hail/src/is/hail/backend/Backend.scala @@ -1,28 +1,20 @@ package is.hail.backend -import is.hail.asm4s._ -import is.hail.backend.Backend.jsonToBytes +import is.hail.asm4s.HailClassLoader import is.hail.backend.spark.SparkBackend -import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader} +import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink -import is.hail.io.vcf.LoadVCF -import is.hail.types._ +import is.hail.io.fs.FS +import is.hail.types.RTable import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.types.virtual.TFloat64 -import is.hail.utils._ -import is.hail.variant.ReferenceGenome +import is.hail.utils.fatal import scala.reflect.ClassTag -import java.io._ -import java.nio.charset.StandardCharsets +import java.io.{Closeable, OutputStream} -import org.json4s._ -import org.json4s.jackson.JsonMethods import sourcecode.Enclosing object Backend { @@ -38,23 +30,19 @@ object Backend { ctx: ExecuteContext, t: PTuple, off: Long, - bufferSpecString: String, + bufferSpec: BufferSpec, os: OutputStream, ): Unit = { - val bs = BufferSpec.parseOrDefault(bufferSpecString) assert(t.size == 1) val elementType = t.fields(0).typ val codec = TypedCodecSpec( EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, - bs, + bufferSpec, ) assert(t.isFieldDefined(off, 0)) codec.encode(ctx, 
elementType, t.loadField(off, 0), os) } - - def jsonToBytes(f: => JValue): Array[Byte] = - JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) } abstract class BroadcastValue[T] { def value: T } @@ -82,8 +70,8 @@ abstract class Backend extends Closeable { f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) - def asSpark(op: String): SparkBackend = - fatal(s"${getClass.getSimpleName}: $op requires SparkBackend") + def asSpark(implicit E: Enclosing): SparkBackend = + fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend") def lowerDistributedSort( ctx: ExecuteContext, @@ -116,70 +104,7 @@ abstract class Backend extends Closeable { def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage - def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T - - final def valueType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_value_ir(ctx, s).typ.toJSON - } - } - - final def tableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_table_ir(ctx, s).typ.toJSON - } - } - - final def matrixTableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_matrix_ir(ctx, s).typ.toJSON - } - } - - final def blockMatrixType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON - } - } - - def loadReferencesFromDataset(path: String): Array[Byte] - - def fromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: Array[String], - yContigs: Array[String], - mtContigs: Array[String], - parInput: Array[String], - ): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput).toJSON - } - } - - def parseVCFMetadata(path: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - Extraction.decompose { - LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - }(defaultJSONFormats) - } - } - - def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) - : Array[Byte] = - withExecuteContext { ctx => - LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes( - StandardCharsets.UTF_8 - ) - } - def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] + + def backendContext(ctx: ExecuteContext): BackendContext } diff --git a/hail/hail/src/is/hail/backend/BackendRpc.scala b/hail/hail/src/is/hail/backend/BackendRpc.scala new file mode 100644 index 00000000000..834185cd0b7 --- /dev/null +++ b/hail/hail/src/is/hail/backend/BackendRpc.scala @@ -0,0 +1,259 @@ +package is.hail.backend + +import is.hail.expr.ir.IRParser +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey +import is.hail.io.BufferSpec +import is.hail.io.plink.LoadPlink +import is.hail.io.vcf.LoadVCF +import is.hail.services.retryTransientErrors +import is.hail.types.virtual.{Kind, TFloat64, VType} +import is.hail.types.virtual.Kinds._ +import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer, FastSeq} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.util.control.NonFatal + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets + +import 
org.json4s.{DefaultFormats, Extraction, Formats, JArray, JValue} +import org.json4s.jackson.JsonMethods + +case class SerializedIRFunction( + name: String, + type_parameters: Array[String], + value_parameter_names: Array[String], + value_parameter_types: Array[String], + return_type: String, + rendered_body: String, +) + +trait BackendRpc { + + sealed trait Command extends Product with Serializable + + object Commands { + case class TypeOf(k: Kind[_ <: VType], ir: String) extends Command + case class Execute(ir: String, fs: Array[SerializedIRFunction], codec: String) extends Command + case class ParseVcfMetadata(path: String) extends Command + + case class ImportFam(path: String, isQuantPheno: Boolean, delimiter: String, missing: String) + extends Command + + case class LoadReferencesFromDataset(path: String) extends Command + + case class LoadReferencesFromFASTA( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], + ) extends Command + } + + type Env + + trait Reader { + def command(env: Env): Command + } + + trait Context { + def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) + def putReferences(env: Env)(refs: Iterable[ReferenceGenome]): Unit + } + + trait Writer { + def timings(env: Env, t: ExecutionTimer.Timings): Unit + def result(env: Env, r: Array[Byte]): Unit + def error(env: Env, t: Throwable): Unit + } + + implicit val fmts: Formats = DefaultFormats + + import Commands._ + + final def runRpc(env: Env)(implicit Reader: Reader, Context: Context, Write: Writer): Unit = + try { + val command = Reader.command(env) + val (result, timings) = retryTransientErrors { + Context.scoped(env) { ctx => + command match { + case TypeOf(kind, s) => + jsonToBytes { + (kind match { + case Value => IRParser.parse_value_ir(ctx, s) + case Table => IRParser.parse_table_ir(ctx, s) + case Matrix => IRParser.parse_matrix_ir(ctx, s) + case BlockMatrix => IRParser.parse_blockmatrix_ir(ctx, s) + }).typ.toJSON + } + + case Execute(s, fns, codec) => + val bufferSpec = BufferSpec.parseOrDefault(codec) + withIRFunctionsReadFromInput(ctx, fns) { + val ir = IRParser.parse_value_ir(ctx, s) + val res = ctx.backend.execute(ctx, ir) + res match { + case Left(_) => Array.empty[Byte] + case Right((pt, off)) => + using(new ByteArrayOutputStream()) { os => + Backend.encodeToOutputStream(ctx, pt, off, bufferSpec, os) + os.toByteArray + } + } + } + + case ParseVcfMetadata(path) => + jsonToBytes { + val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + Extraction.decompose(metadata) + } + + case ImportFam(path, isQuantPheno, delimiter, missing) => + jsonToBytes { + LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missing) + } + + case LoadReferencesFromDataset(path) => + jsonToBytes { + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + Context.putReferences(env)(rgs) + JArray(rgs.map(_.toJSON).toList) + } + + case LoadReferencesFromFASTA(name, fasta, index, xContigs, yContigs, mtContigs, par) => + jsonToBytes { + val rg = ReferenceGenome.fromFASTAFile( + ctx, + name, + fasta, + index, + xContigs, + yContigs, + mtContigs, + par, + ) + Context.putReferences(env)(FastSeq(rg)) + rg.toJSON + } + } + } + } + + Write.result(env, result) + Write.timings(env, timings) + } catch { + case NonFatal(error) => Write.error(env, error) + } + + def jsonToBytes(v: JValue): Array[Byte] = + JsonMethods.compact(v).getBytes(StandardCharsets.UTF_8) + + private[this] def 
withIRFunctionsReadFromInput[A]( + ctx: ExecuteContext, + serializedFunctions: Array[SerializedIRFunction], + )( + body: => A + ): A = { + val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length) + try { + for (func <- serializedFunctions) { + fns += IRFunctionRegistry.registerIR( + ctx, + func.name, + func.type_parameters, + func.value_parameter_names, + func.value_parameter_types, + func.return_type, + func.rendered_body, + ) + } + + body + } finally + for (i <- 0 until fns.length) + IRFunctionRegistry.unregisterIr(fns(i)) + } +} + +trait HttpLikeRpc extends BackendRpc { + + import Commands._ + + trait Routing extends Reader { + + sealed trait Route extends Product with Serializable + + object Routes { + case class TypeOf(kind: Kind[_ <: VType]) extends Route + case object Execute extends Route + case object ParseVcfMetadata extends Route + case object ImportFam extends Route + case object LoadReferencesFromDataset extends Route + case object LoadReferencesFromFASTA extends Route + } + + def route(a: Env): Route + def payload(a: Env): JValue + + final override def command(a: Env): Command = + route(a) match { + case Routes.TypeOf(k) => + TypeOf(k, payload(a).extract[IRTypePayload].ir) + case Routes.Execute => + val ExecutePayload(ir, fns, codec) = payload(a).extract[ExecutePayload] + Execute(ir, fns, codec) + case Routes.ParseVcfMetadata => + ParseVcfMetadata(payload(a).extract[ParseVCFMetadataPayload].path) + case Routes.ImportFam => + val config = payload(a).extract[ImportFamPayload] + ImportFam(config.path, config.quant_pheno, config.delimiter, config.missing) + case Routes.LoadReferencesFromDataset => + val path = payload(a).extract[LoadReferencesFromDatasetPayload].path + LoadReferencesFromDataset(path) + case Routes.LoadReferencesFromFASTA => + val config = payload(a).extract[FromFASTAFilePayload] + LoadReferencesFromFASTA( + config.name, + config.fasta_file, + config.index_file, + config.x_contigs, + config.y_contigs, + config.mt_contigs, + config.par, + ) + } + } + + private[this] case class IRTypePayload(ir: String) + private[this] case class LoadReferencesFromDatasetPayload(path: String) + + private[this] case class FromFASTAFilePayload( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], + ) + + private[this] case class ParseVCFMetadataPayload(path: String) + + private[this] case class ImportFamPayload( + path: String, + quant_pheno: Boolean, + delimiter: String, + missing: String, + ) + + private[this] case class ExecutePayload( + ir: String, + fns: Array[SerializedIRFunction], + stream_codec: String, + ) +} diff --git a/hail/hail/src/is/hail/backend/BackendServer.scala b/hail/hail/src/is/hail/backend/BackendServer.scala deleted file mode 100644 index 24e7f7b9875..00000000000 --- a/hail/hail/src/is/hail/backend/BackendServer.scala +++ /dev/null @@ -1,153 +0,0 @@ -package is.hail.backend - -import is.hail.expr.ir.IRParser -import is.hail.utils._ - -import scala.util.control.NonFatal - -import java.io.Closeable -import java.net.InetSocketAddress -import java.nio.charset.StandardCharsets -import java.util.concurrent._ - -import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} -import org.json4s._ -import org.json4s.jackson.JsonMethods -import org.json4s.jackson.JsonMethods.compact - -case class IRTypePayload(ir: String) -case class LoadReferencesFromDatasetPayload(path: String) - -case class FromFASTAFilePayload( - name: String, - 
fasta_file: String, - index_file: String, - x_contigs: Array[String], - y_contigs: Array[String], - mt_contigs: Array[String], - par: Array[String], -) - -case class ParseVCFMetadataPayload(path: String) -case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) -case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean) - -class BackendServer(backend: Backend) extends Closeable { - // 0 => let the OS pick an available port - private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) - private[this] val handler = new BackendHttpHandler(backend) - - private[this] val thread = { - // This HTTP server *must not* start non-daemon threads because such threads keep the JVM - // alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest - // when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of the - // JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel - // explicitly regardless of the JVM). It *does* manifest when submitting jobs with - // - // gcloud dataproc submit ... - // - // or - // - // spark-submit - // - // setExecutor(null) ensures the server creates no new threads: - // - // > If this method is not called (before start()) or if it is called with a null Executor, then - // > a default implementation is used, which uses the thread which was created by the start() - // > method. - // - /* Source: - * https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ - // - httpServer.createContext("/", handler) - httpServer.setExecutor(null) - val t = Executors.defaultThreadFactory().newThread(new Runnable() { - def run(): Unit = - httpServer.start() - }) - t.setDaemon(true) - t - } - - def port = httpServer.getAddress.getPort - - def start(): Unit = - thread.start() - - override def close(): Unit = - httpServer.stop(10) -} - -class BackendHttpHandler(backend: Backend) extends HttpHandler { - def handle(exchange: HttpExchange): Unit = { - implicit val formats: Formats = DefaultFormats - - try { - val body = using(exchange.getRequestBody)(JsonMethods.parse(_)) - if (exchange.getRequestURI.getPath == "/execute") { - val ExecutePayload(irStr, streamCodec, timed) = body.extract[ExecutePayload] - backend.withExecuteContext { ctx => - val (res, timings) = ExecutionTimer.time { timer => - ctx.local(timer = timer) { ctx => - val irData = IRParser.parse_value_ir(ctx, irStr) - backend.execute(ctx, irData) - } - } - - if (timed) { - exchange.getResponseHeaders.add("X-Hail-Timings", compact(timings.toJSON)) - } - - res match { - case Left(_) => exchange.sendResponseHeaders(200, -1L) - case Right((t, off)) => - exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body - using(exchange.getResponseBody) { os => - Backend.encodeToOutputStream(ctx, t, off, streamCodec, os) - } - } - } - return - } - - val response: Array[Byte] = exchange.getRequestURI.getPath match { - case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) - case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) - case "/matrixtable/type" => backend.matrixTableType(body.extract[IRTypePayload].ir) - case "/blockmatrix/type" => backend.blockMatrixType(body.extract[IRTypePayload].ir) - case "/references/load" => - backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) - case 
"/references/from_fasta" => - val config = body.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - config.name, - config.fasta_file, - config.index_file, - config.x_contigs, - config.y_contigs, - config.mt_contigs, - config.par, - ) - case "/vcf/metadata/parse" => - backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) - case "/fam/import" => - val config = body.extract[ImportFamPayload] - backend.importFam(config.path, config.quant_pheno, config.delimiter, config.missing) - } - - exchange.sendResponseHeaders(200, response.length) - using(exchange.getResponseBody())(_.write(response)) - } catch { - case NonFatal(t) => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - val errorJson = JObject( - "short" -> JString(shortMessage), - "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId), - ) - val errorBytes = JsonMethods.compact(errorJson).getBytes(StandardCharsets.UTF_8) - exchange.sendResponseHeaders(500, errorBytes.length) - using(exchange.getResponseBody())(_.write(errorBytes)) - } - } -} diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index a9217359af1..3eb03b9bfb3 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -1,6 +1,6 @@ package is.hail.backend -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailFeatureFlags import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext @@ -23,7 +23,7 @@ trait TempFileManager extends AutoCloseable { def newTmpPath(tmpdir: String, prefix: String, extension: String = null): String } -class OwningTempFileManager(fs: FS) extends TempFileManager { +class OwningTempFileManager(val fs: FS) extends TempFileManager { private[this] val tmpPaths = mutable.ArrayBuffer[String]() override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = { @@ -39,7 +39,7 @@ class OwningTempFileManager(fs: FS) extends TempFileManager { } } -class NonOwningTempFileManager private (owner: TempFileManager) extends TempFileManager { +private class NonOwningTempFileManager private (owner: TempFileManager) extends TempFileManager { override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = owner.newTmpPath(tmpdir, prefix, extension) @@ -55,13 +55,6 @@ object NonOwningTempFileManager { } object ExecuteContext { - def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = { - val result = HailContext.sparkBackend("ExecuteContext.scoped").withExecuteContext( - selfContainedExecution = false - )(f) - result - } - def scoped[T]( tmpdir: String, localTmpdir: String, @@ -72,7 +65,6 @@ object ExecuteContext { tempFileManager: TempFileManager, theHailClassLoader: HailClassLoader, flags: HailFeatureFlags, - backendContext: BackendContext, irMetadata: IrMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix], codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], @@ -94,7 +86,6 @@ object ExecuteContext { tempFileManager, theHailClassLoader, flags, - backendContext, irMetadata, blockMatrixCache, codeCache, @@ -124,10 +115,9 @@ class ExecuteContext( val fs: FS, val r: Region, val timer: ExecutionTimer, - _tempFileManager: TempFileManager, + val tempFileManager: TempFileManager, val theHailClassLoader: HailClassLoader, val flags: HailFeatureFlags, - val backendContext: BackendContext, val irMetadata: IrMetadata, val BlockMatrixCache: mutable.Map[String, 
BlockMatrix], val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], @@ -146,10 +136,8 @@ class ExecuteContext( ) } - val stateManager = HailStateManager(references) - - val tempFileManager: TempFileManager = - if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) + def stateManager: HailStateManager = + HailStateManager(references) def fsBc: BroadcastValue[FS] = fs.broadcast @@ -197,7 +185,6 @@ class ExecuteContext( tempFileManager: TempFileManager = NonOwningTempFileManager(this.tempFileManager), theHailClassLoader: HailClassLoader = this.theHailClassLoader, flags: HailFeatureFlags = this.flags, - backendContext: BackendContext = this.backendContext, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, @@ -217,7 +204,6 @@ class ExecuteContext( tempFileManager, theHailClassLoader, flags, - backendContext, irMetadata, blockMatrixCache, codeCache, diff --git a/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala b/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala new file mode 100644 index 00000000000..9634edaf6b6 --- /dev/null +++ b/hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala @@ -0,0 +1,420 @@ +package is.hail.backend.api + +import is.hail.{linalg, HailFeatureFlags} +import is.hail.asm4s.HailClassLoader +import is.hail.backend._ +import is.hail.backend.spark.SparkBackend +import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} +import is.hail.expr.ir._ +import is.hail.expr.ir.IRParser.parseType +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer +import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx} +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs._ +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.types.physical.PStruct +import is.hail.types.virtual.{TArray, TInterval} +import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} +import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.annotation.nowarn +import scala.collection.mutable +import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} + +import java.io.Closeable +import java.net.InetSocketAddress +import java.util + +import com.google.api.client.http.HttpStatusCodes +import com.sun.net.httpserver.{HttpExchange, HttpServer} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.DataFrame +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} +import sourcecode.Enclosing + +final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { + + private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() + private[this] val hcl = new HailClassLoader(getClass.getClassLoader) + private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*) + private[this] val bmCache = mutable.Map[String, linalg.BlockMatrix]() + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) + private[this] val persistedIr = mutable.Map[Int, BaseIR]() + private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) + private[this] var irID: Int = 0 + + private[this] var tmpdir: String = _ + private[this] var localTmpdir: String = _ + + private[this] var tmpFileManager = new OwningTempFileManager( + 
newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) + ) + + def pyFs: FS = + synchronized(tmpFileManager.fs) + + def pyGetFlag(name: String): String = + synchronized(flags.get(name)) + + def pySetFlag(name: String, value: String): Unit = + synchronized(flags.set(name, value)) + + def pyAvailableFlags: java.util.ArrayList[String] = + flags.available + + def pySetRemoteTmp(tmp: String): Unit = + synchronized { tmpdir = tmp } + + def pySetLocalTmp(tmp: String): Unit = + synchronized { localTmpdir = tmp } + + def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit = + synchronized { + tmpFileManager.close() + + val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) + + val rpConfig: Option[RequesterPaysConfig] = + ( + Option(project).filter(_.nonEmpty), + Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty), + ) match { + case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets)) + case (None, Some(_)) => fatal( + "A non-empty, non-null requester pays google project is required to configure requester pays buckets." + ) + case (None, None) => None + } + + val fs = newFs( + cloudfsConf.copy( + google = (cloudfsConf.google, rpConfig) match { + case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig)) + case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig)) + case _ => None + } + ) + ) + + tmpFileManager = new OwningTempFileManager(fs) + } + + def pyRemoveJavaIR(id: Int): Unit = + synchronized(persistedIr.remove(id)) + + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = + synchronized { + val seq = IndexedFastaSequenceFile(tmpFileManager.fs, fastaFile, indexFile) + references(name).addSequence(seq) + } + + def pyRemoveSequence(name: String): Unit = + synchronized(references(name).removeSequence()) + + def pyExportBlockMatrix( + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: java.lang.Integer, + entries: String, + ): Unit = + withExecuteContext() { ctx => + val rm = linalg.RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) + entries match { + case "full" => + rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "lower" => + rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_lower" => + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "upper" => + rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_upper" => + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + } + } + + def pyRegisterIR( + name: String, + typeParamStrs: java.util.ArrayList[String], + argNameStrs: java.util.ArrayList[String], + argTypeStrs: java.util.ArrayList[String], + returnType: String, + bodyStr: String, + ): Unit = + withExecuteContext() { ctx => + IRFunctionRegistry.registerIR( + ctx, + name, + typeParamStrs.asScala.toArray, + argNameStrs.asScala.toArray, + argTypeStrs.asScala.toArray, + returnType, + bodyStr, + ) + } + + def pyExecuteLiteral(irStr: String): Int = + withExecuteContext() { ctx => + val ir = IRParser.parse_value_ir(ctx, irStr) + assert(ir.typ.isRealizable) + backend.execute(ctx, ir) match { + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + val field = 
GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + addJavaIR(ctx, field) + } + }._1 + + def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = + withExecuteContext(selfContainedExecution = false) { ctx => + val key = jKey.asScala.toArray.toFastSeq + val signature = + SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] + val tir = TableLiteral( + TableValue( + ctx, + signature.virtualType, + key, + df.rdd, + Some(signature), + ), + ctx.theHailClassLoader, + ) + val id = addJavaIR(ctx, tir) + (id, JsonMethods.compact(tir.typ.toJSON)) + }._1 + + def pyToDF(s: String): DataFrame = + withExecuteContext() { ctx => + val tir = IRParser.parse_table_ir(ctx, s) + Interpret(tir, ctx).toDF() + }._1 + + def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = + withExecuteContext() { ctx => + implicit val fmts: Formats = DefaultFormats + log.info("pyReadMultipleMatrixTables: got query") + + val kvs = JsonMethods.parse(jsonQuery).extract[Map[String, JValue]] + val paths = kvs("paths").extract[IndexedSeq[String]] + val intervalPointType = parseType(kvs("intervalPointType").extract[String]) + val intervalObjects = + JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) + .asInstanceOf[IndexedSeq[Interval]] + + val opts = NativeReaderOptions(intervalObjects, intervalPointType) + val matrixReaders: util.List[MatrixIR] = + paths.map { p => + log.info(s"creating MatrixRead node for $p") + val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) + MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR + }.asJava + + log.info("pyReadMultipleMatrixTables: returning N matrix tables") + matrixReaders + }._1 + + def pyAddReference(jsonConfig: String): Unit = + synchronized(addReference(ReferenceGenome.fromJSON(jsonConfig))) + + def pyRemoveReference(name: String): Unit = + synchronized(removeReference(name)) + + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = + synchronized { + references(name).addLiftover( + references(destRGName), + LiftOver(tmpFileManager.fs, chainFile), + ) + } + + def pyRemoveLiftover(name: String, destRGName: String): Unit = + synchronized(references(name).removeLiftover(destRGName)) + + def parse_blockmatrix_ir(s: String): BlockMatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_blockmatrix_ir(ctx, s) + }._1 + + private[this] def addReference(rg: ReferenceGenome): Unit = + ReferenceGenome.addFatalOnCollision(references, FastSeq(rg)) + + override def close(): Unit = + synchronized { + bmCache.values.foreach(_.unpersist()) + codeCache.clear() + persistedIr.clear() + coercerCache.clear() + backend.close() + } + + private[this] def removeReference(name: String): Unit = + references -= name + + private[this] def withExecuteContext[T]( + selfContainedExecution: Boolean = true + )( + f: ExecuteContext => T + )(implicit E: Enclosing + ): (T, Timings) = + synchronized { + ExecutionTimer.time { timer => + ExecuteContext.scoped( + tmpdir = tmpdir, + localTmpdir = localTmpdir, + backend = backend, + references = references.toMap, + fs = tmpFileManager.fs, + timer = timer, + tempFileManager = + if (!selfContainedExecution) NonOwningTempFileManager(tmpFileManager) + else new OwningTempFileManager(tmpFileManager.fs), + theHailClassLoader = hcl, + flags = flags, + irMetadata = new IrMetadata(), + blockMatrixCache = bmCache, + codeCache = codeCache, + irCache = persistedIr, + coercerCache = coercerCache, + 
)(f) + } + } + + private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS = + backend match { + case s: SparkBackend => + val conf = new Configuration(s.sc.hadoopConfiguration) + cloudfsConfig.google.flatMap(_.requester_pays_config).foreach { + case RequesterPaysConfig(prj, bkts) => + bkts + .map { buckets => + conf.set("fs.gs.requester.pays.mode", "CUSTOM") + conf.set("fs.gs.requester.pays.project.id", prj) + conf.set("fs.gs.requester.pays.buckets", buckets.mkString(",")) + } + .getOrElse { + conf.set("fs.gs.requester.pays.mode", "AUTO") + conf.set("fs.gs.requester.pays.project.id", prj) + } + } + new HadoopFS(new SerializableHadoopConfiguration(conf)) + + case _ => + RouterFS.buildRoutes(cloudfsConfig) + } + + private[this] def nextIRID(): Int = { + irID += 1 + irID + } + + private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { + val id = nextIRID() + ctx.IrCache += (id -> ir) + id + } + + def pyHttpServer: HttpLikeRpc with Closeable = + new HttpLikeRpc with Closeable { + + override type Env = HttpExchange + + implicit object Reader extends Routing { + override def route(req: HttpExchange): Route = + req.getRequestURI.getPath match { + case "/value/type" => Routes.TypeOf(Value) + case "/table/type" => Routes.TypeOf(Table) + case "/matrixtable/type" => Routes.TypeOf(Matrix) + case "/blockmatrix/type" => Routes.TypeOf(BlockMatrix) + case "/execute" => Routes.Execute + case "/vcf/metadata/parse" => Routes.ParseVcfMetadata + case "/fam/import" => Routes.ImportFam + case "/references/load" => Routes.LoadReferencesFromDataset + case "/references/from_fasta" => Routes.LoadReferencesFromFASTA + } + + override def payload(req: HttpExchange): JValue = + using(req.getRequestBody)(JsonMethods.parse(_)) + } + + implicit object Writer extends Writer { + override def timings(req: HttpExchange, t: Timings): Unit = { + val ts = Serialization.write(Map("timings" -> t)) + req.getResponseHeaders.add("X-Hail-Timings", ts) + } + + override def result(req: HttpExchange, result: Array[Byte]): Unit = + respond(req, HttpStatusCodes.STATUS_CODE_OK, result) + + override def error(req: HttpExchange, t: Throwable): Unit = + respond( + req, + HttpStatusCodes.STATUS_CODE_SERVER_ERROR, + jsonToBytes { + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId), + ) + }, + ) + + private[this] def respond(req: HttpExchange, code: Int, payload: Array[Byte]): Unit = { + req.sendResponseHeaders(code, payload.length) + using(req.getResponseBody)(_.write(payload)) + } + } + + implicit object Context extends Context { + override def scoped[A](req: HttpExchange)(f: ExecuteContext => A): (A, Timings) = + withExecuteContext()(f) + + override def putReferences(req: HttpExchange)(refs: Iterable[ReferenceGenome]): Unit = + refs.foreach(addReference) + } + + // 0 => let the OS pick an available port + private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) + + @nowarn def port: Int = httpServer.getAddress.getPort + override def close(): Unit = httpServer.stop(10) + + // This HTTP server *must not* start non-daemon threads because such threads keep the JVM + // alive. A living JVM indicates to Spark that the job is incomplete. 
This does not manifest + // when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of + // the JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel + // explicitly regardless of the JVM). It *does* manifest when submitting jobs with + // + // gcloud dataproc submit ... + // + // or + // + // spark-submit + httpServer.createContext("/", runRpc(_: HttpExchange)) + httpServer.setExecutor(null) // ensures the server creates no new threads + httpServer.start() + } +} diff --git a/hail/hail/src/is/hail/backend/api/ServiceBackendApi.scala b/hail/hail/src/is/hail/backend/api/ServiceBackendApi.scala new file mode 100644 index 00000000000..4a573f2b798 --- /dev/null +++ b/hail/hail/src/is/hail/backend/api/ServiceBackendApi.scala @@ -0,0 +1,274 @@ +package is.hail.backend.api + +import is.hail.{HailContext, HailFeatureFlags} +import is.hail.annotations.Memory +import is.hail.asm4s.HailClassLoader +import is.hail.backend.{Backend, ExecuteContext, HttpLikeRpc, OwningTempFileManager} +import is.hail.backend.service._ +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS} +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.services._ +import is.hail.types.virtual.Kinds._ +import is.hail.utils.{ + toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, ImmutableMap, Logging, +} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.annotation.switch + +import java.io.OutputStream +import java.nio.charset.StandardCharsets +import java.nio.file.Path + +import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods + +object ServiceBackendApi extends HttpLikeRpc with Logging { + + case class Env( + backend: Backend, + flags: HailFeatureFlags, + hcl: HailClassLoader, + rpcConfig: ServiceBackendRPCPayload, + fs: FS, + references: Map[String, ReferenceGenome], + outputUrl: String, + action: Int, + payload: JValue, + ) + + implicit object Reader extends Routing { + import Routes._ + + override def route(env: Env): Route = + (env.action: @switch) match { + case 1 => TypeOf(Value) + case 2 => TypeOf(Table) + case 3 => TypeOf(Matrix) + case 4 => TypeOf(BlockMatrix) + case 5 => Execute + case 6 => ParseVcfMetadata + case 7 => ImportFam + case 8 => LoadReferencesFromDataset + case 9 => LoadReferencesFromFASTA + } + + override def payload(env: Env): JValue = env.payload + } + + implicit object Writer extends Writer with ErrorHandling { + // service backend doesn't support sending timings back to the python client + override def timings(env: Env, t: Timings): Unit = + () + + override def result(env: Env, result: Array[Byte]): Unit = + retryTransientErrors { + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + + override def error(env: Env, t: Throwable): Unit = + retryTransientErrors { + val (shortMessage, expandedMessage, errorId) = + t match { + case t: HailWorkerException => + log.error( + "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", + t, + ) + (t.shortMessage, t.expandedMessage, t.errorId) + case _ => + log.error( + "An exception occurred in the driver. 
The exception was written for Python but we will re-throw to fail this driver job.", + t, + ) + handleForPython(t) + } + + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + + throw t + } + } + + implicit object Context extends Context { + override def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) = + ExecutionTimer.time { timer => + ExecuteContext.scoped( + env.rpcConfig.tmp_dir, + env.rpcConfig.tmp_dir, + env.backend, + env.references, + env.fs, + timer, + new OwningTempFileManager(env.fs), + env.hcl, + env.flags, + new IrMetadata(), + ImmutableMap.empty, + ImmutableMap.empty, + ImmutableMap.empty, + ImmutableMap.empty, + )(f) + } + + override def putReferences(env: Env)(refs: Iterable[ReferenceGenome]): Unit = + ReferenceGenome.addFatalOnCollision(env.references, refs) + } + + def main(argv: Array[String]): Unit = { + assert(argv.length == 7, argv.toFastSeq) + + val scratchDir = argv(0) + // val logFile = argv(1) + val jarLocation = argv(2) + val kind = argv(3) + assert(kind == Main.DRIVER) + val name = argv(4) + val inputURL = argv(5) + val outputURL = argv(6) + + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") + DeployConfig.set(deployConfig) + sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) + + var fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) + ) + + val (rpcConfig, jobConfig, action, payload) = + using(fs.openNoCompression(inputURL)) { is => + val input = JsonMethods.parse(is) + ( + (input \ "rpc_config").extract[ServiceBackendRPCPayload], + (input \ "job_config").extract[BatchJobConfig], + (input \ "action").extract[Int], + input \ "payload", + ) + } + + // requester pays config is conveyed in feature flags currently + val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) + fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + featureFlags, + ) + ) + + val references = + ReferenceGenome.addFatalOnCollision( + ReferenceGenome.builtinReferences(), + rpcConfig.custom_references.map(ReferenceGenome.fromJSON), + ) + + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } + + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } + + val backend = new ServiceBackend( + name, + BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), + JarUrl(jarLocation), + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), + jobConfig, + ) + + log.info("ServiceBackend allocated.") + if (HailContext.isInitialized) { + HailContext.get.backend = backend + log.info("Default references added to already initialized HailContexet.") + } else { + HailContext(backend, 50, 3) + log.info("HailContexet initialized.") + } + + // FIXME: when can the classloader be shared? (optimizer benefits!) 
+ runRpc( + Env( + backend, + featureFlags, + new HailClassLoader(getClass.getClassLoader), + rpcConfig, + fs, + references, + outputURL, + action, + payload, + ) + ) + } +} + +private class HailSocketAPIOutputStream( + private[this] val out: OutputStream +) extends AutoCloseable { + private[this] var closed: Boolean = false + private[this] val dummy = new Array[Byte](8) + + def writeBool(b: Boolean): Unit = + out.write(if (b) 1 else 0) + + def writeInt(v: Int): Unit = { + Memory.storeInt(dummy, 0, v) + out.write(dummy, 0, 4) + } + + def writeLong(v: Long): Unit = { + Memory.storeLong(dummy, 0, v) + out.write(dummy) + } + + def writeBytes(bytes: Array[Byte]): Unit = { + writeInt(bytes.length) + out.write(bytes) + } + + def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) + + def close(): Unit = + if (!closed) { + out.close() + closed = true + } +} + +case class SequenceConfig(fasta: String, index: String) + +case class ServiceBackendRPCPayload( + tmp_dir: String, + flags: Map[String, String], + custom_references: Array[String], + liftovers: Map[String, Map[String, String]], + sequences: Map[String, SequenceConfig], +) + +case class BatchJobConfig( + token: String, + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], +) diff --git a/hail/hail/src/is/hail/backend/local/LocalBackend.scala b/hail/hail/src/is/hail/backend/local/LocalBackend.scala index 89847e5af85..fcc40daab83 100644 --- a/hail/hail/src/is/hail/backend/local/LocalBackend.scala +++ b/hail/hail/src/is/hail/backend/local/LocalBackend.scala @@ -1,13 +1,11 @@ package is.hail.backend.local -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{CancellingExecutorService, HailContext} import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ -import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple @@ -18,17 +16,13 @@ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ -import is.hail.variant.ReferenceGenome -import scala.collection.mutable import scala.reflect.ClassTag import java.io.PrintWriter import com.fasterxml.jackson.core.StreamReadConstraints import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop -import sourcecode.Enclosing class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable @@ -36,8 +30,7 @@ class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskC override def attemptNumber(): Int = 0 } -object LocalBackend { - private var theLocalBackend: LocalBackend = _ +object LocalBackend extends Backend { // From https://github.com/hail-is/hail/issues/14580 : // IR can get quite big, especially as it can contain an arbitrary @@ -52,91 +45,32 @@ object LocalBackend { ) def apply( - tmpdir: String, logFile: String = "hail.log", quiet: Boolean = false, append: Boolean = false, skipLoggingConfiguration: Boolean = false, - ): LocalBackend = synchronized { - require(theLocalBackend == null) - - if (!skipLoggingConfiguration) - HailContext.configureLogging(logFile, quiet, append) - - theLocalBackend = new LocalBackend( 
- tmpdir, - mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), - ) - - theLocalBackend - } - - def stop(): Unit = synchronized { - if (theLocalBackend != null) { - theLocalBackend = null - // Hadoop does not honor the hadoop configuration as a component of the cache key for file - // systems, so we blow away the cache so that a new configuration can successfully take - // effect. - // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 - hadoop.fs.FileSystem.closeAll() + ): LocalBackend.type = + synchronized { + if (!skipLoggingConfiguration) HailContext.configureLogging(logFile, quiet, append) + this } - } -} -class LocalBackend( - val tmpdir: String, - override val references: mutable.Map[String, ReferenceGenome], -) extends Backend with Py4JBackendExtensions { - - override def backend: Backend = this - override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - override def longLifeTempFileManager: TempFileManager = null - - private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) - private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) - private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() - private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - - // flags can be set after construction from python - def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => - val fs = this.fs - ExecuteContext.scoped( - tmpdir, - tmpdir, - this, - references.toMap, - fs, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }, - new IrMetadata(), - ImmutableMap.empty, - codeCache, - persistedIR, - coercerCache, - )(f) - } + private case class Context(hcl: HailClassLoader, override val executionCache: ExecutionCache) + extends BackendContext def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value) private[this] var stageIdx: Int = 0 - private[this] def nextStageId(): Int = { - val current = stageIdx - stageIdx += 1 - current - } + private[this] def nextStageId(): Int = + synchronized { + val current = stageIdx + stageIdx += 1 + current + } override def parallelizeAndComputeWithIndex( - backendContext: BackendContext, + ctx: BackendContext, fs: FS, contexts: IndexedSeq[Array[Byte]], stageIdentifier: String, @@ -147,22 +81,24 @@ class LocalBackend( ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { val stageId = nextStageId() + val hcl = ctx.asInstanceOf[Context].hcl runAllKeepFirstError(new CancellingExecutorService(MoreExecutors.newDirectExecutorService())) { partitions.getOrElse(contexts.indices).map { i => ( - () => - using(new LocalTaskContext(i, stageId)) { - f(contexts(i), _, theHailClassLoader, fs) - }, + () => using(new LocalTaskContext(i, stageId))(f(contexts(i), _, hcl, fs)), i, ) } } } + override def backendContext(ctx: ExecuteContext): BackendContext = + Context(ctx.theHailClassLoader, ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir)) + def defaultParallelism: Int = 1 - def close(): Unit = LocalBackend.stop() + def close(): Unit = + synchronized { stageIdx = 0 } private[this] def _jvmLowerAndExecute( ctx: ExecuteContext, @@ -170,8 +106,7 @@ class LocalBackend( print: Option[PrintWriter] = None, ): Either[Unit, (PTuple, Long)] = ctx.time { - val ir = - 
LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, ir0).asInstanceOf[IR] + val ir = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All)(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") diff --git a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala deleted file mode 100644 index aa084dcf46f..00000000000 --- a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ /dev/null @@ -1,259 +0,0 @@ -package is.hail.backend.py4j - -import is.hail.HailFeatureFlags -import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} -import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} -import is.hail.expr.ir.{ - BaseIR, BindingEnv, BlockMatrixIR, IR, IRParser, Interpret, MatrixIR, MatrixNativeReader, - MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral, TableValue, -} -import is.hail.expr.ir.IRParser.parseType -import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx} -import is.hail.expr.ir.functions.IRFunctionRegistry -import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} -import is.hail.linalg.RowMatrix -import is.hail.types.physical.PStruct -import is.hail.types.virtual.{TArray, TInterval} -import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval} -import is.hail.variant.ReferenceGenome - -import scala.collection.mutable -import scala.jdk.CollectionConverters.{ - asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, -} - -import java.nio.charset.StandardCharsets -import java.util - -import org.apache.spark.sql.DataFrame -import org.json4s -import org.json4s.Formats -import org.json4s.jackson.{JsonMethods, Serialization} -import sourcecode.Enclosing - -trait Py4JBackendExtensions { - def backend: Backend - def references: mutable.Map[String, ReferenceGenome] - def flags: HailFeatureFlags - def longLifeTempFileManager: TempFileManager - - def pyGetFlag(name: String): String = - flags.get(name) - - def pySetFlag(name: String, value: String): Unit = - flags.set(name, value) - - def pyAvailableFlags: java.util.ArrayList[String] = - flags.available - - private[this] var irID: Int = 0 - - private[this] def nextIRID(): Int = { - irID += 1 - irID - } - - private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { - val id = nextIRID() - ctx.IrCache += (id -> ir) - id - } - - def pyRemoveJavaIR(id: Int): Unit = - backend.withExecuteContext(_.IrCache.remove(id)) - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - backend.withExecuteContext { ctx => - references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile)) - } - - def pyRemoveSequence(name: String): Unit = - references(name).removeSequence() - - def pyExportBlockMatrix( - pathIn: String, - pathOut: String, - delimiter: String, - header: String, - addIndex: Boolean, - exportType: String, - partitionSize: java.lang.Integer, - entries: String, - ): Unit = - backend.withExecuteContext { ctx => - val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) - entries match { - case "full" => - rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "lower" => - rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_lower" => - rm.exportStrictLowerTriangle( - ctx, - pathOut, - delimiter, - Option(header), - 
addIndex, - exportType, - ) - case "upper" => - rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_upper" => - rm.exportStrictUpperTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - } - } - - def pyRegisterIR( - name: String, - typeParamStrs: java.util.ArrayList[String], - argNameStrs: java.util.ArrayList[String], - argTypeStrs: java.util.ArrayList[String], - returnType: String, - bodyStr: String, - ): Unit = - backend.withExecuteContext { ctx => - IRFunctionRegistry.registerIR( - ctx, - name, - typeParamStrs.asScala.toArray, - argNameStrs.asScala.toArray, - argTypeStrs.asScala.toArray, - returnType, - bodyStr, - ) - } - - def pyExecuteLiteral(irStr: String): Int = - backend.withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(ctx, irStr) - assert(ir.typ.isRealizable) - backend.execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(ctx, field) - } - } - - def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { - val key = jKey.asScala.toArray.toFastSeq - val signature = - SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = TableLiteral( - TableValue( - ctx, - signature.virtualType, - key, - df.rdd, - Some(signature), - ), - ctx.theHailClassLoader, - ) - val id = addJavaIR(ctx, tir) - (id, JsonMethods.compact(tir.typ.toJSON)) - } - } - - def pyToDF(s: String): DataFrame = - backend.withExecuteContext { ctx => - val tir = IRParser.parse_table_ir(ctx, s) - Interpret(tir, ctx).toDF() - } - - def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = - backend.withExecuteContext { ctx => - log.info("pyReadMultipleMatrixTables: got query") - val kvs = JsonMethods.parse(jsonQuery) match { - case json4s.JObject(values) => values.toMap - } - - val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { - case json4s.JString(s) => s - } - - val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) - val intervalObjects = - JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) - .asInstanceOf[IndexedSeq[Interval]] - - val opts = NativeReaderOptions(intervalObjects, intervalPointType) - val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => - log.info(s"creating MatrixRead node for $p") - val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) - MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR - } - log.info("pyReadMultipleMatrixTables: returning N matrix tables") - matrixReaders.asJava - } - - def pyAddReference(jsonConfig: String): Unit = - addReference(ReferenceGenome.fromJSON(jsonConfig)) - - def pyRemoveReference(name: String): Unit = - removeReference(name) - - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - backend.withExecuteContext { ctx => - references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile)) - } - - def pyRemoveLiftover(name: String, destRGName: String): Unit = - references(name).removeLiftover(destRGName) - - private[this] def addReference(rg: ReferenceGenome): Unit = - ReferenceGenome.addFatalOnCollision(references, FastSeq(rg)) - - private[this] def removeReference(name: String): Unit = - references -= name - - def parse_value_ir(s: 
String, refMap: java.util.Map[String, String]): IR = - backend.withExecuteContext { ctx => - IRParser.parse_value_ir( - ctx, - s, - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s)) - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s)) - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_blockmatrix_ir(ctx, s) - } - - def loadReferencesFromDataset(path: String): Array[Byte] = - backend.withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - ReferenceGenome.addFatalOnCollision(references, rgs) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } - - def withExecuteContext[T]( - selfContainedExecution: Boolean = true - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = - backend.withExecuteContext { ctx => - val tempFileManager = longLifeTempFileManager - if (selfContainedExecution && tempFileManager != null) f(ctx) - else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) - } -} diff --git a/hail/hail/src/is/hail/backend/service/Main.scala b/hail/hail/src/is/hail/backend/service/Main.scala index a31f99e7946..391fefcdaad 100644 --- a/hail/hail/src/is/hail/backend/service/Main.scala +++ b/hail/hail/src/is/hail/backend/service/Main.scala @@ -1,5 +1,7 @@ package is.hail.backend.service +import is.hail.backend.api.ServiceBackendApi + import com.fasterxml.jackson.core.StreamReadConstraints object Main { @@ -18,8 +20,7 @@ object Main { def main(argv: Array[String]): Unit = argv(3) match { case WORKER => Worker.main(argv) - case DRIVER => ServiceBackendAPI.main(argv) - + case DRIVER => ServiceBackendApi.main(argv) // Batch's "JvmJob" is a special kind of job that can only call `Main.main`. // TEST is used for integration testing the `BatchClient` to verify that we // can create JvmJobs without having to mock the payload to a `Worker` job. 
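A recurring pattern in this patch is the removal of explicit `op: String` parameters (for example on `Backend.asSpark` and `HailContext.sparkBackend`) in favour of an implicit `sourcecode.Enclosing`, which the compiler materialises at each call site with the caller's enclosing definition path. The following is a minimal, self-contained sketch of that pattern, not part of the patch: the `Example` object and its method names are hypothetical, and the real `Backend.asSpark` additionally prefixes the message with the backend's class name and raises it via `fatal`.

```scala
import sourcecode.Enclosing

object Example {
  // The implicit Enclosing is filled in at each call site with the caller's
  // enclosing path (something like "Example.collectTable"), so the error
  // message can name the failing operation without callers threading an
  // `op: String` argument through, as the pre-patch code did.
  def asSpark(implicit E: Enclosing): Nothing =
    throw new UnsupportedOperationException(s"${E.value} requires SparkBackend")

  def collectTable(): Unit =
    asSpark // throws: "... Example.collectTable requires SparkBackend"
}
```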
diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index 6e4309056b2..56df04ed276 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -1,140 +1,55 @@ package is.hail.backend.service -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{CancellingExecutorService, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.api.BatchJobConfig +import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate -import is.hail.expr.ir.{ - IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck, -} +import is.hail.expr.ir.{IR, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple -import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ -import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.services.{BatchClient, JobGroupRequest, _} import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual._ +import is.hail.types.virtual.TVoid import is.hail.utils._ -import is.hail.variant.ReferenceGenome -import scala.annotation.switch -import scala.collection.mutable import scala.reflect.ClassTag import java.io._ import java.nio.charset.StandardCharsets -import java.nio.file.Path import java.util.concurrent._ -import org.apache.log4j.Logger -import org.json4s.{DefaultFormats, Formats} -import org.json4s.JsonAST._ -import org.json4s.jackson.{JsonMethods, Serialization} -import sourcecode.Enclosing - -class ServiceBackendContext( - val billingProject: String, - val remoteTmpDir: String, - val workerCores: String, - val workerMemory: String, - val storageRequirement: String, - val regions: Array[String], - val cloudfuseConfig: Array[CloudfuseConfig], - val profile: Boolean, - val executionCache: ExecutionCache, -) extends BackendContext with Serializable {} - object ServiceBackend { - - def apply( - jarLocation: String, - name: String, - theHailClassLoader: HailClassLoader, - batchClient: BatchClient, - batchConfig: BatchConfig, - scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""), - rpcConfig: ServiceBackendRPCPayload, - env: Map[String, String], - ): ServiceBackend = { - - val flags = HailFeatureFlags.fromEnv(rpcConfig.flags) - val shouldProfile = flags.get("profile") != null - val fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - flags, - env, - ) - ) - - val backendContext = new ServiceBackendContext( - rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - shouldProfile, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ) - - val references = mutable.Map.empty[String, ReferenceGenome] - references ++= ReferenceGenome.builtinReferences() - ReferenceGenome.addFatalOnCollision( - references, - rpcConfig.custom_references.map(ReferenceGenome.fromJSON), - ) - - 
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) - } - } - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) - } - - new ServiceBackend( - JarUrl(jarLocation), - name, - theHailClassLoader, - references, - batchClient, - batchConfig, - flags, - rpcConfig.tmp_dir, - fs, - backendContext, - scratchDir, - ) - } + val MaxAvailableGcsConnections = 1000 } class ServiceBackend( - val jarSpec: JarSpec, - var name: String, - val theHailClassLoader: HailClassLoader, - val references: mutable.Map[String, ReferenceGenome], - val batchClient: BatchClient, + val name: String, + batchClient: BatchClient, + jarSpec: JarSpec, val batchConfig: BatchConfig, - val flags: HailFeatureFlags, - val tmpdir: String, - val fs: FS, - val serviceBackendContext: ServiceBackendContext, - val scratchDir: String, + jobConfig: BatchJobConfig, ) extends Backend with Logging { + case class Context( + remoteTmpDir: String, + flags: HailFeatureFlags, + override val executionCache: ExecutionCache, + ) extends BackendContext + private[this] var stageCount = 0 - private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 - private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS) + + private[this] val executor = lazily { + Executors.newFixedThreadPool(MaxAvailableGcsConnections) + } def defaultParallelism: Int = 4 @@ -152,6 +67,13 @@ class ServiceBackend( } } + override def backendContext(ctx: ExecuteContext): BackendContext = + Context( + remoteTmpDir = ctx.tmpdir, + flags = ctx.flags, + executionCache = ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir), + ) + private[this] def readString(in: DataInputStream): String = { val n = in.readInt() val bytes = new Array[Byte](n) @@ -160,7 +82,7 @@ class ServiceBackend( } private[this] def submitJobGroupAndWait( - backendContext: ServiceBackendContext, + ctx: Context, collection: IndexedSeq[Array[Byte]], token: String, root: String, @@ -170,7 +92,7 @@ class ServiceBackend( JvmJob( command = null, spec = jarSpec, - profile = flags.get("profile") != null, + profile = ctx.flags.get("profile") != null, ) val defaultJob = @@ -180,13 +102,13 @@ class ServiceBackend( resources = Some( JobResources( preemptible = true, - cpu = Some(backendContext.workerCores).filter(_ != "None"), - memory = Some(backendContext.workerMemory).filter(_ != "None"), - storage = Some(backendContext.storageRequirement).filter(_ != "0Gi"), + cpu = Some(jobConfig.worker_cores).filter(_ != "None"), + memory = Some(jobConfig.worker_memory).filter(_ != "None"), + storage = Some(jobConfig.storage).filter(_ != "0Gi"), ) ), - regions = Some(backendContext.regions).filter(_.nonEmpty), - cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty), + regions = Some(jobConfig.regions).filter(_.nonEmpty), + cloudfuse = Some(jobConfig.cloudfuse_configs).filter(_.nonEmpty), ) val jobs = @@ -216,7 +138,7 @@ class ServiceBackend( batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) } - private[this] def readResult(root: String, i: Int): Array[Byte] = { + private[this] def readResult(fs: FS, root: String, i: Int): Array[Byte] = { val bytes = fs.readNoCompression(s"$root/result.$i") if (bytes(0) != 0) { bytes.slice(1, bytes.length) @@ -241,7 +163,7 @@ class ServiceBackend( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => 
Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { - val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] + val backendContext = _backendContext.asInstanceOf[Context] val token = tokenUrlSafe val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" @@ -286,7 +208,7 @@ class ServiceBackend( val startTime = System.nanoTime() var r @ (err, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) { (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => - (() => readResult(root, jobIndex), partIdx) + (() => readResult(fs, root, jobIndex), partIdx) } } @@ -315,7 +237,7 @@ class ServiceBackend( } override def close(): Unit = { - executor.shutdownNow() + if (executor.isEvaluated) executor.shutdownNow() batchClient.close() } @@ -374,308 +296,6 @@ class ServiceBackend( def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => - ExecuteContext.scoped( - tmpdir, - "file:///tmp", - this, - references.toMap, - fs, - timer, - null, - theHailClassLoader, - flags, - serviceBackendContext, - new IrMetadata(), - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - )(f) - } - - override def loadReferencesFromDataset(path: String): Array[Byte] = - withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - ReferenceGenome.addFatalOnCollision(references, rgs) - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } } -class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) - -object ServiceBackendAPI { - private[this] val log = Logger.getLogger(getClass.getName()) - - def main(argv: Array[String]): Unit = { - assert(argv.length == 7, argv.toFastSeq) - - val scratchDir = argv(0) - // val logFile = argv(1) - val jarLocation = argv(2) - val kind = argv(3) - assert(kind == Main.DRIVER) - val name = argv(4) - val inputURL = argv(5) - val outputURL = argv(6) - - implicit val formats: Formats = DefaultFormats - - val fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - HailFeatureFlags.fromEnv(), - ) - ) - val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") - DeployConfig.set(deployConfig) - sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - - val batchClient = BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")) - log.info("BatchClient allocated.") - - val batchConfig = - BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) - log.info("BatchConfig parsed.") - - val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) - val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] - - // FIXME: when can the classloader be shared? (optimizer benefits!) 
- val backend = ServiceBackend( - jarLocation, - name, - new HailClassLoader(getClass().getClassLoader()), - batchClient, - batchConfig, - scratchDir, - rpcConfig, - sys.env, - ) - log.info("ServiceBackend allocated.") - if (HailContext.isInitialized) { - HailContext.get.backend = backend - log.info("Default references added to already initialized HailContexet.") - } else { - HailContext(backend, 50, 3) - log.info("HailContexet initialized.") - } - - val action = (input \ "action").extract[Int] - val payload = (input \ "payload") - new ServiceBackendAPI(backend, fs, outputURL).executeOneCommand(action, payload) - } -} - -private class HailSocketAPIOutputStream( - private[this] val out: OutputStream -) extends AutoCloseable { - private[this] var closed: Boolean = false - private[this] val dummy = new Array[Byte](8) - - def writeBool(b: Boolean): Unit = - out.write(if (b) 1 else 0) - - def writeInt(v: Int): Unit = { - Memory.storeInt(dummy, 0, v) - out.write(dummy, 0, 4) - } - - def writeLong(v: Long): Unit = { - Memory.storeLong(dummy, 0, v) - out.write(dummy) - } - - def writeBytes(bytes: Array[Byte]): Unit = { - writeInt(bytes.length) - out.write(bytes) - } - - def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) - - def close(): Unit = - if (!closed) { - out.close() - closed = true - } -} - -case class SequenceConfig(fasta: String, index: String) - -case class ServiceBackendRPCPayload( - tmp_dir: String, - remote_tmpdir: String, - billing_project: String, - worker_cores: String, - worker_memory: String, - storage: String, - cloudfuse_configs: Array[CloudfuseConfig], - regions: Array[String], - flags: Map[String, String], - custom_references: Array[String], - liftovers: Map[String, Map[String, String]], - sequences: Map[String, SequenceConfig], -) - -case class ServiceBackendExecutePayload( - functions: Array[SerializedIRFunction], - idempotency_token: String, - payload: ExecutePayload, -) - -case class SerializedIRFunction( - name: String, - type_parameters: Array[String], - value_parameter_names: Array[String], - value_parameter_types: Array[String], - return_type: String, - rendered_body: String, -) - -class ServiceBackendAPI( - private[this] val backend: ServiceBackend, - private[this] val fs: FS, - private[this] val outputURL: String, -) extends Thread { - private[this] val LOAD_REFERENCES_FROM_DATASET = 1 - private[this] val VALUE_TYPE = 2 - private[this] val TABLE_TYPE = 3 - private[this] val MATRIX_TABLE_TYPE = 4 - private[this] val BLOCK_MATRIX_TYPE = 5 - private[this] val EXECUTE = 6 - private[this] val PARSE_VCF_METADATA = 7 - private[this] val IMPORT_FAM = 8 - private[this] val FROM_FASTA_FILE = 9 - - private[this] val log = Logger.getLogger(getClass.getName()) - - private[this] def doAction(action: Int, payload: JValue): Array[Byte] = retryTransientErrors { - implicit val formats: Formats = DefaultFormats - (action: @switch) match { - case LOAD_REFERENCES_FROM_DATASET => - val path = payload.extract[LoadReferencesFromDatasetPayload].path - backend.loadReferencesFromDataset(path) - case VALUE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.valueType(ir) - case TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.tableType(ir) - case MATRIX_TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.matrixTableType(ir) - case BLOCK_MATRIX_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.blockMatrixType(ir) - case EXECUTE => - val qobExecutePayload = payload.extract[ServiceBackendExecutePayload] - 
val bufferSpecString = qobExecutePayload.payload.stream_codec - val code = qobExecutePayload.payload.ir - backend.withExecuteContext { ctx => - withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => - val ir = IRParser.parse_value_ir(ctx, code) - backend.execute(ctx, ir) match { - case Left(()) => - Array() - case Right((pt, off)) => - using(new ByteArrayOutputStream()) { os => - Backend.encodeToOutputStream(ctx, pt, off, bufferSpecString, os) - os.toByteArray - } - } - } - } - case PARSE_VCF_METADATA => - val path = payload.extract[ParseVCFMetadataPayload].path - backend.parseVCFMetadata(path) - case IMPORT_FAM => - val famPayload = payload.extract[ImportFamPayload] - val path = famPayload.path - val quantPheno = famPayload.quant_pheno - val delimiter = famPayload.delimiter - val missing = famPayload.missing - backend.importFam(path, quantPheno, delimiter, missing) - case FROM_FASTA_FILE => - val fastaPayload = payload.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - fastaPayload.name, - fastaPayload.fasta_file, - fastaPayload.index_file, - fastaPayload.x_contigs, - fastaPayload.y_contigs, - fastaPayload.mt_contigs, - fastaPayload.par, - ) - } - } - - private[this] def withIRFunctionsReadFromInput( - serializedFunctions: Array[SerializedIRFunction], - ctx: ExecuteContext, - )( - body: () => Array[Byte] - ): Array[Byte] = { - try { - serializedFunctions.foreach { func => - IRFunctionRegistry.registerIR( - ctx, - func.name, - func.type_parameters, - func.value_parameter_names, - func.value_parameter_types, - func.return_type, - func.rendered_body, - ) - } - body() - } finally - IRFunctionRegistry.clearUserFunctions() - } - - def executeOneCommand(action: Int, payload: JValue): Unit = { - try { - val result = doAction(action, payload) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(true) - output.writeBytes(result) - } - } - } catch { - case exc: HailWorkerException => - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(exc.shortMessage) - output.writeString(exc.expandedMessage) - output.writeInt(exc.errorId) - } - } - log.error( - "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job." - ) - throw exc - case t: Throwable => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(shortMessage) - output.writeString(expandedMessage) - output.writeInt(errorId) - } - } - log.error( - "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job." 
- ) - throw t - } - } -} diff --git a/hail/hail/src/is/hail/backend/service/Worker.scala b/hail/hail/src/is/hail/backend/service/Worker.scala index 4f596de3b4a..0b3ce24cdc9 100644 --- a/hail/hail/src/is/hail/backend/service/Worker.scala +++ b/hail/hail/src/is/hail/backend/service/Worker.scala @@ -166,38 +166,11 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - if (HailContext.isInitialized) { - HailContext.get.backend = new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - null, - scratchDir, - ) - } else { - HailContext( - // FIXME: workers should not have backends, but some things do need hail contexts - new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - null, - scratchDir, - ) - ) - } + + // FIXME: workers should not have backends, but some things do need hail contexts + val backend = new ServiceBackend(null, null, null, null, null) + if (HailContext.isInitialized) HailContext.get.backend = backend + else HailContext(backend) val result = using(new ServiceTaskContext(i)) { htc => try diff --git a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala index 728c8067705..8ea2b4a2b58 100644 --- a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala +++ b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala @@ -1,27 +1,23 @@ package is.hail.backend.spark -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailContext import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ -import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ -import is.hail.linalg.BlockMatrix import is.hail.rvd.RVD import is.hail.types._ import is.hail.types.physical.{PStruct, PTuple} import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.variant.ReferenceGenome import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -33,7 +29,6 @@ import java.io.PrintWriter import com.fasterxml.jackson.core.StreamReadConstraints import org.apache.hadoop -import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -91,7 +86,7 @@ object SparkBackend { private var theSparkBackend: SparkBackend = _ - def sparkContext(op: String): SparkContext = HailContext.sparkBackend(op).sc + def sparkContext(implicit E: Enclosing): SparkContext = HailContext.sparkBackend.sc def checkSparkCompatibility(jarVersion: String, sparkVersion: String): Unit = { def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".") @@ -218,15 +213,11 @@ object SparkBackend { append: Boolean = false, skipLoggingConfiguration: Boolean = false, minBlockSize: Long = 1L, - tmpdir: String = "/tmp", - localTmpdir: String = "file:///tmp", - gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { if (theSparkBackend == null) return SparkBackend(sc, appName, master, local, logFile, quiet, 
append, skipLoggingConfiguration, - minBlockSize, tmpdir, localTmpdir, gcsRequesterPaysProject, gcsRequesterPaysBuckets) + minBlockSize) // there should be only one SparkContext assert(sc == null || (sc eq theSparkBackend.sc)) @@ -262,10 +253,6 @@ object SparkBackend { append: Boolean = false, skipLoggingConfiguration: Boolean = false, minBlockSize: Long = 1L, - tmpdir: String, - localTmpdir: String, - gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { require(theSparkBackend == null) @@ -280,32 +267,18 @@ object SparkBackend { checkSparkConfiguration(sc1) - if (!quiet) - ProgressBarBuilder.build(sc1) + if (!quiet) ProgressBarBuilder.build(sc1) sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) - theSparkBackend = - new SparkBackend( - tmpdir, - localTmpdir, - sc1, - mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), - gcsRequesterPaysProject, - gcsRequesterPaysBuckets, - ) + theSparkBackend = new SparkBackend(sc1) theSparkBackend } def stop(): Unit = synchronized { if (theSparkBackend != null) { - theSparkBackend.sc.stop() + theSparkBackend.close() theSparkBackend = null - // Hadoop does not honor the hadoop configuration as a component of the cache key for file - // systems, so we blow away the cache so that a new configuration can successfully take - // effect. - // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 - hadoop.fs.FileSystem.closeAll() } } } @@ -316,105 +289,31 @@ class AnonymousDependency[T](val _rdd: RDD[T]) extends NarrowDependency[T](_rdd) override def getParents(partitionId: Int): Seq[Int] = Seq.empty } -class SparkBackend( - val tmpdir: String, - val localTmpdir: String, - val sc: SparkContext, - override val references: mutable.Map[String, ReferenceGenome], - gcsRequesterPaysProject: String, - gcsRequesterPaysBuckets: String, -) extends Backend with Py4JBackendExtensions { - - assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) - lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() +class SparkBackend(val sc: SparkContext) extends Backend { - private[this] val theHailClassLoader: HailClassLoader = - new HailClassLoader(getClass().getClassLoader()) + private case class Context( + maxStageParallelism: Int, + override val executionCache: ExecutionCache, + ) extends BackendContext - override def canExecuteParallelTasksOnDriver: Boolean = false - - val fs: HadoopFS = { - val conf = new Configuration(sc.hadoopConfiguration) - if (gcsRequesterPaysProject != null) { - if (gcsRequesterPaysBuckets == null) { - conf.set("fs.gs.requester.pays.mode", "AUTO") - conf.set("fs.gs.requester.pays.project.id", gcsRequesterPaysProject) - } else { - conf.set("fs.gs.requester.pays.mode", "CUSTOM") - conf.set("fs.gs.requester.pays.project.id", gcsRequesterPaysProject) - conf.set("fs.gs.requester.pays.buckets", gcsRequesterPaysBuckets) - } + val sparkSession: Lazy[SparkSession] = + lazily { + SparkSession.builder().config(sc.getConf).getOrCreate() } - new HadoopFS(new SerializableHadoopConfiguration(conf)) - } - - override def backend: Backend = this - override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - - override val longLifeTempFileManager: TempFileManager = - new OwningTempFileManager(fs) - - private[this] val bmCache = mutable.Map.empty[String, BlockMatrix] - private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) - private[this] val persistedIr = mutable.Map.empty[Int, BaseIR] - private[this] val 
coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - - def createExecuteContextForTests( - timer: ExecutionTimer, - region: Region, - selfContainedExecution: Boolean = true, - ): ExecuteContext = - new ExecuteContext( - tmpdir, - localTmpdir, - this, - references.toMap, - fs, - region, - timer, - if (selfContainedExecution) null else NonOwningTempFileManager(longLifeTempFileManager), - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.forTesting - }, - new IrMetadata(), - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - ImmutableMap.empty, - ) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => - ExecuteContext.scoped( - tmpdir, - localTmpdir, - this, - references.toMap, - fs, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }, - new IrMetadata(), - bmCache, - codeCache, - persistedIr, - coercerCache, - )(f) - } + override def canExecuteParallelTasksOnDriver: Boolean = false def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new SparkBroadcastValue[T](sc.broadcast(value)) + override def backendContext(ctx: ExecuteContext): BackendContext = + Context( + ctx.flags.get(SparkBackend.Flags.MaxStageParallelism).toInt, + ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir), + ) + override def parallelizeAndComputeWithIndex( - backendContext: BackendContext, + ctx: BackendContext, fs: FS, contexts: IndexedSeq[Array[Byte]], stageIdentifier: String, @@ -444,13 +343,12 @@ class SparkBackend( } } - val chunkSize = flags.get(SparkBackend.Flags.MaxStageParallelism).toInt val partsToRun = partitions.getOrElse(contexts.indices) val buffer = new ArrayBuffer[(Array[Byte], Int)](partsToRun.length) var failure: Option[Throwable] = None try { - for (subparts <- partsToRun.grouped(chunkSize)) { + for (subparts <- partsToRun.grouped(ctx.asInstanceOf[Context].maxStageParallelism)) { sc.runJob( rdd, (_: TaskContext, it: Iterator[Array[Byte]]) => it.next(), @@ -459,8 +357,11 @@ class SparkBackend( ) } } catch { - case e: ExecutionException => failure = failure.orElse(Some(e.getCause)) case NonFatal(t) => failure = failure.orElse(Some(t)) + case e: ExecutionException => failure = failure.orElse(Some(e.getCause)) + case _: InterruptedException => + sc.cancelAllJobs() + Thread.currentThread().interrupt() } (failure, buffer.sortBy(_._2)) @@ -468,13 +369,18 @@ class SparkBackend( def defaultParallelism: Int = sc.defaultParallelism - override def asSpark(op: String): SparkBackend = this + override def asSpark(implicit E: Enclosing): SparkBackend = this - def close(): Unit = { - bmCache.values.foreach(_.unpersist()) - SparkBackend.stop() - longLifeTempFileManager.close() - } + def close(): Unit = + synchronized { + if (sparkSession.isEvaluated) sparkSession.close() + sc.stop() + // Hadoop does not honor the hadoop configuration as a component of the cache key for file + // systems, so we blow away the cache so that a new configuration can successfully take + // effect. 
+ // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 + hadoop.fs.FileSystem.closeAll() + } def startProgressBar(): Unit = ProgressBarBuilder.build(sc) @@ -507,8 +413,7 @@ class SparkBackend( case (false, true) => DArrayLowering.BMOnly case (false, false) => throw new LowererUnsupportedOperation("no lowering enabled") } - val ir = - LoweringPipeline.darrayLowerer(optimize)(typesToLower).apply(ctx, ir0).asInstanceOf[IR] + val ir = LoweringPipeline.darrayLowerer(optimize)(typesToLower)(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") @@ -546,11 +451,11 @@ class SparkBackend( Validate(ir) ctx.irMetadata.semhash = SemanticHash(ctx)(ir) try { - val lowerTable = flags.get("lower") != null - val lowerBM = flags.get("lower_bm") != null + val lowerTable = ctx.flags.get("lower") != null + val lowerBM = ctx.flags.get("lower_bm") != null _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM) } catch { - case e: LowererUnsupportedOperation if flags.get("lower_only") != null => throw e + case e: LowererUnsupportedOperation if ctx.flags.get("lower_only") != null => throw e case _: LowererUnsupportedOperation => CompileAndEvaluate._apply(ctx, ir, optimize = true) } @@ -563,7 +468,7 @@ class SparkBackend( rt: RTable, nPartitions: Option[Int], ): TableReader = { - if (flags.get("use_new_shuffle") != null) + if (ctx.flags.get("use_new_shuffle") != null) return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) val (globals, rvd) = TableStageToRVD(ctx, stage) diff --git a/hail/hail/src/is/hail/expr/ir/Emit.scala b/hail/hail/src/is/hail/expr/ir/Emit.scala index 07cabb71695..3b472e48bb8 100644 --- a/hail/hail/src/is/hail/expr/ir/Emit.scala +++ b/hail/hail/src/is/hail/expr/ir/Emit.scala @@ -3381,7 +3381,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { Array[Array[Byte]], ]( "collectDArray", - mb.getObject(ctx.executeContext.backendContext), + mb.getObject(ctx.executeContext.backend.backendContext(ctx.executeContext)), mb.getHailClassLoader, mb.getFS, functionID, diff --git a/hail/hail/src/is/hail/expr/ir/GenericLines.scala b/hail/hail/src/is/hail/expr/ir/GenericLines.scala index 8adb4fe75eb..000348f272f 100644 --- a/hail/hail/src/is/hail/expr/ir/GenericLines.scala +++ b/hail/hail/src/is/hail/expr/ir/GenericLines.scala @@ -449,7 +449,7 @@ class GenericLinesRDDPartition(val index: Int, val context: Any) extends Partiti class GenericLinesRDD( @(transient @param) contexts: IndexedSeq[Any], body: (Any) => CloseableIterator[GenericLine], -) extends RDD[GenericLine](SparkBackend.sparkContext("GenericLinesRDD"), Seq()) { +) extends RDD[GenericLine](SparkBackend.sparkContext, Seq()) { protected def getPartitions: Array[Partition] = contexts.iterator.zipWithIndex.map { case (c, i) => diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala index 43784087b1b..d548ecbf9f9 100644 --- a/hail/hail/src/is/hail/expr/ir/TableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala @@ -713,7 +713,7 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR val upcastF = mb.genFieldThisRef[AsmFunction2RegionLongLong]("rvdreader_upcast") val broadcastRVD = - mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark("RVDReader"), rvd)) + mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark, rvd)) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -3198,7 
+3198,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { fileStack += filesToMerge filesToMerge = - ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize( + ContextRDD.weaken(SparkBackend.sparkContext.parallelize( 0 until nToMerge, nToMerge, )) diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index 8b5441db48a..a1dd38c72d1 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -188,7 +188,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } def toDF(): DataFrame = - HailContext.sparkBackend("toDF").sparkSession.createDataFrame( + HailContext.sparkBackend.sparkSession.createDataFrame( rvd.toRows, typ.rowType.schema.asInstanceOf[StructType], ) diff --git a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala index b621937d220..c47b795e4c9 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala @@ -22,7 +22,9 @@ import scala.reflect._ import org.apache.spark.sql.Row object IRFunctionRegistry { - private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = + type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type])) + + private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] = mutable.HashSet.empty def clearUserFunctions(): Unit = { @@ -70,25 +72,41 @@ object IRFunctionRegistry { typeParamStrs: Array[String], argNameStrs: Array[String], argTypeStrs: Array[String], - returnType: String, + returnTypeStr: String, bodyStr: String, - ): Unit = { + ): UserDefinedFnKey = { requireJavaIdentifier(name) - val argNames = argNameStrs.map(Name) val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq - val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*) - val body = IRParser.parse_value_ir(ctx, bodyStr, refMap) + val argNames = argNameStrs.map(Name) + + val body = + IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*)) + val returnType = IRParser.parseType(returnTypeStr) + assert(body.typ == returnType) - userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) + val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes)) + userAddedFunctions += key addIR( name, typeParameters, valueParameterTypes, - IRParser.parseType(returnType), + returnType, false, (_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)), ) + key + } + + def unregisterIr(key: UserDefinedFnKey): Unit = { + val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key + if (userAddedFunctions.remove(key)) + removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes) + else { + throw new NoSuchElementException( + s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}" + ) + } } def removeIRFunction( @@ -113,7 +131,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(f) => Some(f) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." 
+ ) } def lookupFunctionOrFail( @@ -125,28 +145,34 @@ object IRFunctionRegistry { jvmRegistry.lift(name) match { case None => fatal( - s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType" + s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." ) case Some(functions) => functions.filter(t => t.unify(typeParameters, valueParameterTypes, returnType) ).toSeq match { case Seq() => - val prettyFunctionSignature = - s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n") fatal( - s"No function found with the signature $prettyFunctionSignature.\n" + + s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" + s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures" ) case Seq(f) => f case _ => fatal( - s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})." + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)})." ) } } } + private[this] def prettyFunctionSignature( + name: String, + returnType: Type, + typeParameterTypes: Seq[Type], + valueParameterTypes: Seq[Type], + ): String = + s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" + def lookupIR( name: String, returnType: Type, @@ -166,7 +192,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(kv) => Some(kv) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." 
+ ) } } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala index 85d8673cc59..0dc13ec6598 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -126,9 +126,7 @@ object TableStageToRVD { partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }, ) - val sparkContext = ctx.backend - .asSpark("TableStageToRVD") - .sc + val sparkContext = ctx.backend.asSpark.sc val globalsAndBroadcastVals = Let( diff --git a/hail/hail/src/is/hail/io/plink/LoadPlink.scala b/hail/hail/src/is/hail/io/plink/LoadPlink.scala index 4b7b30007b4..ca5572235fe 100644 --- a/hail/hail/src/is/hail/io/plink/LoadPlink.scala +++ b/hail/hail/src/is/hail/io/plink/LoadPlink.scala @@ -20,7 +20,6 @@ import is.hail.variant._ import org.apache.spark.TaskContext import org.apache.spark.sql.Row import org.json4s.{DefaultFormats, Formats, JValue} -import org.json4s.jackson.JsonMethods case class FamFileConfig( isQuantPheno: Boolean = false, @@ -83,14 +82,13 @@ object LoadPlink { isQuantPheno: Boolean, delimiter: String, missingValue: String, - ): String = { + ): JValue = { val ffConfig = FamFileConfig(isQuantPheno, delimiter, missingValue) val (data, ptyp) = LoadPlink.parseFam(fs, path, ffConfig) - val jv = JSONAnnotationImpex.exportAnnotation( + JSONAnnotationImpex.exportAnnotation( Row(ptyp.virtualType.toString, data), TStruct("type" -> TString, "data" -> TArray(ptyp.virtualType)), ) - JsonMethods.compact(jv) } def parseFam(fs: FS, filename: String, ffConfig: FamFileConfig) diff --git a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala index 15ffe4a6ec2..f40ae888796 100644 --- a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala +++ b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala @@ -1685,7 +1685,7 @@ class PartitionedVCFRDD( file: String, @(transient @param) reverseContigMapping: Map[String, String], @(transient @param) _partitions: Array[Partition], -) extends RDD[WithContext[String]](SparkBackend.sparkContext("PartitionedVCFRDD"), Seq()) { +) extends RDD[WithContext[String]](SparkBackend.sparkContext, Seq()) { val contigRemappingBc = if (reverseContigMapping.size != 0) sparkContext.broadcast(reverseContigMapping) else null @@ -1806,7 +1806,7 @@ object MatrixVCFReader { val fsConfigBC = backend.broadcast(fs.getConfiguration()) val (failureOpt, _) = backend.parallelizeAndComputeWithIndex( - ctx.backendContext, + ctx.backend.backendContext(ctx), fs, files.tail.map(_.getBytes), "load_vcf_parse_header", diff --git a/hail/hail/src/is/hail/linalg/BlockMatrix.scala b/hail/hail/src/is/hail/linalg/BlockMatrix.scala index 7cda148db27..f5a6eb1e1b4 100644 --- a/hail/hail/src/is/hail/linalg/BlockMatrix.scala +++ b/hail/hail/src/is/hail/linalg/BlockMatrix.scala @@ -44,7 +44,7 @@ case class CollectMatricesRDDPartition( } class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) - extends RDD[BDM[Double]](SparkBackend.sparkContext("CollectMatricesRDD"), Nil) { + extends RDD[BDM[Double]](SparkBackend.sparkContext, Nil) { private val nBlocks = bms.map(_.blocks.getNumPartitions) private val firstPartition = nBlocks.scan(0)(_ + _).init @@ -114,7 +114,7 @@ object BlockMatrix { def apply(gp: GridPartitioner, piBlock: (GridPartitioner, Int) => ((Int, Int), BDM[Double])) : BlockMatrix = new BlockMatrix( - new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext("BlockMatrix.apply"), Nil) { + new RDD[((Int, Int), 
BDM[Double])](SparkBackend.sparkContext, Nil) { override val partitioner = Some(gp) protected def getPartitions: Array[Partition] = Array.tabulate(gp.numPartitions)(pi => @@ -2199,7 +2199,7 @@ class WriteBlocksRDD( parentPartStarts: Array[Long], entryField: String, gp: GridPartitioner, -) extends RDD[(Int, String)](SparkBackend.sparkContext("WriteBlocksRDD"), Nil) { +) extends RDD[(Int, String)](SparkBackend.sparkContext, Nil) { require(gp.nRows == parentPartStarts.last) @@ -2369,7 +2369,7 @@ class BlockMatrixReadRowBlockedRDD( metadata: BlockMatrixMetadata, maybeMaximumCacheMemoryInBytes: Option[Int], ) extends RDD[RVDContext => Iterator[Long]]( - SparkBackend.sparkContext("BlockMatrixReadRowBlockedRDD"), + SparkBackend.sparkContext, Nil, ) { import BlockMatrixReadRowBlockedRDD._ diff --git a/hail/hail/src/is/hail/linalg/RowMatrix.scala b/hail/hail/src/is/hail/linalg/RowMatrix.scala index 6adbaca171a..969ef2d4916 100644 --- a/hail/hail/src/is/hail/linalg/RowMatrix.scala +++ b/hail/hail/src/is/hail/linalg/RowMatrix.scala @@ -286,7 +286,7 @@ class ReadBlocksAsRowsRDD( partFiles: IndexedSeq[String], partitionCounts: Array[Long], gp: GridPartitioner, -) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext("ReadBlocksAsRowsRDD"), Nil) { +) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext, Nil) { private val partitionStarts = partitionCounts.scanLeft(0L)(_ + _) diff --git a/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala b/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala index 43ade7d1418..5db999bbbba 100644 --- a/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala +++ b/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala @@ -197,7 +197,7 @@ case class MatrixExportEntriesByCol( // clean up temporary files val temps = tempFolders.result() val fsBc = fs.broadcast - SparkBackend.sparkContext("MatrixExportEntriesByCol.execute").parallelize( + SparkBackend.sparkContext.parallelize( temps, (temps.length / 32).max(1), ).foreach(path => fsBc.value.delete(path, recursive = true)) diff --git a/hail/hail/src/is/hail/rvd/RVD.scala b/hail/hail/src/is/hail/rvd/RVD.scala index 9c463915722..2cd97c3df03 100644 --- a/hail/hail/src/is/hail/rvd/RVD.scala +++ b/hail/hail/src/is/hail/rvd/RVD.scala @@ -1377,7 +1377,7 @@ object RVD { ) } - val sc = SparkBackend.sparkContext("writeRowsSplitFiles") + val sc = SparkBackend.sparkContext val localTmpdir = execCtx.localTmpdir val fs = execCtx.fs val fsBc = fs.broadcast diff --git a/hail/hail/src/is/hail/sparkextras/ContextRDD.scala b/hail/hail/src/is/hail/sparkextras/ContextRDD.scala index fe9c4d4e4ac..9ecefc16e13 100644 --- a/hail/hail/src/is/hail/sparkextras/ContextRDD.scala +++ b/hail/hail/src/is/hail/sparkextras/ContextRDD.scala @@ -86,7 +86,7 @@ object ContextRDD { def empty[T: ClassTag](): ContextRDD[T] = new ContextRDD( - SparkBackend.sparkContext("ContextRDD.empty").emptyRDD[RVDContext => Iterator[T]] + SparkBackend.sparkContext.emptyRDD[RVDContext => Iterator[T]] ) def union[T: ClassTag]( @@ -117,7 +117,7 @@ object ContextRDD { filterAndReplace: TextInputFilterAndReplace, ): ContextRDD[WithContext[String]] = ContextRDD.weaken( - SparkBackend.sparkContext("ContxtRDD.textFilesLines").textFilesLines( + SparkBackend.sparkContext.textFilesLines( files, nPartitions, ) @@ -129,12 +129,12 @@ object ContextRDD { weaken(sc.parallelize(data, nPartitions.getOrElse(sc.defaultMinPartitions))).map(x => x) def parallelize[T: ClassTag](data: Seq[T], numSlices: Int): ContextRDD[T] = - 
weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data, numSlices)).map { + weaken(SparkBackend.sparkContext.parallelize(data, numSlices)).map { x => x } def parallelize[T: ClassTag](data: Seq[T]): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data)).map(x => x) + weaken(SparkBackend.sparkContext.parallelize(data)).map(x => x) type ElementType[T] = RVDContext => Iterator[T] diff --git a/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala b/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala index d7316f3d736..94e00942a2f 100644 --- a/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala +++ b/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala @@ -15,7 +15,7 @@ class IndexReadRDD[T: ClassTag]( @transient val partFiles: Array[String], @transient val intervalBounds: Option[Array[Interval]], f: (IndexedFilePartition, TaskContext) => T, -) extends RDD[T](SparkBackend.sparkContext("IndexReadRDD"), Nil) { +) extends RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(partFiles.length) { i => IndexedFilePartition(i, partFiles(i), intervalBounds.map(_(i))) diff --git a/hail/hail/src/is/hail/types/virtual/package.scala b/hail/hail/src/is/hail/types/virtual/package.scala new file mode 100644 index 00000000000..032c862bafb --- /dev/null +++ b/hail/hail/src/is/hail/types/virtual/package.scala @@ -0,0 +1,12 @@ +package is.hail.types + +package virtual { + sealed abstract class Kind[T <: VType] extends Product with Serializable + + object Kinds { + case object Value extends Kind[Type] + case object Table extends Kind[TableType] + case object Matrix extends Kind[MatrixType] + case object BlockMatrix extends Kind[BlockMatrixType] + } +} diff --git a/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala b/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala index a8bb6d8babb..cbf9833ed15 100644 --- a/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala +++ b/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala @@ -16,7 +16,7 @@ object SpillingCollectIterator { val nPartitions = rdd.partitions.length val x = new SpillingCollectIterator(localTmpdir, fs, nPartitions, sizeLimit) val ctc = classTag[T] - SparkBackend.sparkContext("SpillingCollectIterator.apply").runJob( + SparkBackend.sparkContext.runJob( rdd, (_, it: Iterator[T]) => it.toArray(ctc), 0 until nPartitions, diff --git a/hail/hail/src/is/hail/utils/package.scala b/hail/hail/src/is/hail/utils/package.scala index 8b6affd9801..5b5d2b33022 100644 --- a/hail/hail/src/is/hail/utils/package.scala +++ b/hail/hail/src/is/hail/utils/package.scala @@ -8,7 +8,7 @@ import scala.collection.{mutable, GenTraversableOnce, TraversableOnce} import scala.collection.generic.CanBuildFrom import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionException -import scala.language.higherKinds +import scala.language.{higherKinds, implicitConversions} import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -90,6 +90,23 @@ package utils { b.result() } } + + class Lazy[A] private[utils] (f: => A) { + private[this] var option: Option[A] = None + + def apply(): A = + synchronized { + option match { + case Some(a) => a + case None => val a = f; option = Some(a); a + } + } + + def isEvaluated: Boolean = + synchronized { + option.isDefined + } + } } package object utils @@ -1027,7 +1044,7 @@ package object utils val buffer = new mutable.ArrayBuffer[(A, Int)](tasks.length) val completer = new ExecutorCompletionService[(A, 
Int)](executor) - tasks.foreach { case (t, k) => completer.submit(() => t() -> k) } + val futures = tasks.map { case (t, k) => completer.submit(() => t() -> k) } tasks.foreach { _ => try buffer += completer.take().get() catch { @@ -1035,6 +1052,9 @@ package object utils err = accum(err, e.getCause) case NonFatal(ex) => err = accum(err, ex) + case _: InterruptedException => + futures.foreach(_.cancel(true)) + Thread.currentThread().interrupt() } } @@ -1044,6 +1064,12 @@ package object utils def runAllKeepFirstError[A](executor: ExecutorService) : IndexedSeq[(() => A, Int)] => (Option[Throwable], IndexedSeq[(A, Int)]) = runAll[Option, A](executor) { case (opt, e) => opt.orElse(Some(e)) }(None) + + def lazily[A](f: => A): Lazy[A] = + new Lazy(f) + + implicit def evalLazy[A](f: Lazy[A]): A = + f() } class CancellingExecutorService(delegate: ExecutorService) extends AbstractExecutorService { diff --git a/hail/hail/src/is/hail/variant/ReferenceGenome.scala b/hail/hail/src/is/hail/variant/ReferenceGenome.scala index 88e07e04073..78b82265ac7 100644 --- a/hail/hail/src/is/hail/variant/ReferenceGenome.scala +++ b/hail/hail/src/is/hail/variant/ReferenceGenome.scala @@ -650,7 +650,7 @@ object ReferenceGenome { def addFatalOnCollision( existing: mutable.Map[String, ReferenceGenome], - newReferences: IndexedSeq[ReferenceGenome], + newReferences: Iterable[ReferenceGenome], ): Unit = for (rg <- newReferences) { if (existing.get(rg.name).exists(_ != rg)) @@ -662,4 +662,14 @@ object ReferenceGenome { existing += (rg.name -> rg) } + + def addFatalOnCollision( + existing: Map[String, ReferenceGenome], + newReferences: Iterable[ReferenceGenome], + ): Map[String, ReferenceGenome] = { + val mut = mutable.Map(existing.toSeq: _*) + addFatalOnCollision(mut, newReferences) + mut.toMap + } + } diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index e986ee5dc6d..9b7ba1cbec1 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -1,9 +1,9 @@ package is.hail import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations._ -import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.asm4s.HailClassLoader +import is.hail.backend.{Backend, ExecuteContext, OwningTempFileManager} import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ bindIR, freshName, rangeIR, BindingEnv, BlockMatrixIR, Compilable, Env, Forall, IR, Interpret, @@ -13,22 +13,33 @@ import is.hail.expr.ir.defs.{ BlockMatrixCollect, Cast, MakeTuple, NDArrayRef, Ref, StreamMap, ToArray, } import is.hail.io.fs.FS +import is.hail.expr.ir.{TestUtils => _, _} +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs.{FS, HadoopFS} import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.variant.ReferenceGenome import java.io.{File, PrintWriter} import breeze.linalg.DenseMatrix +import org.apache.hadoop +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.sql.Row import org.scalatestplus.testng.TestNGSuite import org.testng.ITestContext -import org.testng.annotations.{AfterMethod, BeforeClass, BeforeMethod} +import org.testng.annotations.{AfterClass, AfterMethod, BeforeClass, BeforeMethod} object HailSuite { - val theHailClassLoader = TestUtils.theHailClassLoader - def withSparkBackend(): HailContext = { + val theHailClassLoader: HailClassLoader = + new HailClassLoader(getClass.getClassLoader) + + val flags: HailFeatureFlags = + 
HailFeatureFlags.fromEnv(sys.env + ("lower" -> "1")) + + lazy val hc: HailContext = { HailContext.configureLogging("/tmp/hail.log", quiet = false, append = false) val backend = SparkBackend( sc = new SparkContext( @@ -40,63 +51,71 @@ object HailSuite { ) .set("spark.unsafe.exceptionOnMemoryLeak", "true") ), - tmpdir = "/tmp", - localTmpdir = "file:///tmp", skipLoggingConfiguration = true, ) - HailContext(backend) - } - - lazy val hc: HailContext = { - val hc = withSparkBackend() - hc.sparkBackend("HailSuite.hc").flags.set("lower", "1") + val hc = HailContext(backend) hc.checkRVDKeys = true hc } } -class HailSuite extends TestNGSuite { - val theHailClassLoader = HailSuite.theHailClassLoader +class HailSuite extends TestNGSuite with TestUtils { def hc: HailContext = HailSuite.hc - @BeforeClass def ensureHailContextInitialized(): Unit = hc - - def backend: SparkBackend = hc.sparkBackend("HailSuite.backend") - - def sc: SparkContext = backend.sc - - def fs: FS = backend.fs - - def fsBc: BroadcastValue[FS] = fs.broadcast - - var timer: ExecutionTimer = _ + @BeforeClass + def initFs(): Unit = { + val conf = new Configuration(sc.hadoopConfiguration) + fs = new HadoopFS(new SerializableHadoopConfiguration(conf)) + } - var ctx: ExecuteContext = _ + @AfterClass + def closeFS(): Unit = + hadoop.fs.FileSystem.closeAll() + var fs: FS = _ var pool: RegionPool = _ + private[this] var ctx_ : ExecuteContext = _ + + def backend: Backend = hc.backend + def sc: SparkContext = hc.backend.asSpark.sc + def timer: ExecutionTimer = ctx.timer + def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader + override def ctx: ExecuteContext = ctx_ def getTestResource(localPath: String): String = s"hail/test/resources/$localPath" @BeforeMethod def setupContext(context: ITestContext): Unit = { - assert(timer == null) - timer = new ExecutionTimer("HailSuite") - assert(ctx == null) pool = RegionPool() - ctx = backend.createExecuteContextForTests(timer, Region(pool = pool)) + ctx_ = new ExecuteContext( + tmpdir = "/tmp", + localTmpdir = "file:///tmp", + backend = hc.backend, + references = ReferenceGenome.builtinReferences(), + fs = fs, + r = Region(pool = pool), + timer = new ExecutionTimer(context.getName), + tempFileManager = new OwningTempFileManager(fs), + theHailClassLoader = HailSuite.theHailClassLoader, + flags = HailSuite.flags, + irMetadata = new IrMetadata(), + BlockMatrixCache = ImmutableMap.empty, + CodeCache = ImmutableMap.empty, + IrCache = ImmutableMap.empty, + CoercerCache = ImmutableMap.empty, + ) } @AfterMethod def tearDownContext(context: ITestContext): Unit = { - ctx.close() - ctx = null - timer.finish() - timer = null + ctx_.timer.finish() + ctx_.close() + ctx_ = null pool.close() - if (backend.sc.isStopped) - throw new RuntimeException(s"method stopped spark context!") + if (sc.isStopped) + throw new RuntimeException(s"'${context.getName}' stopped spark context!") } def assertEvalsTo( @@ -113,73 +132,71 @@ class HailSuite extends TestNGSuite { val t = x.typ assert(t == TVoid || t.typeCheck(expected), s"$t, $expected") - ExecuteContext.scoped { ctx => - val filteredExecStrats: Set[ExecStrategy] = - if (HailContext.backend.isInstanceOf[SparkBackend]) - execStrats - else { - info("skipping interpret and non-lowering compile steps on non-spark backend") - execStrats.intersect(ExecStrategy.backendOnly) - } + val filteredExecStrats: Set[ExecStrategy] = + if (HailContext.backend.isInstanceOf[SparkBackend]) + execStrats + else { + info("skipping interpret and non-lowering compile steps on non-spark 
backend") + execStrats.intersect(ExecStrategy.backendOnly) + } - filteredExecStrats.foreach { strat => - try { - val res = strat match { - case ExecStrategy.Interpret => - assert(agg.isEmpty) - Interpret[Any](ctx, x, env, args) - case ExecStrategy.InterpretUnoptimized => - assert(agg.isEmpty) - Interpret[Any](ctx, x, env, args, optimize = false) - case ExecStrategy.JvmCompile => - assert(Forall(x, node => Compilable(node))) - eval( - x, - env, - args, - agg, - bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - true, - ctx, - ) - case ExecStrategy.JvmCompileUnoptimized => - assert(Forall(x, node => Compilable(node))) - eval( - x, - env, - args, - agg, - bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - optimize = false, - ctx, - ) - case ExecStrategy.LoweredJVMCompile => - loweredExecute(ctx, x, env, args, agg) - } - if (t != TVoid) { - assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") - assert( - t.valuesSimilar(res, expected), - s"\n result=$res\n expect=$expected\n strategy=$strat)", + filteredExecStrats.foreach { strat => + try { + val res = strat match { + case ExecStrategy.Interpret => + assert(agg.isEmpty) + Interpret[Any](ctx, x, env, args) + case ExecStrategy.InterpretUnoptimized => + assert(agg.isEmpty) + Interpret[Any](ctx, x, env, args, optimize = false) + case ExecStrategy.JvmCompile => + assert(Forall(x, node => Compilable(node))) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + true, + ctx, + ) + case ExecStrategy.JvmCompileUnoptimized => + assert(Forall(x, node => Compilable(node))) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + optimize = false, + ctx, ) - } - } catch { - case e: Exception => - error(s"error from strategy $strat") - if (execStrats.contains(strat)) throw e + case ExecStrategy.LoweredJVMCompile => + loweredExecute(ctx, x, env, args, agg) + } + if (t != TVoid) { + assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") + assert( + t.valuesSimilar(res, expected), + s"\n result=$res\n expect=$expected\n strategy=$strat)", + ) } + } catch { + case e: Exception => + error(s"error from strategy $strat") + if (execStrats.contains(strat)) throw e } } } @@ -258,35 +275,33 @@ class HailSuite extends TestNGSuite { expected: DenseMatrix[Double], )(implicit execStrats: Set[ExecStrategy] ): Unit = { - ExecuteContext.scoped { ctx => - val filteredExecStrats: Set[ExecStrategy] = - if (HailContext.backend.isInstanceOf[SparkBackend]) execStrats - else { - info("skipping interpret and non-lowering compile steps on non-spark backend") - execStrats.intersect(ExecStrategy.backendOnly) - } - filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => - try { - val res = strat match { - case ExecStrategy.Interpret => - Interpret(bm, ctx, optimize = true) - case ExecStrategy.InterpretUnoptimized => - Interpret(bm, ctx, optimize = false) - } - 
assert(res.toBreezeMatrix() == expected) - } catch { - case e: Exception => - error(s"error from strategy $strat") - if (execStrats.contains(strat)) throw e + val filteredExecStrats: Set[ExecStrategy] = + if (HailContext.backend.isInstanceOf[SparkBackend]) execStrats + else { + info("skipping interpret and non-lowering compile steps on non-spark backend") + execStrats.intersect(ExecStrategy.backendOnly) + } + filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => + try { + val res = strat match { + case ExecStrategy.Interpret => + Interpret(bm, ctx, optimize = true) + case ExecStrategy.InterpretUnoptimized => + Interpret(bm, ctx, optimize = false) } + assert(res.toBreezeMatrix() == expected) + } catch { + case e: Exception => + error(s"error from strategy $strat") + if (execStrats.contains(strat)) throw e } - val expectedArray = Array.tabulate(expected.rows)(i => - Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq - ).toFastSeq - assertNDEvals(BlockMatrixCollect(bm), expectedArray)( - filteredExecStrats.filterNot(ExecStrategy.interpretOnly) - ) } + val expectedArray = Array.tabulate(expected.rows)(i => + Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq + ).toFastSeq + assertNDEvals(BlockMatrixCollect(bm), expectedArray)( + filteredExecStrats.filterNot(ExecStrategy.interpretOnly) + ) } def assertAllEvalTo( diff --git a/hail/hail/test/src/is/hail/TestUtils.scala b/hail/hail/test/src/is/hail/TestUtils.scala index da061cc8dc6..cd8965d514a 100644 --- a/hail/hail/test/src/is/hail/TestUtils.scala +++ b/hail/hail/test/src/is/hail/TestUtils.scala @@ -45,8 +45,9 @@ object ExecStrategy extends Enumeration { val allRelational: Set[ExecStrategy] = interpretOnly.union(lowering) } -object TestUtils { - val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) +trait TestUtils { + + def ctx: ExecuteContext = ??? 
import org.scalatest.Assertions._ @@ -99,7 +100,7 @@ object TestUtils { def removeConstantCols(A: DenseMatrix[Int]): DenseMatrix[Int] = { val data = (0 until A.cols).flatMap { j => val col = A(::, j) - if (TestUtils.isConstant(col)) + if (isConstant(col)) Array[Int]() else col.toArray @@ -123,23 +124,19 @@ object TestUtils { if (agg.isDefined || !env.isEmpty || !args.isEmpty) throw new LowererUnsupportedOperation("can't test with aggs or user defined args/env") - HailContext.sparkBackend("TestUtils.loweredExecute") - .jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, lowerBM = true, - print = bytecodePrinter) - } - - def eval(x: IR): Any = ExecuteContext.scoped { ctx => - eval(x, Env.empty, FastSeq(), None, None, true, ctx) + HailContext.sparkBackend.jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, + lowerBM = true, + print = bytecodePrinter) } def eval( x: IR, - env: Env[(Any, Type)], - args: IndexedSeq[(Any, Type)], - agg: Option[(IndexedSeq[Row], TStruct)], + env: Env[(Any, Type)] = Env.empty, + args: IndexedSeq[(Any, Type)] = FastSeq(), + agg: Option[(IndexedSeq[Row], TStruct)] = None, bytecodePrinter: Option[PrintWriter] = None, optimize: Boolean = true, - ctx: ExecuteContext, + ctx: ExecuteContext = ctx, ): Any = { val inputTypesB = new BoxedArrayBuilder[Type]() val inputsB = new mutable.ArrayBuffer[Any]() @@ -211,26 +208,29 @@ object TestUtils { assert(resultType2.virtualType == resultType) ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 + ctx.local(r = region) { ctx => + val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 + } + rvb.endTuple() + val argsOff = rvb.end() + + rvb.start(aggArrayPType) + rvb.startArray(aggElements.length) + aggElements.foreach(r => rvb.addAnnotation(aggType, r)) + rvb.endArray() + val aggOff = rvb.end() + + ctx.scopedExecution { (hcl, fs, tc, r) => + val off = f(hcl, fs, tc, r)(r, argsOff, aggOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], off).get(0) + } } - rvb.endTuple() - val argsOff = rvb.end() - - rvb.start(aggArrayPType) - rvb.startArray(aggElements.length) - aggElements.foreach(r => rvb.addAnnotation(aggType, r)) - rvb.endArray() - val aggOff = rvb.end() - - val resultOff = - f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff, aggOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) } case None => @@ -250,19 +250,22 @@ object TestUtils { assert(resultType2.virtualType == resultType) ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 + ctx.local(r = region) { ctx => + val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 + } + rvb.endTuple() + val argsOff = rvb.end() + ctx.scopedExecution { (hcl, fs, tc, r) => + val resultOff = f(hcl, fs, tc, r)(r, argsOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + } } - rvb.endTuple() - val argsOff = rvb.end() - - val resultOff = f(theHailClassLoader, 
ctx.fs, ctx.taskContext, region)(region, argsOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) } } } @@ -276,7 +279,7 @@ object TestUtils { def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]): Unit = { val t = x.typ - val (i, i2, c) = ExecuteContext.scoped { ctx => + val (i, i2, c) = { val i = Interpret[Any](ctx, x, env, args) val i2 = Interpret[Any](ctx, x, env, args, optimize = false) val c = eval(x, env, args, None, None, true, ctx) @@ -299,12 +302,11 @@ object TestUtils { env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String, - ): Unit = - ExecuteContext.scoped { ctx => - interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) - interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) - interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) - } + ): Unit = { + interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) + interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) + interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) + } def assertFatal(x: IR, regex: String): Unit = assertThrows[HailException](x, regex) @@ -322,9 +324,7 @@ object TestUtils { args: IndexedSeq[(Any, Type)], regex: String, ): Unit = - ExecuteContext.scoped { ctx => - interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) - } + interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) def assertCompiledThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertCompiledThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) diff --git a/hail/hail/test/src/is/hail/TestUtilsSuite.scala b/hail/hail/test/src/is/hail/TestUtilsSuite.scala index a6c91c5aef5..2c2b0c08406 100644 --- a/hail/hail/test/src/is/hail/TestUtilsSuite.scala +++ b/hail/hail/test/src/is/hail/TestUtilsSuite.scala @@ -11,21 +11,21 @@ class TestUtilsSuite extends HailSuite { val V = DenseVector(0d, 1d) val V1 = DenseVector(0d, 0.5d) - TestUtils.assertMatrixEqualityDouble(M, DenseMatrix.eye(2)) - TestUtils.assertMatrixEqualityDouble(M, M1, 0.001) - TestUtils.assertVectorEqualityDouble(V, 2d * V1) + assertMatrixEqualityDouble(M, DenseMatrix.eye(2)) + assertMatrixEqualityDouble(M, M1, 0.001) + assertVectorEqualityDouble(V, 2d * V1) - intercept[Exception](TestUtils.assertVectorEqualityDouble(V, V1)) - intercept[Exception](TestUtils.assertMatrixEqualityDouble(M, M1)) + intercept[Exception](assertVectorEqualityDouble(V, V1)) + intercept[Exception](assertMatrixEqualityDouble(M, M1)) } @Test def constantVectorTest(): Unit = { - assert(TestUtils.isConstant(DenseVector())) - assert(TestUtils.isConstant(DenseVector(0))) - assert(TestUtils.isConstant(DenseVector(0, 0))) - assert(TestUtils.isConstant(DenseVector(0, 0, 0))) - assert(!TestUtils.isConstant(DenseVector(0, 1))) - assert(!TestUtils.isConstant(DenseVector(0, 0, 1))) + assert(isConstant(DenseVector())) + assert(isConstant(DenseVector(0))) + assert(isConstant(DenseVector(0, 0))) + assert(isConstant(DenseVector(0, 0, 0))) + assert(!isConstant(DenseVector(0, 1))) + assert(!isConstant(DenseVector(0, 0, 1))) } @Test def removeConstantColsTest(): Unit = { @@ -33,6 +33,6 @@ class TestUtilsSuite extends HailSuite { val M1 = DenseMatrix((0, 1, 0), (1, 0, 1)) - assert(TestUtils.removeConstantCols(M) == M1) + assert(removeConstantCols(M) == M1) } } diff --git a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala index 
dd5d5b4e1e6..8c14e2bb7e7 100644 --- a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala @@ -54,7 +54,7 @@ class UnsafeSuite extends HailSuite { @DataProvider(name = "codecs") def codecs(): Array[Array[Any]] = - ExecuteContext.scoped(ctx => codecs(ctx)) + codecs(ctx) def codecs(ctx: ExecuteContext): Array[Array[Any]] = (BufferSpec.specs ++ Array(TypedCodecSpec( diff --git a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala index 461eed812f4..9b9c2a2c67e 100644 --- a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala @@ -1,15 +1,13 @@ package is.hail.backend -import is.hail.HailFeatureFlags -import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload} -import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} +import is.hail.HailSuite +import is.hail.backend.api.BatchJobConfig +import is.hail.backend.service.ServiceBackend import is.hail.services._ import is.hail.services.JobGroupStates.Success import is.hail.utils.{tokenUrlSafe, using} -import scala.collection.mutable -import scala.reflect.io.{Directory, Path} +import scala.reflect.io.Directory import scala.util.Random import java.io.Closeable @@ -19,137 +17,105 @@ import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when import org.scalatest.OptionValues import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { +class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { rpcConfig => + withObjectSpied[is.hail.utils.UtilsType] { + // not obvious how to pull out `tokenUrlSafe` and inject this directory + // using a spy is a hack and i don't particularly like it. 
+ when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" + + val jobConfig = BatchJobConfig( + token = tokenUrlSafe, + billing_project = "fancy", + worker_cores = "128", + worker_memory = "a lot.", + storage = "a big ssd?", + cloudfuse_configs = Array(), + regions = Array("lunar1"), + ) + val batchClient = mock[BatchClient] - using(ServiceBackend(batchClient, rpcConfig)) { backend => - val contexts = Array.tabulate(1)(_.toString.getBytes) - - // verify that - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - - when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { - jobGroup: JobGroupRequest => - jobGroup.batch_id shouldBe backend.batchConfig.batchId - jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId - val jobs = jobGroup.jobs - jobs.length shouldEqual contexts.length - jobs.foreach { payload => - payload.regions.value shouldBe rpcConfig.regions - payload.resources.value shouldBe JobResources( - preemptible = true, - cpu = Some(rpcConfig.worker_cores), - memory = Some(rpcConfig.worker_memory), - storage = Some(rpcConfig.storage), + using(ServiceBackend(batchClient, jobConfig)) { backend => + using(LocalTmpFolder) { tmp => + val contexts = Array.tabulate(1)(_.toString.getBytes) + + // verify that + // - the number of jobs matches the number of partitions, and + // - each job is created in the specified region, and + // - each job's resource configuration matches the rpc config + + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + jobGroup: JobGroupRequest => + jobGroup.batch_id shouldBe backend.batchConfig.batchId + jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId + val jobs = jobGroup.jobs + jobs.length shouldEqual contexts.length + jobs.foreach { payload => + payload.regions.value shouldBe jobConfig.regions + payload.resources.value shouldBe JobResources( + preemptible = true, + cpu = Some(jobConfig.worker_cores), + memory = Some(jobConfig.worker_memory), + storage = Some(jobConfig.storage), + ) + } + + backend.batchConfig.jobGroupId + 1 + } + + // the service backend expects that each job write its output to a well-known + // location when it finishes. + when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + id shouldEqual backend.batchConfig.batchId + jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 + + val resultsDir = tmp / "parallelizeAndComputeWithIndex" / tokenUrlSafe + + resultsDir.createDirectory() + for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Success, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = contexts.length, + n_failed = 0, + n_cancelled = 0, ) - } - - backend.batchConfig.jobGroupId + 1 - } + } + + ctx.local(tmpdir = tmp.toString()) { ctx => + val (failure, _) = + backend.parallelizeAndComputeWithIndex( + backend.backendContext(ctx), + ctx.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) + + failure.foreach(throw _) + } + batchClient.newJobGroup(any) wasCalled once + batchClient.waitForJobGroup(any, any) wasCalled once - // the service backend expects that each job write its output to a well-known - // location when it finishes. 
- when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { - (id: Int, jobGroupId: Int) => - id shouldEqual backend.batchConfig.batchId - jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 - - val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe - - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JobGroupResponse( - batch_id = id, - job_group_id = jobGroupId, - state = Success, - complete = true, - n_jobs = contexts.length, - n_completed = contexts.length, - n_succeeded = contexts.length, - n_failed = 0, - n_cancelled = 0, - ) } - - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) - - failure.foreach(throw _) - - batchClient.newJobGroup(any) wasCalled once - batchClient.waitForJobGroup(any, any) wasCalled once } } - def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = { - val flags = HailFeatureFlags.fromEnv() - val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) + def ServiceBackend(client: BatchClient, jobConfig: BatchJobConfig): ServiceBackend = new ServiceBackend( - jarSpec = GitRevision("123"), name = "name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - references = mutable.Map.empty, batchClient = client, + jarSpec = GitRevision("123"), batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), - flags = flags, - tmpdir = rpcConfig.tmp_dir, - fs = fs, - serviceBackendContext = - new ServiceBackendContext( - rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - profile = false, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ), - scratchDir = rpcConfig.remote_tmpdir, + jobConfig = jobConfig, ) - } - - def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = - using(LocalTmpFolder) { tmp => - withObjectSpied[is.hail.utils.UtilsType] { - // not obvious how to pull out `tokenUrlSafe` and inject this directory - // using a spy is a hack and i don't particularly like it. 
- when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" - - test { - ServiceBackendRPCPayload( - tmp_dir = tmp.path, - remote_tmpdir = tmp.path, - billing_project = "fancy", - worker_cores = "128", - worker_memory = "a lot.", - storage = "a big ssd?", - cloudfuse_configs = Array(), - regions = Array("lunar1"), - flags = Map(), - custom_references = Array(), - liftovers = Map(), - sequences = Map(), - ) - } - } - } def LocalTmpFolder: Directory with Closeable = new Directory(Directory.makeTemp("hail-testing-tmp").jfile) with Closeable { diff --git a/hail/hail/test/src/is/hail/check/Gen.scala b/hail/hail/test/src/is/hail/check/Gen.scala index be6a9f8eb94..78ab722918f 100644 --- a/hail/hail/test/src/is/hail/check/Gen.scala +++ b/hail/hail/test/src/is/hail/check/Gen.scala @@ -553,10 +553,10 @@ class Gen[+T](val apply: Parameters => T) extends AnyVal { def flatMap[U](f: T => Gen[U]): Gen[U] = Gen(p => f(apply(p))(p)) def resize(newSize: Int): Gen[T] = Gen((p: Parameters) => apply(p.copy(size = newSize))) - + def withFilter(f: T => Boolean): Gen[T] = Gen((p: Parameters) => Stream.continually(apply(p)).takeWhile(f).head) - def filter(f: T => Boolean): Gen[T] = + def filter(f: T => Boolean): Gen[T] = withFilter(f) } diff --git a/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala index 4273f6e2dd6..e74fae67136 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{ArraySlice, F32, F64, I32, In, MakeArray, NA, Str} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala index e63415f6ddf..bd0dae68f6a 100644 --- a/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.TestUtils.IRCall import is.hail.expr.ir.defs.{False, I32, Str, True} +import is.hail.expr.ir.TestUtils.{IRArray, IRCall} import is.hail.types.virtual.{TArray, TBoolean, TCall, TInt32} import is.hail.variant._ @@ -61,7 +62,7 @@ class CallFunctionsSuite extends HailSuite { assertEvalsTo(invoke("Call", TCall, I32(1), False()), Call1(1, false)) assertEvalsTo(invoke("Call", TCall, I32(0), I32(0), False()), Call2(0, 0, false)) assertEvalsTo( - invoke("Call", TCall, TestUtils.IRArray(0, 1), False()), + invoke("Call", TCall, IRArray(0, 1), False()), CallN(Array(0, 1), false), ) assertEvalsTo(invoke("Call", TCall, Str("0|1")), Call2(0, 1, true)) diff --git a/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala index b211701f942..b780c67cc7d 100644 --- a/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{NA, ToSet, ToStream} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala index 7ca4f22567f..acf32fc65ee 100644 --- 
a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala @@ -1,10 +1,8 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ -import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} +import is.hail.annotations.{Region, RegionPool, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ -import is.hail.backend.ExecuteContext import is.hail.expr.ir.agg.{CollectStateSig, PhysicalAggSig, TypedStateSig} import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs._ @@ -1064,14 +1062,14 @@ class EmitStreamSuite extends HailSuite { def assertMemoryDoesNotScaleWithStreamSize(lowSize: Int = 50, highSize: Int = 2500)(f: IR => IR) : Unit = { - val memUsed1 = ExecuteContext.scoped { ctx => - eval(f(lowSize), Env.empty, FastSeq(), None, None, false, ctx) - ctx.r.pool.getHighestTotalUsage + val memUsed1 = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool))(ctx => eval(f(lowSize), optimize = false, ctx = ctx)) + pool.getHighestTotalUsage } - val memUsed2 = ExecuteContext.scoped { ctx => - eval(f(highSize), Env.empty, FastSeq(), None, None, false, ctx) - ctx.r.pool.getHighestTotalUsage + val memUsed2 = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool))(ctx => eval(f(highSize), optimize = false, ctx = ctx)) + pool.getHighestTotalUsage } if (memUsed1 != memUsed2) diff --git a/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala index b158fb2a506..5f39f4bd01f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.TestUtils._ import is.hail.expr.Nat import is.hail.expr.ir.defs._ import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala index 521922af0e7..f6ba5f123a3 100644 --- a/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual.TFloat64 import is.hail.utils.FastSeq diff --git a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala index dbd4af08a34..6c636ecb386 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala @@ -2,10 +2,10 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ -import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} +import is.hail.annotations.{BroadcastRow, ExtendedOrdering, Region, RegionPool, SafeNDArray} import is.hail.backend.ExecuteContext import is.hail.expr.Nat +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.agg._ import is.hail.expr.ir.defs._ import is.hail.expr.ir.defs.ArrayZipBehavior.ArrayZipBehavior @@ -1403,11 +1403,11 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(zip(ArrayZipBehavior.AssumeSameLength, range8, range8)), 8) // https://github.com/hail-is/hail/issues/8359 - is.hail.TestUtils.assertThrows[HailException]( + assertThrows[HailException]( zipToTuple(ArrayZipBehavior.AssertSameLength, range6, range8): IR, 
"zip: length mismatch": String, ) - is.hail.TestUtils.assertThrows[HailException]( + assertThrows[HailException]( zipToTuple(ArrayZipBehavior.AssertSameLength, range12, lit6): IR, "zip: length mismatch": String, ) @@ -1562,7 +1562,7 @@ class IRSuite extends HailSuite { val na = NA(TDict(TInt32, TString)) assertEvalsTo(LowerBoundOnOrderedCollection(na, I32(0), onKey = true), null) - val dwna = TestUtils.IRDict((1, 3), (3, null), (null, 5)) + val dwna = IRDict((1, 3), (3, null), (null, 5)) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(-1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(2), onKey = true), 1) @@ -1570,7 +1570,7 @@ class IRSuite extends HailSuite { assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(5), onKey = true), 2) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, NA(TInt32), onKey = true), 2) - val dwoutna = TestUtils.IRDict((1, 3), (3, null)) + val dwoutna = IRDict((1, 3), (3, null)) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, I32(-1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, I32(4), onKey = true), 2) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, NA(TInt32), onKey = true), 2) @@ -1801,18 +1801,18 @@ class IRSuite extends HailSuite { @Test def testStreamFold(): Unit = { assertEvalsTo(foldIR(StreamRange(1, 2, 1), NA(TBoolean))((accum, elt) => IsNA(accum)), true) - assertEvalsTo(foldIR(TestUtils.IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) + assertEvalsTo(foldIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) assertEvalsTo( - foldIR(TestUtils.IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt), + foldIR(IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt), null, ) assertEvalsTo( - foldIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt), + foldIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt), null, ) - assertEvalsTo(foldIR(TestUtils.IRStream(1, null, 3), 0)((accum, elt) => accum + elt), null) + assertEvalsTo(foldIR(IRStream(1, null, 3), 0)((accum, elt) => accum + elt), null) assertEvalsTo( - foldIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => I32(5) + I32(5)), + foldIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => I32(5) + I32(5)), 10, ) } @@ -1846,15 +1846,15 @@ class IRSuite extends HailSuite { FastSeq(null, true, false, false), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, 2, 3), 0)((accum, elt) => accum + elt)), + ToArray(streamScanIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt)), FastSeq(0, 1, 3, 6), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt)), + ToArray(streamScanIR(IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt)), FastSeq(null, null, null, null), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => + ToArray(streamScanIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt )), FastSeq(null, null, null, null), @@ -3252,7 +3252,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "valueIRs") def valueIRs(): Array[Array[Object]] = - ExecuteContext.scoped(ctx => valueIRs(ctx)) + valueIRs(ctx) def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { val fs = ctx.fs @@ -3314,8 +3314,7 @@ class IRSuite extends HailSuite { val table = TableRange(100, 10) val mt = MatrixIR.range(20, 2, Some(3)) - val vcf = is.hail.TestUtils.importVCF(ctx, 
getTestResource("sample.vcf")) - + val vcf = importVCF(ctx, getTestResource("sample.vcf")) val bgenReader = MatrixBGENReader( ctx, FastSeq(getTestResource("example.8bits.bgen")), @@ -3597,7 +3596,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "tableIRs") def tableIRs(): Array[Array[TableIR]] = - ExecuteContext.scoped(ctx => tableIRs(ctx)) + tableIRs(ctx) def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { try { @@ -3706,7 +3705,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "matrixIRs") def matrixIRs(): Array[Array[MatrixIR]] = - ExecuteContext.scoped(ctx => matrixIRs(ctx)) + matrixIRs(ctx) def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { try { @@ -3730,7 +3729,7 @@ class IRSuite extends HailSuite { val read = MatrixIR.read(fs, getTestResource("backward_compatability/1.0.0/matrix_table/0.hmt")) val range = MatrixIR.range(3, 7, None) - val vcf = is.hail.TestUtils.importVCF(ctx, getTestResource("sample.vcf")) + val vcf = importVCF(ctx, getTestResource("sample.vcf")) val bgenReader = MatrixBGENReader( ctx, @@ -3908,39 +3907,32 @@ class IRSuite extends HailSuite { val cache = mutable.Map.empty[String, BlockMatrix] val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3) try { - backend.withExecuteContext { ctx => - ctx.local(blockMatrixCache = cache) { ctx => - backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) - } + ctx.local(blockMatrixCache = cache) { ctx => + backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) } - backend.withExecuteContext { ctx => - ctx.local(blockMatrixCache = cache) { ctx => - val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) - val s = Pretty.sexprStyle(persist, elideLiterals = false) - val x2 = IRParser.parse_blockmatrix_ir(ctx, s) - assert(x2 == persist) - } + ctx.local(blockMatrixCache = cache) { ctx => + val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) + val s = Pretty.sexprStyle(persist, elideLiterals = false) + val x2 = IRParser.parse_blockmatrix_ir(ctx, s) + assert(x2 == persist) } - } finally + } finally { cache.values.foreach(_.unpersist()) + } } @Test def testCachedIR(): Unit = { val cached = Literal(TSet(TInt32), Set(1)) val s = s"(JavaIR 1)" - val x2 = ExecuteContext.scoped { ctx => - ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) - } + val x2 = ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) assert(x2 eq cached) } @Test def testCachedTableIR(): Unit = { val cached = TableRange(1, 1) val s = s"(JavaTable 1)" - val x2 = ExecuteContext.scoped { ctx => - ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) - } + val x2 = ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) assert(x2 eq cached) } @@ -4222,16 +4214,18 @@ class IRSuite extends HailSuite { val startingArg = SafeNDArray(IndexedSeq[Long](4L, 4L), (0 until 16).toFastSeq) - var memUsed = 0L - - ExecuteContext.scoped { ctx => - eval(ndSum, Env.empty, FastSeq(2 -> TInt32, startingArg -> ndType), None, None, true, ctx) - memUsed = ctx.r.pool.getHighestTotalUsage + val memUsed = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval(ndSum, Env.empty, FastSeq(2 -> TInt32, startingArg -> ndType), None, None, true, ctx) + pool.getHighestTotalUsage + } } - ExecuteContext.scoped { ctx => - eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> 
ndType), None, None, true, ctx) - assert(memUsed == ctx.r.pool.getHighestTotalUsage) + RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> ndType), None, None, true, ctx) + assert(memUsed == pool.getHighestTotalUsage) + } } } diff --git a/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala index 243c6a92a32..f868039fd29 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{ ErrorIDs, False, GetTupleElement, I32, If, Literal, MakeTuple, NA, True, } diff --git a/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala index 6ec3d983de6..309ae863d59 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{stats, ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{ErrorIDs, F32, F64, False, I32, I64, Str, True} import is.hail.types.virtual._ import is.hail.utils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala index 5e6b03bb47c..9ae90fa1ebc 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations.BroadcastRow import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.TestUtils._ @@ -381,7 +380,7 @@ class MatrixIRSuite extends HailSuite { } @Test def testMatrixMultiWriteDifferentTypesRaisesError(): Unit = { - val vcf = is.hail.TestUtils.importVCF(ctx, getTestResource("sample.vcf")) + val vcf = importVCF(ctx, getTestResource("sample.vcf")) val range = rangeMatrix(10, 2, None) val path1 = ctx.createTmpPath("test1") val path2 = ctx.createTmpPath("test2") diff --git a/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala index 90cb2d4713f..5fa628f8312 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala @@ -1,9 +1,10 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.TestUtils.eval import is.hail.backend.ExecuteContext import is.hail.expr.ir.defs.{Literal, ToArray, ToStream} +import is.hail.annotations.{Region, RegionPool} +import is.hail.expr.ir import is.hail.types.virtual.{TArray, TBoolean, TSet, TString} import is.hail.utils._ @@ -17,19 +18,22 @@ class MemoryLeakSuite extends HailSuite { def run(size: Int): Long = { val lit = Literal(TSet(TString), (0 until litSize).map(_.toString).toSet) val queries = Literal(TArray(TString), (0 until size).map(_.toString).toFastSeq) - ExecuteContext.scoped { ctx => - eval( - ToArray( - mapIR(ToStream(queries))(r => invoke("contains", TBoolean, lit, r)) - ), - Env.empty, - FastSeq(), - None, - None, - false, - ctx, - ) - ctx.r.pool.getHighestTotalUsage + RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval( + ToArray( + mapIR(ToStream(queries))(r => ir.invoke("contains", 
TBoolean, lit, r)) + ), + Env.empty, + FastSeq(), + None, + None, + false, + ctx, + ) + } + + pool.getHighestTotalUsage } } diff --git a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala index 1f41d54260c..c7827fa214e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.check.{Gen, Prop} diff --git a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala index e62319bc968..a4e02032bb8 100644 --- a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.agg.CallStatsState import is.hail.expr.ir.defs._ @@ -104,7 +103,7 @@ class RequirednessSuite extends HailSuite { def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) @DataProvider(name = "valueIR") - def valueIR(): Array[Array[Any]] = ExecuteContext.scoped { ctx => + def valueIR(): Array[Array[Any]] = { val nodes = new BoxedArrayBuilder[Array[Any]](50) val allRequired = Array( diff --git a/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala index dd3ec81340e..e1bfbbddc7c 100644 --- a/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{I32, NA, ToSet, ToStream} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala b/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala index a169e56dccc..8d9c78bd747 100644 --- a/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.TestUtils.IRAggCount import is.hail.expr.ir.defs._ +import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval} import is.hail.variant.Locus diff --git a/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala index 09211f5cd83..c5ad27273c2 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{F32, I32, I64, MakeTuple, NA, Str} import is.hail.types.virtual._ diff --git a/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala index d532f9f9ed9..697747dfbf7 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import 
is.hail.expr.ir.defs.{I32, In, NA, Str} import is.hail.types.virtual.TString import is.hail.utils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala index 62901f74263..39974a502a5 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations.SafeNDArray import is.hail.expr.Nat import is.hail.expr.ir.TestUtils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala index dc6c42b31cf..f08a632e4e1 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TrapNodeSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{ArrayRef, Die, GetTupleElement, I32, If, IsNA, Literal, Str, Trap} import is.hail.types.virtual._ import is.hail.utils._ diff --git a/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala index 3c7eb8636ae..cc02aa909a4 100644 --- a/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.defs.{Die, False, MakeStream, NA, Str, True} import is.hail.types.virtual.{TBoolean, TInt32, TStream} diff --git a/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala b/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala index 68ba43506f2..e31431c6e2f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir.analyses import is.hail.{HAIL_PRETTY_VERSION, HailSuite} -import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.expr.ir.defs._ import is.hail.io.fs.{FS, FakeFS, FakeURL, FileListEntry} @@ -292,12 +291,14 @@ class SemanticHashSuite extends HailSuite { @Test(dataProvider = "isBaseIRSemanticallyEquivalent") def testSemanticEquivalence(a: BaseIR, b: BaseIR, isEqual: Boolean, comment: String): Unit = - assertResult( - isEqual, - s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", - )( - semhash(fakeFs)(a) == semhash(fakeFs)(b) - ) + ctx.local(fs = fakeFs) { ctx => + assertResult( + isEqual, + s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", + )( + SemanticHash(ctx)(a) == SemanticHash(ctx)(b) + ) + } @Test def testFileNotFoundExceptions(): Unit = { @@ -309,14 +310,13 @@ class SemanticHashSuite extends HailSuite { val ir = importMatrix("gs://fake-bucket/fake-matrix") - assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( - semhash(fs)(ir) - ) + ctx.local(fs = fs) { ctx => + assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( + SemanticHash(ctx)(ir) + ) + } } - def semhash(fs: FS)(ir: BaseIR): Option[SemanticHash.Type] = - ExecuteContext.scoped(_.local(fs = fs)(SemanticHash(_)(ir))) - val fakeFs: FS = new FakeFS { override def eTag(url: FakeURL): Option[String] = diff --git 
a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 6bd31b6d4b4..72d5c29a580 100644 --- a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -7,6 +7,7 @@ import is.hail.expr.ir.{ import is.hail.expr.ir.defs.{ Apply, ErrorIDs, GetField, I32, Literal, MakeStruct, Ref, SelectFields, ToArray, ToStream, } +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.LowerDistributedSort.samplePartition import is.hail.types.RTable import is.hail.types.virtual.{TArray, TInt32, TStruct} @@ -15,7 +16,7 @@ import is.hail.utils.FastSeq import org.apache.spark.sql.Row import org.testng.annotations.Test -class LowerDistributedSortSuite extends HailSuite { +class LowerDistributedSortSuite extends HailSuite with TestUtils { implicit val execStrats = ExecStrategy.compileOnly @Test def testSamplePartition(): Unit = { @@ -71,14 +72,14 @@ class LowerDistributedSortSuite extends HailSuite { val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) .lower(ctx, myTable.typ.copy(key = FastSeq())) val res = - TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ + eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ IndexedSeq[Row] ]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2 - val unsortedCollect = is.hail.expr.ir.TestUtils.collect(myTable) + val unsortedCollect = collect(myTable) val unsortedAnalyses = LoweringAnalyses.apply(unsortedCollect, ctx) - val unsorted = TestUtils.eval(LowerTableIR.apply( + val unsorted = eval(LowerTableIR.apply( unsortedCollect, DArrayLowering.All, ctx, diff --git a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala index df686e1b7ed..25dca918bf2 100644 --- a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala @@ -1,14 +1,13 @@ package is.hail.expr.ir.table import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils.loweredExecute -import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.expr.ir.TestUtils.IRAggCollect import is.hail.expr.ir.defs.{ ApplyBinaryPrimOp, ErrorIDs, GetField, MakeStream, MakeStruct, Ref, Str, StreamRange, TableAggregate, TableGetGlobals, } +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ @@ -99,7 +98,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testLowering(): Unit = { - val table = TestUtils.collect(mkTableGen()) + val table = collect(mkTableGen()) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } @@ -107,13 +106,13 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testNumberOfContextsMatchesPartitions(): Unit = { val errorId = 42 - val table = TestUtils.collect(mkTableGen( + val table = collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), errorId = Some(errorId), )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[HailException] { - ExecuteContext.scoped(ctx => 
loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) + loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) } ex.errorId shouldBe errorId ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") @@ -122,7 +121,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testRowsAreCorrectlyKeyed(): Unit = { val errorId = 56 - val table = TestUtils.collect(mkTableGen( + val table = collect(mkTableGen( partitioner = Some(new RVDPartitioner( ctx.stateManager, TStruct("a" -> TInt32), @@ -135,7 +134,7 @@ class TableGenSuite extends HailSuite { )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[SparkException] { - ExecuteContext.scoped(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) + loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) }.getCause.asInstanceOf[HailException] ex.errorId shouldBe errorId diff --git a/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala b/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala index 7f85df6d302..a310abdb5bc 100644 --- a/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala +++ b/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala @@ -1,7 +1,6 @@ package is.hail.io.compress import is.hail.HailSuite -import is.hail.TestUtils._ import is.hail.check.Gen import is.hail.check.Prop.forAll import is.hail.expr.ir.GenericLines diff --git a/hail/hail/test/src/is/hail/io/fs/FSSuite.scala b/hail/hail/test/src/is/hail/io/fs/FSSuite.scala index 3173b646b06..fa834fd5de4 100644 --- a/hail/hail/test/src/is/hail/io/fs/FSSuite.scala +++ b/hail/hail/test/src/is/hail/io/fs/FSSuite.scala @@ -12,13 +12,10 @@ import org.apache.hadoop.fs.FileAlreadyExistsException import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -trait FSSuite extends TestNGSuite { +trait FSSuite extends TestNGSuite with TestUtils { val root: String = System.getenv("HAIL_TEST_STORAGE_URI") - def fsResourcesRoot: String = System.getenv("HAIL_FS_TEST_CLOUD_RESOURCES_URI") - def tmpdir: String = System.getenv("HAIL_TEST_STORAGE_URI") - def fs: FS /* Structure of src/test/resources/fs: @@ -73,7 +70,7 @@ trait FSSuite extends TestNGSuite { @Test def testFileStatusOnDirIsFailure(): Unit = { val f = r("/dir") - TestUtils.interceptException[FileNotFoundException](f)( + interceptException[FileNotFoundException](f)( fs.fileStatus(f) ) } diff --git a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala index 04a3a8387f4..da85781d8b2 100644 --- a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala @@ -1,6 +1,7 @@ package is.hail.linalg import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.check._ import is.hail.check.Arbitrary._ import is.hail.check.Gen._ @@ -926,9 +927,9 @@ class BlockMatrixSuite extends HailSuite { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize = 2) val expected = new BDM[Double](2, 3, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) - TestUtils.assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) - TestUtils.assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) - TestUtils.assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) + assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) + assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) + 
assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) } def filteredEquals(bm1: BlockMatrix, bm2: BlockMatrix): Boolean = @@ -1219,17 +1220,17 @@ class BlockMatrixSuite extends HailSuite { val v0 = Array(0.0, Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) - TestUtils.interceptFatal(notSupported)(bm0 / bm0) - TestUtils.interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) - TestUtils.interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) - TestUtils.interceptFatal(notSupported)(1 / bm0) + interceptFatal(notSupported)(bm0 / bm0) + interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) + interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) + interceptFatal(notSupported)(1 / bm0) - TestUtils.interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) - TestUtils.interceptFatal(notSupported)(bm0.colVectorDiv(v0)) - TestUtils.interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) - TestUtils.interceptFatal("division by scalar 0.0")(bm0 / 0) + interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) + interceptFatal(notSupported)(bm0.colVectorDiv(v0)) + interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) + interceptFatal("division by scalar 0.0")(bm0 / 0) - TestUtils.interceptFatal(notSupported)(bm0.pow(-1)) + interceptFatal(notSupported)(bm0.pow(-1)) } @Test diff --git a/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala b/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala index 63ce2cd71c1..85d007d753f 100644 --- a/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala +++ b/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala @@ -1,6 +1,6 @@ package is.hail.methods -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.annotations.Annotation import is.hail.check.{Gen, Properties} import is.hail.expr.ir.{Interpret, MatrixValue, TableValue} @@ -110,7 +110,7 @@ class LocalLDPruneSuite extends HailSuite { val nCores = 4 lazy val mt = Interpret( - TestUtils.importVCF(ctx, getTestResource("sample.vcf.bgz"), nPartitions = Option(10)), + importVCF(ctx, getTestResource("sample.vcf.bgz"), nPartitions = Option(10)), ctx, false, ).toMatrixValue(Array("s")) diff --git a/hail/hail/test/src/is/hail/methods/SkatSuite.scala b/hail/hail/test/src/is/hail/methods/SkatSuite.scala index a9b5e723f21..59ca348b593 100644 --- a/hail/hail/test/src/is/hail/methods/SkatSuite.scala +++ b/hail/hail/test/src/is/hail/methods/SkatSuite.scala @@ -1,6 +1,6 @@ package is.hail.methods -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.expr.ir.DoubleArrayBuilder import is.hail.utils._ @@ -31,6 +31,6 @@ class SkatSuite extends HailSuite { val (qLarge, gramianLarge) = Skat.computeGramianLargeN(st) assert(D_==(qSmall, qLarge)) - TestUtils.assertMatrixEqualityDouble(gramianSmall, gramianLarge) + assertMatrixEqualityDouble(gramianSmall, gramianLarge) } } diff --git a/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala b/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala index 73bcb32e9e4..3a51cb32f4a 100644 --- a/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala +++ b/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala @@ -1,6 +1,6 @@ package is.hail.stats -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils._ import breeze.linalg.{eigSym, svd, DenseMatrix, DenseVector} @@ -109,7 +109,7 @@ class eigSymDSuite extends HailSuite { val x = DenseVector.fill[Double](n)(rand.nextGaussian()) - TestUtils.assertVectorEqualityDouble(x, TriSolve(A, A * x)) + assertVectorEqualityDouble(x, 
TriSolve(A, A * x)) } } } diff --git a/hail/hail/test/src/is/hail/utils/RichArraySuite.scala b/hail/hail/test/src/is/hail/utils/RichArraySuite.scala index df1a5809f51..b7793450052 100644 --- a/hail/hail/test/src/is/hail/utils/RichArraySuite.scala +++ b/hail/hail/test/src/is/hail/utils/RichArraySuite.scala @@ -1,6 +1,6 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils.richUtils.RichArray import org.testng.annotations.Test @@ -15,7 +15,7 @@ class RichArraySuite extends HailSuite { RichArray.importFromDoubles(fs, file, a2, bufSize = 16) assert(a === a2) - TestUtils.interceptFatal("Premature") { + interceptFatal("Premature") { RichArray.importFromDoubles(fs, file, new Array[Double](101), bufSize = 64) } } diff --git a/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala b/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala index a8d05321bf4..282e71d77e0 100644 --- a/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala +++ b/hail/hail/test/src/is/hail/utils/RichDenseMatrixDoubleSuite.scala @@ -1,6 +1,6 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.linalg.BlockMatrix import is.hail.utils.richUtils.RichDenseMatrixDouble @@ -33,7 +33,7 @@ class RichDenseMatrixDoubleSuite extends HailSuite { val lmT2 = RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 50, rowMajor = true) assert(mT === lmT2) - TestUtils.interceptFatal("Premature") { + interceptFatal("Premature") { RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 100, rowMajor = true) } } diff --git a/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala b/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala index a2a057b30fc..1f7374d180d 100644 --- a/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala @@ -8,8 +8,7 @@ import is.hail.utils._ import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class GenotypeSuite extends TestNGSuite { - +class GenotypeSuite extends TestNGSuite with TestUtils { val v = Variant("1", 1, "A", "T") @Test def gtPairGtIndexIsId(): Unit = @@ -118,6 +117,6 @@ class GenotypeSuite extends TestNGSuite { assert(Call.parse("0|1") == Call2(0, 1, phased = true)) intercept[UnsupportedOperationException](Call.parse("1/1/1")) intercept[UnsupportedOperationException](Call.parse("1|1|1")) - TestUtils.interceptFatal("invalid call expression:")(Call.parse("0/")) + interceptFatal("invalid call expression:")(Call.parse("0/")) } } diff --git a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala index 0816b4afa89..7371b2d7939 100644 --- a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala +++ b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala @@ -1,6 +1,6 @@ package is.hail.variant -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils._ import org.testng.annotations.Test @@ -161,11 +161,11 @@ class LocusIntervalSuite extends HailSuite { true, true, )) - TestUtils.interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( + interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( "[X:0-5)", rg, )) - TestUtils.interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( + interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( s"[X:1-${xMax + 1}]", rg, )) @@ -208,19 +208,19 @@ 
class LocusIntervalSuite extends HailSuite { false, )) - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("4::start-5:end", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("4:start-", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("1:1.1111K-2k", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("1:1.1111111M-2M", rg) } diff --git a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala index 05422e43c51..823fc3c8a1a 100644 --- a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,7 +1,8 @@ package is.hail.variant -import is.hail.{HailSuite, TestUtils} -import is.hail.backend.{ExecuteContext, HailStateManager} +import is.hail.HailSuite +import is.hail.backend.HailStateManager +import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} @@ -13,13 +14,9 @@ import org.testng.annotations.Test class ReferenceGenomeSuite extends HailSuite { - def hasReference(name: String) = ctx.stateManager.referenceGenomes.contains(name) - - def getReference(name: String) = ctx.references(name) - @Test def testGRCh37(): Unit = { - assert(hasReference(ReferenceGenome.GRCh37)) - val grch37 = getReference(ReferenceGenome.GRCh37) + assert(ctx.references.contains(ReferenceGenome.GRCh37)) + val grch37 = ctx.references(ReferenceGenome.GRCh37) assert(grch37.inX("X") && grch37.inY("Y") && grch37.isMitochondrial("MT")) assert(grch37.contigLength("1") == 249250621) @@ -34,8 +31,8 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testGRCh38(): Unit = { - assert(hasReference(ReferenceGenome.GRCh38)) - val grch38 = getReference(ReferenceGenome.GRCh38) + assert(ctx.references.contains(ReferenceGenome.GRCh38)) + val grch38 = ctx.references(ReferenceGenome.GRCh38) assert(grch38.inX("chrX") && grch38.inY("chrY") && grch38.isMitochondrial("chrM")) assert(grch38.contigLength("chr1") == 248956422) @@ -50,18 +47,18 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testAssertions(): Unit = { - TestUtils.interceptFatal("Must have at least one contig in the reference genome.")( + interceptFatal("Must have at least one contig in the reference genome.")( ReferenceGenome("test", Array.empty[String], Map.empty[String, Int]) ) - TestUtils.interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( + interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5), )) - TestUtils.interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( + interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5, "4" -> 100)) ) - TestUtils.interceptFatal("The following X contig names are absent from the reference:")( + interceptFatal("The following X contig names are absent from the reference:")( ReferenceGenome( "test", Array("1", "2", "3"), @@ -69,7 +66,7 @@ class ReferenceGenomeSuite extends HailSuite { xContigs = Set("X"), ) ) - 
TestUtils.interceptFatal("The following Y contig names are absent from the reference:")( + interceptFatal("The following Y contig names are absent from the reference:")( ReferenceGenome( "test", Array("1", "2", "3"), @@ -77,7 +74,7 @@ class ReferenceGenomeSuite extends HailSuite { yContigs = Set("Y"), ) ) - TestUtils.interceptFatal( + interceptFatal( "The following mitochondrial contig names are absent from the reference:" )(ReferenceGenome( "test", @@ -85,13 +82,13 @@ class ReferenceGenomeSuite extends HailSuite { Map("1" -> 5, "2" -> 5, "3" -> 5), mtContigs = Set("MT"), )) - TestUtils.interceptFatal("The contig name for PAR interval")(ReferenceGenome( + interceptFatal("The contig name for PAR interval")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), parInput = Array((Locus("X", 1), Locus("X", 5))), )) - TestUtils.interceptFatal("in both X and Y contigs.")(ReferenceGenome( + interceptFatal("in both X and Y contigs.")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), @@ -102,13 +99,13 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testContigRemap(): Unit = { val mapping = Map("23" -> "foo") - TestUtils.interceptFatal("have remapped contigs in reference genome")( - getReference(ReferenceGenome.GRCh37).validateContigRemap(mapping) + interceptFatal("have remapped contigs in reference genome")( + ctx.references(ReferenceGenome.GRCh37).validateContigRemap(mapping) ) } @Test def testComparisonOps(): Unit = { - val rg = getReference(ReferenceGenome.GRCh37) + val rg = ctx.references(ReferenceGenome.GRCh37) // Test contigs assert(rg.compare("3", "18") < 0) @@ -127,7 +124,7 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testWriteToFile(): Unit = { val tmpFile = ctx.createTmpPath("grWrite", "json") - val rg = getReference(ReferenceGenome.GRCh37) + val rg = ctx.references(ReferenceGenome.GRCh37) rg.copy(name = "GRCh37_2").write(fs, tmpFile) val gr2 = ReferenceGenome.fromFile(fs, tmpFile) @@ -221,19 +218,18 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeOnFB(): Unit = { - ExecuteContext.scoped { ctx => - val grch38 = ctx.references(ReferenceGenome.GRCh38) - val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") - val rgfield = fb.getReferenceGenome(grch38.name) - fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) - - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) + val grch38 = ctx.references(ReferenceGenome.GRCh38) + val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") + val rgfield = fb.getReferenceGenome(grch38.name) + fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) + ctx.scopedExecution { (cl, fs, tc, r) => + val f = fb.resultWithIndex()(cl, fs, tc, r) assert(f("X") == grch38.isValidContig("X")) } } - @Test def testSerializeWithLiftoverOnFB(): Unit = { - ExecuteContext.scoped { ctx => + @Test def testSerializeWithLiftoverOnFB(): Unit = + ctx.local(references = ctx.references.mapValues(_.copy())) { ctx => val grch37 = ctx.references(ReferenceGenome.GRCh37) val liftoverFile = getTestResource("grch37_to_grch38_chr20.over.chain.gz") @@ -249,13 +245,13 @@ class ReferenceGenomeSuite extends HailSuite { fb.getCodeParam[Double](3), )) - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) - assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( - "GRCh38", - Locus("20", 60001), - 0.95, - )) - 
grch37.removeLiftover("GRCh38") + ctx.scopedExecution { (cl, fs, tc, r) => + val f = fb.resultWithIndex()(cl, fs, tc, r) + assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( + "GRCh38", + Locus("20", 60001), + 0.95, + )) + } } - } } diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 995d0b668c4..e7e14bbe90a 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -3,7 +3,7 @@ import zipfile from dataclasses import dataclass from enum import Enum -from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union import orjson @@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation: raise ValueError(f'Hail requires either {hail_jar} or {hail_all_spark_jar}.') +class IRFunction: + def __init__( + self, + name: str, + type_parameters: Union[Tuple[HailType, ...], List[HailType]], + value_parameter_names: Union[Tuple[str, ...], List[str]], + value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], + return_type: HailType, + body: Expression, + ): + assert len(value_parameter_names) == len(value_parameter_types) + render = CSERenderer() + self._name = name + self._type_parameters = type_parameters + self._value_parameter_names = value_parameter_names + self._value_parameter_types = value_parameter_types + self._return_type = return_type + self._rendered_body = render(finalize_randomness(body._ir)) + + def to_dataclass(self): + return SerializedIRFunction( + name=self._name, + type_parameters=[tp._parsable_string() for tp in self._type_parameters], + value_parameter_names=list(self._value_parameter_names), + value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], + return_type=self._return_type._parsable_string(), + rendered_body=self._rendered_body, + ) + + class ActionTag(Enum): - LOAD_REFERENCES_FROM_DATASET = 1 - VALUE_TYPE = 2 - TABLE_TYPE = 3 - MATRIX_TABLE_TYPE = 4 - BLOCK_MATRIX_TYPE = 5 - EXECUTE = 6 - PARSE_VCF_METADATA = 7 - IMPORT_FAM = 8 + VALUE_TYPE = 1 + TABLE_TYPE = 2 + MATRIX_TABLE_TYPE = 3 + BLOCK_MATRIX_TYPE = 4 + EXECUTE = 5 + PARSE_VCF_METADATA = 6 + IMPORT_FAM = 7 + LOAD_REFERENCES_FROM_DATASET = 8 FROM_FASTA_FILE = 9 @@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload): ir: str +@dataclass +class SerializedIRFunction: + name: str + type_parameters: List[str] + value_parameter_names: List[str] + value_parameter_types: List[str] + return_type: str + rendered_body: str + + @dataclass class ExecutePayload(ActionPayload): ir: str + fns: List[SerializedIRFunction] stream_codec: str - timed: bool @dataclass @@ -164,6 +204,8 @@ def _valid_flags(self) -> AbstractSet[str]: def __init__(self): self._persisted_locations = dict() self._references = {} + self.functions: List[IRFunction] = [] + self._registered_ir_function_names: Set[str] = set() @abc.abstractmethod def validate_file(self, uri: str): @@ -171,10 +213,15 @@ def validate_file(self, uri: str): @abc.abstractmethod def stop(self): - pass + self.functions = [] + self._registered_ir_function_names = set() def execute(self, ir: BaseIR, timed: bool = False) -> Any: - payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed) + payload = ExecutePayload( + self._render_ir(ir), + fns=[fn.to_dataclass() for fn in self.functions], + stream_codec='{"name":"StreamBufferSpec"}', + ) try: result, timings = 
self._rpc(ActionTag.EXECUTE, payload) except FatalError as e: @@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset: tempfile.__exit__(None, None, None) return unpersisted - @abc.abstractmethod def register_ir_function( self, name: str, @@ -310,11 +356,13 @@ def register_ir_function( return_type: HailType, body: Expression, ): - pass + self._registered_ir_function_names.add(name) + self.functions.append( + IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body) + ) - @abc.abstractmethod def _is_registered_ir_function_name(self, name: str) -> bool: - pass + return name in self._registered_ir_function_names @abc.abstractmethod def persist_expression(self, expr: Expression) -> Expression: diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 708eacaef2d..311a5c7e9bb 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -74,7 +74,6 @@ def __init__( hail_package = getattr(self._gateway.jvm, 'is').hail jbackend = hail_package.backend.local.LocalBackend.apply( - tmpdir, log, True, append, @@ -82,23 +81,14 @@ def __init__( ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) - super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc) + super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_configuration self._fs = self._exit_stack.enter_context( RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration}) ) self._logger = None - - flags = {} - if gcs_requester_pays_configuration is not None: - if isinstance(gcs_requester_pays_configuration, str): - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration - else: - assert isinstance(gcs_requester_pays_configuration, tuple) - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0] - flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1]) - - self._initialize_flags(flags) + self._initialize_flags({}) def validate_file(self, uri: str) -> None: async_to_blocking(validate_file(uri, self._fs.afs)) @@ -119,7 +109,7 @@ def register_ir_function( ) def stop(self): - super().stop() + super(Py4JBackend, self).stop() self._exit_stack.close() uninstall_exception_handler() diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 03bc9be5dfe..e3d33e97c17 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -4,7 +4,7 @@ import socketserver import sys from threading import Thread -from typing import Mapping, Optional, Set, Tuple +from typing import Mapping, Optional, Tuple import orjson import py4j @@ -13,9 +13,11 @@ import hail from hail.expr import construct_expr +from hail.fs.hadoop_fs import HadoopFS from hail.ir import JavaIR from hail.utils.java import Env, FatalError, scala_package_object from hail.version import __version__ +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from ..hail_logging import Logger from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet @@ -171,8 +173,15 @@ def parse(node): class Py4JBackend(Backend): @abc.abstractmethod - def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject): - super(Py4JBackend, self).__init__() + def __init__( + self, + jvm: JVMView, + jbackend: JavaObject, + jhc: JavaObject, + tmpdir: str, + 
remote_tmpdir: str, + ): + super().__init__() import base64 def decode_bytearray(encoded): @@ -185,15 +194,18 @@ def decode_bytearray(encoded): self._jvm = jvm self._hail_package = getattr(self._jvm, 'is').hail self._utils_package_object = scala_package_object(self._hail_package.utils) - self._jbackend = jbackend self._jhc = jhc - self._backend_server = self._hail_package.backend.BackendServer(self._jbackend) - self._backend_server_port: int = self._backend_server.port() - self._backend_server.start() + self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) + self._jbackend.pySetLocalTmp(tmpdir) + self._jbackend.pySetRemoteTmp(remote_tmpdir) + + self._jhttp_server = self._jbackend.pyHttpServer() + self._backend_server_port: int = self._jhttp_server.port() self._requests_session = requests.Session() - self._registered_ir_function_names: Set[str] = set() + self._gcs_requester_pays_config = None + self._fs = None # This has to go after creating the SparkSession. Unclear why. # Maybe it does its own patch? @@ -216,6 +228,23 @@ def hail_package(self): def utils_package_object(self): return self._utils_package_object + @property + def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]: + return self._gcs_requester_pays_config + + @gcs_requester_pays_configuration.setter + def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]): + self._gcs_requester_pays_config = config + project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config + self._jbackend.pySetGcsRequesterPaysConfig(project, buckets) + self._fs = None # stale + + @property + def fs(self): + if self._fs is None: + self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs()) + return self._fs + @property def logger(self): if self._logger is None: @@ -238,9 +267,6 @@ def persist_expression(self, expr): t = expr.dtype return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t) - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def set_flags(self, **flags: Mapping[str, str]): available = self._jbackend.pyAvailableFlags() invalid = [] @@ -275,12 +301,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome): def remove_liftover(self, name, dest_reference_genome): self._jbackend.pyRemoveLiftover(name, dest_reference_genome) - def _parse_value_ir(self, code, ref_map={}): - return self._jbackend.parse_value_ir( - code, - {k: t._parsable_string() for k, t in ref_map.items()}, - ) - def _register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, code): self._registered_ir_function_names.add(name) self._jbackend.pyRegisterIR( @@ -292,12 +312,6 @@ def _register_ir_function(self, name, type_parameters, argument_names, argument_ code, ) - def _parse_table_ir(self, code): - return self._jbackend.parse_table_ir(code) - - def _parse_matrix_ir(self, code): - return self._jbackend.parse_matrix_ir(code) - def _parse_blockmatrix_ir(self, code): return self._jbackend.parse_blockmatrix_ir(code) @@ -305,9 +319,9 @@ def _to_java_blockmatrix_ir(self, ir): return self._parse_blockmatrix_ir(self._render_ir(ir)) def stop(self): - self._backend_server.close() + self._jhttp_server.close() self._jbackend.close() self._jhc.stop() self._jhc = None - self._registered_ir_function_names = set() uninstall_exception_handler() + super().stop() diff --git 
a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 953ce749b05..595811d67e5 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -12,10 +12,6 @@ import hailtop.aiotools.fs as afs from hail.context import TemporaryDirectory, TemporaryFilename, tmp_dir from hail.experimental import read_expression, write_expression -from hail.expr.expressions.base_expression import Expression -from hail.expr.types import HailType -from hail.ir import finalize_randomness -from hail.ir.renderer import CSERenderer from hail.utils import FatalError from hail.version import __revision__, __version__ from hailtop import yamlx @@ -33,7 +29,7 @@ from ..builtin_references import BUILTIN_REFERENCES from ..utils import ANY_REGION -from .backend import ActionPayload, ActionTag, Backend, ExecutePayload, fatal_error_from_java_error_triplet +from .backend import ActionPayload, ActionTag, Backend, fatal_error_from_java_error_triplet ReferenceGenomeConfig = Dict[str, Any] @@ -67,53 +63,6 @@ async def read_str(strm: afs.ReadableStream) -> str: return b.decode('utf-8') -@dataclass -class SerializedIRFunction: - name: str - type_parameters: List[str] - value_parameter_names: List[str] - value_parameter_types: List[str] - return_type: str - rendered_body: str - - -class IRFunction: - def __init__( - self, - name: str, - type_parameters: Union[Tuple[HailType, ...], List[HailType]], - value_parameter_names: Union[Tuple[str, ...], List[str]], - value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], - return_type: HailType, - body: Expression, - ): - assert len(value_parameter_names) == len(value_parameter_types) - render = CSERenderer() - self._name = name - self._type_parameters = type_parameters - self._value_parameter_names = value_parameter_names - self._value_parameter_types = value_parameter_types - self._return_type = return_type - self._rendered_body = render(finalize_randomness(body._ir)) - - def to_dataclass(self): - return SerializedIRFunction( - name=self._name, - type_parameters=[tp._parsable_string() for tp in self._type_parameters], - value_parameter_names=list(self._value_parameter_names), - value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], - return_type=self._return_type._parsable_string(), - rendered_body=self._rendered_body, - ) - - -@dataclass -class ServiceBackendExecutePayload(ActionPayload): - functions: List[SerializedIRFunction] - idempotency_token: str - payload: ExecutePayload - - @dataclass class CloudfuseConfig: bucket: str @@ -130,17 +79,21 @@ class SequenceConfig: @dataclass class ServiceBackendRPCConfig: tmp_dir: str - remote_tmpdir: str + flags: Dict[str, str] + custom_references: List[str] + liftovers: Dict[str, Dict[str, str]] + sequences: Dict[str, SequenceConfig] + + +@dataclass +class BatchJobConfig: + token: str billing_project: str worker_cores: str worker_memory: str storage: str cloudfuse_configs: List[CloudfuseConfig] regions: List[str] - flags: Dict[str, str] - custom_references: List[str] - liftovers: Dict[str, Dict[str, str]] - sequences: Dict[str, SequenceConfig] class ServiceBackend(Backend): @@ -149,14 +102,14 @@ class ServiceBackend(Backend): DRIVER = "driver" # is.hail.backend.service.ServiceBackendSocketAPI2 protocol - LOAD_REFERENCES_FROM_DATASET = 1 - VALUE_TYPE = 2 - TABLE_TYPE = 3 - MATRIX_TABLE_TYPE = 4 - BLOCK_MATRIX_TYPE = 5 - EXECUTE = 6 - PARSE_VCF_METADATA = 7 - IMPORT_FAM = 8 + VALUE_TYPE = 1 + TABLE_TYPE = 2 + 
MATRIX_TABLE_TYPE = 3 + BLOCK_MATRIX_TYPE = 4 + EXECUTE = 5 + PARSE_VCF_METADATA = 6 + IMPORT_FAM = 7 + LOAD_REFERENCES_FROM_DATASET = 8 FROM_FASTA_FILE = 9 @staticmethod @@ -289,7 +242,6 @@ def __init__( self.batch_attributes = batch_attributes self.remote_tmpdir = remote_tmpdir self.flags: Dict[str, str] = {} - self.functions: List[IRFunction] = [] self._registered_ir_function_names: Set[str] = set() self.driver_cores = driver_cores self.driver_memory = driver_memory @@ -334,16 +286,16 @@ def logger(self): def stop(self): hail_event_loop().run_until_complete(self._stop()) + super().stop() async def _stop(self): await self._async_exit_stack.aclose() - self.functions = [] - self._registered_ir_function_names = set() async def _run_on_batch( self, name: str, service_backend_config: ServiceBackendRPCConfig, + job_config: BatchJobConfig, action: ActionTag, payload: ActionPayload, *, @@ -357,7 +309,8 @@ async def _run_on_batch( async with await self._async_fs.create(iodir + '/in') as infile: await infile.write( orjson.dumps({ - 'config': service_backend_config, + 'rpc_config': service_backend_config, + 'job_config': job_config, 'action': action.value, 'payload': payload, }) @@ -375,8 +328,8 @@ async def _run_on_batch( elif self.driver_memory is not None: resources['memory'] = str(self.driver_memory) - if service_backend_config.storage != '0Gi': - resources['storage'] = service_backend_config.storage + if job_config.storage != '0Gi': + resources['storage'] = job_config.storage j = self._batch.create_jvm_job( jar_spec=self.jar_spec, @@ -390,7 +343,7 @@ async def _run_on_batch( resources=resources, attributes={'name': name + '_driver'}, regions=self.regions, - cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in service_backend_config.cloudfuse_configs], + cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in job_config.cloudfuse_configs], profile=self.flags['profile'] is not None, ) await self._batch.submit(disable_progress_bar=True) @@ -467,11 +420,6 @@ def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Option return self._cancel_on_ctrl_c(self._async_rpc(action, payload)) async def _async_rpc(self, action: ActionTag, payload: ActionPayload): - if isinstance(payload, ExecutePayload): - payload = ServiceBackendExecutePayload( - [f.to_dataclass() for f in self.functions], self._batch.token, payload - ) - storage_requirement_bytes = 0 readonly_fuse_buckets: Set[str] = set() @@ -486,31 +434,37 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload): readonly_fuse_buckets.add(bucket) storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size() sequence_file_mounts[rg_name] = SequenceConfig( - f'/cloudfuse/{fasta_bucket}/{fasta_path}', f'/cloudfuse/{index_bucket}/{index_path}' + f'/cloudfuse/{fasta_bucket}/{fasta_path}', + f'/cloudfuse/{index_bucket}/{index_path}', ) - storage_gib_str = f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi' - qob_config = ServiceBackendRPCConfig( - tmp_dir=tmp_dir(), - remote_tmpdir=self.remote_tmpdir, - billing_project=self.billing_project, - worker_cores=str(self.worker_cores), - worker_memory=str(self.worker_memory), - storage=storage_gib_str, - cloudfuse_configs=[ - CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets - ], - regions=self.regions, - flags=self.flags, - custom_references=[ - orjson.dumps(rg._config).decode('utf-8') - for rg in self._references.values() - if rg.name not in BUILTIN_REFERENCES - ], - liftovers={rg.name: rg._liftovers for 
rg in self._references.values() if len(rg._liftovers) > 0}, - sequences=sequence_file_mounts, + return await self._run_on_batch( + name=f'{action.name.lower()}(...)', + service_backend_config=ServiceBackendRPCConfig( + tmp_dir=self.remote_tmpdir, + flags=self.flags, + custom_references=[ + orjson.dumps(rg._config).decode('utf-8') + for rg in self._references.values() + if rg.name not in BUILTIN_REFERENCES + ], + liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}, + sequences=sequence_file_mounts, + ), + job_config=BatchJobConfig( + token=self._batch.token, + billing_project=self.billing_project, + worker_cores=str(self.worker_cores), + worker_memory=str(self.worker_memory), + storage=f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi', + cloudfuse_configs=[ + CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets + ], + regions=self.regions, + ), + action=action, + payload=payload, ) - return await self._run_on_batch(f'{action.name.lower()}(...)', qob_config, action, payload) # Sequence and liftover information is stored on the ReferenceGenome # and there is no persistent backend to keep in sync. @@ -533,23 +487,6 @@ def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): def remove_liftover(self, name, dest_reference_genome): # pylint: disable=unused-argument pass - def register_ir_function( - self, - name: str, - type_parameters: Union[Tuple[HailType, ...], List[HailType]], - value_parameter_names: Union[Tuple[str, ...], List[str]], - value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], - return_type: HailType, - body: Expression, - ): - self._registered_ir_function_names.add(name) - self.functions.append( - IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body) - ) - - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def persist_expression(self, expr): # FIXME: should use context manager to clean up persisted resources fname = TemporaryFilename(prefix='persist_expression').name diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index 69f53292443..192ca57eb75 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -7,11 +7,11 @@ import pyspark.sql from hail.expr.table_type import ttable -from hail.fs.hadoop_fs import HadoopFS from hail.ir import BaseIR from hail.ir.renderer import CSERenderer from hail.table import Table from hail.utils import copy_log +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.aiotools.validators import validate_file from hailtop.utils import async_to_blocking @@ -47,12 +47,9 @@ def __init__( skip_logging_configuration, optimizer_iterations, *, - gcs_requester_pays_project: Optional[str] = None, - gcs_requester_pays_buckets: Optional[str] = None, + gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None, copy_log_on_error: bool = False, ): - assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None - try: local_jar_info = local_jar_information() except ValueError: @@ -120,10 +117,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.getOrCreate(jbackend, 
branching_factor, optimizer_iterations) else: @@ -137,10 +130,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) @@ -149,12 +138,12 @@ def __init__( self.sc = sc else: self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc)) - self._jspark_session = jbackend.sparkSession() + self._jspark_session = jbackend.sparkSession().apply() self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session) - super(SparkBackend, self).__init__(jvm, jbackend, jhc) + super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_config - self._fs = None self._logger = None if not quiet: @@ -167,7 +156,7 @@ def __init__( self._initialize_flags({}) self._router_async_fs = RouterAsyncFS( - gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project} + gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config} ) self._tmpdir = tmpdir @@ -181,12 +170,6 @@ def stop(self): self.sc.stop() self.sc = None - @property - def fs(self): - if self._fs is None: - self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs()) - return self._fs - def from_spark(self, df, key): result_tuple = self._jbackend.pyFromDF(df._jdf, key) tir_id, type_json = result_tuple._1(), result_tuple._2() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 6c96599a24b..b3d72e2cef4 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -472,14 +472,10 @@ def init_spark( optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3) app_name = app_name or 'Hail' - ( - gcs_requester_pays_project, - gcs_requester_pays_buckets, - ) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style( - get_gcs_requester_pays_configuration( - gcs_requester_pays_configuration=gcs_requester_pays_configuration, - ) + gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( + gcs_requester_pays_configuration=gcs_requester_pays_configuration, ) + backend = SparkBackend( idempotent, sc, @@ -496,8 +492,7 @@ def init_spark( local_tmpdir, skip_logging_configuration, optimizer_iterations, - gcs_requester_pays_project=gcs_requester_pays_project, - gcs_requester_pays_buckets=gcs_requester_pays_buckets, + gcs_requester_pays_config=gcs_requester_pays_configuration, copy_log_on_error=copy_log_on_error, ) if not backend.fs.exists(tmpdir): diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 81f5fb3b855..cdb1c874c36 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -2,7 +2,7 @@ import random import re import unittest -from test.hail.helpers import resource, skip_unless_spark_backend, skip_when_service_backend +from test.hail.helpers import resource, skip_unless_spark_backend import numpy as np import pytest @@ -189,12 +189,6 @@ def value_ir(value_irs, request): return value_irs[request.param] -@skip_when_service_backend() -def test_ir_parses(value_ir): - env = value_irs_env() - Env.backend()._parse_value_ir(str(value_ir), env) - - def test_ir_value_type(value_ir): env = value_irs_env() typ = Env.backend().value_type( @@ -313,11 +307,6 @@ def table_ir(table_irs, request): return table_irs[request.param] -@skip_when_service_backend() -def test_table_ir_parses(table_ir): - 
    Env.backend()._parse_table_ir(str(table_ir))
-
-
 def test_table_ir_table_type(table_ir):
     typ = Env.backend().table_type(table_ir)
     assert table_ir.typ == typ
@@ -420,11 +409,6 @@ def matrix_ir(matrix_irs, request):
     return matrix_irs[request.param]
 
 
-@skip_when_service_backend()
-def test_matrix_ir_parses(matrix_ir):
-    Env.backend()._parse_matrix_ir(str(matrix_ir))
-
-
 def test_matrix_ir_matrix_type(matrix_ir):
     typ = Env.backend().matrix_type(matrix_ir)
     assert typ == matrix_ir.typ
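Note on the backend.py hunks above: IRFunction bookkeeping moves out of ServiceBackend into the abstract Backend, so every backend (Spark, local, service) records user-defined IR functions the same way — e.g. the functions created via hl.experimental.define_function — and ships them with each EXECUTE payload rather than registering them eagerly. A minimal sketch of that shared flow, assuming nothing beyond what the diff shows; the HailType/Expression machinery is stubbed out and the IR strings are invented for illustration:

```python
# Simplified sketch of the registration flow now shared by all backends.
# Class and field names mirror the diff; everything else is illustrative.
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class SerializedIRFunction:
    name: str
    return_type: str
    rendered_body: str


class MiniBackend:
    def __init__(self):
        self.functions: List[SerializedIRFunction] = []
        self._registered_ir_function_names = set()

    def register_ir_function(self, name: str, return_type: str, rendered_body: str) -> None:
        # Base-class behaviour: remember the function so it can be re-sent
        # with every EXECUTE action rather than registered eagerly on a JVM.
        self._registered_ir_function_names.add(name)
        self.functions.append(SerializedIRFunction(name, return_type, rendered_body))

    def _is_registered_ir_function_name(self, name: str) -> bool:
        return name in self._registered_ir_function_names

    def execute_payload(self, ir: str) -> dict:
        # Matches the new ExecutePayload shape: 'timed' is gone, 'fns' is added.
        return {
            'ir': ir,
            'fns': [asdict(fn) for fn in self.functions],
            'stream_codec': '{"name":"StreamBufferSpec"}',
        }


backend = MiniBackend()
backend.register_ir_function('plus_one', 'int32', '(ApplyBinaryPrimOp `+` (Ref x) (I32 1))')
assert backend._is_registered_ir_function_name('plus_one')
print(backend.execute_payload('(I32 5)'))
```

As the service_backend.py hunks show, ServiceBackend now relies entirely on this base-class list; only the Py4J path keeps its eager `_register_ir_function` hook for the JVM.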
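The py4j_backend.py hunks replace the per-backend fs/requester-pays plumbing with a property pair on Py4JBackend: assigning gcs_requester_pays_configuration forwards the project/buckets split to the JVM via pySetGcsRequesterPaysConfig and drops the cached HadoopFS so the next fs access is rebuilt against the new setting. The normalisation is the one-liner from the diff; the sketch below just restates it standalone with a couple of checks:

```python
from typing import List, Optional, Tuple, Union

# GCSRequesterPaysConfiguration is either a project id or a (project, buckets) tuple.
GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]]


def split_requester_pays(
    config: Optional[GCSRequesterPaysConfiguration],
) -> Tuple[Optional[str], Optional[List[str]]]:
    # Same conditional chain as the property setter: unset -> (None, None),
    # project-only -> (project, None), otherwise pass the tuple through.
    return (None, None) if config is None else (config, None) if isinstance(config, str) else config


assert split_requester_pays(None) == (None, None)
assert split_requester_pays('my-project') == ('my-project', None)
assert split_requester_pays(('my-project', ['bucket-a'])) == ('my-project', ['bucket-a'])
```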
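In service_backend.py, the old single ServiceBackendRPCConfig is split: session-level settings (tmp dir, flags, custom references, liftovers, sequences) stay in ServiceBackendRPCConfig, while per-job resources move to the new BatchJobConfig, and the driver input file now carries them under separate 'rpc_config' and 'job_config' keys. A hedged sketch of that envelope — all field values are invented, and SequenceConfig/CloudfuseConfig are collapsed to plain dicts for brevity:

```python
from dataclasses import dataclass
from typing import Dict, List

import orjson  # the library the backend uses; it serializes dataclasses natively


@dataclass
class ServiceBackendRPCConfig:
    # session-level settings the driver needs for every action
    tmp_dir: str
    flags: Dict[str, str]
    custom_references: List[str]
    liftovers: Dict[str, Dict[str, str]]
    sequences: Dict[str, dict]


@dataclass
class BatchJobConfig:
    # per-job resources for the QoB driver job
    token: str
    billing_project: str
    worker_cores: str
    worker_memory: str
    storage: str
    cloudfuse_configs: List[dict]
    regions: List[str]


# Hypothetical values; the real ones come from ServiceBackend state.
envelope = {
    'rpc_config': ServiceBackendRPCConfig(
        tmp_dir='gs://my-bucket/tmp', flags={}, custom_references=[], liftovers={}, sequences={}
    ),
    'job_config': BatchJobConfig(
        token='abc123',
        billing_project='my-project',
        worker_cores='1',
        worker_memory='standard',
        storage='0Gi',
        cloudfuse_configs=[],
        regions=['us-central1'],
    ),
    'action': 5,  # ActionTag.EXECUTE under the renumbered protocol
    'payload': {'ir': '(I32 5)', 'fns': [], 'stream_codec': '{"name":"StreamBufferSpec"}'},
}
print(orjson.dumps(envelope).decode())
```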
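Finally, context.py stops converting the requester-pays setting into Hadoop-style project/bucket strings and passes the GCSRequesterPaysConfiguration straight into SparkBackend, which also hands it to RouterAsyncFS. Assuming the public hl.init signature (project and bucket names below are placeholders), user code is unchanged:

```python
import hail as hl

# Either a bare project id or a (project, buckets) tuple is accepted;
# the value now flows through to the backend without conversion.
hl.init(
    backend='spark',
    gcs_requester_pays_configuration=('my-project', ['shared-bucket-1', 'shared-bucket-2']),
)
```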