Skip to content

Commit d970a8b

Browse files
committed
[query] Remove lookupOrCompileCachedFunction from Backend interface
1 parent 5db76b7 commit d970a8b

17 files changed

+145
-161
lines changed

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ import is.hail.asm4s._
44
import is.hail.backend.Backend.jsonToBytes
55
import is.hail.backend.spark.SparkBackend
66
import is.hail.expr.ir.{
7-
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
8-
SortField, TableIR, TableReader,
7+
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
98
}
109
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
1110
import is.hail.io.{BufferSpec, TypedCodecSpec}
@@ -92,9 +91,6 @@ abstract class Backend extends Closeable {
9291

9392
def shouldCacheQueryInfo: Boolean = true
9493

95-
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
96-
: CompiledFunction[T]
97-
9894
def lowerDistributedSort(
9995
ctx: ExecuteContext,
10096
stage: TableStage,
@@ -193,23 +189,3 @@ abstract class Backend extends Closeable {
193189

194190
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
195191
}
196-
197-
trait BackendWithCodeCache {
198-
private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50)
199-
200-
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
201-
: CompiledFunction[T] = {
202-
codeCache.get(k) match {
203-
case Some(v) => v.asInstanceOf[CompiledFunction[T]]
204-
case None =>
205-
val compiledFunction = f
206-
codeCache += ((k, compiledFunction))
207-
compiledFunction
208-
}
209-
}
210-
}
211-
212-
trait BackendWithNoCodeCache {
213-
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
214-
: CompiledFunction[T] = f
215-
}

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
44
import is.hail.annotations.{Region, RegionPool}
55
import is.hail.asm4s.HailClassLoader
66
import is.hail.backend.local.LocalTaskContext
7+
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
78
import is.hail.expr.ir.lowering.IrMetadata
89
import is.hail.io.fs.FS
910
import is.hail.linalg.BlockMatrix
@@ -73,6 +74,7 @@ object ExecuteContext {
7374
backendContext: BackendContext,
7475
irMetadata: IrMetadata,
7576
blockMatrixCache: mutable.Map[String, BlockMatrix],
77+
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
7678
)(
7779
f: ExecuteContext => T
7880
): T = {
@@ -92,6 +94,7 @@ object ExecuteContext {
9294
backendContext,
9395
irMetadata,
9496
blockMatrixCache,
97+
codeCache,
9598
))(f(_))
9699
}
97100
}
@@ -122,6 +125,7 @@ class ExecuteContext(
122125
val backendContext: BackendContext,
123126
val irMetadata: IrMetadata,
124127
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
128+
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
125129
) extends Closeable {
126130

127131
val rngNonce: Long =
@@ -189,6 +193,7 @@ class ExecuteContext(
189193
backendContext: BackendContext = this.backendContext,
190194
irMetadata: IrMetadata = this.irMetadata,
191195
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
196+
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
192197
)(
193198
f: ExecuteContext => A
194199
): A =
@@ -206,5 +211,6 @@ class ExecuteContext(
206211
backendContext,
207212
irMetadata,
208213
blockMatrixCache,
214+
codeCache,
209215
))(f)
210216
}

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
88
import is.hail.expr.Validate
99
import is.hail.expr.ir._
1010
import is.hail.expr.ir.analyses.SemanticHash
11+
import is.hail.expr.ir.compile.Compile
1112
import is.hail.expr.ir.defs.MakeTuple
1213
import is.hail.expr.ir.lowering._
1314
import is.hail.io.fs._
@@ -84,13 +85,14 @@ object LocalBackend {
8485
class LocalBackend(
8586
val tmpdir: String,
8687
override val references: mutable.Map[String, ReferenceGenome],
87-
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
88+
) extends Backend with Py4JBackendExtensions {
8889

8990
override def backend: Backend = this
9091
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
9192
override def longLifeTempFileManager: TempFileManager = null
9293

93-
private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
94+
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
95+
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
9496

9597
// flags can be set after construction from python
9698
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -114,6 +116,7 @@ class LocalBackend(
114116
},
115117
new IrMetadata(),
116118
ImmutableMap.empty,
119+
codeCache,
117120
)(f)
118121
}
119122

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import is.hail.asm4s._
66
import is.hail.backend._
77
import is.hail.expr.Validate
88
import is.hail.expr.ir.{
9-
Compile, IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck,
9+
IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck,
1010
}
1111
import is.hail.expr.ir.analyses.SemanticHash
12+
import is.hail.expr.ir.compile.Compile
1213
import is.hail.expr.ir.defs.MakeTuple
1314
import is.hail.expr.ir.functions.IRFunctionRegistry
1415
import is.hail.expr.ir.lowering._
@@ -51,7 +52,6 @@ class ServiceBackendContext(
5152
) extends BackendContext with Serializable {}
5253

5354
object ServiceBackend {
54-
private val log = Logger.getLogger(getClass.getName())
5555

5656
def apply(
5757
jarLocation: String,
@@ -130,8 +130,7 @@ class ServiceBackend(
130130
val fs: FS,
131131
val serviceBackendContext: ServiceBackendContext,
132132
val scratchDir: String,
133-
) extends Backend with BackendWithNoCodeCache {
134-
import ServiceBackend.log
133+
) extends Backend with Logging {
135134

136135
private[this] var stageCount = 0
137136
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
@@ -393,6 +392,7 @@ class ServiceBackend(
393392
serviceBackendContext,
394393
new IrMetadata(),
395394
ImmutableMap.empty,
395+
mutable.Map.empty,
396396
)(f)
397397
}
398398

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
88
import is.hail.expr.Validate
99
import is.hail.expr.ir._
1010
import is.hail.expr.ir.analyses.SemanticHash
11+
import is.hail.expr.ir.compile.Compile
1112
import is.hail.expr.ir.defs.MakeTuple
1213
import is.hail.expr.ir.lowering._
1314
import is.hail.io.{BufferSpec, TypedCodecSpec}
@@ -321,7 +322,7 @@ class SparkBackend(
321322
override val references: mutable.Map[String, ReferenceGenome],
322323
gcsRequesterPaysProject: String,
323324
gcsRequesterPaysBuckets: String,
324-
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
325+
) extends Backend with Py4JBackendExtensions {
325326

326327
assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null)
327328
lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
@@ -353,6 +354,7 @@ class SparkBackend(
353354
new OwningTempFileManager(fs)
354355

355356
private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
357+
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
356358

357359
def createExecuteContextForTests(
358360
timer: ExecutionTimer,
@@ -376,6 +378,7 @@ class SparkBackend(
376378
},
377379
new IrMetadata(),
378380
ImmutableMap.empty,
381+
ImmutableMap.empty,
379382
)
380383

381384
override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -396,6 +399,7 @@ class SparkBackend(
396399
},
397400
new IrMetadata(),
398401
bmCache,
402+
codeCache,
399403
)(f)
400404
}
401405

0 commit comments

Comments
 (0)