
Commit

[query] Lowering + Optimisation with implicit timing context
ehigham committed Oct 16, 2024
1 parent 2f05ecd commit fb23e65
Showing 22 changed files with 713 additions and 714 deletions.
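The title refers to each lowering and optimisation pass timing itself against an ambient timing context: throughout the diff below, pass bodies are wrapped in ctx.time { ... } rather than being timed by their callers. A rough, self-contained sketch of the idea follows; TimingContext, the explicit pass-name argument, and the printed report are illustrative assumptions, not Hail's actual ExecuteContext API (which appears to pick the pass name up implicitly):

    // Sketch only: a scoped, nestable timer in the spirit of the ctx.time blocks below.
    import scala.collection.mutable

    final class TimingContext {
      private val stack = mutable.ListBuffer.empty[String]
      val timings = mutable.ArrayBuffer.empty[(String, Long)]

      // Times `body`, recording elapsed nanos under the chain of enclosing pass names.
      def time[A](name: String)(body: => A): A = {
        stack += name
        val start = System.nanoTime()
        try body
        finally {
          timings += (stack.mkString("/") -> (System.nanoTime() - start))
          stack.remove(stack.length - 1)
        }
      }
    }

    object TimingDemo {
      def main(args: Array[String]): Unit = {
        val ctx = new TimingContext
        // Nested passes record hierarchical timings without threading timers explicitly.
        ctx.time("LowerAndExecute") {
          ctx.time("NormalizeNames")(Thread.sleep(2))
          ctx.time("ForwardLets")(Thread.sleep(2))
        }
        ctx.timings.foreach { case (k, ns) => println(f"$k%-30s ${ns / 1e6}%.2f ms") }
      }
    }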
@@ -184,7 +184,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
       Validate(ir)
       val queryID = Backend.nextID()
       log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
-      ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
+      ctx.irMetadata.semhash = SemanticHash(ctx, ir)
       val res = _jvmLowerAndExecute(ctx, ir)
       log.info(s"finished execution of query $queryID")
       res
@@ -337,7 +337,7 @@ class ServiceBackend(
       Validate(ir)
       val queryID = Backend.nextID()
       log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
-      ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
+      ctx.irMetadata.semhash = SemanticHash(ctx, ir)
       val res = _jvmLowerAndExecute(ctx, ir)
       log.info(s"finished execution of query $queryID")
       res
@@ -542,7 +542,7 @@ class SparkBackend(
     ctx.time {
       TypeCheck(ctx, ir)
       Validate(ir)
-      ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
+      ctx.irMetadata.semhash = SemanticHash(ctx, ir)
       try {
         val lowerTable = getFlag("lower") != null
         val lowerBM = getFlag("lower_bm") != null
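The LocalBackend, ServiceBackend and SparkBackend hunks above all make the same single change: the call shape of SemanticHash goes from a curried apply, SemanticHash(ctx)(ir), to a single parameter list, SemanticHash(ctx, ir). A minimal illustration of the two shapes with stand-in types (Hail's real SemanticHash computes a hash used for query caching; the hashing below is only a placeholder):

    // Stand-in types; not Hail's. The point is only the two apply shapes.
    final case class Ctx(id: Int)
    final case class Ir(op: String, children: List[Ir] = Nil)

    object SemanticHashCurried {
      def apply(ctx: Ctx)(ir: Ir): Int = (ctx.id, ir).hashCode() // placeholder hash
    }

    object SemanticHashUncurried {
      def apply(ctx: Ctx, ir: Ir): Int = (ctx.id, ir).hashCode() // placeholder hash
    }

    object CallShapeDemo {
      def main(args: Array[String]): Unit = {
        val ctx = Ctx(0)
        val ir = Ir("TableFilter", List(Ir("Ref")))
        println(SemanticHashCurried(ctx)(ir))   // old shape: two argument lists
        println(SemanticHashUncurried(ctx, ir)) // new shape: one argument list
      }
    }

The uncurried form reads as an ordinary two-argument call and leaves the curried first-parameter-list style free for passes that are configured first and applied later, as the next hunks show.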
10 changes: 5 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -33,11 +33,11 @@ abstract class BaseIR {
   // New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
   var mark: Int = 0

-  def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
-    /* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
-     * names */
-    NormalizeNames(ctx, this, allowFreeVariables = true) ==
-      NormalizeNames(ctx, other, allowFreeVariables = true)
+  def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
+    // FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names
+    val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true)
+    normalize(ctx, this) == normalize(ctx, other)
+  }

   def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
     val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
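The BaseIR hunk shows the other recurring change in this commit: NormalizeNames(allowFreeVariables = true) now yields a configured pass of type (ExecuteContext, BaseIR) => BaseIR, which isAlphaEquiv applies to both sides before comparing. A toy version of that shape over a minimal lambda IR; everything below (Ctx, Expr, the renaming scheme) is illustrative and not Hail's:

    // Toy IR and a NormalizeNames-like object: apply(allowFreeVariables) returns a
    // configured pass (Ctx, Expr) => Expr.
    final case class Ctx(verbose: Boolean = false)

    sealed trait Expr
    final case class Var(name: String) extends Expr
    final case class Lam(param: String, body: Expr) extends Expr
    final case class App(fn: Expr, arg: Expr) extends Expr

    object NormalizeNames {
      def apply(allowFreeVariables: Boolean): (Ctx, Expr) => Expr = { (_, e0) =>
        var counter = 0
        def fresh(): String = { counter += 1; s"x$counter" }
        // Rename every binder to a canonical fresh name, threading a renaming env.
        def go(e: Expr, env: Map[String, String]): Expr = e match {
          case Var(n) =>
            env.get(n) match {
              case Some(n2) => Var(n2)
              case None if allowFreeVariables => Var(n) // free names pass through
              case None => throw new IllegalArgumentException(s"free variable: $n")
            }
          case Lam(p, body) =>
            val p2 = fresh()
            Lam(p2, go(body, env + (p -> p2)))
          case App(f, a) => App(go(f, env), go(a, env))
        }
        go(e0, Map.empty)
      }
    }

    object AlphaEquivDemo {
      def isAlphaEquiv(ctx: Ctx, a: Expr, b: Expr): Boolean = {
        val normalize = NormalizeNames(allowFreeVariables = true)
        normalize(ctx, a) == normalize(ctx, b)
      }

      def main(args: Array[String]): Unit = {
        val ctx = Ctx()
        val e1 = Lam("a", App(Var("a"), Var("free")))
        val e2 = Lam("b", App(Var("b"), Var("free")))
        println(isAlphaEquiv(ctx, e1, e2)) // true: the two differ only in the bound name
      }
    }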
5 changes: 2 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -43,7 +43,7 @@ object Compile {
     print: Option[PrintWriter] = None,
   ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) =
     ctx.time {
-      val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
+      val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
       val k =
         CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
       (ctx.backend.lookupOrCompileCachedFunction[F](k) {
@@ -108,8 +108,7 @@ object CompileWithAggregators {
     (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
   ) =
     ctx.time {
-      val normalizedBody =
-        NormalizeNames(ctx, body, allowFreeVariables = true)
+      val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
       val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
       (ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

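Compile and CompileWithAggregators feed the normalized body into CodeCacheKey, so that bodies differing only in binder names map to the same cached compiled function. A hedged sketch of that caching idea: CodeCacheKey and lookupOrCompile below are simplified stand-ins for Hail's CodeCacheKey and ctx.backend.lookupOrCompileCachedFunction, and the toy IR and normalizer are not Hail's:

    import scala.collection.mutable

    sealed trait IR
    final case class Ref(name: String) extends IR
    final case class I32(x: Int) extends IR
    final case class Let(name: String, value: IR, body: IR) extends IR
    final case class Add(l: IR, r: IR) extends IR

    object CompileCacheSketch {
      // Rename binders to canonical names so alpha-equivalent bodies collide in the cache.
      def normalize(ir: IR, env: Map[String, String] = Map.empty, n: Int = 0): IR = ir match {
        case Ref(x) => Ref(env.getOrElse(x, x))
        case I32(x) => I32(x)
        case Add(l, r) => Add(normalize(l, env, n), normalize(r, env, n))
        case Let(x, v, b) =>
          val x2 = s"__$n"
          Let(x2, normalize(v, env, n), normalize(b, env + (x -> x2), n + 1))
      }

      final case class CodeCacheKey(params: IndexedSeq[String], normalizedBody: IR)

      private val cache = mutable.Map.empty[CodeCacheKey, () => Int]
      var compilations = 0

      // Stands in for ctx.backend.lookupOrCompileCachedFunction: compile only on a miss.
      def lookupOrCompile(key: CodeCacheKey)(compile: => (() => Int)): () => Int =
        cache.getOrElseUpdate(key, { compilations += 1; compile })

      def main(args: Array[String]): Unit = {
        val bodyA = Let("x", I32(1), Add(Ref("x"), I32(2)))
        val bodyB = Let("y", I32(1), Add(Ref("y"), I32(2))) // alpha-equivalent to bodyA
        Seq(bodyA, bodyB).foreach { body =>
          val key = CodeCacheKey(IndexedSeq.empty, normalize(body))
          lookupOrCompile(key)(() => 3) // "compiled" thunk
        }
        println(s"compilations = $compilations") // 1: the second body hit the cache
      }
    }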
73 changes: 37 additions & 36 deletions hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala
@@ -24,52 +24,53 @@ object ExtractIntervalFilters {

   val MAX_LITERAL_SIZE = 4096

-  def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = {
-    MapIR.mapBaseIR(
-      ir0,
-      (ir: BaseIR) => {
-        (
-          ir match {
-            case TableFilter(child, pred) =>
-              extractPartitionFilters(
-                ctx,
-                pred,
-                Ref(TableIR.rowName, child.typ.rowType),
-                child.typ.key,
-              )
-                .map { case (newCond, intervals) =>
+  def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR =
+    ctx.time {
+      MapIR.mapBaseIR(
+        ir0,
+        ir =>
+          (
+            ir match {
+              case TableFilter(child, pred) =>
+                extractPartitionFilters(
+                  ctx,
+                  pred,
+                  Ref(TableIR.rowName, child.typ.rowType),
+                  child.typ.key,
+                )
+                  .map { case (newCond, intervals) =>
+                    log.info(
+                      s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
+                        s"Intervals: ${intervals.mkString(", ")}\n " +
+                        s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
+                    )
+                    TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
+                  }
+              case MatrixFilterRows(child, pred) =>
+                extractPartitionFilters(
+                  ctx,
+                  pred,
+                  Ref(MatrixIR.rowName, child.typ.rowType),
+                  child.typ.rowKey,
+                ).map { case (newCond, intervals) =>
                   log.info(
-                    s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
+                    s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
                       s"Intervals: ${intervals.mkString(", ")}\n " +
                       s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
                   )
-                  TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
+                  MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
                 }
-            case MatrixFilterRows(child, pred) => extractPartitionFilters(
-                ctx,
-                pred,
-                Ref(MatrixIR.rowName, child.typ.rowType),
-                child.typ.rowKey,
-              ).map { case (newCond, intervals) =>
-                log.info(
-                  s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
-                    s"Intervals: ${intervals.mkString(", ")}\n " +
-                    s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
-                )
-                MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
-              }

-            case _ => None
-          }
-        ).getOrElse(ir)
-      },
-    )
-  }
+              case _ => None
+            }
+          ).getOrElse(ir),
+      )
+    }

   def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String])
     : Option[(IR, IndexedSeq[Interval])] = {
     if (key.isEmpty) None
-    else {
+    else ctx.time {
       val extract =
         new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key))
       val trueSet = extract.analyze(cond, ref.name)
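The shape of the ExtractIntervalFilters pass is unchanged by this commit: map over every node, try to turn a filter predicate into partition intervals, and fall back to the original node when extraction fails. A stripped-down sketch of that Option-plus-getOrElse rewrite over a toy table IR (the node names mirror Hail's, but every definition below is a simplified stand-in):

    sealed trait TableIR
    final case class TableRead(ranges: List[Range]) extends TableIR
    final case class TableFilter(child: TableIR, pred: String) extends TableIR
    final case class TableFilterIntervals(child: TableIR, intervals: List[Range], keep: Boolean)
        extends TableIR

    object IntervalFilterSketch {
      // Bottom-up map over the toy IR, standing in for MapIR.mapBaseIR.
      def mapIR(ir: TableIR)(f: TableIR => TableIR): TableIR = {
        val mappedChildren = ir match {
          case TableFilter(child, pred) => TableFilter(mapIR(child)(f), pred)
          case TableFilterIntervals(child, i, k) => TableFilterIntervals(mapIR(child)(f), i, k)
          case leaf: TableRead => leaf
        }
        f(mappedChildren)
      }

      // Pretend extraction: only predicates of the form "key<N" yield an interval.
      def extractPartitionFilters(pred: String): Option[(String, List[Range])] =
        if (pred.startsWith("key<") && pred.length > 4 && pred.drop(4).forall(_.isDigit))
          Some(("true", List(0 until pred.drop(4).toInt)))
        else None

      def apply(ir0: TableIR): TableIR =
        mapIR(ir0) {
          case TableFilter(child, pred) =>
            extractPartitionFilters(pred)
              .map { case (newCond, intervals) =>
                TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
              }
              .getOrElse(TableFilter(child, pred)) // nothing extracted: keep the node as-is
          case other => other
        }

      def main(args: Array[String]): Unit =
        println(apply(TableFilter(TableRead(Nil), "key<10")))
    }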
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala
@@ -6,7 +6,9 @@ import is.hail.utils.HailException

 object FoldConstants {
   def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
-    ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))
+    ctx.time {
+      ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir)))
+    }

   private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR =
     RewriteBottomUp(
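FoldConstants keeps its existing pattern of borrowing a scratch region from the pool and running the fold in a child context with that region installed (ctx.local(r = region)), now inside ctx.time. A rough sketch of the scoped-resource-plus-local-context pattern; Region, RegionPool and local below are placeholder types and do not reproduce Hail's real signatures:

    // Placeholder resource types.
    final class Region { var closed = false; def close(): Unit = closed = true }

    final class RegionPool {
      // Lends a fresh region to `body` and guarantees it is released afterwards.
      def scopedRegion[A](body: Region => A): A = {
        val r = new Region
        try body(r)
        finally r.close()
      }
    }

    final case class Ctx(pool: RegionPool, r: Region) {
      // Run `body` with a context that differs only in its current region.
      def local[A](r: Region = this.r)(body: Ctx => A): A = body(copy(r = r))
    }

    object FoldConstantsSketch {
      // Stand-in for the real pass: would evaluate constant subtrees into ctx.r.
      def foldConstants(ctx: Ctx, ir: String): String = s"folded($ir) in region ${ctx.r}"

      def apply(ctx: Ctx, ir: String): String =
        ctx.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))

      def main(args: Array[String]): Unit = {
        val pool = new RegionPool
        pool.scopedRegion(root => println(apply(Ctx(pool, root), "x + 1")))
      }
    }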
95 changes: 53 additions & 42 deletions hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
@@ -2,20 +2,20 @@ package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.types.virtual.TVoid
import is.hail.utils.BoxedArrayBuilder
import is.hail.utils.{fatal, BoxedArrayBuilder}

import scala.collection.Set
import scala.util.control.NonFatal

object ForwardLets {
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {
def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
ctx.time {
val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ctx, ir1)

def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int)
: Boolean = {
: Boolean =
IsPure(value) && (
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
@@ -27,45 +27,56 @@ object ForwardLets {
!ContainsAgg(value)) &&
!ContainsAggIntermediate(value)
)
}

ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR =
ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
}
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
)
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)
val ir = rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)))

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
try
TypeCheck(ctx, ir)
catch {
case NonFatal(e) =>
fatal(
s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}",
e,
)
}
}

rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
}
ir.asInstanceOf[T]
}
}
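The rewritten ForwardLets also adds a safety net that was not there before: after forwarding, the result is run through TypeCheck, and any failure is rethrown as a fatal error carrying the pretty-printed input IR. A generic sketch of that validate-after-rewrite pattern with a toy expression language and a deliberately buggy rewrite (TypeCheck, Pretty and fatal are Hail's; nothing below is):

    import scala.util.control.NonFatal

    object ValidateAfterRewrite {
      sealed trait Expr
      final case class IntLit(x: Int) extends Expr
      final case class BoolLit(b: Boolean) extends Expr
      final case class Plus(l: Expr, r: Expr) extends Expr

      // Toy type checker: Plus demands integer operands.
      def typeCheck(e: Expr): Unit = e match {
        case IntLit(_) | BoolLit(_) => ()
        case Plus(l, r) =>
          typeCheck(l); typeCheck(r)
          require(!l.isInstanceOf[BoolLit] && !r.isInstanceOf[BoolLit], "Plus applied to a Boolean")
      }

      // A deliberately buggy "optimisation" so the guard has something to catch.
      def rewrite(e: Expr): Expr = e match {
        case Plus(IntLit(0), r) => BoolLit(true) // wrong on purpose: should be r
        case Plus(l, r) => Plus(rewrite(l), rewrite(r))
        case other => other
      }

      def forwardLets(e0: Expr): Expr = {
        val e = rewrite(e0)
        try typeCheck(e)
        catch {
          case NonFatal(err) =>
            // Mirrors the new guard: report the pre-rewrite IR alongside the failure.
            throw new RuntimeException(s"bad ir from rewrite, started as $e0", err)
        }
        e
      }

      def main(args: Array[String]): Unit =
        try println(forwardLets(Plus(Plus(IntLit(0), IntLit(2)), IntLit(3))))
        catch { case NonFatal(e) => println(e.getMessage) }
    }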