Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,150 @@ class DataSourceV2EnhancedPartitionFilterSuite
}
}

test("extract partition filter from translated OR with mixed partition and data references") {
  withTable(partFilterTableName) {
    // Partition on part_col; include extra rows so the data-column conjuncts matter.
    sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) " +
      s"USING $v2Source PARTITIONED BY (part_col)")
    sql(s"INSERT INTO $partFilterTableName VALUES " +
      "('a', 'x'), ('a', 'other'), ('b', 'y'), ('c', 'z')")

    // Each OR branch pairs a partition-column equality with a data-column equality.
    // The partition-only weakening (part_col = 'a' OR part_col = 'b') should be
    // pushed down as a single partition predicate, pruning partition 'c'.
    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "(part_col = 'a' AND data = 'x') OR (part_col = 'b' AND data = 'y')")
    checkAnswer(result, Seq(Row("a", "x"), Row("b", "y")))
    assertPushedPartitionPredicates(result, 1)
    assertScanReturnsPartitionKeys(result, Set("a", "b"))
    assertReferencedPartitionFieldOrdinals(result, Array(0), Array("part_col"))
  }
}

test("extract partition filter from untranslatable OR with mixed partition and data references") {
  withTable(partFilterTableName) {
    sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) " +
      s"USING $v2Source PARTITIONED BY (part_col)")
    sql(s"INSERT INTO $partFilterTableName VALUES ('a', 'x'), ('b', 'y'), ('c', 'z')")

    // A Scala UDF has no V2 filter translation, so neither OR branch is translatable;
    // the partition filter must still be extracted from the untranslatable predicate.
    spark.udf.register("my_upper_extract",
      (s: String) => Option(s).map(_.toUpperCase(Locale.ROOT)).orNull)

    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "(my_upper_extract(part_col) = 'A' AND data = 'x') OR " +
      "(my_upper_extract(part_col) = 'B' AND data = 'y')")
    checkAnswer(result, Seq(Row("a", "x"), Row("b", "y")))
    assertPushedPartitionPredicates(result, 1)
    assertScanReturnsPartitionKeys(result, Set("a", "b"))
    assertReferencedPartitionFieldOrdinals(result, Array(0), Array("part_col"))
  }
}

test("extract partition filter from OR with one partition-only and one mixed filter") {
  withTable(partFilterTableName) {
    sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) " +
      s"USING $v2Source PARTITIONED BY (part_col)")
    sql(s"INSERT INTO $partFilterTableName VALUES " +
      "('a', 'x'), ('a', 'other'), ('b', 'y'), ('b', 'other'), ('c', 'z')")

    // Left branch touches only the partition column; right branch mixes partition
    // and data columns. The inferred partition filter keeps partitions 'a' and 'b'.
    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "part_col = 'a' OR (part_col = 'b' AND data = 'y')")
    checkAnswer(result, Seq(Row("a", "x"), Row("a", "other"), Row("b", "y")))
    assertPushedPartitionPredicates(result, 1)
    assertScanReturnsPartitionKeys(result, Set("a", "b"))
    assertReferencedPartitionFieldOrdinals(result, Array(0), Array("part_col"))
  }
}

test("extract multi-column partition filter from OR with mixed partition and data references") {
  withTable(partFilterTableName) {
    // Two partition columns: each OR branch pins both p1 and p2 plus a data column.
    sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " +
      s"USING $v2Source PARTITIONED BY (p1, p2)")
    sql(s"INSERT INTO $partFilterTableName VALUES " +
      "('a', 'x', 'd1'), ('a', 'y', 'd2'), ('b', 'x', 'd3'), ('b', 'y', 'd4')")

    // The extracted partition filter should reference both partition columns and
    // prune everything except the (a, x) and (b, y) partitions.
    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "(p1 = 'a' AND p2 = 'x' AND data = 'd1') OR (p1 = 'b' AND p2 = 'y' AND data = 'd4')")
    checkAnswer(result, Seq(Row("a", "x", "d1"), Row("b", "y", "d4")))
    assertPushedPartitionPredicates(result, 1)
    assertScanReturnsPartitionKeys(result, Set("a/x", "b/y"))
    assertReferencedPartitionFieldOrdinals(result, Array(0, 1), Array("p1", "p2"))
  }
}

