Skip to content

Commit

Permalink
[query] Move LoweredTableReaderCoercer into ExecuteContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Jan 29, 2025
1 parent 0b07e0a commit 006475e
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 376 deletions.
2 changes: 0 additions & 2 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ abstract class Backend extends Closeable {
def asSpark(op: String): SparkBackend =
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")

def shouldCacheQueryInfo: Boolean = true

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down
6 changes: 6 additions & 0 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -76,6 +77,7 @@ object ExecuteContext {
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
irCache: mutable.Map[Int, BaseIR],
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -97,6 +99,7 @@ object ExecuteContext {
blockMatrixCache,
codeCache,
irCache,
coercerCache,
))(f(_))
}
}
Expand Down Expand Up @@ -129,6 +132,7 @@ class ExecuteContext(
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
val IrCache: mutable.Map[Int, BaseIR],
val CoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -198,6 +202,7 @@ class ExecuteContext(
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.CoercerCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -217,5 +222,6 @@ class ExecuteContext(
blockMatrixCache,
codeCache,
irCache,
coercerCache,
))(f)
}
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import is.hail.backend._
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
Expand Down Expand Up @@ -94,6 +95,7 @@ class LocalBackend(
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -119,6 +121,7 @@ class LocalBackend(
ImmutableMap.empty,
codeCache,
persistedIR,
coercerCache,
)(f)
}

Expand Down
5 changes: 2 additions & 3 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ class ServiceBackend(
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS)

override def shouldCacheQueryInfo: Boolean = false

def defaultParallelism: Int = 4

def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = {
Expand Down Expand Up @@ -392,7 +390,8 @@ class ServiceBackend(
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
ImmutableMap.empty,
ImmutableMap.empty,
ImmutableMap.empty,
)(f)
}
Expand Down
6 changes: 5 additions & 1 deletion hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import is.hail.backend._
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
Expand Down Expand Up @@ -356,6 +357,7 @@ class SparkBackend(
private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -379,7 +381,8 @@ class SparkBackend(
},
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
ImmutableMap.empty,
ImmutableMap.empty,
ImmutableMap.empty,
)

Expand All @@ -403,6 +406,7 @@ class SparkBackend(
bmCache,
codeCache,
persistedIr,
coercerCache,
)(f)
}

Expand Down
35 changes: 12 additions & 23 deletions hail/hail/src/is/hail/expr/ir/GenericTableValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.defs.{Literal, PartitionReader, ReadPartition, ToStream}
import is.hail.expr.ir.functions.UtilFunctions
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
Expand Down Expand Up @@ -144,16 +145,6 @@ class PartitionIteratorLongReader(
)
}

abstract class LoweredTableReaderCoercer {
def coerce(
ctx: ExecuteContext,
globals: IR,
contextType: Type,
contexts: IndexedSeq[Any],
body: IR => IR,
): TableStage
}

class GenericTableValue(
val fullTableType: TableType,
val uidFieldName: String,
Expand All @@ -169,12 +160,11 @@ class GenericTableValue(
assert(contextType.hasField("partitionIndex"))
assert(contextType.fieldType("partitionIndex") == TInt32)

private var ltrCoercer: LoweredTableReaderCoercer = _

private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any)
: LoweredTableReaderCoercer = {
if (ltrCoercer == null) {
ltrCoercer = LoweredTableReader.makeCoercer(
: LoweredTableReaderCoercer =
ctx.CoercerCache.getOrElseUpdate(
(1, contextType, fullTableType.key, cacheKey),
LoweredTableReader.makeCoercer(
ctx,
fullTableType.key,
1,
Expand All @@ -185,11 +175,8 @@ class GenericTableValue(
bodyPType,
body,
context,
cacheKey,
)
}
ltrCoercer
}
),
)

def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any)
: TableStage = {
Expand Down Expand Up @@ -218,11 +205,13 @@ class GenericTableValue(
val contextsIR = ToStream(Literal(TArray(contextType), contexts))
TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody)
} else {
getLTVCoercer(ctx, context, cacheKey).coerce(
getLTVCoercer(ctx, context, cacheKey)(
ctx,
globalsIR,
contextType, contexts,
requestedBody)
contextType,
contexts,
requestedBody,
)
}
}
}
Loading

0 comments on commit 006475e

Please sign in to comment.