From 001f93a6432a8b2072c3ccd398aded29e87888ab Mon Sep 17 00:00:00 2001 From: Patrick Schultz Date: Wed, 15 Nov 2023 15:14:39 -0500 Subject: [PATCH] [compiler] allow relational IR in ExtractIntervals (#14013) Fix bug where relational IR inside the condition of a TableFilter or MatrixFilter causes a ClassCastException. This can happen if, for example, there's a TableAggregate inside the condition. --- .../is/hail/expr/ir/ExtractIntervalFilters.scala | 5 ++++- .../hail/expr/ir/ExtractIntervalFiltersSuite.scala | 13 +++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala index 4619fe17287..ce317d8618b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala @@ -866,7 +866,10 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { } res = if (res == null) { - val children = x.children.map(child => recur(child.asInstanceOf[IR])).toFastSeq + val children = x.children.map { + case child: IR => recur(child) + case _ => AbstractLattice.top + }.toFastSeq val keyOrConstVal = computeKeyOrConst(x, children) if (x.typ == TBoolean) { if (keyOrConstVal == AbstractLattice.top) diff --git a/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala index 81b01856209..9eb1c93bad7 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala @@ -757,6 +757,19 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } + @Test def testRelationalChildren(): Unit = { + val testRows = FastSeq( + Row(0, 0, true), + Row(0, 10, true), + Row(0, 20, true), + Row(0, null, true)) + + val count = TableAggregate(TableRange(10, 1), ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq()))) + print(count.typ) + val filter = gt(count, Cast(k1, TInt64)) + check(filter, ref1, k1Full, testRows, filter, FastSeq(Interval(Row(), Row(), true, true))) + } + @Test def testIntegration() { hc // force initialization val tab1 = TableRange(10, 5)