[query] Remove lookupOrCompileCachedFunction from Backend interface
ehigham committed Jan 29, 2025
1 parent 512cbb9 commit b9e29ee
Showing 18 changed files with 156 additions and 167 deletions.
26 changes: 1 addition & 25 deletions hail/hail/src/is/hail/backend/Backend.scala
@@ -4,8 +4,7 @@ import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
@@ -92,9 +91,6 @@ abstract class Backend extends Closeable {

def shouldCacheQueryInfo: Boolean = true

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
@@ -193,23 +189,3 @@ abstract class Backend extends Closeable {

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

trait BackendWithCodeCache {
private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50)

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = {
codeCache.get(k) match {
case Some(v) => v.asInstanceOf[CompiledFunction[T]]
case None =>
val compiledFunction = f
codeCache += ((k, compiledFunction))
compiledFunction
}
}
}

trait BackendWithNoCodeCache {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = f
}
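
For reference, the behaviour of the removed BackendWithCodeCache.lookupOrCompileCachedFunction reduces to a single getOrElseUpdate against any mutable.Map-shaped cache. A minimal sketch of an equivalent standalone helper (the names CodeCacheHelpers and lookupOrCompile are illustrative, not part of this commit):

import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}

import scala.collection.mutable

object CodeCacheHelpers {
  // Same behaviour as the deleted trait method, but parameterised on the cache
  // rather than tied to the Backend that owns it.
  def lookupOrCompile[T](cache: mutable.Map[CodeCacheKey, CompiledFunction[_]])(
    k: CodeCacheKey
  )(f: => CompiledFunction[T]): CompiledFunction[T] =
    cache.getOrElseUpdate(k, f).asInstanceOf[CompiledFunction[T]]
}
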
6 changes: 6 additions & 0 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
@@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
@@ -73,6 +74,7 @@ object ExecuteContext {
backendContext: BackendContext,
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
)(
f: ExecuteContext => T
): T = {
@@ -92,6 +94,7 @@
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f(_))
}
}
@@ -122,6 +125,7 @@ class ExecuteContext(
val backendContext: BackendContext,
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
) extends Closeable {

val rngNonce: Long =
@@ -189,6 +193,7 @@ class ExecuteContext(
backendContext: BackendContext = this.backendContext,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
)(
f: ExecuteContext => A
): A =
@@ -206,5 +211,6 @@ class ExecuteContext(
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f)
}
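
With the cache now carried on ExecuteContext, a compile entry point can consult ctx.CodeCache directly instead of delegating to the Backend. The commit does not show the body of is.hail.expr.ir.compile.Compile, so the following is only a sketch of the assumed call shape (CompileSketch and compileOrLookup are made-up names):

import is.hail.backend.ExecuteContext
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}

object CompileSketch {
  // Assumed shape: the compile path keys the per-context cache itself, so the
  // caching policy is simply whichever map the backend supplied to the context.
  def compileOrLookup[T](ctx: ExecuteContext, key: CodeCacheKey)(
    compile: => CompiledFunction[T]
  ): CompiledFunction[T] =
    ctx.CodeCache.getOrElseUpdate(key, compile).asInstanceOf[CompiledFunction[T]]
}
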
7 changes: 5 additions & 2 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
@@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
@@ -84,13 +85,14 @@
class LocalBackend(
val tmpdir: String,
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -114,6 +116,7 @@ class LocalBackend(
},
new IrMetadata(),
ImmutableMap.empty,
codeCache,
)(f)
}
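
LocalBackend now owns the bounded code cache and passes the same instance to every ExecuteContext it creates, so compiled functions are reused across queries in one JVM. Assuming Hail's Cache[K, V](capacity) behaves like a capacity-bounded, access-ordered mutable map (it is supplied where a mutable.Map is expected), a rough stand-in for what the 50-entry cache provides:

import java.util

// Illustrative only: evicts the least recently used entry once the map grows
// past `capacity`; successful lookups count as accesses.
final class BoundedCache[K, V](capacity: Int) {
  private[this] val m = new util.LinkedHashMap[K, V](16, 0.75f, true) {
    override def removeEldestEntry(eldest: util.Map.Entry[K, V]): Boolean =
      size() > capacity
  }

  def getOrElseUpdate(k: K, v: => V): V =
    Option(m.get(k)).getOrElse { val computed = v; m.put(k, computed); computed }
}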

8 changes: 4 additions & 4 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
@@ -6,9 +6,10 @@ import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck,
IR, IRParser, IRSize, LoweringAnalyses, SortField, TableIR, TableReader, TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
@@ -51,7 +52,6 @@ class ServiceBackendContext(
) extends BackendContext with Serializable {}

object ServiceBackend {
private val log = Logger.getLogger(getClass.getName())

def apply(
jarLocation: String,
@@ -130,8 +130,7 @@ class ServiceBackend(
val fs: FS,
val serviceBackendContext: ServiceBackendContext,
val scratchDir: String,
) extends Backend with BackendWithNoCodeCache {
import ServiceBackend.log
) extends Backend with Logging {

private[this] var stageCount = 0
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
@@ -393,6 +392,7 @@ class ServiceBackend(
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)(f)
}

11 changes: 8 additions & 3 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
@@ -4,15 +4,16 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend._
import is.hail.backend.caching.BlockMatrixCache
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
import is.hail.expr.ir.lowering._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.rvd.RVD
import is.hail.types._
import is.hail.types.physical.{PStruct, PTuple}
@@ -26,9 +27,10 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionException
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import java.io.PrintWriter

import com.fasterxml.jackson.core.StreamReadConstraints
import is.hail.linalg.BlockMatrix
import org.apache.hadoop
import org.apache.hadoop.conf.Configuration
import org.apache.spark._
@@ -320,7 +322,7 @@ class SparkBackend(
override val references: mutable.Map[String, ReferenceGenome],
gcsRequesterPaysProject: String,
gcsRequesterPaysBuckets: String,
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null)
lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
@@ -352,6 +354,7 @@
new OwningTempFileManager(fs)

private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

def createExecuteContextForTests(
timer: ExecutionTimer,
@@ -375,6 +378,7 @@
},
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -395,6 +399,7 @@
},
new IrMetadata(),
bmCache,
codeCache,
)(f)
}
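
Taken together, the wiring preserves the old trait split without the extra interface: LocalBackend and SparkBackend build one bounded cache per backend and reuse it for every context, while ServiceBackend (and SparkBackend's test contexts) supply a fresh empty map, so nothing is cached across requests, matching the old BackendWithNoCodeCache behaviour. A toy, self-contained illustration of the two policies (types simplified; this is not Hail code):

import scala.collection.mutable

object CachePolicyDemo extends App {
  // Stand-in for an expensive Compile call; the Int result is irrelevant here.
  def runQuery(cache: mutable.Map[String, Int], key: String): Int =
    cache.getOrElseUpdate(key, { println(s"compiling $key"); key.length })

  val shared = mutable.Map.empty[String, Int] // one cache per backend instance
  runQuery(shared, "q1")
  runQuery(shared, "q1") // cache hit: "compiling q1" is printed only once

  runQuery(mutable.Map.empty, "q1")
  runQuery(mutable.Map.empty, "q1") // fresh map each time: compiled twice
}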
