diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
index bfce7175d481..3e695dac9181 100644
--- a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
+++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
@@ -6,10 +6,7 @@ import is.hail.backend._
 import is.hail.backend.caching.BlockMatrixCache
 import is.hail.backend.spark.SparkBackend
 import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
-import is.hail.expr.ir.{
-  BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser,
-  Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
-}
+import is.hail.expr.ir.{BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
 import is.hail.expr.ir.IRParser.parseType
 import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
 import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -26,12 +23,10 @@ import is.hail.variant.ReferenceGenome
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}
-
 import java.io.Closeable
 import java.net.InetSocketAddress
 import java.util
 import java.util.concurrent._
-
 import com.google.api.client.http.HttpStatusCodes
 import com.sun.net.httpserver.{HttpExchange, HttpServer}
 import org.apache.hadoop
@@ -42,41 +37,40 @@ import org.json4s._
 import org.json4s.jackson.{JsonMethods, Serialization}
 import sourcecode.Enclosing
 
+import javax.annotation.Nullable
+
 final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
 
-  private[this] val tmpdir: String = ???
-  private[this] val localTmpdir: String = ???
-  private[this] val longLifeTempFileManager = null: TempFileManager
   private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
-  private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
+  private[this] val hcl = new HailClassLoader(getClass.getClassLoader)
   private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)
   private[this] val bmCache = new BlockMatrixCache()
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
   private[this] val persistedIr = mutable.Map[Int, BaseIR]()
   private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
-
-  private[this] def cloudfsConfig = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
-
-  def fs: FS = backend match {
-    case s: SparkBackend =>
-      val conf = new Configuration(s.sc.hadoopConfiguration)
-      cloudfsConfig.google.flatMap(_.requester_pays_config).foreach {
-        case RequesterPaysConfig(prj, bkts) =>
-          bkts
-            .map { buckets =>
-              conf.set("fs.gs.requester.pays.mode", "CUSTOM")
-              conf.set("fs.gs.requester.pays.project.id", prj)
-              conf.set("fs.gs.requester.pays.buckets", buckets.mkString(","))
-            }
-            .getOrElse {
-              conf.set("fs.gs.requester.pays.mode", "AUTO")
-              conf.set("fs.gs.requester.pays.project.id", prj)
-            }
-      }
-      new HadoopFS(new SerializableHadoopConfiguration(conf))
+  private[this] var irID: Int = 0
+  private[this] var tmpdir: String = _
+  private[this] var localTmpdir: String = _
+
+  private[this] object tmpFileManager extends TempFileManager {
+    private[this] var fs = newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
+    private[this] var manager = new OwningTempFileManager(fs)
+
+    def setFs(fs: FS): Unit = {
+      close()
+      this.fs = fs
+      manager = new OwningTempFileManager(fs)
+    }
+
+    def getFs: FS =
+      fs
+
+    override def newTmpPath(tmpdir: String, prefix: String, extension: String): String =
+      manager.newTmpPath(tmpdir, prefix, extension)
 
-    case _ =>
-      RouterFS.buildRoutes(cloudfsConfig)
+    override def close(): Unit =
+      manager.close()
   }
 
   def pyGetFlag(name: String): String =
@@ -88,24 +82,40 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
   def pyAvailableFlags: java.util.ArrayList[String] =
     flags.available
 
-  private[this] var irID: Int = 0
+  def pySetTmpdir(tmp: String): Unit =
+    tmpdir = tmp
 
-  private[this] def nextIRID(): Int = {
-    irID += 1
-    irID
-  }
+  def pySetLocalTmp(tmp: String): Unit =
+    localTmpdir = tmp
 
-  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
-    val id = nextIRID()
-    ctx.IrCache += (id -> ir)
-    id
+  def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = {
+    val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
+
+    val rpConfig: Option[RequesterPaysConfig] =
+      (Option(project).filter(_.nonEmpty), Option(buckets)) match {
+        case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets.map(_.asScala.toSet)))
+        case (None, Some(_)) => fatal("A non-empty, non-null requester pays google project is required to configure requester pays buckets.")
+        case (None, None) => None
+      }
+
+    val fs = newFs(
+      cloudfsConf.copy(
+        google = (cloudfsConf.google, rpConfig) match {
+          case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig))
+          case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig))
+          case _ => None
+        }
+      )
+    )
+
+    tmpFileManager.setFs(fs)
   }
 
   def pyRemoveJavaIR(id: Int): Unit =
     persistedIr.remove(id)
 
   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
-    references(name).addSequence(IndexedFastaSequenceFile(fs, fastaFile, indexFile))
+    references(name).addSequence(IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile))
 
   def pyRemoveSequence(name: String): Unit =
     references(name).removeSequence()
@@ -239,7 +249,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
     removeReference(name)
 
   def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
-    references(name).addLiftover(references(destRGName), LiftOver(fs, chainFile))
+    references(name).addLiftover(references(destRGName), LiftOver(tmpFileManager.getFs, chainFile))
 
   def pyRemoveLiftover(name: String, destRGName: String): Unit =
     references(name).removeLiftover(destRGName)
@@ -267,7 +277,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
   private[this] def removeReference(name: String): Unit =
     references -= name
 
-  private def withExecuteContext[T](
+  private[this] def withExecuteContext[T](
     selfContainedExecution: Boolean = true
   )(
     f: ExecuteContext => T
@@ -278,12 +288,12 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
       tmpdir = tmpdir,
       localTmpdir = localTmpdir,
       backend = backend,
-      fs = fs,
+      fs = tmpFileManager.getFs,
       timer = timer,
       tempFileManager =
         if (selfContainedExecution) null
-        else NonOwningTempFileManager(longLifeTempFileManager),
-      theHailClassLoader = theHailClassLoader,
+        else NonOwningTempFileManager(tmpFileManager),
+      theHailClassLoader = hcl,
       flags = flags,
       irMetadata = IrMetadata(None),
       references = references,
@@ -294,6 +304,40 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
     )(f)
   }
 
+  private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS =
+    backend match {
+      case s: SparkBackend =>
+        val conf = new Configuration(s.sc.hadoopConfiguration)
+        cloudfsConfig.google.flatMap(_.requester_pays_config).foreach {
+          case RequesterPaysConfig(prj, bkts) =>
+            bkts
+              .map { buckets =>
+                conf.set("fs.gs.requester.pays.mode", "CUSTOM")
+                conf.set("fs.gs.requester.pays.project.id", prj)
+                conf.set("fs.gs.requester.pays.buckets", buckets.mkString(","))
+              }
+              .getOrElse {
+                conf.set("fs.gs.requester.pays.mode", "AUTO")
+                conf.set("fs.gs.requester.pays.project.id", prj)
+              }
+        }
+        new HadoopFS(new SerializableHadoopConfiguration(conf))
+
+      case _ =>
+        RouterFS.buildRoutes(cloudfsConfig)
+    }
+
+  private[this] def nextIRID(): Int = {
+    irID += 1
+    irID
+  }
+
+  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
+    val id = nextIRID()
+    ctx.IrCache += (id -> ir)
+    id
+  }
+
   override def close(): Unit =
     synchronized {
       bmCache.close()
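
A note for reviewers: the argument handling in `pySetRequesterPays` is the subtle part of this change, since both parameters cross the Py4J boundary and may be `null`. The sketch below is a self-contained approximation of that validation logic; `RequesterPaysConfig` here is a simplified stand-in for Hail's config case class, and `RequesterPaysValidation.resolve` is a hypothetical name, not part of the diff.

```scala
import scala.jdk.CollectionConverters._

// Simplified stand-in for Hail's RequesterPaysConfig: a project plus an
// optional explicit bucket set.
final case class RequesterPaysConfig(project: String, buckets: Option[Set[String]])

object RequesterPaysValidation {
  // Mirrors pySetRequesterPays: arguments arrive from Python and may be
  // null, hence the Option(...) wrapping before pattern matching.
  def resolve(
    project: String,                // nullable
    buckets: java.util.List[String] // nullable
  ): Option[RequesterPaysConfig] =
    (Option(project).filter(_.nonEmpty), Option(buckets)) match {
      // A project with optional buckets: None selects "AUTO" mode
      // downstream, Some(...) selects "CUSTOM" with an explicit list.
      case (Some(p), bs) => Some(RequesterPaysConfig(p, bs.map(_.asScala.toSet)))
      // Buckets without a project cannot be honoured.
      case (None, Some(_)) =>
        throw new IllegalArgumentException(
          "a non-empty, non-null requester pays project is required to configure buckets"
        )
      case (None, None) => None
    }
}
```

So `resolve("my-project", null)` yields an AUTO-mode config, while `resolve(null, java.util.Arrays.asList("bkt"))` fails fast, matching the `fatal(...)` branch in the diff.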
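The other structural change is that the eagerly built `fs` is gone: the new `tmpFileManager` object owns both the active `FS` and an `OwningTempFileManager` built on it, and `pySetRequesterPays` swaps the pair together. Here is a minimal sketch of that swap pattern, assuming hypothetical simplified `FS`/`TempFileManager` traits in place of Hail's real interfaces:

```scala
// Simplified stand-ins; Hail's real FS and TempFileManager carry more.
trait FS { def scheme: String }

trait TempFileManager extends AutoCloseable {
  def newTmpPath(tmpdir: String, prefix: String, extension: String): String
}

final class OwningTempFileManager(fs: FS) extends TempFileManager {
  private var counter = 0
  override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = {
    counter += 1
    s"$tmpdir/$prefix-$counter.$extension" // real impl also records paths for cleanup
  }
  override def close(): Unit = () // real impl deletes every path it handed out
}

// A long-lived manager whose backing FS can be replaced after construction.
object SwappableTempFileManager extends TempFileManager {
  private var fs: FS = new FS { def scheme = "file" }
  private var manager: TempFileManager = new OwningTempFileManager(fs)

  def setFs(newFs: FS): Unit = {
    close() // clean up temp files under the old FS before switching
    fs = newFs
    manager = new OwningTempFileManager(newFs)
  }

  def getFs: FS = fs

  override def newTmpPath(tmpdir: String, prefix: String, extension: String): String =
    manager.newTmpPath(tmpdir, prefix, extension)

  override def close(): Unit = manager.close()
}
```

Closing the old manager before building the new one ensures temp files created under the previous credentials are cleaned up by the file system that created them, which is why `setFs` calls `close()` first in the diff as well.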