diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 6a35ad3e9ecab..d3c9cf3b402e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -157,13 +157,22 @@ case class DataSourceV2Relation( * @param keyGroupedPartitioning if set, the partitioning expressions that are used to split the * rows in the scan across different partitions * @param ordering if set, the ordering provided by the scan + * @param pushedFilters Catalyst expressions for filters that were fully pushed to the data + * source and do not appear as post-scan filters */ case class DataSourceV2ScanRelation( relation: DataSourceV2Relation, scan: Scan, output: Seq[AttributeReference], keyGroupedPartitioning: Option[Seq[Expression]] = None, - ordering: Option[Seq[SortOrder]] = None) extends LeafNode with NamedRelation { + ordering: Option[Seq[SortOrder]] = None, + pushedFilters: Seq[Expression] = Seq.empty) extends LeafNode with NamedRelation { + + // TODO: Override validConstraints to return ExpressionSet(pushedFilters) so that pushed + // filters participate in constraint propagation (InferFiltersFromConstraints, PruneFilters). + // This changes which filters InferFiltersFromConstraints adds or removes (e.g., it may + // skip adding IsNotNull when the scan already implies it, or infer new filters across + // joins), so plan stability testing is needed first. 
override def name: String = relation.name @@ -197,7 +206,8 @@ case class DataSourceV2ScanRelation( ), ordering = ordering.map( _.map(o => o.copy(child = QueryPlan.normalizeExpressions(o.child, output))) - ) + ), + pushedFilters = pushedFilters.map(QueryPlan.normalizeExpressions(_, output)) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 4a4ccab47cad0..5fbb40934e1c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT} -import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExpressionSet, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation} @@ -95,6 +95,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val postScanFilters = postScanFiltersWithoutSubquery ++ 
normalizedFiltersWithSubquery + // Compute the pushed filter expressions: the normalized filters that were fully pushed + // down (i.e., not in postScanFilters). These are stored on the scan relation for + // potential future use in constraint propagation. + val postScanFilterSet = ExpressionSet(postScanFiltersWithoutSubquery) + sHolder.pushedFilterExpressions = normalizedFiltersWithoutSubquery + .filterNot(postScanFilterSet.contains) + .filter(_.deterministic) + logInfo( log""" |Pushing operators to ${MDC(RELATION_NAME, sHolder.relation.name)} @@ -698,6 +706,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(realOutput.length == holder.output.length, "The data source returns unexpected number of columns") val wrappedScan = getWrappedScan(scan, holder) + // Note: holder.pushedFilterExpressions is not propagated here because the output schema + // changes to aggregate columns. When validConstraints is wired up, this needs revisiting. val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) val projectList = realOutput.zip(holder.output).map { case (a1, a2) => // The data source may return columns with arbitrary data types and it's safer to cast them @@ -715,6 +725,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(realOutput.length == holder.output.length, "The data source returns unexpected number of columns") val wrappedScan = getWrappedScan(scan, holder) + // Note: holder.pushedFilterExpressions is not propagated here because the output schema + // changes with pushed join. When validConstraints is wired up, this needs revisiting. 
val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) // When join is pushed down, the real output is going to be, for example, @@ -737,6 +749,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val scan = holder.builder.build() val realOutput = toAttributes(scan.readSchema()) val wrappedScan = getWrappedScan(scan, holder) + // Note: holder.pushedFilterExpressions is not propagated here because the output schema + // changes with variant extraction. When validConstraints is wired up, this needs revisiting. val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) // Create projection to map real output to expected output (with transformed types) @@ -787,14 +801,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val wrappedScan = getWrappedScan(scan, sHolder) - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - val projectionOverSchema = ProjectionOverSchema(output.toStructType, AttributeSet(output)) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } + // Remap pushed filter attributes to the pruned output schema and drop filters + // whose references are no longer in the pruned output. + val remappedPushedFilters = sHolder.pushedFilterExpressions.map(projectionFunc) + .filter(_.references.subsetOf(AttributeSet(output))) + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output, + pushedFilters = remappedPushedFilters) + val finalFilters = normalizedFilters.map(projectionFunc) // bottom-most filters are put in the left of the list. 
val withFilter = finalFilters.foldLeft[LogicalPlan](scanRelation)((plan, cond) => { @@ -1018,6 +1037,8 @@ case class ScanBuilderHolder( var pushedVariantAttributeMap: Map[ExprId, AttributeReference] = Map.empty var pushedVariants: Option[VariantInRelation] = None + + var pushedFilterExpressions: Seq[Expression] = Seq.empty } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index a09b7e0827c49..6ea1ea3faa0ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -28,7 +28,7 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.ScalarSubquery +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan => CatalystGreaterThan, Literal => CatalystLiteral, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -48,7 +48,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch import 
org.apache.spark.util.ArrayImplicits._ @@ -1158,6 +1158,196 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS checkAnswer(query, Row(4, 1)) } } + + private def getScanRelation(query: DataFrame): DataSourceV2ScanRelation = { + query.queryExecution.optimizedPlan.collect { + case s: DataSourceV2ScanRelation => s + }.head + } + + test("pushedFilters are set for fully pushed filters") { + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + // AdvancedDataSourceV2 only supports pushing GreaterThan on column "i". + // i > 3 matches, so it is fully pushed. + val q = df.filter($"i" > 3) + checkAnswer(q, (4 until 10).map(i => Row(i, -i))) + + val scanRelation = getScanRelation(q) + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should be non-empty when filters are fully pushed") + // The pushed filter should reference column i + assert(scanRelation.pushedFilters.flatMap(_.references.map(_.name)).contains("i")) + } + + test("pushedFilters are empty when no filters are pushed") { + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + // AdvancedDataSourceV2 only supports pushing GreaterThan on column "i". + // j < -10 does not match, so it is not pushed. 
+ val q = df.filter($"j" < -10) + checkAnswer(q, Nil) + + val scanRelation = getScanRelation(q) + assert(scanRelation.pushedFilters.isEmpty, + "pushedFilters should be empty when no filters are pushed") + } + + test("pushedFilters are empty when no filter is present") { + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + val q = df.select($"i", $"j") + checkAnswer(q, (0 until 10).map(i => Row(i, -i))) + + val scanRelation = getScanRelation(q) + assert(scanRelation.pushedFilters.isEmpty, + "pushedFilters should be empty when there is no filter") + } + + test("pushedFilters contains only pushed filters in mixed case") { + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + // AdvancedDataSourceV2 only supports pushing GreaterThan on column "i". + // i > 3 matches so it is pushed; j < 0 does not match so it is not pushed. + val q = df.filter($"i" > 3 && $"j" < 0) + checkAnswer(q, (4 until 10).map(i => Row(i, -i))) + + val scanRelation = getScanRelation(q) + // Only i > 3 should be in pushedFilters, not j < 0 + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should be non-empty for the pushed portion") + val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet + assert(referencedCols.contains("i"), + "pushedFilters should contain the pushed filter on column i") + assert(!referencedCols.contains("j"), + "pushedFilters should not contain the unsupported filter on column j") + } + + test("pushedFilters does not include filters that remain as post-scan") { + val df = spark.read.format(classOf[OverlappingFilterDataSourceV2].getName).load() + // OverlappingFilterDataSourceV2 only evaluates GreaterThan on column "i". + // i > 3 is fully pushed; j < 0 overlaps (reported by pushedFilters() but also post-scan). 
+ val q = df.filter($"i" > 3 && $"j" < 0) + checkAnswer(q, (4 until 10).map(i => Row(i, -i))) + + val scanRelation = getScanRelation(q) + // Only i > 3 should be in pushedFilters (fully pushed, not in post-scan). + // j < 0 is in post-scan so it should be excluded from pushedFilters despite being + // reported by the scan's pushedFilters() API. + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should contain the fully-pushed filter") + val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet + assert(referencedCols.contains("i"), + "pushedFilters should contain the fully-pushed filter on column i") + assert(!referencedCols.contains("j"), + "pushedFilters should not contain the overlapping filter on column j") + } + + test("pushedFilters with V2 filter API") { + val df = spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + // i > 3 is fully pushed via V2 filter API + val q = df.filter($"i" > 3) + checkAnswer(q, (4 until 10).map(i => Row(i, -i))) + + val scanRelation = getScanRelation(q) + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should be non-empty with V2 filter API") + assert(scanRelation.pushedFilters.flatMap(_.references.map(_.name)).contains("i")) + } + + test("pushedFilters with V2 filter API contains only pushed filters in mixed case") { + val df = spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load() + // AdvancedScanBuilderWithV2Filter only supports pushing ">" predicates. + // i > 3 has predicate name ">" so it is pushed; j < 0 has predicate name "<" so it is not. 
+ val q = df.filter($"i" > 3 && $"j" < 0) + checkAnswer(q, (4 until 10).map(i => Row(i, -i))) + + val scanRelation = getScanRelation(q) + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should be non-empty for the pushed portion") + val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet + assert(referencedCols.contains("i"), + "pushedFilters should contain the pushed filter on column i") + assert(!referencedCols.contains("j"), + "pushedFilters should not contain the unsupported filter on column j") + } + + test("pushedFilters are remapped by ProjectionOverSchema after nested schema pruning") { + val df = spark.read.format(classOf[NestedSchemaDataSourceV2].getName).load() + // NestedSchemaScanBuilder pushes GreaterThan on "s.a". + // Selecting only s.a triggers nested schema pruning: s goes from struct<a: int, b: int> to struct<a: int>. + val q = df.select($"s.a").filter($"s.a" > 3) + checkAnswer(q, (4 until 10).map(i => Row(i))) + + val scanRelation = getScanRelation(q) + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should be non-empty") + // Find the struct attribute referenced by the pushed filter. + // Before remapping it would have type struct<a: int, b: int>; after remapping, struct<a: int>. 
+ val structAttrs = scanRelation.pushedFilters + .flatMap(_.collect { case a: AttributeReference if a.name == "s" => a }) + assert(structAttrs.nonEmpty, "pushed filter should reference struct column s") + val prunedStructType = structAttrs.head.dataType.asInstanceOf[StructType] + assert(prunedStructType.fieldNames.toSeq == Seq("a"), + s"struct column in pushed filter should be pruned to struct<a: int> but was $prunedStructType") + } + + test("scan canonicalization with pushedFilters") { + // Use SimpleDataSourceV2 whose scan implements equals, so canonicalization comparison works + val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty()) + + val relation1 = DataSourceV2Relation.create( + table, None, None, CaseInsensitiveStringMap.empty()) + val relation2 = DataSourceV2Relation.create( + table, None, None, CaseInsensitiveStringMap.empty()) + val scan1 = relation1.table.asReadable.newScanBuilder(relation1.options).build() + val scan2 = relation2.table.asReadable.newScanBuilder(relation2.options).build() + + val filter1 = CatalystGreaterThan(relation1.output.head, CatalystLiteral(3)) + val filter2 = CatalystGreaterThan(relation2.output.head, CatalystLiteral(3)) + + val scanRelation1 = DataSourceV2ScanRelation(relation1, scan1, relation1.output, + pushedFilters = Seq(filter1)) + val scanRelation2 = DataSourceV2ScanRelation(relation2, scan2, relation2.output, + pushedFilters = Seq(filter2)) + + assert(scanRelation1 != scanRelation2, + "Two instances should not be equal before canonicalization") + assert(scanRelation1.canonicalized == scanRelation2.canonicalized, + "Canonicalized instances with equivalent pushedFilters should be equal") + } + + test("pushedFilters excludes non-deterministic filters") { + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + // i > 3 is pushable and deterministic; rand() > 0.5 is non-deterministic and not pushable. 
+ // Before the fix, ExpressionSet.contains would miss the non-deterministic filter, + // causing it to incorrectly appear in pushedFilterExpressions. + val q = df.filter($"i" > 3 && rand() > 0.5) + + val scanRelation = getScanRelation(q) + // pushedFilters should only contain the deterministic pushed filter (i > 3). + assert(scanRelation.pushedFilters.nonEmpty, + "pushedFilters should contain the deterministic pushed filter") + assert(scanRelation.pushedFilters.forall(_.deterministic), + "pushedFilters should not contain non-deterministic filters") + val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet + assert(referencedCols.contains("i"), + "pushedFilters should contain the pushed filter on column i") + } + + test("pushedFilters drops filters referencing pruned columns") { + // Disable constraint propagation so IsNotNull(i) is not added (it would keep + // column i in the scan output). This simulates a connector that pushes IsNotNull. + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + // i > 3 is fully pushed; selecting only j causes column pruning to drop i. + val q = df.filter($"i" > 3).select($"j") + checkAnswer(q, (4 until 10).map(i => Row(-i))) + + val scanRelation = getScanRelation(q) + assert(!scanRelation.output.exists(_.name == "i"), + "column i should be pruned from scan output") + assert(scanRelation.pushedFilters.isEmpty, + "pushedFilters should drop filters referencing pruned columns") + } + } + } case class RangeInputPartition(start: Int, end: Int) extends InputPartition @@ -1443,6 +1633,171 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF } +// Data source where pushed filters overlap with post-scan filters. +// pushFilters returns unsupported filters, but pushedFilters() returns ALL filters +// (including the unsupported ones). 
This mimics the Parquet row-group filter pattern +// where a filter is pushed for best-effort evaluation but must also be re-evaluated. +class OverlappingFilterDataSourceV2 extends TestingV2Source { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new OverlappingScanBuilder() + } + } +} + +class OverlappingScanBuilder extends ScanBuilder + with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns { + + var requiredSchema = TestingV2Source.schema + private var allFilters = Array.empty[Filter] + private var evaluableFilters = Array.empty[Filter] + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def readSchema(): StructType = requiredSchema + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.evaluableFilters = supported + this.allFilters = filters + // Return unsupported filters as post-scan, but report ALL filters as pushed + unsupported + } + + // Reports all filters as pushed (including unsupported ones that overlap with post-scan) + override def pushedFilters(): Array[Filter] = allFilters + + override def build(): Scan = this + + override def toBatch: Batch = new AdvancedBatch(evaluableFilters, requiredSchema) +} + +// Data source with a nested (struct) column to test ProjectionOverSchema remapping. +// Schema: s: struct<a: int, b: int>, i: int +// Pushes GreaterThan on nested field "s.a". 
+class NestedSchemaDataSourceV2 extends TableProvider { + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = + NestedSchemaDataSourceV2.schema + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + new Table with SupportsRead { + override def name(): String = "nested-schema-test" + override def schema(): StructType = NestedSchemaDataSourceV2.schema + override def capabilities(): util.Set[TableCapability] = + util.EnumSet.of(TableCapability.BATCH_READ) + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new NestedSchemaScanBuilder() + } + } +} + +object NestedSchemaDataSourceV2 { + val schema: StructType = StructType(Seq( + StructField("s", StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType) + ))), + StructField("i", IntegerType) + )) +} + +class NestedSchemaScanBuilder extends ScanBuilder + with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns { + + var requiredSchema: StructType = NestedSchemaDataSourceV2.schema + var filters = Array.empty[Filter] + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def readSchema(): StructType = requiredSchema + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + // Push GreaterThan on nested field "s.a" + val (supported, unsupported) = filters.partition { + case GreaterThan("s.a", _: Int) => true + case _ => false + } + this.filters = supported + unsupported + } + + override def pushedFilters(): Array[Filter] = filters + + override def build(): Scan = this + + override def toBatch: Batch = new NestedSchemaBatch(filters, requiredSchema) +} + +class NestedSchemaBatch( + val filters: Array[Filter], + val requiredSchema: StructType) extends Batch { + + override def planInputPartitions(): Array[InputPartition] = { + val lowerBound = filters.collectFirst { + 
case GreaterThan("s.a", v: Int) => v + } + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + if (lowerBound.isEmpty) { + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 4) { + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 9) { + res.append(RangeInputPartition(lowerBound.get + 1, 10)) + } + res.toArray + } + + override def createReaderFactory(): PartitionReaderFactory = + new NestedSchemaReaderFactory(requiredSchema) +} + +class NestedSchemaReaderFactory( + requiredSchema: StructType) extends PartitionReaderFactory { + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): InternalRow = { + val values = requiredSchema.map { field => + field.name match { + case "s" => + val structType = field.dataType.asInstanceOf[StructType] + val structValues = structType.map(_.name).map { + case "a" => current + case "b" => -current + } + InternalRow.fromSeq(structValues) + case "i" => current + } + } + InternalRow.fromSeq(values) + } + + override def close(): Unit = {} + } + } +} + class SchemaRequiredDataSource extends TableProvider { class MyScanBuilder(schema: StructType) extends SimpleScanBuilder {