diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala index d49ff779e737..a3c2fbdc83d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala @@ -409,6 +409,150 @@ class DataSourceV2EnhancedPartitionFilterSuite } } + test("extract partition filter from translated OR with mixed partition and data references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " + + "PARTITIONED BY (part_col)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x'), ('a', 'other'), ('b', 'y'), ('c', 'z')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(part_col = 'a' AND data = 'x') OR (part_col = 'b' AND data = 'y')") + checkAnswer(df, Seq(Row("a", "x"), Row("b", "y"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "b")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col")) + } + } + + test("extract partition filter from untranslatable OR with mixed partition and data references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " + + "PARTITIONED BY (part_col)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x'), ('b', 'y'), ('c', 'z')") + + spark.udf.register("my_upper_extract", (s: String) => + if (s == null) null else s.toUpperCase(Locale.ROOT)) + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(my_upper_extract(part_col) = 'A' AND data = 'x') OR " + + "(my_upper_extract(part_col) = 'B' AND data = 'y')") + checkAnswer(df, Seq(Row("a", "x"), Row("b", "y"))) + 
assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "b")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col")) + } + } + + test("extract partition filter from OR with one partition-only and one mixed filter") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " + + "PARTITIONED BY (part_col)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x'), ('a', 'other'), ('b', 'y'), ('b', 'other'), ('c', 'z')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "part_col = 'a' OR (part_col = 'b' AND data = 'y')") + checkAnswer(df, Seq(Row("a", "x"), Row("a", "other"), Row("b", "y"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "b")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col")) + } + } + + test("extract multi-column partition filter from OR with mixed partition and data references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " + + s"USING $v2Source PARTITIONED BY (p1, p2)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "('a', 'x', 'd1'), ('a', 'y', 'd2'), ('b', 'x', 'd3'), ('b', 'y', 'd4')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(p1 = 'a' AND p2 = 'x' AND data = 'd1') OR (p1 = 'b' AND p2 = 'y' AND data = 'd4')") + checkAnswer(df, Seq(Row("a", "x", "d1"), Row("b", "y", "d4"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a/x", "b/y")) + assertReferencedPartitionFieldOrdinals(df, Array(0, 1), Array("p1", "p2")) + } + } + + test("two partition predicates pushed: UDF on p1 and " + + "extracted filter on p2 from mixed data and partition references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " + + s"USING $v2Source PARTITIONED BY (p1, p2)") 
+ sql(s"INSERT INTO $partFilterTableName VALUES " + "('a', 'x', 'd1'), " + "('a', 'y', 'd4'), " + "('b', 'x', 'd3'), " + "('b', 'y', 'd4'), " + "('c', 'z', 'd5')") + + spark.udf.register("my_upper_multi", (s: String) => + if (s == null) null else s.toUpperCase(Locale.ROOT)) + + // my_upper_multi(p1) = 'A' is untranslatable and partition-only, so it is a partition filter. + // The OR mixes p2 and data; we infer (p2 = 'x' OR p2 = 'y') as a partition filter. + // Both are pushed as separate PartitionPredicates. + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "my_upper_multi(p1) = 'A' AND " + + "((p2 = 'x' AND data = 'd1') OR (p2 = 'y' AND data = 'd4'))") + checkAnswer(df, Seq(Row("a", "x", "d1"), Row("a", "y", "d4"))) + assertPushedPartitionPredicates(df, 2) + assertScanReturnsPartitionKeys(df, Set("a/x", "a/y")) + } + } + + test("nested partition: extract partition filter from " + + "OR with mixed data and partition references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName " + + s"(s struct<tz: string, x: int>, data string) USING $v2Source " + + "PARTITIONED BY (s.tz)") + sql(s"INSERT INTO $partFilterTableName VALUES " + + "(named_struct('tz', 'LA', 'x', 1), 'a'), " + + "(named_struct('tz', 'NY', 'x', 2), 'b'), " + + "(named_struct('tz', 'SF', 'x', 3), 'c')") + + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "(s.tz = 'LA' AND data = 'a') OR (s.tz = 'NY' AND data = 'b')") + checkAnswer(df, Seq(Row(Row("LA", 1), "a"), Row(Row("NY", 2), "b"))) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("LA", "NY")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("s.tz")) + } + } + + test("nested partition: two partition predicates from " + + "UDF and extracted mixed data and partition references") { + withTable(partFilterTableName) { + sql(s"CREATE TABLE $partFilterTableName " + + s"(s struct<tz: string, x: int>, data string) USING $v2Source " + + "PARTITIONED BY (s.tz)") + sql(s"INSERT INTO 
$partFilterTableName VALUES " + + "(named_struct('tz', 'LA', 'x', 1), 'a'), " + + "(named_struct('tz', 'la', 'x', 2), 'b'), " + + "(named_struct('tz', 'NY', 'x', 3), 'c'), " + + "(named_struct('tz', 'SF', 'x', 4), 'd')") + + spark.udf.register("my_upper_nested2", (s: String) => + if (s == null) null else s.toUpperCase(Locale.ROOT)) + + // my_upper_nested2(s.tz) = 'LA' is untranslatable and partition-only, + // it is a partition filter. + // The OR mixes s.tz and data; we infer (s.tz = 'LA' OR s.tz = 'la') as a partition filter. + // Both are pushed as separate PartitionPredicates. + val df = sql(s"SELECT * FROM $partFilterTableName WHERE " + + "my_upper_nested2(s.tz) = 'LA' AND " + + "((s.tz = 'LA' AND data = 'a') OR (s.tz = 'la' AND data = 'b'))") + checkAnswer(df, Seq(Row(Row("LA", 1), "a"), Row(Row("la", 2), "b"))) + assertPushedPartitionPredicates(df, 2) + assertScanReturnsPartitionKeys(df, Set("LA", "la")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("s.tz")) + } + } + private def assertTranslatableBeforeUntranslatableInPostScan(df: DataFrame): Unit = { val postScanFilterExec = df.queryExecution.executedPlan.collect { case f @ FilterExec(_, _) if f.exists(_.isInstanceOf[BatchScanExec]) => f