Skip to content

Commit

Permalink
[query] Expose references via ExecuteContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Dec 16, 2024
1 parent b8b22b8 commit a33bb7f
Show file tree
Hide file tree
Showing 19 changed files with 315 additions and 350 deletions.
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/HailFeatureFlags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class HailFeatureFlags private (
flags.update(flag, value)
}

/** Builds a new `HailFeatureFlags` from the current `flags` extended with one `(name, value)` pair.
  *
  * @param feature
  *   the flag name paired with the value to record for it
  * @return
  *   a freshly constructed `HailFeatureFlags` wrapping `flags + (name -> value)`
  */
def +(feature: (String, String)): HailFeatureFlags = {
  val updated = flags + (feature._1 -> feature._2)
  new HailFeatureFlags(updated)
}

/** Returns the value stored in `flags` under `flag`.
  *
  * NOTE(review): this indexes `flags` directly, so the outcome for a name that is absent depends
  * on the underlying map — confirm before calling with unvalidated names.
  */
def get(flag: String): String = flags(flag)

def lookup(flag: String): Option[String] =
Expand Down
8 changes: 2 additions & 6 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ abstract class Backend extends Closeable {
/** Abstract hook: yields the `CompiledFunction` associated with the cache key `k`.
  *
  * `f` is a by-name argument, so it is evaluated only if and when an implementation uses it.
  * NOTE(review): the method name suggests `f` runs only on a cache miss — confirm against the
  * implementations.
  */
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def references: mutable.Map[String, ReferenceGenome]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -194,10 +192,8 @@ abstract class Backend extends Closeable {
): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
Extraction.decompose {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}(defaultJSONFormats)
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

Expand Down
9 changes: 7 additions & 2 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ object ExecuteContext {
tmpdir: String,
localTmpdir: String,
backend: Backend,
references: Map[String, ReferenceGenome],
fs: FS,
timer: ExecutionTimer,
tempFileManager: TempFileManager,
Expand All @@ -79,6 +80,7 @@ object ExecuteContext {
tmpdir,
localTmpdir,
backend,
references,
fs,
region,
timer,
Expand Down Expand Up @@ -107,6 +109,7 @@ class ExecuteContext(
val tmpdir: String,
val localTmpdir: String,
val backend: Backend,
val references: Map[String, ReferenceGenome],
val fs: FS,
val r: Region,
val timer: ExecutionTimer,
Expand All @@ -128,7 +131,7 @@ class ExecuteContext(
)
}

// Diff view of a single member (the page carries no +/- markers). Presumably the `def` line is
// the pre-commit form, which rebuilt the manager from `backend.references.toMap` on every call,
// and the `val` line is its replacement, built once from this context's own `references`.
// TODO confirm against the real diff.
def stateManager = HailStateManager(backend.references.toMap)
val stateManager = HailStateManager(references)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
Expand All @@ -154,7 +157,7 @@ class ExecuteContext(

/** Returns the value of the feature flag `name`, exactly as reported by `flags.get`. */
def getFlag(name: String): String = flags.get(name)

// Diff view of a single member (the page carries no +/- markers). Presumably the first line is
// the pre-commit form, which looked `name` up in `backend.references`, and the second is its
// replacement, which reads this context's own `references` map. TODO confirm against the diff.
def getReference(name: String): ReferenceGenome = backend.references(name)
def getReference(name: String): ReferenceGenome = references(name)

/** Reports whether the `write_ir_files` flag currently has a non-null value. */
def shouldWriteIRFiles(): Boolean =
  Option(getFlag("write_ir_files")).isDefined

Expand All @@ -174,6 +177,7 @@ class ExecuteContext(
tmpdir: String = this.tmpdir,
localTmpdir: String = this.localTmpdir,
backend: Backend = this.backend,
references: Map[String, ReferenceGenome] = this.references,
fs: FS = this.fs,
r: Region = this.r,
timer: ExecutionTimer = this.timer,
Expand All @@ -189,6 +193,7 @@ class ExecuteContext(
tmpdir,
localTmpdir,
backend,
references,
fs,
r,
timer,
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class LocalBackend(
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {

// Supplies the `backend` member with this instance itself.
override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

Expand All @@ -88,6 +89,7 @@ class LocalBackend(
tmpdir,
tmpdir,
this,
references.toMap,
fs,
timer,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import is.hail.expr.ir.{
}
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.RowMatrix
import is.hail.types.physical.PStruct
import is.hail.types.virtual.{TArray, TInterval}
Expand All @@ -30,7 +31,9 @@ import org.json4s.Formats
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

trait Py4JBackendExtensions { this: Backend =>
trait Py4JBackendExtensions {
def backend: Backend
def references: mutable.Map[String, ReferenceGenome]
def persistedIR: mutable.Map[Int, BaseIR]
def flags: HailFeatureFlags
def longLifeTempFileManager: TempFileManager
Expand Down Expand Up @@ -61,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend =>
persistedIR.remove(id)

/** Adds a FASTA sequence to the reference registered under `name`.
  *
  * Diff view (no +/- markers): the one-line `withExecuteContext(...)` body is presumably the
  * pre-commit form. The `backend.withExecuteContext { ... }` body that follows builds an
  * `IndexedFastaSequenceFile` from `ctx.fs`, `fastaFile` and `indexFile` and passes it to
  * `addSequence`.
  */
def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))
backend.withExecuteContext { ctx =>
references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile))
}

/** Looks up the reference registered under `name` and calls `removeSequence()` on it. */
def pyRemoveSequence(name: String): Unit = {
  val reference = references(name)
  reference.removeSequence()
}
Expand All @@ -76,7 +81,7 @@ trait Py4JBackendExtensions { this: Backend =>
partitionSize: java.lang.Integer,
entries: String,
): Unit =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
entries match {
case "full" =>
Expand Down Expand Up @@ -114,7 +119,7 @@ trait Py4JBackendExtensions { this: Backend =>
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
Expand All @@ -128,10 +133,10 @@ trait Py4JBackendExtensions { this: Backend =>
}

def pyExecuteLiteral(irStr: String): Int =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
assert(ir.typ.isRealizable)
execute(ctx, ir) match {
backend.execute(ctx, ir) match {
case Left(_) => throw new HailException("Can't create literal")
case Right((pt, addr)) =>
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
Expand Down Expand Up @@ -160,13 +165,13 @@ trait Py4JBackendExtensions { this: Backend =>
}

/** Parses the table IR text `s` (resolving persisted IRs through `persistedIR`), interprets it
  * within an execute context and converts the result with `toDF()`.
  *
  * Diff view (no +/- markers): the bare `withExecuteContext { ctx =>` opener is presumably the
  * pre-commit line and `backend.withExecuteContext { ctx =>` its replacement — only one of the
  * two exists in the real file.
  */
def pyToDF(s: String): DataFrame =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
Interpret(tir, ctx).toDF()
}

def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
log.info("pyReadMultipleMatrixTables: got query")
val kvs = JsonMethods.parse(jsonQuery) match {
case json4s.JObject(values) => values.toMap
Expand Down Expand Up @@ -195,10 +200,12 @@ trait Py4JBackendExtensions { this: Backend =>
addReference(ReferenceGenome.fromJSON(jsonConfig))

/** Removes the reference registered under `name`.
  *
  * Diff view (no +/- markers): `references.remove(name)` is presumably the pre-commit body and
  * the `removeReference(name)` call its replacement.
  */
def pyRemoveReference(name: String): Unit =
references.remove(name)
removeReference(name)

/** Registers a liftover on the reference `name`, targeting the reference `destRGName`.
  *
  * Diff view (no +/- markers): the one-line `withExecuteContext(...)` body is presumably the
  * pre-commit form. The `backend.withExecuteContext { ... }` body that follows looks up both
  * references in `references` and passes `addLiftover` the destination reference together with
  * a `LiftOver` built from `ctx.fs` and `chainFile`.
  */
def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))
backend.withExecuteContext { ctx =>
references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile))
}

/** Looks up the reference registered under `name` and calls `removeLiftover(destRGName)` on it. */
def pyRemoveLiftover(name: String, destRGName: String): Unit = {
  val source = references(name)
  source.removeLiftover(destRGName)
}
Expand All @@ -218,8 +225,11 @@ trait Py4JBackendExtensions { this: Backend =>
}
}

/** Removes the entry keyed by `name` from the mutable `references` map, if one is present. */
private[this] def removeReference(name: String): Unit = {
  references.remove(name)
  ()
}

def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
Expand All @@ -245,7 +255,7 @@ trait Py4JBackendExtensions { this: Backend =>
}

def loadReferencesFromDataset(path: String): Array[Byte] =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

Expand All @@ -259,7 +269,7 @@ trait Py4JBackendExtensions { this: Backend =>
f: ExecuteContext => T
)(implicit E: Enclosing
): T =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val tempFileManager = longLifeTempFileManager
if (selfContainedExecution && tempFileManager != null) f(ctx)
else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)
Expand Down
30 changes: 14 additions & 16 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
Expand Down Expand Up @@ -92,11 +93,20 @@ object ServiceBackend {
references += (r.name -> r)
}

val backend = new ServiceBackend(
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
liftoversForSource.foreach { case (destGenome, chainFile) =>
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile))
}
}
rpcConfig.sequences.foreach { case (rg, seq) =>
references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index))
}

new ServiceBackend(
JarUrl(jarLocation),
name,
theHailClassLoader,
references,
references.toMap,
batchClient,
batchConfig,
flags,
Expand All @@ -105,27 +115,14 @@ object ServiceBackend {
backendContext,
scratchDir,
)

backend.withExecuteContext { ctx =>
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
liftoversForSource.foreach { case (destGenome, chainFile) =>
references(sourceGenome).addLiftover(ctx, chainFile, destGenome)
}
}
rpcConfig.sequences.foreach { case (rg, seq) =>
references(rg).addSequence(ctx, seq.fasta, seq.index)
}
}

backend
}
}

class ServiceBackend(
val jarSpec: JarSpec,
var name: String,
val theHailClassLoader: HailClassLoader,
override val references: mutable.Map[String, ReferenceGenome],
val references: Map[String, ReferenceGenome],
val batchClient: BatchClient,
val batchConfig: BatchConfig,
val flags: HailFeatureFlags,
Expand Down Expand Up @@ -396,6 +393,7 @@ class ServiceBackend(
tmpdir,
"file:///tmp",
this,
references,
fs,
timer,
null,
Expand Down
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ class SparkBackend(
new HadoopFS(new SerializableHadoopConfiguration(conf))
}

// Supplies the `backend` member with this instance itself.
// NOTE(review): presumably this satisfies `Py4JBackendExtensions.backend` — the class header is
// not visible in this hunk, confirm.
override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()

override val longLifeTempFileManager: TempFileManager =
Expand Down Expand Up @@ -361,6 +362,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references.toMap,
fs,
region,
timer,
Expand All @@ -380,6 +382,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references.toMap,
fs,
timer,
null,
Expand Down
69 changes: 0 additions & 69 deletions hail/src/main/scala/is/hail/io/reference/LiftOver.scala

This file was deleted.

Loading

0 comments on commit a33bb7f

Please sign in to comment.