Skip to content

Commit

Permalink
No need for another TempFileManager implementation
Browse files (browse the repository at this point in the history)
  • Loading branch information
ehigham committed Oct 30, 2024
1 parent 27771b4 commit de0b114
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 41 deletions.
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ trait TempFileManager extends AutoCloseable {
def newTmpPath(tmpdir: String, prefix: String, extension: String = null): String
}

class OwningTempFileManager(fs: FS) extends TempFileManager {
class OwningTempFileManager(val fs: FS) extends TempFileManager {
private[this] val tmpPaths = mutable.ArrayBuffer[String]()

override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = {
Expand Down
63 changes: 23 additions & 40 deletions hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import com.sun.net.httpserver.{HttpExchange, HttpServer}
import org.apache.hadoop
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.DataFrame
import org.json4s
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing
Expand All @@ -52,6 +51,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin

private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
private[this] val hcl = new HailClassLoader(getClass.getClassLoader)

private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
Expand All @@ -61,29 +61,12 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
private[this] var irID: Int = 0
private[this] var tmpdir: String = _
private[this] var localTmpdir: String = _

private[this] object tmpFileManager extends TempFileManager {
private[this] var fs = newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
private[this] var manager = new OwningTempFileManager(fs)

def setFs(fs: FS): Unit = {
close()
this.fs = fs
manager = new OwningTempFileManager(fs)
}

def getFs: FS =
fs

override def newTmpPath(tmpdir: String, prefix: String, extension: String): String =
manager.newTmpPath(tmpdir, prefix, extension)

override def close(): Unit =
manager.close()
}
private[this] var tmpFileManager = new OwningTempFileManager(
newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
)

def pyFs: FS =
using(weak.acquire())(_ => tmpFileManager.getFs)
using(weak.acquire())(_ => tmpFileManager.fs)

def pyGetFlag(name: String): String =
using(weak.acquire())(_ => flags.get(name))
Expand All @@ -102,6 +85,8 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin

def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
using(weak.acquire()) { _ =>
tmpFileManager.close()

val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)

val rpConfig: Option[RequesterPaysConfig] =
Expand All @@ -126,15 +111,15 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
)
)

tmpFileManager.setFs(fs)
tmpFileManager = new OwningTempFileManager(fs)
}

def pyRemoveJavaIR(id: Int): Unit =
using(weak.acquire())(_ => persistedIr.remove(id))

def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
using(weak.acquire()) { _ =>
val seq = IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)
val seq = IndexedFastaSequenceFile(tmpFileManager.fs, fastaFile, indexFile)
references(name).addSequence(seq)
}

Expand Down Expand Up @@ -240,28 +225,26 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin

def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
withExecuteContext() { ctx =>
implicit val fmts: Formats = DefaultFormats
log.info("pyReadMultipleMatrixTables: got query")
val kvs = JsonMethods.parse(jsonQuery) match {
case json4s.JObject(values) => values.toMap
}

val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map {
case json4s.JString(s) => s
}

val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s)
val kvs = JsonMethods.parse(jsonQuery).extract[Map[String, JValue]]
val paths = kvs("paths").extract[IndexedSeq[String]]
val intervalPointType = parseType(kvs("intervalPointType").extract[String])
val intervalObjects =
JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType)))
.asInstanceOf[IndexedSeq[Interval]]

val opts = NativeReaderOptions(intervalObjects, intervalPointType)
val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p =>
log.info(s"creating MatrixRead node for $p")
val mnr = MatrixNativeReader(ctx.fs, p, Some(opts))
MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR
}
val matrixReaders: util.List[MatrixIR] =
paths.map { p =>
log.info(s"creating MatrixRead node for $p")
val mnr = MatrixNativeReader(ctx.fs, p, Some(opts))
MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR
}.asJava

log.info("pyReadMultipleMatrixTables: returning N matrix tables")
matrixReaders.asJava
matrixReaders
}._1

def pyAddReference(jsonConfig: String): Unit =
Expand All @@ -274,7 +257,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
using(weak.acquire()) { _ =>
references(name).addLiftover(
references(destRGName),
LiftOver(tmpFileManager.getFs, chainFile),
LiftOver(tmpFileManager.fs, chainFile),
)
}

Expand Down Expand Up @@ -316,7 +299,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
tmpdir = tmpdir,
localTmpdir = localTmpdir,
backend = backend,
fs = tmpFileManager.getFs,
fs = tmpFileManager.fs,
timer = timer,
tempFileManager =
if (selfContainedExecution) null
Expand Down

0 comments on commit de0b114

Please sign in to comment.