test("two partition predicates pushed: UDF on p1 and extracted " +
  "filter on p2 from mixed data and partition references") {
  withTable(partFilterTableName) {
    sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " +
      s"USING $v2Source PARTITIONED BY (p1, p2)")
    sql(s"INSERT INTO $partFilterTableName VALUES " +
      "('a', 'x', 'd1'), ('a', 'y', 'd4'), ('b', 'x', 'd3'), " +
      "('b', 'y', 'd4'), ('c', 'z', 'd5')")

    spark.udf.register("my_upper_multi",
      (s: String) => Option(s).map(_.toUpperCase(Locale.ROOT)).orNull)

    // Two independent sources of partition predicates:
    //  1. my_upper_multi(p1) = 'A' — untranslatable but partition-only, pushed as-is.
    //  2. The OR over (p2, data) conjunctions — weakened to (p2 = 'x' OR p2 = 'y').
    // Each should arrive at the scan as its own PartitionPredicate.
    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "my_upper_multi(p1) = 'A' AND " +
      "((p2 = 'x' AND data = 'd1') OR (p2 = 'y' AND data = 'd4'))")
    checkAnswer(result, Seq(Row("a", "x", "d1"), Row("a", "y", "d4")))
    assertPushedPartitionPredicates(result, 2)
    assertScanReturnsPartitionKeys(result, Set("a/x", "a/y"))
  }
}

test("nested partition: extract partition filter from " +
  "OR with mixed data and partition references") {
  withTable(partFilterTableName) {
    // Partition column is a nested struct field (s.tz).
    sql(s"CREATE TABLE $partFilterTableName " +
      s"(s struct<tz: string, x: int>, data string) USING $v2Source " +
      "PARTITIONED BY (s.tz)")
    sql(s"INSERT INTO $partFilterTableName VALUES " +
      "(named_struct('tz', 'LA', 'x', 1), 'a'), " +
      "(named_struct('tz', 'NY', 'x', 2), 'b'), " +
      "(named_struct('tz', 'SF', 'x', 3), 'c')")

    // Same OR-extraction as the top-level case, but the partition reference is
    // the nested field s.tz; the pushed filter should prune the 'SF' partition.
    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "(s.tz = 'LA' AND data = 'a') OR (s.tz = 'NY' AND data = 'b')")
    checkAnswer(result, Seq(Row(Row("LA", 1), "a"), Row(Row("NY", 2), "b")))
    assertPushedPartitionPredicates(result, 1)
    assertScanReturnsPartitionKeys(result, Set("LA", "NY"))
    assertReferencedPartitionFieldOrdinals(result, Array(0), Array("s.tz"))
  }
}

test("nested partition: two partition predicates from " +
  "UDF and extracted mixed data and partition references") {
  withTable(partFilterTableName) {
    sql(s"CREATE TABLE $partFilterTableName " +
      s"(s struct<tz: string, x: int>, data string) USING $v2Source " +
      "PARTITIONED BY (s.tz)")
    // 'LA' and 'la' both uppercase to 'LA', so the UDF filter keeps two partitions.
    sql(s"INSERT INTO $partFilterTableName VALUES " +
      "(named_struct('tz', 'LA', 'x', 1), 'a'), " +
      "(named_struct('tz', 'la', 'x', 2), 'b'), " +
      "(named_struct('tz', 'NY', 'x', 3), 'c'), " +
      "(named_struct('tz', 'SF', 'x', 4), 'd')")

    spark.udf.register("my_upper_nested2",
      (s: String) => Option(s).map(_.toUpperCase(Locale.ROOT)).orNull)

    // Two partition predicates over the nested field s.tz:
    //  1. my_upper_nested2(s.tz) = 'LA' — untranslatable but partition-only.
    //  2. The OR mixing s.tz and data — weakened to (s.tz = 'LA' OR s.tz = 'la').
    // Both should be pushed as separate PartitionPredicates.
    val result = sql(s"SELECT * FROM $partFilterTableName WHERE " +
      "my_upper_nested2(s.tz) = 'LA' AND " +
      "((s.tz = 'LA' AND data = 'a') OR (s.tz = 'la' AND data = 'b'))")
    checkAnswer(result, Seq(Row(Row("LA", 1), "a"), Row(Row("la", 2), "b")))
    assertPushedPartitionPredicates(result, 2)
    assertScanReturnsPartitionKeys(result, Set("LA", "la"))
    assertReferencedPartitionFieldOrdinals(result, Array(0), Array("s.tz"))
  }
}

private def assertTranslatableBeforeUntranslatableInPostScan(df: DataFrame): Unit = {
val postScanFilterExec = df.queryExecution.executedPlan.collect {
case f @ FilterExec(_, _) if f.exists(_.isInstanceOf[BatchScanExec]) => f
Expand Down