Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Lowering + Optimisation with implicit timing context #14731

Open
wants to merge 1 commit into
base: ehigham/stateless-backend
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ object LocalBackend extends Backend {
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class ServiceBackend(
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ class SparkBackend(val sc: SparkContext) extends Backend {
ctx.time {
TypeCheck(ctx, ir)
Validate(ir)
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
try {
val lowerTable = ctx.flags.get("lower") != null
val lowerBM = ctx.flags.get("lower_bm") != null
Expand Down
10 changes: 5 additions & 5 deletions hail/hail/src/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ abstract class BaseIR {
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
var mark: Int = 0

/** Reports whether this IR and `other` compare equal once both have been run through
  * `NormalizeNames` with free variables allowed.
  *
  * Presumably normalisation makes the comparison insensitive to the particular names chosen for
  * bound variables (alpha-equivalence, as the method name suggests) — confirm against
  * `NormalizeNames`.
  *
  * @param ctx   execution context handed to `NormalizeNames`
  * @param other the IR to compare against
  * @return `true` when the two normalised IRs are `==`
  */
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
/* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
* names */
// Both sides are normalised independently, then compared with `==`.
NormalizeNames(ctx, this, allowFreeVariables = true) ==
NormalizeNames(ctx, other, allowFreeVariables = true)
/** Reports whether this IR and `other` compare equal once both have been run through
  * `NormalizeNames` with free variables allowed.
  *
  * Presumably normalisation makes the comparison insensitive to the particular names chosen for
  * bound variables (alpha-equivalence, as the method name suggests) — confirm against
  * `NormalizeNames`.
  *
  * @param ctx   execution context handed to the normalisation pass
  * @param other the IR to compare against
  * @return `true` when the two normalised IRs are `==`
  */
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
// FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names
// Build the normalisation function once so both operands use the same configuration.
val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true)
normalize(ctx, this) == normalize(ctx, other)
}

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
Expand Down
2 changes: 0 additions & 2 deletions hail/hail/src/is/hail/expr/ir/Compilable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ object InterpretableButNotCompilable {
case _: MatrixToValueApply => true
case _: BlockMatrixToValueApply => true
case _: BlockMatrixCollect => true
case _: BlockMatrixToTableApply => true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fruitless test

case _ => false
}
}
Expand All @@ -44,7 +43,6 @@ object Compilable {
case _: TableToValueApply => false
case _: MatrixToValueApply => false
case _: BlockMatrixToValueApply => false
case _: BlockMatrixToTableApply => false
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fruitless test

case _: RelationalRef => false
case _: RelationalLet => false
case _ => true
Expand Down
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ object compile {
N: sourcecode.Name,
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
ctx.time {
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
var ir = Subst(
Expand Down
73 changes: 37 additions & 36 deletions hail/hail/src/is/hail/expr/ir/ExtractIntervalFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,52 +25,53 @@ object ExtractIntervalFilters {

val MAX_LITERAL_SIZE = 4096

def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = {
MapIR.mapBaseIR(
ir0,
(ir: BaseIR) => {
(
ir match {
case TableFilter(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(TableIR.rowName, child.typ.rowType),
child.typ.key,
)
.map { case (newCond, intervals) =>
def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR =
ctx.time {
MapIR.mapBaseIR(
ir0,
ir =>
(
ir match {
case TableFilter(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(TableIR.rowName, child.typ.rowType),
child.typ.key,
)
.map { case (newCond, intervals) =>
log.info(
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
}
case MatrixFilterRows(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(MatrixIR.rowName, child.typ.rowType),
child.typ.rowKey,
).map { case (newCond, intervals) =>
log.info(
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
}
case MatrixFilterRows(child, pred) => extractPartitionFilters(
ctx,
pred,
Ref(MatrixIR.rowName, child.typ.rowType),
child.typ.rowKey,
).map { case (newCond, intervals) =>
log.info(
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
}

case _ => None
}
).getOrElse(ir)
},
)
}
case _ => None
}
).getOrElse(ir),
)
}

def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String])
: Option[(IR, IndexedSeq[Interval])] = {
if (key.isEmpty) None
else {
else ctx.time {
val extract =
new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key))
val trueSet = extract.analyze(cond, ref.name)
Expand Down
4 changes: 3 additions & 1 deletion hail/hail/src/is/hail/expr/ir/FoldConstants.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import is.hail.utils.HailException

object FoldConstants {
def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))
ctx.time {
ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir)))
}

private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR =
RewriteBottomUp(
Expand Down
95 changes: 53 additions & 42 deletions hail/hail/src/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ package is.hail.expr.ir
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.defs.{BaseRef, Binding, Block, In, Ref, Str}
import is.hail.types.virtual.TVoid
import is.hail.utils.BoxedArrayBuilder
import is.hail.utils.{fatal, BoxedArrayBuilder}

import scala.collection.Set
import scala.util.control.NonFatal

object ForwardLets {
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {
def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
ctx.time {
val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ctx, ir1)

def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int)
: Boolean = {
: Boolean =
IsPure(value) && (
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
Expand All @@ -28,45 +28,56 @@ object ForwardLets {
!ContainsAgg(value)) &&
!ContainsAggIntermediate(value)
)
}

ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR =
ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
}
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
)
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)
val ir = rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)))

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
try
TypeCheck(ctx, ir)
catch {
case NonFatal(e) =>
fatal(
s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}",
e,
)
}
}

rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
}
ir.asInstanceOf[T]
}
}
Loading