From 904e13b71af22313d35ab2fb86519290e6c466b4 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Wed, 18 Sep 2024 16:46:37 -0400 Subject: [PATCH] [query] unify backend rpc --- hail/hail/src/is/hail/HailContext.scala | 9 +- hail/hail/src/is/hail/backend/Backend.scala | 98 +--- .../hail/src/is/hail/backend/BackendRpc.scala | 256 ++++++++++ .../src/is/hail/backend/BackendServer.scala | 155 +++--- .../src/is/hail/backend/ExecuteContext.scala | 9 +- .../is/hail/backend/local/LocalBackend.scala | 5 +- .../backend/py4j/Py4JBackendExtensions.scala | 48 +- .../hail/backend/service/ServiceBackend.scala | 473 +++++++----------- .../src/is/hail/backend/service/Worker.scala | 49 +- .../is/hail/backend/spark/SparkBackend.scala | 12 +- .../src/is/hail/expr/ir/GenericLines.scala | 2 +- hail/hail/src/is/hail/expr/ir/TableIR.scala | 4 +- .../hail/src/is/hail/expr/ir/TableValue.scala | 2 +- .../is/hail/expr/ir/functions/Functions.scala | 58 ++- .../expr/ir/lowering/RVDToTableStage.scala | 4 +- .../hail/src/is/hail/io/plink/LoadPlink.scala | 6 +- hail/hail/src/is/hail/io/vcf/LoadVCF.scala | 2 +- .../hail/src/is/hail/linalg/BlockMatrix.scala | 8 +- hail/hail/src/is/hail/linalg/RowMatrix.scala | 2 +- .../methods/MatrixExportEntriesByCol.scala | 2 +- hail/hail/src/is/hail/rvd/RVD.scala | 2 +- .../src/is/hail/sparkextras/ContextRDD.scala | 8 +- .../is/hail/sparkextras/IndexReadRDD.scala | 2 +- .../src/is/hail/types/virtual/package.scala | 12 + .../hail/utils/SpillingCollectIterator.scala | 2 +- hail/hail/test/src/is/hail/HailSuite.scala | 4 +- hail/hail/test/src/is/hail/TestUtils.scala | 6 +- .../is/hail/backend/ServiceBackendSuite.scala | 72 +-- hail/hail/test/src/is/hail/check/Gen.scala | 4 +- .../hail/variant/ReferenceGenomeSuite.scala | 53 +- hail/python/hail/backend/backend.py | 80 ++- hail/python/hail/backend/local_backend.py | 2 +- hail/python/hail/backend/py4j_backend.py | 21 +- hail/python/hail/backend/service_backend.py | 169 ++----- hail/python/test/hail/test_ir.py | 18 +- 35 files changed, 810 insertions(+), 849 deletions(-) create mode 100644 hail/hail/src/is/hail/backend/BackendRpc.scala create mode 100644 hail/hail/src/is/hail/types/virtual/package.scala diff --git a/hail/hail/src/is/hail/HailContext.scala b/hail/hail/src/is/hail/HailContext.scala index 3de89ce13cd..f52efe60ee7 100644 --- a/hail/hail/src/is/hail/HailContext.scala +++ b/hail/hail/src/is/hail/HailContext.scala @@ -19,6 +19,7 @@ import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.RDD import org.json4s.Extraction import org.json4s.jackson.JsonMethods +import sourcecode.Enclosing case class FilePartition(index: Int, file: String) extends Partition @@ -41,7 +42,7 @@ object HailContext { def backend: Backend = get.backend - def sparkBackend(op: String): SparkBackend = get.sparkBackend(op) + def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = { org.apache.log4j.helpers.LogLog.setInternalDebugging(true) @@ -152,7 +153,7 @@ object HailContext { val fsBc = fs.broadcast - new RDD[T](SparkBackend.sparkContext("readPartition"), Nil) { + new RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i))) @@ -175,8 +176,6 @@ class HailContext private ( ) { def stop(): Unit = HailContext.stop() - def sparkBackend(op: String): SparkBackend = backend.asSpark(op) - var checkRVDKeys: Boolean = false def version: String = is.hail.HAIL_PRETTY_VERSION @@ 
-188,7 +187,7 @@ class HailContext private ( maxLines: Int, ): Map[String, Array[WithContext[String]]] = { val regexp = regex.r - SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath)) + SparkBackend.sparkContext.textFilesLines(fs.globAll(files).map(_.getPath)) .filter(line => regexp.findFirstIn(line.value).isDefined) .take(maxLines) .groupBy(_.source.file) diff --git a/hail/hail/src/is/hail/backend/Backend.scala b/hail/hail/src/is/hail/backend/Backend.scala index 666b3c0d2b6..34a85bfde93 100644 --- a/hail/hail/src/is/hail/backend/Backend.scala +++ b/hail/hail/src/is/hail/backend/Backend.scala @@ -1,28 +1,21 @@ package is.hail.backend -import is.hail.asm4s._ -import is.hail.backend.Backend.jsonToBytes +import is.hail.asm4s.HailClassLoader import is.hail.backend.spark.SparkBackend -import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader} +import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink -import is.hail.io.vcf.LoadVCF -import is.hail.types._ +import is.hail.io.fs.FS +import is.hail.types.RTable import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.types.virtual.TFloat64 -import is.hail.utils._ -import is.hail.variant.ReferenceGenome +import is.hail.utils.ExecutionTimer.Timings +import is.hail.utils.fatal import scala.reflect.ClassTag -import java.io._ -import java.nio.charset.StandardCharsets +import java.io.{Closeable, OutputStream} -import org.json4s._ -import org.json4s.jackson.JsonMethods import sourcecode.Enclosing object Backend { @@ -38,23 +31,19 @@ object Backend { ctx: ExecuteContext, t: PTuple, off: Long, - bufferSpecString: String, + bufferSpec: BufferSpec, os: OutputStream, ): Unit = { - val bs = BufferSpec.parseOrDefault(bufferSpecString) assert(t.size == 1) val elementType = t.fields(0).typ val codec = TypedCodecSpec( EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, - bs, + bufferSpec, ) assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) } - - def jsonToBytes(f: => JValue): Array[Byte] = - JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) } abstract class BroadcastValue[T] { def value: T } @@ -82,8 +71,8 @@ abstract class Backend extends Closeable { f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) - def asSpark(op: String): SparkBackend = - fatal(s"${getClass.getSimpleName}: $op requires SparkBackend") + def asSpark(implicit E: Enclosing): SparkBackend = + fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend") def lowerDistributedSort( ctx: ExecuteContext, @@ -116,70 +105,7 @@ abstract class Backend extends Closeable { def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage - def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T - - final def valueType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_value_ir(ctx, s).typ.toJSON - } - } - - final def tableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_table_ir(ctx, s).typ.toJSON - } - } - - final def matrixTableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - 
IRParser.parse_matrix_ir(ctx, s).typ.toJSON - } - } - - final def blockMatrixType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON - } - } - - def loadReferencesFromDataset(path: String): Array[Byte] - - def fromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: Array[String], - yContigs: Array[String], - mtContigs: Array[String], - parInput: Array[String], - ): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput).toJSON - } - } - - def parseVCFMetadata(path: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - Extraction.decompose { - LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - }(defaultJSONFormats) - } - } - - def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) - : Array[Byte] = - withExecuteContext { ctx => - LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes( - StandardCharsets.UTF_8 - ) - } + def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } diff --git a/hail/hail/src/is/hail/backend/BackendRpc.scala b/hail/hail/src/is/hail/backend/BackendRpc.scala new file mode 100644 index 00000000000..bb2df117ad1 --- /dev/null +++ b/hail/hail/src/is/hail/backend/BackendRpc.scala @@ -0,0 +1,256 @@ +package is.hail.backend + +import is.hail.expr.ir.IRParser +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey +import is.hail.io.BufferSpec +import is.hail.io.plink.LoadPlink +import is.hail.io.vcf.LoadVCF +import is.hail.services.retryTransientErrors +import is.hail.types.virtual.{Kind, TFloat64, VType} +import is.hail.types.virtual.Kinds._ +import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer, FastSeq} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.util.control.NonFatal + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets + +import org.json4s.{DefaultFormats, Extraction, Formats, JArray, JValue} +import org.json4s.jackson.JsonMethods + +case class IRTypePayload(ir: String) +case class LoadReferencesFromDatasetPayload(path: String) + +case class FromFASTAFilePayload( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], +) + +case class ParseVCFMetadataPayload(path: String) +case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) + +case class ExecutePayload( + ir: String, + fns: Array[SerializedIRFunction], + stream_codec: String, +) + +case class SerializedIRFunction( + name: String, + type_parameters: Array[String], + value_parameter_names: Array[String], + value_parameter_types: Array[String], + return_type: String, + rendered_body: String, +) + +trait BackendRpc { + + sealed trait Command extends Product with Serializable + + object Commands { + case class TypeOf(k: Kind[_ <: VType], ir: String) extends Command + case class Execute(ir: String, fs: Array[SerializedIRFunction], codec: String) extends Command + case class ParseVcfMetadata(path: String) extends Command + + case class ImportFam(path: String, isQuantPheno: Boolean, delimiter: String, 
missing: String) + extends Command + + case class LoadReferencesFromDataset(path: String) extends Command + + case class LoadReferencesFromFASTA( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], + ) extends Command + } + + type Env + + trait Reader { + def command(env: Env): Command + } + + trait Context { + def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) + } + + trait Writer { + def timings(env: Env, t: ExecutionTimer.Timings): Unit + def result(env: Env, r: Array[Byte]): Unit + def error(env: Env, t: Throwable): Unit + } + + implicit val fmts: Formats = DefaultFormats + + import Commands._ + + final def runRpc(env: Env)(implicit Reader: Reader, Context: Context, Write: Writer): Unit = + try { + val command = Reader.command(env) + val (result, timings) = retryTransientErrors { + Context.scoped(env) { ctx => + command match { + case TypeOf(kind, s) => + jsonToBytes { + (kind match { + case Value => IRParser.parse_value_ir(ctx, s) + case Table => IRParser.parse_table_ir(ctx, s) + case Matrix => IRParser.parse_matrix_ir(ctx, s) + case BlockMatrix => IRParser.parse_blockmatrix_ir(ctx, s) + }).typ.toJSON + } + + case Execute(s, fns, codec) => + val bufferSpec = BufferSpec.parseOrDefault(codec) + withIRFunctionsReadFromInput(ctx, fns) { + val ir = IRParser.parse_value_ir(ctx, s) + val res = ctx.backend.execute(ctx, ir) + res match { + case Left(_) => Array.empty[Byte] + case Right((pt, off)) => + using(new ByteArrayOutputStream()) { os => + Backend.encodeToOutputStream(ctx, pt, off, bufferSpec, os) + os.toByteArray + } + } + } + + case ParseVcfMetadata(path) => + jsonToBytes { + val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + Extraction.decompose(metadata) + } + + case ImportFam(path, isQuantPheno, delimiter, missing) => + jsonToBytes { + LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missing) + } + + case LoadReferencesFromDataset(path) => + jsonToBytes { + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + ReferenceGenome.addFatalOnCollision(ctx.references, rgs) + JArray(rgs.map(_.toJSON).toList) + } + + case LoadReferencesFromFASTA(name, fasta, index, xContigs, yContigs, mtContigs, par) => + jsonToBytes { + val rg = ReferenceGenome.fromFASTAFile( + ctx, + name, + fasta, + index, + xContigs, + yContigs, + mtContigs, + par, + ) + ReferenceGenome.addFatalOnCollision(ctx.references, FastSeq(rg)) + rg.toJSON + } + } + } + } + + Write.result(env, result) + Write.timings(env, timings) + } catch { + case NonFatal(error) => Write.error(env, error) + } + + def jsonToBytes(v: JValue): Array[Byte] = + JsonMethods.compact(v).getBytes(StandardCharsets.UTF_8) + + private[this] def withIRFunctionsReadFromInput[A]( + ctx: ExecuteContext, + serializedFunctions: Array[SerializedIRFunction], + )( + body: => A + ): A = { + val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length) + try { + for (func <- serializedFunctions) { + fns += IRFunctionRegistry.registerIR( + ctx, + func.name, + func.type_parameters, + func.value_parameter_names, + func.value_parameter_types, + func.return_type, + func.rendered_body, + ) + } + + body + } finally + for (i <- 0 until fns.length) + IRFunctionRegistry.unregisterIr(fns(i)) + } +} + +trait HttpLikeBackendRpc extends BackendRpc { + + import Commands._ + + trait Routing extends Reader { + + sealed trait Route extends Product with Serializable + + object Routes { + case class 
TypeOf(kind: Kind[_ <: VType]) extends Route + case object Execute extends Route + case object ParseVcfMetadata extends Route + case object ImportFam extends Route + case object LoadReferencesFromDataset extends Route + case object LoadReferencesFromFASTA extends Route + } + + def route(a: Env): Route + def payload(a: Env): JValue + + final override def command(a: Env): Command = + route(a) match { + case Routes.TypeOf(k) => + TypeOf(k, payload(a).extract[IRTypePayload].ir) + case Routes.Execute => + val ExecutePayload(ir, fns, codec) = payload(a).extract[ExecutePayload] + Execute(ir, fns, codec) + case Routes.ParseVcfMetadata => + ParseVcfMetadata(payload(a).extract[ParseVCFMetadataPayload].path) + case Routes.ImportFam => + val config = payload(a).extract[ImportFamPayload] + ImportFam(config.path, config.quant_pheno, config.delimiter, config.missing) + case Routes.LoadReferencesFromDataset => + val path = payload(a).extract[LoadReferencesFromDatasetPayload].path + LoadReferencesFromDataset(path) + case Routes.LoadReferencesFromFASTA => + val config = payload(a).extract[FromFASTAFilePayload] + LoadReferencesFromFASTA( + config.name, + config.fasta_file, + config.index_file, + config.x_contigs, + config.y_contigs, + config.mt_contigs, + config.par, + ) + } + } + + implicit protected def Reader: Routing + implicit protected def Writer: Writer + implicit protected def Context: Context +} diff --git a/hail/hail/src/is/hail/backend/BackendServer.scala b/hail/hail/src/is/hail/backend/BackendServer.scala index 24e7f7b9875..984d24f4709 100644 --- a/hail/hail/src/is/hail/backend/BackendServer.scala +++ b/hail/hail/src/is/hail/backend/BackendServer.scala @@ -1,41 +1,21 @@ package is.hail.backend -import is.hail.expr.ir.IRParser +import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} import is.hail.utils._ - -import scala.util.control.NonFatal +import is.hail.utils.ExecutionTimer.Timings import java.io.Closeable import java.net.InetSocketAddress -import java.nio.charset.StandardCharsets import java.util.concurrent._ +import com.google.api.client.http.HttpStatusCodes import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.json4s._ -import org.json4s.jackson.JsonMethods -import org.json4s.jackson.JsonMethods.compact - -case class IRTypePayload(ir: String) -case class LoadReferencesFromDatasetPayload(path: String) - -case class FromFASTAFilePayload( - name: String, - fasta_file: String, - index_file: String, - x_contigs: Array[String], - y_contigs: Array[String], - mt_contigs: Array[String], - par: Array[String], -) - -case class ParseVCFMetadataPayload(path: String) -case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) -case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean) +import org.json4s.jackson.{JsonMethods, Serialization} class BackendServer(backend: Backend) extends Closeable { // 0 => let the OS pick an available port private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) - private[this] val handler = new BackendHttpHandler(backend) private[this] val thread = { // This HTTP server *must not* start non-daemon threads because such threads keep the JVM @@ -59,7 +39,7 @@ class BackendServer(backend: Backend) extends Closeable { /* Source: * https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ // - httpServer.createContext("/", handler) + httpServer.createContext("/", 
Handler) httpServer.setExecutor(null) val t = Executors.defaultThreadFactory().newThread(new Runnable() { def run(): Unit = @@ -69,85 +49,76 @@ class BackendServer(backend: Backend) extends Closeable { t } - def port = httpServer.getAddress.getPort + def port: Int = httpServer.getAddress.getPort def start(): Unit = thread.start() override def close(): Unit = httpServer.stop(10) -} -class BackendHttpHandler(backend: Backend) extends HttpHandler { - def handle(exchange: HttpExchange): Unit = { - implicit val formats: Formats = DefaultFormats - - try { - val body = using(exchange.getRequestBody)(JsonMethods.parse(_)) - if (exchange.getRequestURI.getPath == "/execute") { - val ExecutePayload(irStr, streamCodec, timed) = body.extract[ExecutePayload] - backend.withExecuteContext { ctx => - val (res, timings) = ExecutionTimer.time { timer => - ctx.local(timer = timer) { ctx => - val irData = IRParser.parse_value_ir(ctx, irStr) - backend.execute(ctx, irData) - } - } - - if (timed) { - exchange.getResponseHeaders.add("X-Hail-Timings", compact(timings.toJSON)) - } - - res match { - case Left(_) => exchange.sendResponseHeaders(200, -1L) - case Right((t, off)) => - exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body - using(exchange.getResponseBody) { os => - Backend.encodeToOutputStream(ctx, t, off, streamCodec, os) - } - } + private[this] object Handler extends HttpHandler with HttpLikeBackendRpc { + + case class Env(exchange: HttpExchange, payload: JValue) + + override def handle(exchange: HttpExchange): Unit = { + val payload = using(exchange.getRequestBody)(JsonMethods.parse(_)) + runRpc(Env(exchange, payload)) + } + + implicit override object Reader extends Routing { + + import Routes._ + + override def route(env: Env): Route = + env.exchange.getRequestURI.getPath match { + case "/value/type" => TypeOf(Value) + case "/table/type" => TypeOf(Table) + case "/matrixtable/type" => TypeOf(Matrix) + case "/blockmatrix/type" => TypeOf(BlockMatrix) + case "/execute" => Execute + case "/vcf/metadata/parse" => ParseVcfMetadata + case "/fam/import" => ImportFam + case "/references/load" => LoadReferencesFromDataset + case "/references/from_fasta" => LoadReferencesFromFASTA } - return - } - val response: Array[Byte] = exchange.getRequestURI.getPath match { - case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) - case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) - case "/matrixtable/type" => backend.matrixTableType(body.extract[IRTypePayload].ir) - case "/blockmatrix/type" => backend.blockMatrixType(body.extract[IRTypePayload].ir) - case "/references/load" => - backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) - case "/references/from_fasta" => - val config = body.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - config.name, - config.fasta_file, - config.index_file, - config.x_contigs, - config.y_contigs, - config.mt_contigs, - config.par, - ) - case "/vcf/metadata/parse" => - backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) - case "/fam/import" => - val config = body.extract[ImportFamPayload] - backend.importFam(config.path, config.quant_pheno, config.delimiter, config.missing) + override def payload(env: Env): JValue = env.payload + } + + implicit override object Writer extends Writer with ErrorHandling { + + override def timings(env: Env, t: Timings): Unit = { + val ts = Serialization.write(Map("timings" -> t)) + env.exchange.getResponseHeaders.add("X-Hail-Timings", ts) } - 
exchange.sendResponseHeaders(200, response.length) - using(exchange.getResponseBody())(_.write(response)) - } catch { - case NonFatal(t) => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - val errorJson = JObject( - "short" -> JString(shortMessage), - "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId), + override def result(env: Env, result: Array[Byte]): Unit = + respond(env, HttpStatusCodes.STATUS_CODE_OK, result) + + override def error(env: Env, t: Throwable): Unit = + respond( + env, + HttpStatusCodes.STATUS_CODE_SERVER_ERROR, + jsonToBytes { + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId), + ) + }, ) - val errorBytes = JsonMethods.compact(errorJson).getBytes(StandardCharsets.UTF_8) - exchange.sendResponseHeaders(500, errorBytes.length) - using(exchange.getResponseBody())(_.write(errorBytes)) + + private[this] def respond(env: Env, code: Int, payload: Array[Byte]): Unit = { + env.exchange.sendResponseHeaders(code, payload.length) + using(env.exchange.getResponseBody)(_.write(payload)) + } + } + + implicit override protected object Context extends Context { + override def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) = + backend.withExecuteContext(f) } } } diff --git a/hail/hail/src/is/hail/backend/ExecuteContext.scala b/hail/hail/src/is/hail/backend/ExecuteContext.scala index a9217359af1..ca37f5752e4 100644 --- a/hail/hail/src/is/hail/backend/ExecuteContext.scala +++ b/hail/hail/src/is/hail/backend/ExecuteContext.scala @@ -55,12 +55,10 @@ object NonOwningTempFileManager { } object ExecuteContext { - def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = { - val result = HailContext.sparkBackend("ExecuteContext.scoped").withExecuteContext( + def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = + HailContext.sparkBackend.withExecuteContext( selfContainedExecution = false )(f) - result - } def scoped[T]( tmpdir: String, @@ -146,7 +144,8 @@ class ExecuteContext( ) } - val stateManager = HailStateManager(references) + def stateManager: HailStateManager = + HailStateManager(references) val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) diff --git a/hail/hail/src/is/hail/backend/local/LocalBackend.scala b/hail/hail/src/is/hail/backend/local/LocalBackend.scala index 89847e5af85..67bdfe350ee 100644 --- a/hail/hail/src/is/hail/backend/local/LocalBackend.scala +++ b/hail/hail/src/is/hail/backend/local/LocalBackend.scala @@ -18,6 +18,7 @@ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -100,8 +101,8 @@ class LocalBackend( // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => val fs = this.fs ExecuteContext.scoped( tmpdir, diff --git a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala 
b/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala index aa084dcf46f..16b570cb9df 100644 --- a/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ b/hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -4,8 +4,8 @@ import is.hail.HailFeatureFlags import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} import is.hail.expr.ir.{ - BaseIR, BindingEnv, BlockMatrixIR, IR, IRParser, Interpret, MatrixIR, MatrixNativeReader, - MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral, TableValue, + BaseIR, BlockMatrixIR, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, + NativeReaderOptions, TableLiteral, TableValue, } import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx} @@ -14,21 +14,17 @@ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.linalg.RowMatrix import is.hail.types.physical.PStruct import is.hail.types.virtual.{TArray, TInterval} -import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval} +import is.hail.utils.{log, toRichIterable, FastSeq, HailException, Interval} import is.hail.variant.ReferenceGenome import scala.collection.mutable -import scala.jdk.CollectionConverters.{ - asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, -} +import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} -import java.nio.charset.StandardCharsets import java.util import org.apache.spark.sql.DataFrame import org.json4s -import org.json4s.Formats -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing trait Py4JBackendExtensions { @@ -140,7 +136,7 @@ trait Py4JBackendExtensions { val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) addJavaIR(ctx, field) } - } + }._1 def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { val key = jKey.asScala.toArray.toFastSeq @@ -166,7 +162,7 @@ trait Py4JBackendExtensions { backend.withExecuteContext { ctx => val tir = IRParser.parse_table_ir(ctx, s) Interpret(tir, ctx).toDF() - } + }._1 def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = backend.withExecuteContext { ctx => @@ -192,7 +188,7 @@ trait Py4JBackendExtensions { } log.info("pyReadMultipleMatrixTables: returning N matrix tables") matrixReaders.asJava - } + }._1 def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) @@ -214,37 +210,11 @@ trait Py4JBackendExtensions { private[this] def removeReference(name: String): Unit = references -= name - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - backend.withExecuteContext { ctx => - IRParser.parse_value_ir( - ctx, - s, - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s)) - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s)) - def parse_blockmatrix_ir(s: String): BlockMatrixIR = withExecuteContext(selfContainedExecution = false) { ctx => IRParser.parse_blockmatrix_ir(ctx, s) } - def loadReferencesFromDataset(path: String): Array[Byte] = - 
backend.withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - ReferenceGenome.addFatalOnCollision(references, rgs) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } - def withExecuteContext[T]( selfContainedExecution: Boolean = true )( @@ -255,5 +225,5 @@ trait Py4JBackendExtensions { val tempFileManager = longLifeTempFileManager if (selfContainedExecution && tempFileManager != null) f(ctx) else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) - } + }._1 } diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index 6e4309056b2..a1283c1fc6c 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -4,14 +4,12 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate -import is.hail.expr.ir.{ - IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck, -} +import is.hail.expr.ir.{IR, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.defs.MakeTuple -import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} @@ -20,8 +18,9 @@ import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual._ +import is.hail.types.virtual.{Kinds, TVoid} import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.annotation.switch @@ -33,108 +32,36 @@ import java.nio.charset.StandardCharsets import java.nio.file.Path import java.util.concurrent._ -import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing -class ServiceBackendContext( - val billingProject: String, - val remoteTmpDir: String, - val workerCores: String, - val workerMemory: String, - val storageRequirement: String, - val regions: Array[String], - val cloudfuseConfig: Array[CloudfuseConfig], - val profile: Boolean, - val executionCache: ExecutionCache, -) extends BackendContext with Serializable {} +case class ServiceBackendContext( + remoteTmpDir: String, + jobConfig: BatchJobConfig, + override val executionCache: ExecutionCache, +) extends BackendContext with Serializable object ServiceBackend { - - def apply( - jarLocation: String, - name: String, - theHailClassLoader: HailClassLoader, - batchClient: BatchClient, - batchConfig: BatchConfig, - scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""), - rpcConfig: ServiceBackendRPCPayload, - env: Map[String, String], - ): ServiceBackend = { - - val flags = HailFeatureFlags.fromEnv(rpcConfig.flags) - val shouldProfile = flags.get("profile") != null - val fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - 
Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - flags, - env, - ) - ) - - val backendContext = new ServiceBackendContext( - rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - shouldProfile, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ) - - val references = mutable.Map.empty[String, ReferenceGenome] - references ++= ReferenceGenome.builtinReferences() - ReferenceGenome.addFatalOnCollision( - references, - rpcConfig.custom_references.map(ReferenceGenome.fromJSON), - ) - - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) - } - } - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) - } - - new ServiceBackend( - JarUrl(jarLocation), - name, - theHailClassLoader, - references, - batchClient, - batchConfig, - flags, - rpcConfig.tmp_dir, - fs, - backendContext, - scratchDir, - ) - } + val MaxAvailableGcsConnections = 1000 } class ServiceBackend( - val jarSpec: JarSpec, - var name: String, - val theHailClassLoader: HailClassLoader, - val references: mutable.Map[String, ReferenceGenome], - val batchClient: BatchClient, + val name: String, + batchClient: BatchClient, + jarSpec: JarSpec, + theHailClassLoader: HailClassLoader, val batchConfig: BatchConfig, - val flags: HailFeatureFlags, - val tmpdir: String, + rpcConfig: ServiceBackendRPCPayload, + jobConfig: BatchJobConfig, + flags: HailFeatureFlags, val fs: FS, - val serviceBackendContext: ServiceBackendContext, - val scratchDir: String, + references: mutable.Map[String, ReferenceGenome], ) extends Backend with Logging { private[this] var stageCount = 0 - private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 - private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS) + private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections) def defaultParallelism: Int = 4 @@ -160,7 +87,6 @@ class ServiceBackend( } private[this] def submitJobGroupAndWait( - backendContext: ServiceBackendContext, collection: IndexedSeq[Array[Byte]], token: String, root: String, @@ -180,13 +106,13 @@ class ServiceBackend( resources = Some( JobResources( preemptible = true, - cpu = Some(backendContext.workerCores).filter(_ != "None"), - memory = Some(backendContext.workerMemory).filter(_ != "None"), - storage = Some(backendContext.storageRequirement).filter(_ != "0Gi"), + cpu = Some(jobConfig.worker_cores).filter(_ != "None"), + memory = Some(jobConfig.worker_memory).filter(_ != "None"), + storage = Some(jobConfig.storage).filter(_ != "0Gi"), ) ), - regions = Some(backendContext.regions).filter(_.nonEmpty), - cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty), + regions = Some(jobConfig.regions).filter(_.nonEmpty), + cloudfuse = Some(jobConfig.cloudfuse_configs).filter(_.nonEmpty), ) val jobs = @@ -280,7 +206,7 @@ class ServiceBackend( uploadFunction.get() uploadContexts.get() - val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) + val jobGroup = submitJobGroupAndWait(parts, token, root, stageIdentifier) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() @@ -375,11 +301,11 @@ class ServiceBackend( : TableStage = 
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => ExecuteContext.scoped( - tmpdir, - "file:///tmp", + rpcConfig.tmp_dir, + rpcConfig.tmp_dir, this, references.toMap, fs, @@ -387,7 +313,11 @@ class ServiceBackend( null, theHailClassLoader, flags, - serviceBackendContext, + ServiceBackendContext( + rpcConfig.tmp_dir, + jobConfig, + ExecutionCache.fromFlags(flags, fs, rpcConfig.tmp_dir), + ), new IrMetadata(), ImmutableMap.empty, ImmutableMap.empty, @@ -395,21 +325,12 @@ class ServiceBackend( ImmutableMap.empty, )(f) } - - override def loadReferencesFromDataset(path: String): Array[Byte] = - withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - ReferenceGenome.addFatalOnCollision(references, rgs) - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } } class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) -object ServiceBackendAPI { - private[this] val log = Logger.getLogger(getClass.getName()) +object ServiceBackendAPI extends HttpLikeBackendRpc with Logging { def main(argv: Array[String]): Unit = { assert(argv.length == 7, argv.toFastSeq) @@ -423,39 +344,70 @@ object ServiceBackendAPI { val inputURL = argv(5) val outputURL = argv(6) - implicit val formats: Formats = DefaultFormats + implicit val fmts: Formats = DefaultFormats - val fs = RouterFS.buildRoutes( + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") + DeployConfig.set(deployConfig) + sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) + + var fs = RouterFS.buildRoutes( CloudStorageFSConfig.fromFlagsAndEnv( Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), HailFeatureFlags.fromEnv(), ) ) - val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") - DeployConfig.set(deployConfig) - sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - val batchClient = BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")) - log.info("BatchClient allocated.") + val (rpcConfig, jobConfig, action, payload) = + using(fs.openNoCompression(inputURL)) { is => + val input = JsonMethods.parse(is) + ( + (input \ "rpc_config").extract[ServiceBackendRPCPayload], + (input \ "job_config").extract[BatchJobConfig], + (input \ "action").extract[Int], + input \ "payload", + ) + } + + // requester pays config is conveyed in feature flags currently + val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) + fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + featureFlags, + ) + ) + + val references = mutable.Map[String, ReferenceGenome]() + references ++= ReferenceGenome.builtinReferences() + ReferenceGenome.addFatalOnCollision( + references, + rpcConfig.custom_references.map(ReferenceGenome.fromJSON), + ) - val batchConfig = - BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) - log.info("BatchConfig parsed.") + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + 
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } - val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) - val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } // FIXME: when can the classloader be shared? (optimizer benefits!) - val backend = ServiceBackend( - jarLocation, + val backend = new ServiceBackend( name, - new HailClassLoader(getClass().getClassLoader()), - batchClient, - batchConfig, - scratchDir, + BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), + JarUrl(jarLocation), + new HailClassLoader(getClass.getClassLoader), + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), rpcConfig, - sys.env, + jobConfig, + featureFlags, + fs, + references, ) + log.info("ServiceBackend allocated.") if (HailContext.isInitialized) { HailContext.get.backend = backend @@ -465,9 +417,84 @@ object ServiceBackendAPI { log.info("HailContexet initialized.") } - val action = (input \ "action").extract[Int] - val payload = (input \ "payload") - new ServiceBackendAPI(backend, fs, outputURL).executeOneCommand(action, payload) + runRpc(Env(backend, fs, outputURL, action, payload)) + } + + case class Env( + backend: ServiceBackend, + fs: FS, + outputUrl: String, + action: Int, + payload: JValue, + ) + + implicit override object Reader extends Routing { + import Routes._ + + override def route(a: Env): Route = + (a.action: @switch) match { + case 1 => TypeOf(Kinds.Value) + case 2 => TypeOf(Kinds.Table) + case 3 => TypeOf(Kinds.Matrix) + case 4 => TypeOf(Kinds.BlockMatrix) + case 5 => Execute + case 6 => ParseVcfMetadata + case 7 => ImportFam + case 8 => LoadReferencesFromDataset + case 9 => LoadReferencesFromFASTA + } + + override def payload(a: Env): JValue = a.payload + } + + implicit override object Writer extends Writer { + + // service backend doesn't support sending timings back to the python client + override def timings(env: Env, t: Timings): Unit = + () + + override def result(env: Env, result: Array[Byte]): Unit = + retryTransientErrors { + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + + override def error(env: Env, t: Throwable): Unit = + retryTransientErrors { + val (shortMessage, expandedMessage, errorId) = + t match { + case t: HailWorkerException => + log.error( + "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", + t, + ) + (t.shortMessage, t.expandedMessage, t.errorId) + case _ => + log.error( + "An exception occurred in the driver. 
The exception was written for Python but we will re-throw to fail this driver job.", + t, + ) + handleForPython(t) + } + + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + + throw t + } + } + + implicit override object Context extends Context { + override def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) = + env.backend.withExecuteContext(f) } } @@ -508,174 +535,18 @@ case class SequenceConfig(fasta: String, index: String) case class ServiceBackendRPCPayload( tmp_dir: String, - remote_tmpdir: String, - billing_project: String, - worker_cores: String, - worker_memory: String, - storage: String, - cloudfuse_configs: Array[CloudfuseConfig], - regions: Array[String], flags: Map[String, String], custom_references: Array[String], liftovers: Map[String, Map[String, String]], sequences: Map[String, SequenceConfig], ) -case class ServiceBackendExecutePayload( - functions: Array[SerializedIRFunction], - idempotency_token: String, - payload: ExecutePayload, -) - -case class SerializedIRFunction( - name: String, - type_parameters: Array[String], - value_parameter_names: Array[String], - value_parameter_types: Array[String], - return_type: String, - rendered_body: String, +case class BatchJobConfig( + token: String, + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], ) - -class ServiceBackendAPI( - private[this] val backend: ServiceBackend, - private[this] val fs: FS, - private[this] val outputURL: String, -) extends Thread { - private[this] val LOAD_REFERENCES_FROM_DATASET = 1 - private[this] val VALUE_TYPE = 2 - private[this] val TABLE_TYPE = 3 - private[this] val MATRIX_TABLE_TYPE = 4 - private[this] val BLOCK_MATRIX_TYPE = 5 - private[this] val EXECUTE = 6 - private[this] val PARSE_VCF_METADATA = 7 - private[this] val IMPORT_FAM = 8 - private[this] val FROM_FASTA_FILE = 9 - - private[this] val log = Logger.getLogger(getClass.getName()) - - private[this] def doAction(action: Int, payload: JValue): Array[Byte] = retryTransientErrors { - implicit val formats: Formats = DefaultFormats - (action: @switch) match { - case LOAD_REFERENCES_FROM_DATASET => - val path = payload.extract[LoadReferencesFromDatasetPayload].path - backend.loadReferencesFromDataset(path) - case VALUE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.valueType(ir) - case TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.tableType(ir) - case MATRIX_TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.matrixTableType(ir) - case BLOCK_MATRIX_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.blockMatrixType(ir) - case EXECUTE => - val qobExecutePayload = payload.extract[ServiceBackendExecutePayload] - val bufferSpecString = qobExecutePayload.payload.stream_codec - val code = qobExecutePayload.payload.ir - backend.withExecuteContext { ctx => - withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => - val ir = IRParser.parse_value_ir(ctx, code) - backend.execute(ctx, ir) match { - case Left(()) => - Array() - case Right((pt, off)) => - using(new ByteArrayOutputStream()) { os => - Backend.encodeToOutputStream(ctx, pt, off, bufferSpecString, os) - os.toByteArray - } - } - } - } - case PARSE_VCF_METADATA => - val path = 
payload.extract[ParseVCFMetadataPayload].path - backend.parseVCFMetadata(path) - case IMPORT_FAM => - val famPayload = payload.extract[ImportFamPayload] - val path = famPayload.path - val quantPheno = famPayload.quant_pheno - val delimiter = famPayload.delimiter - val missing = famPayload.missing - backend.importFam(path, quantPheno, delimiter, missing) - case FROM_FASTA_FILE => - val fastaPayload = payload.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - fastaPayload.name, - fastaPayload.fasta_file, - fastaPayload.index_file, - fastaPayload.x_contigs, - fastaPayload.y_contigs, - fastaPayload.mt_contigs, - fastaPayload.par, - ) - } - } - - private[this] def withIRFunctionsReadFromInput( - serializedFunctions: Array[SerializedIRFunction], - ctx: ExecuteContext, - )( - body: () => Array[Byte] - ): Array[Byte] = { - try { - serializedFunctions.foreach { func => - IRFunctionRegistry.registerIR( - ctx, - func.name, - func.type_parameters, - func.value_parameter_names, - func.value_parameter_types, - func.return_type, - func.rendered_body, - ) - } - body() - } finally - IRFunctionRegistry.clearUserFunctions() - } - - def executeOneCommand(action: Int, payload: JValue): Unit = { - try { - val result = doAction(action, payload) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(true) - output.writeBytes(result) - } - } - } catch { - case exc: HailWorkerException => - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(exc.shortMessage) - output.writeString(exc.expandedMessage) - output.writeInt(exc.errorId) - } - } - log.error( - "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job." - ) - throw exc - case t: Throwable => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(shortMessage) - output.writeString(expandedMessage) - output.writeInt(errorId) - } - } - log.error( - "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job." 
- ) - throw t - } - } -} diff --git a/hail/hail/src/is/hail/backend/service/Worker.scala b/hail/hail/src/is/hail/backend/service/Worker.scala index 4f596de3b4a..9c4a459591a 100644 --- a/hail/hail/src/is/hail/backend/service/Worker.scala +++ b/hail/hail/src/is/hail/backend/service/Worker.scala @@ -166,38 +166,23 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - if (HailContext.isInitialized) { - HailContext.get.backend = new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - null, - scratchDir, - ) - } else { - HailContext( - // FIXME: workers should not have backends, but some things do need hail contexts - new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - null, - scratchDir, - ) - ) - } + + // FIXME: workers should not have backends, but some things do need hail contexts + val backend = new ServiceBackend( + null, + null, + null, + new HailClassLoader(getClass().getClassLoader()), + null, + null, + null, + null, + null, + null, + ) + + if (HailContext.isInitialized) HailContext.get.backend = backend + else HailContext(backend) val result = using(new ServiceTaskContext(i)) { htc => try diff --git a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala index 728c8067705..305a7809640 100644 --- a/hail/hail/src/is/hail/backend/spark/SparkBackend.scala +++ b/hail/hail/src/is/hail/backend/spark/SparkBackend.scala @@ -21,6 +21,7 @@ import is.hail.types.physical.{PStruct, PTuple} import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -91,7 +92,7 @@ object SparkBackend { private var theSparkBackend: SparkBackend = _ - def sparkContext(op: String): SparkContext = HailContext.sparkBackend(op).sc + def sparkContext(implicit E: Enclosing): SparkContext = HailContext.sparkBackend.sc def checkSparkCompatibility(jarVersion: String, sparkVersion: String): Unit = { def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".") @@ -362,7 +363,6 @@ class SparkBackend( def createExecuteContextForTests( timer: ExecutionTimer, region: Region, - selfContainedExecution: Boolean = true, ): ExecuteContext = new ExecuteContext( tmpdir, @@ -372,7 +372,7 @@ class SparkBackend( fs, region, timer, - if (selfContainedExecution) null else NonOwningTempFileManager(longLifeTempFileManager), + null, theHailClassLoader, flags, new BackendContext { @@ -386,8 +386,8 @@ class SparkBackend( ImmutableMap.empty, ) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => ExecuteContext.scoped( tmpdir, localTmpdir, @@ -468,7 +468,7 @@ class SparkBackend( def defaultParallelism: Int = sc.defaultParallelism - override def asSpark(op: String): SparkBackend = this + override def asSpark(implicit E: Enclosing): SparkBackend = this def close(): Unit = { bmCache.values.foreach(_.unpersist()) diff --git a/hail/hail/src/is/hail/expr/ir/GenericLines.scala b/hail/hail/src/is/hail/expr/ir/GenericLines.scala index 8adb4fe75eb..000348f272f 100644 --- 
a/hail/hail/src/is/hail/expr/ir/GenericLines.scala +++ b/hail/hail/src/is/hail/expr/ir/GenericLines.scala @@ -449,7 +449,7 @@ class GenericLinesRDDPartition(val index: Int, val context: Any) extends Partiti class GenericLinesRDD( @(transient @param) contexts: IndexedSeq[Any], body: (Any) => CloseableIterator[GenericLine], -) extends RDD[GenericLine](SparkBackend.sparkContext("GenericLinesRDD"), Seq()) { +) extends RDD[GenericLine](SparkBackend.sparkContext, Seq()) { protected def getPartitions: Array[Partition] = contexts.iterator.zipWithIndex.map { case (c, i) => diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala index 43784087b1b..d548ecbf9f9 100644 --- a/hail/hail/src/is/hail/expr/ir/TableIR.scala +++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala @@ -713,7 +713,7 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR val upcastF = mb.genFieldThisRef[AsmFunction2RegionLongLong]("rvdreader_upcast") val broadcastRVD = - mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark("RVDReader"), rvd)) + mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark, rvd)) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -3198,7 +3198,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { fileStack += filesToMerge filesToMerge = - ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize( + ContextRDD.weaken(SparkBackend.sparkContext.parallelize( 0 until nToMerge, nToMerge, )) diff --git a/hail/hail/src/is/hail/expr/ir/TableValue.scala b/hail/hail/src/is/hail/expr/ir/TableValue.scala index 8b5441db48a..a1dd38c72d1 100644 --- a/hail/hail/src/is/hail/expr/ir/TableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/TableValue.scala @@ -188,7 +188,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } def toDF(): DataFrame = - HailContext.sparkBackend("toDF").sparkSession.createDataFrame( + HailContext.sparkBackend.sparkSession.createDataFrame( rvd.toRows, typ.rowType.schema.asInstanceOf[StructType], ) diff --git a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala index b621937d220..c47b795e4c9 100644 --- a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala +++ b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala @@ -22,7 +22,9 @@ import scala.reflect._ import org.apache.spark.sql.Row object IRFunctionRegistry { - private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = + type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type])) + + private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] = mutable.HashSet.empty def clearUserFunctions(): Unit = { @@ -70,25 +72,41 @@ object IRFunctionRegistry { typeParamStrs: Array[String], argNameStrs: Array[String], argTypeStrs: Array[String], - returnType: String, + returnTypeStr: String, bodyStr: String, - ): Unit = { + ): UserDefinedFnKey = { requireJavaIdentifier(name) - val argNames = argNameStrs.map(Name) val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq - val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*) - val body = IRParser.parse_value_ir(ctx, bodyStr, refMap) + val argNames = argNameStrs.map(Name) + + val body = + IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*)) + val returnType = 
IRParser.parseType(returnTypeStr) + assert(body.typ == returnType) - userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) + val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes)) + userAddedFunctions += key addIR( name, typeParameters, valueParameterTypes, - IRParser.parseType(returnType), + returnType, false, (_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)), ) + key + } + + def unregisterIr(key: UserDefinedFnKey): Unit = { + val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key + if (userAddedFunctions.remove(key)) + removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes) + else { + throw new NoSuchElementException( + s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}" + ) + } } def removeIRFunction( @@ -113,7 +131,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(f) => Some(f) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." + ) } def lookupFunctionOrFail( @@ -125,28 +145,34 @@ object IRFunctionRegistry { jvmRegistry.lift(name) match { case None => fatal( - s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType" + s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." ) case Some(functions) => functions.filter(t => t.unify(typeParameters, valueParameterTypes, returnType) ).toSeq match { case Seq() => - val prettyFunctionSignature = - s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n") fatal( - s"No function found with the signature $prettyFunctionSignature.\n" + + s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" + s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures" ) case Seq(f) => f case _ => fatal( - s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})." + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)})." ) } } } + private[this] def prettyFunctionSignature( + name: String, + returnType: Type, + typeParameterTypes: Seq[Type], + valueParameterTypes: Seq[Type], + ): String = + s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" + def lookupIR( name: String, returnType: Type, @@ -166,7 +192,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(kv) => Some(kv) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." 
+ ) } } diff --git a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala index 85d8673cc59..0dc13ec6598 100644 --- a/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/hail/src/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -126,9 +126,7 @@ object TableStageToRVD { partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }, ) - val sparkContext = ctx.backend - .asSpark("TableStageToRVD") - .sc + val sparkContext = ctx.backend.asSpark.sc val globalsAndBroadcastVals = Let( diff --git a/hail/hail/src/is/hail/io/plink/LoadPlink.scala b/hail/hail/src/is/hail/io/plink/LoadPlink.scala index 4b7b30007b4..ca5572235fe 100644 --- a/hail/hail/src/is/hail/io/plink/LoadPlink.scala +++ b/hail/hail/src/is/hail/io/plink/LoadPlink.scala @@ -20,7 +20,6 @@ import is.hail.variant._ import org.apache.spark.TaskContext import org.apache.spark.sql.Row import org.json4s.{DefaultFormats, Formats, JValue} -import org.json4s.jackson.JsonMethods case class FamFileConfig( isQuantPheno: Boolean = false, @@ -83,14 +82,13 @@ object LoadPlink { isQuantPheno: Boolean, delimiter: String, missingValue: String, - ): String = { + ): JValue = { val ffConfig = FamFileConfig(isQuantPheno, delimiter, missingValue) val (data, ptyp) = LoadPlink.parseFam(fs, path, ffConfig) - val jv = JSONAnnotationImpex.exportAnnotation( + JSONAnnotationImpex.exportAnnotation( Row(ptyp.virtualType.toString, data), TStruct("type" -> TString, "data" -> TArray(ptyp.virtualType)), ) - JsonMethods.compact(jv) } def parseFam(fs: FS, filename: String, ffConfig: FamFileConfig) diff --git a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala index 15ffe4a6ec2..e516052ce58 100644 --- a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala +++ b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala @@ -1685,7 +1685,7 @@ class PartitionedVCFRDD( file: String, @(transient @param) reverseContigMapping: Map[String, String], @(transient @param) _partitions: Array[Partition], -) extends RDD[WithContext[String]](SparkBackend.sparkContext("PartitionedVCFRDD"), Seq()) { +) extends RDD[WithContext[String]](SparkBackend.sparkContext, Seq()) { val contigRemappingBc = if (reverseContigMapping.size != 0) sparkContext.broadcast(reverseContigMapping) else null diff --git a/hail/hail/src/is/hail/linalg/BlockMatrix.scala b/hail/hail/src/is/hail/linalg/BlockMatrix.scala index 7cda148db27..f5a6eb1e1b4 100644 --- a/hail/hail/src/is/hail/linalg/BlockMatrix.scala +++ b/hail/hail/src/is/hail/linalg/BlockMatrix.scala @@ -44,7 +44,7 @@ case class CollectMatricesRDDPartition( } class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) - extends RDD[BDM[Double]](SparkBackend.sparkContext("CollectMatricesRDD"), Nil) { + extends RDD[BDM[Double]](SparkBackend.sparkContext, Nil) { private val nBlocks = bms.map(_.blocks.getNumPartitions) private val firstPartition = nBlocks.scan(0)(_ + _).init @@ -114,7 +114,7 @@ object BlockMatrix { def apply(gp: GridPartitioner, piBlock: (GridPartitioner, Int) => ((Int, Int), BDM[Double])) : BlockMatrix = new BlockMatrix( - new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext("BlockMatrix.apply"), Nil) { + new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext, Nil) { override val partitioner = Some(gp) protected def getPartitions: Array[Partition] = Array.tabulate(gp.numPartitions)(pi => @@ -2199,7 +2199,7 @@ class WriteBlocksRDD( parentPartStarts: Array[Long], entryField: String, gp: GridPartitioner, -) extends 
RDD[(Int, String)](SparkBackend.sparkContext("WriteBlocksRDD"), Nil) { +) extends RDD[(Int, String)](SparkBackend.sparkContext, Nil) { require(gp.nRows == parentPartStarts.last) @@ -2369,7 +2369,7 @@ class BlockMatrixReadRowBlockedRDD( metadata: BlockMatrixMetadata, maybeMaximumCacheMemoryInBytes: Option[Int], ) extends RDD[RVDContext => Iterator[Long]]( - SparkBackend.sparkContext("BlockMatrixReadRowBlockedRDD"), + SparkBackend.sparkContext, Nil, ) { import BlockMatrixReadRowBlockedRDD._ diff --git a/hail/hail/src/is/hail/linalg/RowMatrix.scala b/hail/hail/src/is/hail/linalg/RowMatrix.scala index 6adbaca171a..969ef2d4916 100644 --- a/hail/hail/src/is/hail/linalg/RowMatrix.scala +++ b/hail/hail/src/is/hail/linalg/RowMatrix.scala @@ -286,7 +286,7 @@ class ReadBlocksAsRowsRDD( partFiles: IndexedSeq[String], partitionCounts: Array[Long], gp: GridPartitioner, -) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext("ReadBlocksAsRowsRDD"), Nil) { +) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext, Nil) { private val partitionStarts = partitionCounts.scanLeft(0L)(_ + _) diff --git a/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala b/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala index 43ade7d1418..5db999bbbba 100644 --- a/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala +++ b/hail/hail/src/is/hail/methods/MatrixExportEntriesByCol.scala @@ -197,7 +197,7 @@ case class MatrixExportEntriesByCol( // clean up temporary files val temps = tempFolders.result() val fsBc = fs.broadcast - SparkBackend.sparkContext("MatrixExportEntriesByCol.execute").parallelize( + SparkBackend.sparkContext.parallelize( temps, (temps.length / 32).max(1), ).foreach(path => fsBc.value.delete(path, recursive = true)) diff --git a/hail/hail/src/is/hail/rvd/RVD.scala b/hail/hail/src/is/hail/rvd/RVD.scala index 9c463915722..2cd97c3df03 100644 --- a/hail/hail/src/is/hail/rvd/RVD.scala +++ b/hail/hail/src/is/hail/rvd/RVD.scala @@ -1377,7 +1377,7 @@ object RVD { ) } - val sc = SparkBackend.sparkContext("writeRowsSplitFiles") + val sc = SparkBackend.sparkContext val localTmpdir = execCtx.localTmpdir val fs = execCtx.fs val fsBc = fs.broadcast diff --git a/hail/hail/src/is/hail/sparkextras/ContextRDD.scala b/hail/hail/src/is/hail/sparkextras/ContextRDD.scala index fe9c4d4e4ac..9ecefc16e13 100644 --- a/hail/hail/src/is/hail/sparkextras/ContextRDD.scala +++ b/hail/hail/src/is/hail/sparkextras/ContextRDD.scala @@ -86,7 +86,7 @@ object ContextRDD { def empty[T: ClassTag](): ContextRDD[T] = new ContextRDD( - SparkBackend.sparkContext("ContextRDD.empty").emptyRDD[RVDContext => Iterator[T]] + SparkBackend.sparkContext.emptyRDD[RVDContext => Iterator[T]] ) def union[T: ClassTag]( @@ -117,7 +117,7 @@ object ContextRDD { filterAndReplace: TextInputFilterAndReplace, ): ContextRDD[WithContext[String]] = ContextRDD.weaken( - SparkBackend.sparkContext("ContxtRDD.textFilesLines").textFilesLines( + SparkBackend.sparkContext.textFilesLines( files, nPartitions, ) @@ -129,12 +129,12 @@ object ContextRDD { weaken(sc.parallelize(data, nPartitions.getOrElse(sc.defaultMinPartitions))).map(x => x) def parallelize[T: ClassTag](data: Seq[T], numSlices: Int): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data, numSlices)).map { + weaken(SparkBackend.sparkContext.parallelize(data, numSlices)).map { x => x } def parallelize[T: ClassTag](data: Seq[T]): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data)).map(x => x) + 
weaken(SparkBackend.sparkContext.parallelize(data)).map(x => x) type ElementType[T] = RVDContext => Iterator[T] diff --git a/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala b/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala index d7316f3d736..94e00942a2f 100644 --- a/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala +++ b/hail/hail/src/is/hail/sparkextras/IndexReadRDD.scala @@ -15,7 +15,7 @@ class IndexReadRDD[T: ClassTag]( @transient val partFiles: Array[String], @transient val intervalBounds: Option[Array[Interval]], f: (IndexedFilePartition, TaskContext) => T, -) extends RDD[T](SparkBackend.sparkContext("IndexReadRDD"), Nil) { +) extends RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(partFiles.length) { i => IndexedFilePartition(i, partFiles(i), intervalBounds.map(_(i))) diff --git a/hail/hail/src/is/hail/types/virtual/package.scala b/hail/hail/src/is/hail/types/virtual/package.scala new file mode 100644 index 00000000000..032c862bafb --- /dev/null +++ b/hail/hail/src/is/hail/types/virtual/package.scala @@ -0,0 +1,12 @@ +package is.hail.types + +package virtual { + sealed abstract class Kind[T <: VType] extends Product with Serializable + + object Kinds { + case object Value extends Kind[Type] + case object Table extends Kind[TableType] + case object Matrix extends Kind[MatrixType] + case object BlockMatrix extends Kind[BlockMatrixType] + } +} diff --git a/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala b/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala index a8bb6d8babb..cbf9833ed15 100644 --- a/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala +++ b/hail/hail/src/is/hail/utils/SpillingCollectIterator.scala @@ -16,7 +16,7 @@ object SpillingCollectIterator { val nPartitions = rdd.partitions.length val x = new SpillingCollectIterator(localTmpdir, fs, nPartitions, sizeLimit) val ctc = classTag[T] - SparkBackend.sparkContext("SpillingCollectIterator.apply").runJob( + SparkBackend.sparkContext.runJob( rdd, (_, it: Iterator[T]) => it.toArray(ctc), 0 until nPartitions, diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index e986ee5dc6d..bff392dcd6a 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -49,7 +49,7 @@ object HailSuite { lazy val hc: HailContext = { val hc = withSparkBackend() - hc.sparkBackend("HailSuite.hc").flags.set("lower", "1") + hc.backend.asSpark.flags.set("lower", "1") hc.checkRVDKeys = true hc } @@ -62,7 +62,7 @@ class HailSuite extends TestNGSuite { @BeforeClass def ensureHailContextInitialized(): Unit = hc - def backend: SparkBackend = hc.sparkBackend("HailSuite.backend") + def backend: SparkBackend = hc.backend.asSpark def sc: SparkContext = backend.sc diff --git a/hail/hail/test/src/is/hail/TestUtils.scala b/hail/hail/test/src/is/hail/TestUtils.scala index da061cc8dc6..550be49398a 100644 --- a/hail/hail/test/src/is/hail/TestUtils.scala +++ b/hail/hail/test/src/is/hail/TestUtils.scala @@ -123,9 +123,9 @@ object TestUtils { if (agg.isDefined || !env.isEmpty || !args.isEmpty) throw new LowererUnsupportedOperation("can't test with aggs or user defined args/env") - HailContext.sparkBackend("TestUtils.loweredExecute") - .jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, lowerBM = true, - print = bytecodePrinter) + HailContext.sparkBackend.jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, + lowerBM = true, + print = bytecodePrinter) } def eval(x: IR): Any = 
ExecuteContext.scoped { ctx => diff --git a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala index 461eed812f4..3139849a54f 100644 --- a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala @@ -2,7 +2,9 @@ package is.hail.backend import is.hail.HailFeatureFlags import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload} +import is.hail.backend.service.{ + BatchJobConfig, ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload, +} import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} import is.hail.services._ import is.hail.services.JobGroupStates.Success @@ -25,9 +27,9 @@ import org.testng.annotations.Test class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { rpcConfig => + withMockDriverContext { case (rpcConfig, jobConfig) => val batchClient = mock[BatchClient] - using(ServiceBackend(batchClient, rpcConfig)) { backend => + using(ServiceBackend(batchClient, rpcConfig, jobConfig)) { backend => val contexts = Array.tabulate(1)(_.toString.getBytes) // verify that @@ -42,12 +44,12 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV val jobs = jobGroup.jobs jobs.length shouldEqual contexts.length jobs.foreach { payload => - payload.regions.value shouldBe rpcConfig.regions + payload.regions.value shouldBe jobConfig.regions payload.resources.value shouldBe JobResources( preemptible = true, - cpu = Some(rpcConfig.worker_cores), - memory = Some(rpcConfig.worker_memory), - storage = Some(rpcConfig.storage), + cpu = Some(jobConfig.worker_cores), + memory = Some(jobConfig.worker_memory), + storage = Some(jobConfig.storage), ) } @@ -62,7 +64,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / + Path(rpcConfig.tmp_dir) / "parallelizeAndComputeWithIndex" / tokenUrlSafe @@ -83,7 +85,11 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV val (failure, _) = backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, + ServiceBackendContext( + remoteTmpDir = rpcConfig.tmp_dir, + jobConfig = jobConfig, + executionCache = ExecutionCache.noCache, + ), backend.fs, contexts, "stage1", @@ -96,58 +102,52 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV } } - def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = { + def ServiceBackend( + client: BatchClient, + rpcConfig: ServiceBackendRPCPayload, + jobConfig: BatchJobConfig, + ): ServiceBackend = { val flags = HailFeatureFlags.fromEnv() val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) new ServiceBackend( - jarSpec = GitRevision("123"), name = "name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - references = mutable.Map.empty, batchClient = client, + jarSpec = GitRevision("123"), + theHailClassLoader = new HailClassLoader(getClass.getClassLoader), batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), + rpcConfig = rpcConfig, + jobConfig = jobConfig, flags = flags, - tmpdir = rpcConfig.tmp_dir, fs = fs, - serviceBackendContext = - new ServiceBackendContext( - 
rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - profile = false, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ), - scratchDir = rpcConfig.remote_tmpdir, + references = mutable.Map.empty, ) } - def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = + def withMockDriverContext(test: (ServiceBackendRPCPayload, BatchJobConfig) => Any): Any = using(LocalTmpFolder) { tmp => withObjectSpied[is.hail.utils.UtilsType] { // not obvious how to pull out `tokenUrlSafe` and inject this directory // using a spy is a hack and i don't particularly like it. when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" - test { + test( ServiceBackendRPCPayload( tmp_dir = tmp.path, - remote_tmpdir = tmp.path, + flags = Map(), + custom_references = Array(), + liftovers = Map(), + sequences = Map(), + ), + BatchJobConfig( + token = tokenUrlSafe, billing_project = "fancy", worker_cores = "128", worker_memory = "a lot.", storage = "a big ssd?", cloudfuse_configs = Array(), regions = Array("lunar1"), - flags = Map(), - custom_references = Array(), - liftovers = Map(), - sequences = Map(), - ) - } + ), + ) } } diff --git a/hail/hail/test/src/is/hail/check/Gen.scala b/hail/hail/test/src/is/hail/check/Gen.scala index be6a9f8eb94..78ab722918f 100644 --- a/hail/hail/test/src/is/hail/check/Gen.scala +++ b/hail/hail/test/src/is/hail/check/Gen.scala @@ -553,10 +553,10 @@ class Gen[+T](val apply: Parameters => T) extends AnyVal { def flatMap[U](f: T => Gen[U]): Gen[U] = Gen(p => f(apply(p))(p)) def resize(newSize: Int): Gen[T] = Gen((p: Parameters) => apply(p.copy(size = newSize))) - + def withFilter(f: T => Boolean): Gen[T] = Gen((p: Parameters) => Stream.continually(apply(p)).takeWhile(f).head) - def filter(f: T => Boolean): Gen[T] = + def filter(f: T => Boolean): Gen[T] = withFilter(f) } diff --git a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala index 05422e43c51..c579894a642 100644 --- a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,7 +1,7 @@ package is.hail.variant import is.hail.{HailSuite, TestUtils} -import is.hail.backend.{ExecuteContext, HailStateManager} +import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} @@ -13,13 +13,9 @@ import org.testng.annotations.Test class ReferenceGenomeSuite extends HailSuite { - def hasReference(name: String) = ctx.stateManager.referenceGenomes.contains(name) - - def getReference(name: String) = ctx.references(name) - @Test def testGRCh37(): Unit = { - assert(hasReference(ReferenceGenome.GRCh37)) - val grch37 = getReference(ReferenceGenome.GRCh37) + assert(ctx.references.contains(ReferenceGenome.GRCh37)) + val grch37 = ctx.references(ReferenceGenome.GRCh37) assert(grch37.inX("X") && grch37.inY("Y") && grch37.isMitochondrial("MT")) assert(grch37.contigLength("1") == 249250621) @@ -34,8 +30,8 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testGRCh38(): Unit = { - assert(hasReference(ReferenceGenome.GRCh38)) - val grch38 = getReference(ReferenceGenome.GRCh38) + assert(ctx.references.contains(ReferenceGenome.GRCh38)) + val grch38 = ctx.references(ReferenceGenome.GRCh38) assert(grch38.inX("chrX") && grch38.inY("chrY") && 
grch38.isMitochondrial("chrM")) assert(grch38.contigLength("chr1") == 248956422) @@ -103,12 +99,12 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testContigRemap(): Unit = { val mapping = Map("23" -> "foo") TestUtils.interceptFatal("have remapped contigs in reference genome")( - getReference(ReferenceGenome.GRCh37).validateContigRemap(mapping) + ctx.references(ReferenceGenome.GRCh37).validateContigRemap(mapping) ) } @Test def testComparisonOps(): Unit = { - val rg = getReference(ReferenceGenome.GRCh37) + val rg = ctx.references(ReferenceGenome.GRCh37) // Test contigs assert(rg.compare("3", "18") < 0) @@ -127,7 +123,7 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testWriteToFile(): Unit = { val tmpFile = ctx.createTmpPath("grWrite", "json") - val rg = getReference(ReferenceGenome.GRCh37) + val rg = ctx.references(ReferenceGenome.GRCh37) rg.copy(name = "GRCh37_2").write(fs, tmpFile) val gr2 = ReferenceGenome.fromFile(fs, tmpFile) @@ -221,19 +217,18 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeOnFB(): Unit = { - ExecuteContext.scoped { ctx => - val grch38 = ctx.references(ReferenceGenome.GRCh38) - val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") - val rgfield = fb.getReferenceGenome(grch38.name) - fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) - - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) + val grch38 = ctx.references(ReferenceGenome.GRCh38) + val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") + val rgfield = fb.getReferenceGenome(grch38.name) + fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) + ctx.scopedExecution { (cl, fs, tc, r) => + val f = fb.resultWithIndex()(cl, fs, tc, r) assert(f("X") == grch38.isValidContig("X")) } } - @Test def testSerializeWithLiftoverOnFB(): Unit = { - ExecuteContext.scoped { ctx => + @Test def testSerializeWithLiftoverOnFB(): Unit = + ctx.local(references = ctx.references.mapValues(_.copy())) { ctx => val grch37 = ctx.references(ReferenceGenome.GRCh37) val liftoverFile = getTestResource("grch37_to_grch38_chr20.over.chain.gz") @@ -249,13 +244,13 @@ class ReferenceGenomeSuite extends HailSuite { fb.getCodeParam[Double](3), )) - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) - assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( - "GRCh38", - Locus("20", 60001), - 0.95, - )) - grch37.removeLiftover("GRCh38") + ctx.scopedExecution { (cl, fs, tc, r) => + val f = fb.resultWithIndex()(cl, fs, tc, r) + assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( + "GRCh38", + Locus("20", 60001), + 0.95, + )) + } } - } } diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 995d0b668c4..e7e14bbe90a 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -3,7 +3,7 @@ import zipfile from dataclasses import dataclass from enum import Enum -from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union import orjson @@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation: raise ValueError(f'Hail requires either {hail_jar} or {hail_all_spark_jar}.') +class IRFunction: + def __init__( + self, + name: str, + type_parameters: Union[Tuple[HailType, ...], List[HailType]], + 
value_parameter_names: Union[Tuple[str, ...], List[str]], + value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], + return_type: HailType, + body: Expression, + ): + assert len(value_parameter_names) == len(value_parameter_types) + render = CSERenderer() + self._name = name + self._type_parameters = type_parameters + self._value_parameter_names = value_parameter_names + self._value_parameter_types = value_parameter_types + self._return_type = return_type + self._rendered_body = render(finalize_randomness(body._ir)) + + def to_dataclass(self): + return SerializedIRFunction( + name=self._name, + type_parameters=[tp._parsable_string() for tp in self._type_parameters], + value_parameter_names=list(self._value_parameter_names), + value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], + return_type=self._return_type._parsable_string(), + rendered_body=self._rendered_body, + ) + + class ActionTag(Enum): - LOAD_REFERENCES_FROM_DATASET = 1 - VALUE_TYPE = 2 - TABLE_TYPE = 3 - MATRIX_TABLE_TYPE = 4 - BLOCK_MATRIX_TYPE = 5 - EXECUTE = 6 - PARSE_VCF_METADATA = 7 - IMPORT_FAM = 8 + VALUE_TYPE = 1 + TABLE_TYPE = 2 + MATRIX_TABLE_TYPE = 3 + BLOCK_MATRIX_TYPE = 4 + EXECUTE = 5 + PARSE_VCF_METADATA = 6 + IMPORT_FAM = 7 + LOAD_REFERENCES_FROM_DATASET = 8 FROM_FASTA_FILE = 9 @@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload): ir: str +@dataclass +class SerializedIRFunction: + name: str + type_parameters: List[str] + value_parameter_names: List[str] + value_parameter_types: List[str] + return_type: str + rendered_body: str + + @dataclass class ExecutePayload(ActionPayload): ir: str + fns: List[SerializedIRFunction] stream_codec: str - timed: bool @dataclass @@ -164,6 +204,8 @@ def _valid_flags(self) -> AbstractSet[str]: def __init__(self): self._persisted_locations = dict() self._references = {} + self.functions: List[IRFunction] = [] + self._registered_ir_function_names: Set[str] = set() @abc.abstractmethod def validate_file(self, uri: str): @@ -171,10 +213,15 @@ def validate_file(self, uri: str): @abc.abstractmethod def stop(self): - pass + self.functions = [] + self._registered_ir_function_names = set() def execute(self, ir: BaseIR, timed: bool = False) -> Any: - payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed) + payload = ExecutePayload( + self._render_ir(ir), + fns=[fn.to_dataclass() for fn in self.functions], + stream_codec='{"name":"StreamBufferSpec"}', + ) try: result, timings = self._rpc(ActionTag.EXECUTE, payload) except FatalError as e: @@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset: tempfile.__exit__(None, None, None) return unpersisted - @abc.abstractmethod def register_ir_function( self, name: str, @@ -310,11 +356,13 @@ def register_ir_function( return_type: HailType, body: Expression, ): - pass + self._registered_ir_function_names.add(name) + self.functions.append( + IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body) + ) - @abc.abstractmethod def _is_registered_ir_function_name(self, name: str) -> bool: - pass + return name in self._registered_ir_function_names @abc.abstractmethod def persist_expression(self, expr: Expression) -> Expression: diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 708eacaef2d..7bcb1145259 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -119,7 +119,7 @@ def register_ir_function( ) def stop(self): - 
super().stop() + super(Py4JBackend, self).stop() self._exit_stack.close() uninstall_exception_handler() diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 03bc9be5dfe..2d787a0f339 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -4,7 +4,7 @@ import socketserver import sys from threading import Thread -from typing import Mapping, Optional, Set, Tuple +from typing import Mapping, Optional, Tuple import orjson import py4j @@ -193,8 +193,6 @@ def decode_bytearray(encoded): self._backend_server.start() self._requests_session = requests.Session() - self._registered_ir_function_names: Set[str] = set() - # This has to go after creating the SparkSession. Unclear why. # Maybe it does its own patch? install_exception_handler() @@ -238,9 +236,6 @@ def persist_expression(self, expr): t = expr.dtype return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t) - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def set_flags(self, **flags: Mapping[str, str]): available = self._jbackend.pyAvailableFlags() invalid = [] @@ -275,12 +270,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome): def remove_liftover(self, name, dest_reference_genome): self._jbackend.pyRemoveLiftover(name, dest_reference_genome) - def _parse_value_ir(self, code, ref_map={}): - return self._jbackend.parse_value_ir( - code, - {k: t._parsable_string() for k, t in ref_map.items()}, - ) - def _register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, code): self._registered_ir_function_names.add(name) self._jbackend.pyRegisterIR( @@ -292,12 +281,6 @@ def _register_ir_function(self, name, type_parameters, argument_names, argument_ code, ) - def _parse_table_ir(self, code): - return self._jbackend.parse_table_ir(code) - - def _parse_matrix_ir(self, code): - return self._jbackend.parse_matrix_ir(code) - def _parse_blockmatrix_ir(self, code): return self._jbackend.parse_blockmatrix_ir(code) @@ -309,5 +292,5 @@ def stop(self): self._jbackend.close() self._jhc.stop() self._jhc = None - self._registered_ir_function_names = set() uninstall_exception_handler() + super().stop() diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 953ce749b05..595811d67e5 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -12,10 +12,6 @@ import hailtop.aiotools.fs as afs from hail.context import TemporaryDirectory, TemporaryFilename, tmp_dir from hail.experimental import read_expression, write_expression -from hail.expr.expressions.base_expression import Expression -from hail.expr.types import HailType -from hail.ir import finalize_randomness -from hail.ir.renderer import CSERenderer from hail.utils import FatalError from hail.version import __revision__, __version__ from hailtop import yamlx @@ -33,7 +29,7 @@ from ..builtin_references import BUILTIN_REFERENCES from ..utils import ANY_REGION -from .backend import ActionPayload, ActionTag, Backend, ExecutePayload, fatal_error_from_java_error_triplet +from .backend import ActionPayload, ActionTag, Backend, fatal_error_from_java_error_triplet ReferenceGenomeConfig = Dict[str, Any] @@ -67,53 +63,6 @@ async def read_str(strm: afs.ReadableStream) -> str: return b.decode('utf-8') -@dataclass -class SerializedIRFunction: - name: str - type_parameters: 
List[str] - value_parameter_names: List[str] - value_parameter_types: List[str] - return_type: str - rendered_body: str - - -class IRFunction: - def __init__( - self, - name: str, - type_parameters: Union[Tuple[HailType, ...], List[HailType]], - value_parameter_names: Union[Tuple[str, ...], List[str]], - value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], - return_type: HailType, - body: Expression, - ): - assert len(value_parameter_names) == len(value_parameter_types) - render = CSERenderer() - self._name = name - self._type_parameters = type_parameters - self._value_parameter_names = value_parameter_names - self._value_parameter_types = value_parameter_types - self._return_type = return_type - self._rendered_body = render(finalize_randomness(body._ir)) - - def to_dataclass(self): - return SerializedIRFunction( - name=self._name, - type_parameters=[tp._parsable_string() for tp in self._type_parameters], - value_parameter_names=list(self._value_parameter_names), - value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], - return_type=self._return_type._parsable_string(), - rendered_body=self._rendered_body, - ) - - -@dataclass -class ServiceBackendExecutePayload(ActionPayload): - functions: List[SerializedIRFunction] - idempotency_token: str - payload: ExecutePayload - - @dataclass class CloudfuseConfig: bucket: str @@ -130,17 +79,21 @@ class SequenceConfig: @dataclass class ServiceBackendRPCConfig: tmp_dir: str - remote_tmpdir: str + flags: Dict[str, str] + custom_references: List[str] + liftovers: Dict[str, Dict[str, str]] + sequences: Dict[str, SequenceConfig] + + +@dataclass +class BatchJobConfig: + token: str billing_project: str worker_cores: str worker_memory: str storage: str cloudfuse_configs: List[CloudfuseConfig] regions: List[str] - flags: Dict[str, str] - custom_references: List[str] - liftovers: Dict[str, Dict[str, str]] - sequences: Dict[str, SequenceConfig] class ServiceBackend(Backend): @@ -149,14 +102,14 @@ class ServiceBackend(Backend): DRIVER = "driver" # is.hail.backend.service.ServiceBackendSocketAPI2 protocol - LOAD_REFERENCES_FROM_DATASET = 1 - VALUE_TYPE = 2 - TABLE_TYPE = 3 - MATRIX_TABLE_TYPE = 4 - BLOCK_MATRIX_TYPE = 5 - EXECUTE = 6 - PARSE_VCF_METADATA = 7 - IMPORT_FAM = 8 + VALUE_TYPE = 1 + TABLE_TYPE = 2 + MATRIX_TABLE_TYPE = 3 + BLOCK_MATRIX_TYPE = 4 + EXECUTE = 5 + PARSE_VCF_METADATA = 6 + IMPORT_FAM = 7 + LOAD_REFERENCES_FROM_DATASET = 8 FROM_FASTA_FILE = 9 @staticmethod @@ -289,7 +242,6 @@ def __init__( self.batch_attributes = batch_attributes self.remote_tmpdir = remote_tmpdir self.flags: Dict[str, str] = {} - self.functions: List[IRFunction] = [] self._registered_ir_function_names: Set[str] = set() self.driver_cores = driver_cores self.driver_memory = driver_memory @@ -334,16 +286,16 @@ def logger(self): def stop(self): hail_event_loop().run_until_complete(self._stop()) + super().stop() async def _stop(self): await self._async_exit_stack.aclose() - self.functions = [] - self._registered_ir_function_names = set() async def _run_on_batch( self, name: str, service_backend_config: ServiceBackendRPCConfig, + job_config: BatchJobConfig, action: ActionTag, payload: ActionPayload, *, @@ -357,7 +309,8 @@ async def _run_on_batch( async with await self._async_fs.create(iodir + '/in') as infile: await infile.write( orjson.dumps({ - 'config': service_backend_config, + 'rpc_config': service_backend_config, + 'job_config': job_config, 'action': action.value, 'payload': payload, }) @@ -375,8 +328,8 @@ async def 
_run_on_batch( elif self.driver_memory is not None: resources['memory'] = str(self.driver_memory) - if service_backend_config.storage != '0Gi': - resources['storage'] = service_backend_config.storage + if job_config.storage != '0Gi': + resources['storage'] = job_config.storage j = self._batch.create_jvm_job( jar_spec=self.jar_spec, @@ -390,7 +343,7 @@ async def _run_on_batch( resources=resources, attributes={'name': name + '_driver'}, regions=self.regions, - cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in service_backend_config.cloudfuse_configs], + cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in job_config.cloudfuse_configs], profile=self.flags['profile'] is not None, ) await self._batch.submit(disable_progress_bar=True) @@ -467,11 +420,6 @@ def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Option return self._cancel_on_ctrl_c(self._async_rpc(action, payload)) async def _async_rpc(self, action: ActionTag, payload: ActionPayload): - if isinstance(payload, ExecutePayload): - payload = ServiceBackendExecutePayload( - [f.to_dataclass() for f in self.functions], self._batch.token, payload - ) - storage_requirement_bytes = 0 readonly_fuse_buckets: Set[str] = set() @@ -486,31 +434,37 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload): readonly_fuse_buckets.add(bucket) storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size() sequence_file_mounts[rg_name] = SequenceConfig( - f'/cloudfuse/{fasta_bucket}/{fasta_path}', f'/cloudfuse/{index_bucket}/{index_path}' + f'/cloudfuse/{fasta_bucket}/{fasta_path}', + f'/cloudfuse/{index_bucket}/{index_path}', ) - storage_gib_str = f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi' - qob_config = ServiceBackendRPCConfig( - tmp_dir=tmp_dir(), - remote_tmpdir=self.remote_tmpdir, - billing_project=self.billing_project, - worker_cores=str(self.worker_cores), - worker_memory=str(self.worker_memory), - storage=storage_gib_str, - cloudfuse_configs=[ - CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets - ], - regions=self.regions, - flags=self.flags, - custom_references=[ - orjson.dumps(rg._config).decode('utf-8') - for rg in self._references.values() - if rg.name not in BUILTIN_REFERENCES - ], - liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}, - sequences=sequence_file_mounts, + return await self._run_on_batch( + name=f'{action.name.lower()}(...)', + service_backend_config=ServiceBackendRPCConfig( + tmp_dir=self.remote_tmpdir, + flags=self.flags, + custom_references=[ + orjson.dumps(rg._config).decode('utf-8') + for rg in self._references.values() + if rg.name not in BUILTIN_REFERENCES + ], + liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}, + sequences=sequence_file_mounts, + ), + job_config=BatchJobConfig( + token=self._batch.token, + billing_project=self.billing_project, + worker_cores=str(self.worker_cores), + worker_memory=str(self.worker_memory), + storage=f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi', + cloudfuse_configs=[ + CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets + ], + regions=self.regions, + ), + action=action, + payload=payload, ) - return await self._run_on_batch(f'{action.name.lower()}(...)', qob_config, action, payload) # Sequence and liftover information is stored on the ReferenceGenome # and there is no persistent backend to keep in sync. 
@@ -533,23 +487,6 @@ def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): def remove_liftover(self, name, dest_reference_genome): # pylint: disable=unused-argument pass - def register_ir_function( - self, - name: str, - type_parameters: Union[Tuple[HailType, ...], List[HailType]], - value_parameter_names: Union[Tuple[str, ...], List[str]], - value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], - return_type: HailType, - body: Expression, - ): - self._registered_ir_function_names.add(name) - self.functions.append( - IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body) - ) - - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def persist_expression(self, expr): # FIXME: should use context manager to clean up persisted resources fname = TemporaryFilename(prefix='persist_expression').name diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 81f5fb3b855..cdb1c874c36 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -2,7 +2,7 @@ import random import re import unittest -from test.hail.helpers import resource, skip_unless_spark_backend, skip_when_service_backend +from test.hail.helpers import resource, skip_unless_spark_backend import numpy as np import pytest @@ -189,12 +189,6 @@ def value_ir(value_irs, request): return value_irs[request.param] -@skip_when_service_backend() -def test_ir_parses(value_ir): - env = value_irs_env() - Env.backend()._parse_value_ir(str(value_ir), env) - - def test_ir_value_type(value_ir): env = value_irs_env() typ = Env.backend().value_type( @@ -313,11 +307,6 @@ def table_ir(table_irs, request): return table_irs[request.param] -@skip_when_service_backend() -def test_table_ir_parses(table_ir): - Env.backend()._parse_table_ir(str(table_ir)) - - def test_table_ir_table_type(table_ir): typ = Env.backend().table_type(table_ir) assert table_ir.typ == typ @@ -420,11 +409,6 @@ def matrix_ir(matrix_irs, request): return matrix_irs[request.param] -@skip_when_service_backend() -def test_matrix_ir_parses(matrix_ir): - Env.backend()._parse_matrix_ir(str(matrix_ir)) - - def test_matrix_ir_matrix_type(matrix_ir): typ = Env.backend().matrix_type(matrix_ir) assert typ == matrix_ir.typ
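
Illustration of the unified RPC payload after this change: the EXECUTE action now carries user-registered IR functions inline in the shared payload (the service-only ServiceBackendExecutePayload wrapper is gone), and registration lives on the base Backend. A minimal Python sketch, assuming this patch is applied and hail is importable; the function name, IR text, and type strings below are placeholders, not real rendered IR:

    import orjson

    from hail.backend.backend import ActionTag, ExecutePayload, SerializedIRFunction

    # A hypothetical user-defined function, already rendered to strings the way
    # IRFunction.to_dataclass() does (the name, types, and body are made up).
    fn = SerializedIRFunction(
        name='example_fn',
        type_parameters=[],
        value_parameter_names=['x'],
        value_parameter_types=['Int32'],
        return_type='Int32',
        rendered_body='(Ref x)',
    )

    payload = ExecutePayload(
        ir='(I32 5)',                               # produced by Backend._render_ir
        fns=[fn],                                   # previously service-backend-only
        stream_codec='{"name":"StreamBufferSpec"}',
    )

    # Every backend sends the same dataclass; orjson serializes dataclasses
    # natively, and the service backend additionally wraps it with rpc_config
    # and job_config before shipping the request to Batch.
    print(orjson.dumps({'action': ActionTag.EXECUTE.value, 'payload': payload}).decode())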