Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Move LoweredTableReaderCoercer into ExecuteContext #14696

Open
wants to merge 1 commit into
base: ehigham/ctx-persisted-ir
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions hail/hail/src/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ abstract class Backend extends Closeable {
def asSpark(op: String): SparkBackend =
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")

def shouldCacheQueryInfo: Boolean = true

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down
6 changes: 6 additions & 0 deletions hail/hail/src/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -76,6 +77,7 @@ object ExecuteContext {
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
irCache: mutable.Map[Int, BaseIR],
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -97,6 +99,7 @@ object ExecuteContext {
blockMatrixCache,
codeCache,
irCache,
coercerCache,
))(f(_))
}
}
Expand Down Expand Up @@ -129,6 +132,7 @@ class ExecuteContext(
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
val IrCache: mutable.Map[Int, BaseIR],
val CoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -198,6 +202,7 @@ class ExecuteContext(
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.CoercerCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -217,5 +222,6 @@ class ExecuteContext(
blockMatrixCache,
codeCache,
irCache,
coercerCache,
))(f)
}
3 changes: 3 additions & 0 deletions hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import is.hail.backend._
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
Expand Down Expand Up @@ -94,6 +95,7 @@ class LocalBackend(
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -119,6 +121,7 @@ class LocalBackend(
ImmutableMap.empty,
codeCache,
persistedIR,
coercerCache,
)(f)
}

Expand Down
5 changes: 2 additions & 3 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ class ServiceBackend(
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS)

override def shouldCacheQueryInfo: Boolean = false

def defaultParallelism: Int = 4

def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = {
Expand Down Expand Up @@ -392,7 +390,8 @@ class ServiceBackend(
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
ImmutableMap.empty,
ImmutableMap.empty,
ImmutableMap.empty,
)(f)
}
Expand Down
4 changes: 4 additions & 0 deletions hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import is.hail.backend._
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.defs.MakeTuple
Expand Down Expand Up @@ -356,6 +357,7 @@ class SparkBackend(
private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -381,6 +383,7 @@ class SparkBackend(
ImmutableMap.empty,
ImmutableMap.empty,
ImmutableMap.empty,
ImmutableMap.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -403,6 +406,7 @@ class SparkBackend(
bmCache,
codeCache,
persistedIr,
coercerCache,
)(f)
}

Expand Down
35 changes: 12 additions & 23 deletions hail/hail/src/is/hail/expr/ir/GenericTableValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.defs.{Literal, PartitionReader, ReadPartition, ToStream}
import is.hail.expr.ir.functions.UtilFunctions
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
Expand Down Expand Up @@ -144,16 +145,6 @@ class PartitionIteratorLongReader(
)
}

abstract class LoweredTableReaderCoercer {
def coerce(
ctx: ExecuteContext,
globals: IR,
contextType: Type,
contexts: IndexedSeq[Any],
body: IR => IR,
): TableStage
}

class GenericTableValue(
val fullTableType: TableType,
val uidFieldName: String,
Expand All @@ -169,12 +160,11 @@ class GenericTableValue(
assert(contextType.hasField("partitionIndex"))
assert(contextType.fieldType("partitionIndex") == TInt32)

private var ltrCoercer: LoweredTableReaderCoercer = _

private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any)
: LoweredTableReaderCoercer = {
if (ltrCoercer == null) {
ltrCoercer = LoweredTableReader.makeCoercer(
: LoweredTableReaderCoercer =
ctx.CoercerCache.getOrElseUpdate(
(1, contextType, fullTableType.key, cacheKey),
LoweredTableReader.makeCoercer(
ctx,
fullTableType.key,
1,
Expand All @@ -185,11 +175,8 @@ class GenericTableValue(
bodyPType,
body,
context,
cacheKey,
)
}
ltrCoercer
}
),
)

def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any)
: TableStage = {
Expand Down Expand Up @@ -218,11 +205,13 @@ class GenericTableValue(
val contextsIR = ToStream(Literal(TArray(contextType), contexts))
TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody)
} else {
getLTVCoercer(ctx, context, cacheKey).coerce(
getLTVCoercer(ctx, context, cacheKey)(
ctx,
globalsIR,
contextType, contexts,
requestedBody)
contextType,
contexts,
requestedBody,
)
}
}
}
Loading