Commit fc4df24

[query] Remove persistedIr from Backend interface
1 parent 9ecc3df commit fc4df24

12 files changed, +466 -499 lines changed
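The thread running through all of these files: the map of persisted IRs moves off the Backend interface and travels with the ExecuteContext as IrCache, and the IRParser entry points take the context directly instead of an IRParserEnvironment wrapper. A condensed before/after of that call pattern, using only names that appear in the hunks below (irText stands in for an actual IR string):

// Before: the parser environment carried a snapshot of the Backend-level map.
val before = IRParser.parse_value_ir(irText, IRParserEnvironment(ctx, backend.persistedIR.toMap))

// After: the context itself carries the cache, so the parser only needs ctx.
val after = IRParser.parse_value_ir(ctx, irText)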

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 13 additions & 17 deletions
@@ -3,9 +3,7 @@ package is.hail.backend
 import is.hail.asm4s._
 import is.hail.backend.Backend.jsonToBytes
 import is.hail.backend.spark.SparkBackend
-import is.hail.expr.ir.{
-  BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
-}
+import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
 import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
 import is.hail.io.{BufferSpec, TypedCodecSpec}
 import is.hail.io.fs._
@@ -18,7 +16,6 @@ import is.hail.types.virtual.TFloat64
 import is.hail.utils._
 import is.hail.variant.ReferenceGenome

-import scala.collection.mutable
 import scala.reflect.ClassTag

 import java.io._
@@ -67,7 +64,6 @@ trait BackendContext {
 }

 abstract class Backend extends Closeable {
-  val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

   def defaultParallelism: Int

@@ -125,30 +121,30 @@ abstract class Backend extends Closeable {
   def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

   final def valueType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_value_ir(ctx, s).typ.toJSON
       }
     }

   final def tableType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_table_ir(ctx, s).typ.toJSON
       }
     }

   final def matrixTableType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_matrix_ir(ctx, s).typ.toJSON
       }
     }

   final def blockMatrixType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
       }
     }
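Callers of these type-query endpoints see no change in shape: they still pass an IR string and receive JSON bytes; only the nesting flips so the context is opened first and the parser resolves any persisted-IR references through it. A hypothetical call, with an illustrative value-IR string:

val typeJson: Array[Byte] = backend.valueType("(I32 5)")  // the IR text is illustrative
println(new String(typeJson))                             // JSON encoding of the value's type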

hail/src/main/scala/is/hail/backend/BackendServer.scala

Lines changed: 2 additions & 5 deletions
@@ -1,6 +1,6 @@
 package is.hail.backend

-import is.hail.expr.ir.{IRParser, IRParserEnvironment}
+import is.hail.expr.ir.IRParser
 import is.hail.utils._

 import scala.util.control.NonFatal
@@ -89,10 +89,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
       backend.withExecuteContext { ctx =>
         val (res, timings) = ExecutionTimer.time { timer =>
           ctx.local(timer = timer) { ctx =>
-            val irData = IRParser.parse_value_ir(
-              irStr,
-              IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
-            )
+            val irData = IRParser.parse_value_ir(ctx, irStr)
             backend.execute(ctx, irData)
           }
         }

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
 import is.hail.annotations.{Region, RegionPool}
 import is.hail.asm4s.HailClassLoader
 import is.hail.backend.local.LocalTaskContext
-import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
+import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
 import is.hail.expr.ir.lowering.IrMetadata
 import is.hail.io.fs.FS
 import is.hail.linalg.BlockMatrix
@@ -75,6 +75,7 @@ object ExecuteContext {
     irMetadata: IrMetadata,
     blockMatrixCache: mutable.Map[String, BlockMatrix],
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
+    irCache: mutable.Map[Int, BaseIR],
   )(
     f: ExecuteContext => T
   ): T = {
@@ -95,6 +96,7 @@ object ExecuteContext {
       irMetadata,
       blockMatrixCache,
       codeCache,
+      irCache,
     ))(f(_))
   }
 }
@@ -126,6 +128,7 @@ class ExecuteContext(
   val irMetadata: IrMetadata,
   val BlockMatrixCache: mutable.Map[String, BlockMatrix],
   val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
+  val IrCache: mutable.Map[Int, BaseIR],
 ) extends Closeable {

   val rngNonce: Long =
@@ -194,6 +197,7 @@ class ExecuteContext(
     irMetadata: IrMetadata = this.irMetadata,
     blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
+    irCache: mutable.Map[Int, BaseIR] = this.IrCache,
   )(
     f: ExecuteContext => A
   ): A =
@@ -212,5 +216,6 @@ class ExecuteContext(
       irMetadata,
       blockMatrixCache,
       codeCache,
+      irCache,
     ))(f)
 }
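Because irCache is threaded through both the factory above and the copy-style method (which defaults it to this.IrCache), derived contexts share one persisted-IR map; whether that map outlives a single context depends on what the backend passes in: LocalBackend and SparkBackend below hand every context the same long-lived mutable map, while ServiceBackend passes ImmutableMap.empty. A small sketch of the sharing this enables, with an illustrative id and IR text:

backend.withExecuteContext { ctx =>
  ctx.IrCache += (7 -> IRParser.parse_value_ir(ctx, "(I32 5)"))  // id 7 and the IR text are illustrative
}
backend.withExecuteContext { ctx =>
  assert(ctx.IrCache.contains(7))  // holds when the backend reuses one long-lived map across contexts
}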

hail/src/main/scala/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@ class LocalBackend(

   private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
+  private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

   // flags can be set after construction from python
   def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -116,6 +117,7 @@ class LocalBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       codeCache,
+      persistedIR,
     )(f)
   }

hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 15 additions & 21 deletions
@@ -4,9 +4,9 @@ import is.hail.HailFeatureFlags
 import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
 import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
 import is.hail.expr.ir.{
-  BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser,
-  IRParserEnvironment, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, Name,
-  NativeReaderOptions, TableIR, TableLiteral, TableValue,
+  BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, Interpret,
+  MatrixIR, MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral,
+  TableValue,
 }
 import is.hail.expr.ir.IRParser.parseType
 import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -34,7 +34,6 @@ import sourcecode.Enclosing
 trait Py4JBackendExtensions {
   def backend: Backend
   def references: mutable.Map[String, ReferenceGenome]
-  def persistedIR: mutable.Map[Int, BaseIR]
   def flags: HailFeatureFlags
   def longLifeTempFileManager: TempFileManager

@@ -54,14 +53,14 @@ trait Py4JBackendExtensions {
     irID
   }

-  private[this] def addJavaIR(ir: BaseIR): Int = {
+  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
     val id = nextIRID()
-    persistedIR += (id -> ir)
+    ctx.IrCache += (id -> ir)
     id
   }

   def pyRemoveJavaIR(id: Int): Unit =
-    persistedIR.remove(id)
+    backend.withExecuteContext(_.IrCache.remove(id))

   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
     backend.withExecuteContext { ctx =>
@@ -118,7 +117,7 @@ trait Py4JBackendExtensions {
     argTypeStrs: java.util.ArrayList[String],
     returnType: String,
     bodyStr: String,
-  ): Unit = {
+  ): Unit =
     backend.withExecuteContext { ctx =>
       IRFunctionRegistry.registerIR(
         ctx,
@@ -130,17 +129,16 @@ trait Py4JBackendExtensions {
         bodyStr,
       )
     }
-  }

   def pyExecuteLiteral(irStr: String): Int =
     backend.withExecuteContext { ctx =>
-      val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
+      val ir = IRParser.parse_value_ir(ctx, irStr)
       assert(ir.typ.isRealizable)
       backend.execute(ctx, ir) match {
        case Left(_) => throw new HailException("Can't create literal")
        case Right((pt, addr)) =>
          val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
-          addJavaIR(field)
+          addJavaIR(ctx, field)
       }
     }

@@ -159,14 +157,14 @@ trait Py4JBackendExtensions {
         ),
         ctx.theHailClassLoader,
       )
-      val id = addJavaIR(tir)
+      val id = addJavaIR(ctx, tir)
       (id, JsonMethods.compact(tir.typ.toJSON))
     }
   }

   def pyToDF(s: String): DataFrame =
     backend.withExecuteContext { ctx =>
-      val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
+      val tir = IRParser.parse_table_ir(ctx, s)
       Interpret(tir, ctx).toDF()
     }

@@ -219,27 +217,23 @@ trait Py4JBackendExtensions {
   def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
     backend.withExecuteContext { ctx =>
       IRParser.parse_value_ir(
+        ctx,
         s,
-        IRParserEnvironment(ctx, irMap = persistedIR.toMap),
         BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
           Name(n) -> IRParser.parseType(t)
         }.toSeq: _*),
       )
     }

   def parse_table_ir(s: String): TableIR =
-    withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
-    }
+    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s))

   def parse_matrix_ir(s: String): MatrixIR =
-    withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
-    }
+    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s))

   def parse_blockmatrix_ir(s: String): BlockMatrixIR =
     withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
+      IRParser.parse_blockmatrix_ir(ctx, s)
     }

   def loadReferencesFromDataset(path: String): Array[Byte] =
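From Python's side the py4j lifecycle is unchanged: execute a literal, get back an integer id, drop it later. Internally the id now keys ctx.IrCache rather than a Backend field, and pyRemoveJavaIR opens a context purely to reach that shared map. A hypothetical round trip (the IR text is illustrative):

val id: Int = pyExecuteLiteral("(I32 5)")  // parses with the ctx-taking overload, executes, caches the resulting IR
// ... the id is handed to Python and may be referenced by later IR ...
pyRemoveJavaIR(id)                         // drops the entry from the context-backed cache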

hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 0 deletions
@@ -392,6 +392,7 @@ class ServiceBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       mutable.Map.empty,
+      ImmutableMap.empty,
     )(f)
   }

hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala

Lines changed: 3 additions & 0 deletions
@@ -354,6 +354,7 @@ class SparkBackend(

   private[this] val bmCache = new BlockMatrixCache()
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
+  private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]

   def createExecuteContextForTests(
     timer: ExecutionTimer,
@@ -378,6 +379,7 @@ class SparkBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       mutable.Map.empty,
+      ImmutableMap.empty,
     )

   override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -399,6 +401,7 @@ class SparkBackend(
       new IrMetadata(),
       bmCache,
       codeCache,
+      persistedIr,
     )(f)
   }

hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala

Lines changed: 6 additions & 6 deletions
@@ -113,14 +113,14 @@ case class MatrixLiteral(typ: MatrixType, tl: TableLiteral) extends MatrixIR {
 }

 object MatrixReader {
-  def fromJson(env: IRParserEnvironment, jv: JValue): MatrixReader = {
+  def fromJson(ctx: ExecuteContext, jv: JValue): MatrixReader = {
     implicit val formats: Formats = DefaultFormats
     (jv \ "name").extract[String] match {
-      case "MatrixRangeReader" => MatrixRangeReader.fromJValue(env.ctx, jv)
-      case "MatrixNativeReader" => MatrixNativeReader.fromJValue(env.ctx.fs, jv)
-      case "MatrixBGENReader" => MatrixBGENReader.fromJValue(env, jv)
-      case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(env.ctx, jv)
-      case "MatrixVCFReader" => MatrixVCFReader.fromJValue(env.ctx, jv)
+      case "MatrixRangeReader" => MatrixRangeReader.fromJValue(ctx, jv)
+      case "MatrixNativeReader" => MatrixNativeReader.fromJValue(ctx.fs, jv)
+      case "MatrixBGENReader" => MatrixBGENReader.fromJValue(ctx, jv)
+      case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(ctx, jv)
+      case "MatrixVCFReader" => MatrixVCFReader.fromJValue(ctx, jv)
     }
   }
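The same substitution reaches the reader level: JSON-to-reader resolution now receives the ExecuteContext itself, so the env.ctx indirection disappears. A hypothetical call, assuming jv already holds a parsed reader spec such as {"name": "MatrixRangeReader", ...}:

val reader: MatrixReader = MatrixReader.fromJson(ctx, jv)  // jv: JValue is assumed to come from elsewhere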
