Skip to content

Commit a33bb7f

Browse files
committed
[query] Expose references via ExecuteContext
1 parent b8b22b8 commit a33bb7f

19 files changed

+315
-350
lines changed

hail/src/main/scala/is/hail/HailFeatureFlags.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class HailFeatureFlags private (
6868
flags.update(flag, value)
6969
}
7070

71+
def +(feature: (String, String)): HailFeatureFlags =
72+
new HailFeatureFlags(flags + (feature._1 -> feature._2))
73+
7174
def get(flag: String): String = flags(flag)
7275

7376
def lookup(flag: String): Option[String] =

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ abstract class Backend extends Closeable {
118118
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
119119
: CompiledFunction[T]
120120

121-
def references: mutable.Map[String, ReferenceGenome]
122-
123121
def lowerDistributedSort(
124122
ctx: ExecuteContext,
125123
stage: TableStage,
@@ -194,10 +192,8 @@ abstract class Backend extends Closeable {
194192
): Array[Byte] =
195193
withExecuteContext { ctx =>
196194
jsonToBytes {
197-
Extraction.decompose {
198-
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
199-
xContigs, yContigs, mtContigs, parInput).toJSON
200-
}(defaultJSONFormats)
195+
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
196+
xContigs, yContigs, mtContigs, parInput).toJSON
201197
}
202198
}
203199

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ object ExecuteContext {
6363
tmpdir: String,
6464
localTmpdir: String,
6565
backend: Backend,
66+
references: Map[String, ReferenceGenome],
6667
fs: FS,
6768
timer: ExecutionTimer,
6869
tempFileManager: TempFileManager,
@@ -79,6 +80,7 @@ object ExecuteContext {
7980
tmpdir,
8081
localTmpdir,
8182
backend,
83+
references,
8284
fs,
8385
region,
8486
timer,
@@ -107,6 +109,7 @@ class ExecuteContext(
107109
val tmpdir: String,
108110
val localTmpdir: String,
109111
val backend: Backend,
112+
val references: Map[String, ReferenceGenome],
110113
val fs: FS,
111114
val r: Region,
112115
val timer: ExecutionTimer,
@@ -128,7 +131,7 @@ class ExecuteContext(
128131
)
129132
}
130133

131-
def stateManager = HailStateManager(backend.references.toMap)
134+
val stateManager = HailStateManager(references)
132135

133136
val tempFileManager: TempFileManager =
134137
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
@@ -154,7 +157,7 @@ class ExecuteContext(
154157

155158
def getFlag(name: String): String = flags.get(name)
156159

157-
def getReference(name: String): ReferenceGenome = backend.references(name)
160+
def getReference(name: String): ReferenceGenome = references(name)
158161

159162
def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null
160163

@@ -174,6 +177,7 @@ class ExecuteContext(
174177
tmpdir: String = this.tmpdir,
175178
localTmpdir: String = this.localTmpdir,
176179
backend: Backend = this.backend,
180+
references: Map[String, ReferenceGenome] = this.references,
177181
fs: FS = this.fs,
178182
r: Region = this.r,
179183
timer: ExecutionTimer = this.timer,
@@ -189,6 +193,7 @@ class ExecuteContext(
189193
tmpdir,
190194
localTmpdir,
191195
backend,
196+
references,
192197
fs,
193198
r,
194199
timer,

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LocalBackend(
7373
override val references: mutable.Map[String, ReferenceGenome],
7474
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
7575

76+
override def backend: Backend = this
7677
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
7778
override def longLifeTempFileManager: TempFileManager = null
7879

@@ -88,6 +89,7 @@ class LocalBackend(
8889
tmpdir,
8990
tmpdir,
9091
this,
92+
references.toMap,
9193
fs,
9294
timer,
9395
null,

hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import is.hail.expr.ir.{
1010
}
1111
import is.hail.expr.ir.IRParser.parseType
1212
import is.hail.expr.ir.functions.IRFunctionRegistry
13+
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1314
import is.hail.linalg.RowMatrix
1415
import is.hail.types.physical.PStruct
1516
import is.hail.types.virtual.{TArray, TInterval}
@@ -30,7 +31,9 @@ import org.json4s.Formats
3031
import org.json4s.jackson.{JsonMethods, Serialization}
3132
import sourcecode.Enclosing
3233

33-
trait Py4JBackendExtensions { this: Backend =>
34+
trait Py4JBackendExtensions {
35+
def backend: Backend
36+
def references: mutable.Map[String, ReferenceGenome]
3437
def persistedIR: mutable.Map[Int, BaseIR]
3538
def flags: HailFeatureFlags
3639
def longLifeTempFileManager: TempFileManager
@@ -61,7 +64,9 @@ trait Py4JBackendExtensions { this: Backend =>
6164
persistedIR.remove(id)
6265

6366
def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
64-
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))
67+
backend.withExecuteContext { ctx =>
68+
references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile))
69+
}
6570

6671
def pyRemoveSequence(name: String): Unit =
6772
references(name).removeSequence()
@@ -76,7 +81,7 @@ trait Py4JBackendExtensions { this: Backend =>
7681
partitionSize: java.lang.Integer,
7782
entries: String,
7883
): Unit =
79-
withExecuteContext { ctx =>
84+
backend.withExecuteContext { ctx =>
8085
val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
8186
entries match {
8287
case "full" =>
@@ -114,7 +119,7 @@ trait Py4JBackendExtensions { this: Backend =>
114119
returnType: String,
115120
bodyStr: String,
116121
): Unit = {
117-
withExecuteContext { ctx =>
122+
backend.withExecuteContext { ctx =>
118123
IRFunctionRegistry.registerIR(
119124
ctx,
120125
name,
@@ -128,10 +133,10 @@ trait Py4JBackendExtensions { this: Backend =>
128133
}
129134

130135
def pyExecuteLiteral(irStr: String): Int =
131-
withExecuteContext { ctx =>
136+
backend.withExecuteContext { ctx =>
132137
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
133138
assert(ir.typ.isRealizable)
134-
execute(ctx, ir) match {
139+
backend.execute(ctx, ir) match {
135140
case Left(_) => throw new HailException("Can't create literal")
136141
case Right((pt, addr)) =>
137142
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
@@ -160,13 +165,13 @@ trait Py4JBackendExtensions { this: Backend =>
160165
}
161166

162167
def pyToDF(s: String): DataFrame =
163-
withExecuteContext { ctx =>
168+
backend.withExecuteContext { ctx =>
164169
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
165170
Interpret(tir, ctx).toDF()
166171
}
167172

168173
def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
169-
withExecuteContext { ctx =>
174+
backend.withExecuteContext { ctx =>
170175
log.info("pyReadMultipleMatrixTables: got query")
171176
val kvs = JsonMethods.parse(jsonQuery) match {
172177
case json4s.JObject(values) => values.toMap
@@ -195,10 +200,12 @@ trait Py4JBackendExtensions { this: Backend =>
195200
addReference(ReferenceGenome.fromJSON(jsonConfig))
196201

197202
def pyRemoveReference(name: String): Unit =
198-
references.remove(name)
203+
removeReference(name)
199204

200205
def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
201-
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))
206+
backend.withExecuteContext { ctx =>
207+
references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile))
208+
}
202209

203210
def pyRemoveLiftover(name: String, destRGName: String): Unit =
204211
references(name).removeLiftover(destRGName)
@@ -218,8 +225,11 @@ trait Py4JBackendExtensions { this: Backend =>
218225
}
219226
}
220227

228+
private[this] def removeReference(name: String): Unit =
229+
references -= name
230+
221231
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
222-
withExecuteContext { ctx =>
232+
backend.withExecuteContext { ctx =>
223233
IRParser.parse_value_ir(
224234
s,
225235
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
@@ -245,7 +255,7 @@ trait Py4JBackendExtensions { this: Backend =>
245255
}
246256

247257
def loadReferencesFromDataset(path: String): Array[Byte] =
248-
withExecuteContext { ctx =>
258+
backend.withExecuteContext { ctx =>
249259
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
250260
rgs.foreach(addReference)
251261

@@ -259,7 +269,7 @@ trait Py4JBackendExtensions { this: Backend =>
259269
f: ExecuteContext => T
260270
)(implicit E: Enclosing
261271
): T =
262-
withExecuteContext { ctx =>
272+
backend.withExecuteContext { ctx =>
263273
val tempFileManager = longLifeTempFileManager
264274
if (selfContainedExecution && tempFileManager != null) f(ctx)
265275
else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import is.hail.expr.ir.analyses.SemanticHash
1313
import is.hail.expr.ir.functions.IRFunctionRegistry
1414
import is.hail.expr.ir.lowering._
1515
import is.hail.io.fs._
16+
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
1617
import is.hail.linalg.BlockMatrix
1718
import is.hail.services.{BatchClient, JobGroupRequest, _}
1819
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
@@ -92,11 +93,20 @@ object ServiceBackend {
9293
references += (r.name -> r)
9394
}
9495

95-
val backend = new ServiceBackend(
96+
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
97+
liftoversForSource.foreach { case (destGenome, chainFile) =>
98+
references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile))
99+
}
100+
}
101+
rpcConfig.sequences.foreach { case (rg, seq) =>
102+
references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index))
103+
}
104+
105+
new ServiceBackend(
96106
JarUrl(jarLocation),
97107
name,
98108
theHailClassLoader,
99-
references,
109+
references.toMap,
100110
batchClient,
101111
batchConfig,
102112
flags,
@@ -105,27 +115,14 @@ object ServiceBackend {
105115
backendContext,
106116
scratchDir,
107117
)
108-
109-
backend.withExecuteContext { ctx =>
110-
rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) =>
111-
liftoversForSource.foreach { case (destGenome, chainFile) =>
112-
references(sourceGenome).addLiftover(ctx, chainFile, destGenome)
113-
}
114-
}
115-
rpcConfig.sequences.foreach { case (rg, seq) =>
116-
references(rg).addSequence(ctx, seq.fasta, seq.index)
117-
}
118-
}
119-
120-
backend
121118
}
122119
}
123120

124121
class ServiceBackend(
125122
val jarSpec: JarSpec,
126123
var name: String,
127124
val theHailClassLoader: HailClassLoader,
128-
override val references: mutable.Map[String, ReferenceGenome],
125+
val references: Map[String, ReferenceGenome],
129126
val batchClient: BatchClient,
130127
val batchConfig: BatchConfig,
131128
val flags: HailFeatureFlags,
@@ -396,6 +393,7 @@ class ServiceBackend(
396393
tmpdir,
397394
"file:///tmp",
398395
this,
396+
references,
399397
fs,
400398
timer,
401399
null,

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ class SparkBackend(
332332
new HadoopFS(new SerializableHadoopConfiguration(conf))
333333
}
334334

335+
override def backend: Backend = this
335336
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
336337

337338
override val longLifeTempFileManager: TempFileManager =
@@ -361,6 +362,7 @@ class SparkBackend(
361362
tmpdir,
362363
localTmpdir,
363364
this,
365+
references.toMap,
364366
fs,
365367
region,
366368
timer,
@@ -380,6 +382,7 @@ class SparkBackend(
380382
tmpdir,
381383
localTmpdir,
382384
this,
385+
references.toMap,
383386
fs,
384387
timer,
385388
null,

hail/src/main/scala/is/hail/io/reference/LiftOver.scala

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)