[query] Expose references via ExecuteContext
ehigham committed Jan 29, 2025
1 parent 5539028 commit bd12bb0
Showing 26 changed files with 343 additions and 378 deletions.
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/HailFeatureFlags.scala
@@ -68,6 +68,9 @@ class HailFeatureFlags private (
flags.update(flag, value)
}

def +(feature: (String, String)): HailFeatureFlags =
new HailFeatureFlags(flags + (feature._1 -> feature._2))

def get(flag: String): String = flags(flag)

def lookup(flag: String): Option[String] =
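A note on the new HailFeatureFlags `+` method in the hunk above: it builds a fresh HailFeatureFlags with one extra entry rather than mutating the flags in place. A minimal usage sketch; the flag name is invented for illustration:

import is.hail.HailFeatureFlags

val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
val updated: HailFeatureFlags = flags + ("example_flag" -> "1") // returns a copy; `flags` itself is untouched
assert(updated.get("example_flag") == "1")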
8 changes: 2 additions & 6 deletions hail/hail/src/is/hail/backend/Backend.scala
@@ -105,8 +105,6 @@ abstract class Backend extends Closeable {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def references: mutable.Map[String, ReferenceGenome]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
@@ -181,10 +179,8 @@
): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
Extraction.decompose {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}(defaultJSONFormats)
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

9 changes: 6 additions & 3 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
@@ -63,6 +63,7 @@ object ExecuteContext {
tmpdir: String,
localTmpdir: String,
backend: Backend,
references: Map[String, ReferenceGenome],
fs: FS,
timer: ExecutionTimer,
tempFileManager: TempFileManager,
@@ -79,6 +80,7 @@
tmpdir,
localTmpdir,
backend,
references,
fs,
region,
timer,
@@ -107,6 +109,7 @@ class ExecuteContext(
val tmpdir: String,
val localTmpdir: String,
val backend: Backend,
val references: Map[String, ReferenceGenome],
val fs: FS,
val r: Region,
val timer: ExecutionTimer,
@@ -128,7 +131,7 @@ class ExecuteContext(
)
}

def stateManager = HailStateManager(backend.references.toMap)
val stateManager = HailStateManager(references)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
@@ -154,8 +157,6 @@ class ExecuteContext(

def getFlag(name: String): String = flags.get(name)

def getReference(name: String): ReferenceGenome = backend.references(name)

def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null

def shouldNotLogIR(): Boolean = flags.get("no_ir_logging") != null
@@ -174,6 +175,7 @@ class ExecuteContext(
tmpdir: String = this.tmpdir,
localTmpdir: String = this.localTmpdir,
backend: Backend = this.backend,
references: Map[String, ReferenceGenome] = this.references,
fs: FS = this.fs,
r: Region = this.r,
timer: ExecutionTimer = this.timer,
@@ -189,6 +191,7 @@ class ExecuteContext(
tmpdir,
localTmpdir,
backend,
references,
fs,
r,
timer,
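The net effect of the ExecuteContext changes above: the context now carries an immutable Map[String, ReferenceGenome], stateManager is computed once from that map instead of on every call, and the getReference helper is removed. A minimal call-site sketch (the helper name here is hypothetical):

import is.hail.backend.ExecuteContext
import is.hail.variant.ReferenceGenome

// Callers read the reference map carried by the context itself...
def lookupGenome(ctx: ExecuteContext, rgName: String): ReferenceGenome =
  ctx.references(rgName) // ...where they previously went through ctx.getReference(rgName), i.e. backend.references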
2 changes: 2 additions & 0 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
@@ -87,6 +87,7 @@ class LocalBackend(
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

@@ -102,6 +103,7 @@ class LocalBackend(
tmpdir,
tmpdir,
this,
references.toMap,
fs,
timer,
null,
38 changes: 25 additions & 13 deletions hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala
@@ -10,12 +10,14 @@ import is.hail.expr.ir.{
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx}
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.RowMatrix
import is.hail.types.physical.PStruct
import is.hail.types.virtual.{TArray, TInterval}
import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval}
import is.hail.variant.ReferenceGenome

import scala.collection.mutable
import scala.jdk.CollectionConverters.{
asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter,
}
@@ -29,7 +31,10 @@ import org.json4s.Formats
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

trait Py4JBackendExtensions { this: Backend =>
trait Py4JBackendExtensions {
def backend: Backend
def references: mutable.Map[String, ReferenceGenome]
def persistedIR: mutable.Map[Int, BaseIR]
def flags: HailFeatureFlags
def longLifeTempFileManager: TempFileManager

@@ -59,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend =>
persistedIR.remove(id)

def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))
backend.withExecuteContext { ctx =>
references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile))
}

def pyRemoveSequence(name: String): Unit =
references(name).removeSequence()
@@ -74,7 +81,7 @@
partitionSize: java.lang.Integer,
entries: String,
): Unit =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
entries match {
case "full" =>
@@ -112,7 +119,7 @@
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
@@ -126,10 +133,10 @@
}

def pyExecuteLiteral(irStr: String): Int =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
assert(ir.typ.isRealizable)
execute(ctx, ir) match {
backend.execute(ctx, ir) match {
case Left(_) => throw new HailException("Can't create literal")
case Right((pt, addr)) =>
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
@@ -158,13 +165,13 @@
}

def pyToDF(s: String): DataFrame =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
Interpret(tir, ctx).toDF()
}

def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
log.info("pyReadMultipleMatrixTables: got query")
val kvs = JsonMethods.parse(jsonQuery) match {
case json4s.JObject(values) => values.toMap
@@ -193,19 +200,24 @@
addReference(ReferenceGenome.fromJSON(jsonConfig))

def pyRemoveReference(name: String): Unit =
references.remove(name)
removeReference(name)

def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))
backend.withExecuteContext { ctx =>
references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile))
}

def pyRemoveLiftover(name: String, destRGName: String): Unit =
references(name).removeLiftover(destRGName)

private[this] def addReference(rg: ReferenceGenome): Unit =
ReferenceGenome.addFatalOnCollision(references, FastSeq(rg))

private[this] def removeReference(name: String): Unit =
references -= name

def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
@@ -231,7 +243,7 @@
}

def loadReferencesFromDataset(path: String): Array[Byte] =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
ReferenceGenome.addFatalOnCollision(references, rgs)

@@ -245,7 +257,7 @@
f: ExecuteContext => T
)(implicit E: Enclosing
): T =
withExecuteContext { ctx =>
backend.withExecuteContext { ctx =>
val tempFileManager = longLifeTempFileManager
if (selfContainedExecution && tempFileManager != null) f(ctx)
else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)
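In miniature, the Py4JBackendExtensions change above swaps the `this: Backend =>` self-type for an abstract `backend` member, so the Python-facing mixin delegates to a Backend (LocalBackend and SparkBackend satisfy it with `override def backend: Backend = this`) and its withExecuteContext/execute calls become explicit `backend.…` calls. A simplified sketch of that pattern, using toy types rather than Hail's:

final class Context

trait Service {
  def withContext[A](f: Context => A): A
}

// Before: trait Extensions { this: Service => ... } forced every mixer-in to be a Service.
// After: an abstract member expresses the same dependency by delegation.
trait Extensions {
  def backend: Service
  def run[A](f: Context => A): A = backend.withContext(f)
}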
28 changes: 13 additions & 15 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
@@ -13,6 +13,7 @@ import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
@@ -93,7 +94,16 @@ object ServiceBackend {
rpcConfig.custom_references.map(ReferenceGenome.fromJSON),
)

val backend = new ServiceBackend(
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
liftoversForSource.foreach { case (destGenome, chainFile) =>
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile))
}
}
rpcConfig.sequences.foreach { case (rg, seq) =>
references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index))
}

new ServiceBackend(
JarUrl(jarLocation),
name,
theHailClassLoader,
@@ -106,27 +116,14 @@
backendContext,
scratchDir,
)

backend.withExecuteContext { ctx =>
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
liftoversForSource.foreach { case (destGenome, chainFile) =>
references(sourceGenome).addLiftover(ctx, chainFile, destGenome)
}
}
rpcConfig.sequences.foreach { case (rg, seq) =>
references(rg).addSequence(ctx, seq.fasta, seq.index)
}
}

backend
}
}

class ServiceBackend(
val jarSpec: JarSpec,
var name: String,
val theHailClassLoader: HailClassLoader,
override val references: mutable.Map[String, ReferenceGenome],
val references: mutable.Map[String, ReferenceGenome],
val batchClient: BatchClient,
val batchConfig: BatchConfig,
val flags: HailFeatureFlags,
@@ -397,6 +394,7 @@ class ServiceBackend(
tmpdir,
"file:///tmp",
this,
references.toMap,
fs,
timer,
null,
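One consequence visible in the ServiceBackend hunk above: liftovers and FASTA sequences are now attached to reference genomes using only a filesystem handle, before the backend (and any ExecuteContext) is constructed. A sketch of that setup; the genome names and paths are placeholders standing in for rpcConfig values:

import is.hail.io.fs.FS
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
import is.hail.variant.ReferenceGenome

import scala.collection.mutable

def attachResources(fs: FS, references: mutable.Map[String, ReferenceGenome]): Unit = {
  // Chain file and FASTA/index paths would come from rpcConfig.liftovers / rpcConfig.sequences.
  references("GRCh37").addLiftover(references("GRCh38"), LiftOver(fs, "gs://bucket/grch37_to_grch38.over.chain.gz"))
  references("GRCh38").addSequence(IndexedFastaSequenceFile(fs, "gs://bucket/GRCh38.fasta.gz", "gs://bucket/GRCh38.fasta.fai"))
}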
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
@@ -346,6 +346,7 @@ class SparkBackend(
new HadoopFS(new SerializableHadoopConfiguration(conf))
}

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()

override val longLifeTempFileManager: TempFileManager =
@@ -375,6 +376,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references.toMap,
fs,
region,
timer,
@@ -394,6 +396,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references.toMap,
fs,
timer,
null,
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/EmitClassBuilder.scala
@@ -72,7 +72,7 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) {
}

def referenceGenomes(): IndexedSeq[ReferenceGenome] =
rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name)
rgContainers.keys.map(ctx.references(_)).toIndexedSeq.sortBy(_.name)

def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] =
rgContainers.toFastSeq.sortBy(_._1).map(_._2)
4 changes: 2 additions & 2 deletions hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala
@@ -647,7 +647,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
BoolValue.fromComparison(l, op).restrict(keySet)
case Contig(rgStr) =>
// locus contig equality comparison
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match {
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.references(rgStr)) match {
case Some(i) =>
val b = BoolValue(
KeySet(i),
Expand All @@ -671,7 +671,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
case Position(rgStr) =>
// locus position comparison
val posBoolValue = BoolValue.fromComparison(l, op)
val rg = ctx.getReference(rgStr)
val rg = ctx.references(rgStr)
val b = BoolValue(
KeySet(liftPosIntervalsToLocus(posBoolValue.trueBound, rg, ctx)),
KeySet(liftPosIntervalsToLocus(posBoolValue.falseBound, rg, ctx)),
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/MatrixValue.scala
@@ -210,7 +210,7 @@ case class MatrixValue(
ReferenceGenome.exportReferences(
fs,
refPath,
ReferenceGenome.getReferences(t).map(ctx.getReference(_)),
ReferenceGenome.getReferences(t).map(ctx.references(_)),
)
}

2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/io/plink/LoadPlink.scala
@@ -187,7 +187,7 @@ object MatrixPLINKReader {
implicit val formats: Formats = DefaultFormats
val params = jv.extract[MatrixPLINKReaderParameters]

val referenceGenome = params.rg.map(ctx.getReference)
val referenceGenome = params.rg.map(ctx.references)
referenceGenome.foreach(_.validateContigRemap(params.contigRecoding))

val locusType = TLocus.schemaFromRG(params.rg)