Commit 9397c47 (parent: d56d118)

[query] Remove persistedIr from Backend interface
12 files changed: +463, -496 lines

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 13 additions & 17 deletions

@@ -3,9 +3,7 @@ package is.hail.backend
 import is.hail.asm4s._
 import is.hail.backend.Backend.jsonToBytes
 import is.hail.backend.spark.SparkBackend
-import is.hail.expr.ir.{
-  BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
-}
+import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
 import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
 import is.hail.io.{BufferSpec, TypedCodecSpec}
 import is.hail.io.fs._
@@ -18,7 +16,6 @@ import is.hail.types.virtual.TFloat64
 import is.hail.utils._
 import is.hail.variant.ReferenceGenome
 
-import scala.collection.mutable
 import scala.reflect.ClassTag
 
 import java.io._
@@ -67,7 +64,6 @@ trait BackendContext {
 }
 
 abstract class Backend extends Closeable {
-  val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
 
   def defaultParallelism: Int
 
@@ -125,30 +121,30 @@ abstract class Backend extends Closeable {
   def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T
 
   final def valueType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_value_ir(ctx, s).typ.toJSON
       }
     }
 
   final def tableType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_table_ir(ctx, s).typ.toJSON
      }
    }
 
   final def matrixTableType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_matrix_ir(ctx, s).typ.toJSON
      }
    }
 
   final def blockMatrixType(s: String): Array[Byte] =
-    jsonToBytes {
-      withExecuteContext { ctx =>
-        IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
+    withExecuteContext { ctx =>
+      jsonToBytes {
+        IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
      }
    }
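
Note: the hunks above drop the IRParserEnvironment plumbing; parsers now take the ExecuteContext directly. A minimal sketch of a caller written against the new signature (hypothetical helper, not part of this commit; assumes the imports already present in Backend.scala):

  // Parse a value IR and return its type as JSON, using the ctx-first
  // IRParser entry point; no IRParserEnvironment or persistedIR map needed.
  def valueTypeJson(backend: Backend, irStr: String) =
    backend.withExecuteContext { ctx =>
      IRParser.parse_value_ir(ctx, irStr).typ.toJSON
    }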

hail/hail/src/is/hail/backend/BackendServer.scala

Lines changed: 2 additions & 5 deletions

@@ -1,6 +1,6 @@
 package is.hail.backend
 
-import is.hail.expr.ir.{IRParser, IRParserEnvironment}
+import is.hail.expr.ir.IRParser
 import is.hail.utils._
 
 import scala.util.control.NonFatal
@@ -89,10 +89,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
       backend.withExecuteContext { ctx =>
         val (res, timings) = ExecutionTimer.time { timer =>
           ctx.local(timer = timer) { ctx =>
-            val irData = IRParser.parse_value_ir(
-              irStr,
-              IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
-            )
+            val irData = IRParser.parse_value_ir(ctx, irStr)
             backend.execute(ctx, irData)
           }
         }

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 1 deletion

@@ -4,7 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
 import is.hail.annotations.{Region, RegionPool}
 import is.hail.asm4s.HailClassLoader
 import is.hail.backend.local.LocalTaskContext
-import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
+import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
 import is.hail.expr.ir.lowering.IrMetadata
 import is.hail.io.fs.FS
 import is.hail.linalg.BlockMatrix
@@ -75,6 +75,7 @@ object ExecuteContext {
     irMetadata: IrMetadata,
     blockMatrixCache: mutable.Map[String, BlockMatrix],
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
+    irCache: mutable.Map[Int, BaseIR],
   )(
     f: ExecuteContext => T
   ): T = {
@@ -95,6 +96,7 @@ object ExecuteContext {
       irMetadata,
       blockMatrixCache,
       codeCache,
+      irCache,
     ))(f(_))
   }
 }
@@ -126,6 +128,7 @@ class ExecuteContext(
   val irMetadata: IrMetadata,
   val BlockMatrixCache: mutable.Map[String, BlockMatrix],
   val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
+  val IrCache: mutable.Map[Int, BaseIR],
 ) extends Closeable {
 
   val rngNonce: Long =
@@ -194,6 +197,7 @@ class ExecuteContext(
     irMetadata: IrMetadata = this.irMetadata,
     blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
+    irCache: mutable.Map[Int, BaseIR] = this.IrCache,
   )(
     f: ExecuteContext => A
   ): A =
@@ -212,5 +216,6 @@ class ExecuteContext(
       irMetadata,
       blockMatrixCache,
      codeCache,
+      irCache,
     ))(f)
 }
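
Note: with irCache threaded through ExecuteContext, persisted IRs live in a map supplied by the backend and are reached through the context. A minimal sketch of that access pattern (hypothetical helpers, not part of this commit):

  // Store and look up a persisted IR via the per-context cache added above.
  def persistIR(ctx: ExecuteContext, id: Int, ir: BaseIR): Unit =
    ctx.IrCache += (id -> ir)

  def lookupPersistedIR(ctx: ExecuteContext, id: Int): Option[BaseIR] =
    ctx.IrCache.get(id)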

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 2 additions & 0 deletions

@@ -93,6 +93,7 @@ class LocalBackend(
 
   private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
+  private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
 
   // flags can be set after construction from python
   def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -117,6 +118,7 @@ class LocalBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       codeCache,
+      persistedIR,
     )(f)
   }

hail/hail/src/is/hail/backend/py4j/Py4JBackendExtensions.scala

Lines changed: 14 additions & 20 deletions

@@ -4,8 +4,8 @@ import is.hail.HailFeatureFlags
 import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
 import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
 import is.hail.expr.ir.{
-  BaseIR, BindingEnv, BlockMatrixIR, IR, IRParser, IRParserEnvironment, Interpret, MatrixIR,
-  MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral, TableValue,
+  BaseIR, BindingEnv, BlockMatrixIR, IR, IRParser, Interpret, MatrixIR, MatrixNativeReader,
+  MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral, TableValue,
 }
 import is.hail.expr.ir.IRParser.parseType
 import is.hail.expr.ir.defs.{EncodedLiteral, GetFieldByIdx}
@@ -34,7 +34,6 @@ import sourcecode.Enclosing
 trait Py4JBackendExtensions {
   def backend: Backend
   def references: mutable.Map[String, ReferenceGenome]
-  def persistedIR: mutable.Map[Int, BaseIR]
   def flags: HailFeatureFlags
   def longLifeTempFileManager: TempFileManager
 
@@ -54,14 +53,14 @@ trait Py4JBackendExtensions {
     irID
   }
 
-  private[this] def addJavaIR(ir: BaseIR): Int = {
+  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
     val id = nextIRID()
-    persistedIR += (id -> ir)
+    ctx.IrCache += (id -> ir)
     id
   }
 
   def pyRemoveJavaIR(id: Int): Unit =
-    persistedIR.remove(id)
+    backend.withExecuteContext(_.IrCache.remove(id))
 
   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
     backend.withExecuteContext { ctx =>
@@ -118,7 +117,7 @@ trait Py4JBackendExtensions {
     argTypeStrs: java.util.ArrayList[String],
     returnType: String,
     bodyStr: String,
-  ): Unit = {
+  ): Unit =
    backend.withExecuteContext { ctx =>
      IRFunctionRegistry.registerIR(
        ctx,
@@ -130,17 +129,16 @@ trait Py4JBackendExtensions {
        bodyStr,
      )
    }
-  }
 
   def pyExecuteLiteral(irStr: String): Int =
     backend.withExecuteContext { ctx =>
-      val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
+      val ir = IRParser.parse_value_ir(ctx, irStr)
       assert(ir.typ.isRealizable)
       backend.execute(ctx, ir) match {
        case Left(_) => throw new HailException("Can't create literal")
        case Right((pt, addr)) =>
          val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
-          addJavaIR(field)
+          addJavaIR(ctx, field)
      }
    }
 
@@ -159,14 +157,14 @@ trait Py4JBackendExtensions {
        ),
        ctx.theHailClassLoader,
      )
-      val id = addJavaIR(tir)
+      val id = addJavaIR(ctx, tir)
      (id, JsonMethods.compact(tir.typ.toJSON))
    }
   }
 
   def pyToDF(s: String): DataFrame =
     backend.withExecuteContext { ctx =>
-      val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
+      val tir = IRParser.parse_table_ir(ctx, s)
       Interpret(tir, ctx).toDF()
     }
 
@@ -219,27 +217,23 @@ trait Py4JBackendExtensions {
   def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
     backend.withExecuteContext { ctx =>
       IRParser.parse_value_ir(
+        ctx,
         s,
-        IRParserEnvironment(ctx, irMap = persistedIR.toMap),
         BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
           Name(n) -> IRParser.parseType(t)
         }.toSeq: _*),
       )
     }
 
   def parse_table_ir(s: String): TableIR =
-    withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
-    }
+    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s))
 
   def parse_matrix_ir(s: String): MatrixIR =
-    withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
-    }
+    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s))
 
   def parse_blockmatrix_ir(s: String): BlockMatrixIR =
     withExecuteContext(selfContainedExecution = false) { ctx =>
-      IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
+      IRParser.parse_blockmatrix_ir(ctx, s)
     }
 
   def loadReferencesFromDataset(path: String): Array[Byte] =
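
Note: the Py4J helpers now reach persisted IRs through the context (see addJavaIR and pyRemoveJavaIR above). A read-only check in the same style would look like this (hypothetical, not part of this commit):

  // Mirrors pyRemoveJavaIR: resolve the per-context IR cache inside withExecuteContext.
  def pyHasJavaIR(id: Int): Boolean =
    backend.withExecuteContext(_.IrCache.contains(id))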

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 0 deletions

@@ -399,6 +399,7 @@ class ServiceBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       mutable.Map.empty,
+      ImmutableMap.empty,
     )(f)
   }

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 3 additions & 0 deletions

@@ -355,6 +355,7 @@ class SparkBackend(
 
   private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
+  private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
 
   def createExecuteContextForTests(
     timer: ExecutionTimer,
@@ -379,6 +380,7 @@ class SparkBackend(
       new IrMetadata(),
       ImmutableMap.empty,
       ImmutableMap.empty,
+      ImmutableMap.empty,
     )
 
   override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -400,6 +402,7 @@ class SparkBackend(
       new IrMetadata(),
       bmCache,
       codeCache,
+      persistedIr,
     )(f)
   }

hail/hail/src/is/hail/expr/ir/MatrixIR.scala

Lines changed: 6 additions & 6 deletions

@@ -114,14 +114,14 @@ case class MatrixLiteral(typ: MatrixType, tl: TableLiteral) extends MatrixIR {
 }
 
 object MatrixReader {
-  def fromJson(env: IRParserEnvironment, jv: JValue): MatrixReader = {
+  def fromJson(ctx: ExecuteContext, jv: JValue): MatrixReader = {
     implicit val formats: Formats = DefaultFormats
     (jv \ "name").extract[String] match {
-      case "MatrixRangeReader" => MatrixRangeReader.fromJValue(env.ctx, jv)
-      case "MatrixNativeReader" => MatrixNativeReader.fromJValue(env.ctx.fs, jv)
-      case "MatrixBGENReader" => MatrixBGENReader.fromJValue(env, jv)
-      case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(env.ctx, jv)
-      case "MatrixVCFReader" => MatrixVCFReader.fromJValue(env.ctx, jv)
+      case "MatrixRangeReader" => MatrixRangeReader.fromJValue(ctx, jv)
+      case "MatrixNativeReader" => MatrixNativeReader.fromJValue(ctx.fs, jv)
+      case "MatrixBGENReader" => MatrixBGENReader.fromJValue(ctx, jv)
+      case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(ctx, jv)
+      case "MatrixVCFReader" => MatrixVCFReader.fromJValue(ctx, jv)
     }
   }
 
