diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 6a35ad3e9ecab..d3c9cf3b402e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -157,13 +157,22 @@ case class DataSourceV2Relation(
* @param keyGroupedPartitioning if set, the partitioning expressions that are used to split the
* rows in the scan across different partitions
* @param ordering if set, the ordering provided by the scan
+ * @param pushedFilters Catalyst expressions for filters that were fully pushed to the data
+ * source and do not appear as post-scan filters
*/
case class DataSourceV2ScanRelation(
relation: DataSourceV2Relation,
scan: Scan,
output: Seq[AttributeReference],
keyGroupedPartitioning: Option[Seq[Expression]] = None,
- ordering: Option[Seq[SortOrder]] = None) extends LeafNode with NamedRelation {
+ ordering: Option[Seq[SortOrder]] = None,
+ pushedFilters: Seq[Expression] = Seq.empty) extends LeafNode with NamedRelation {
+
+ // TODO: Override validConstraints to return ExpressionSet(pushedFilters) so that pushed
+ // filters participate in constraint propagation (InferFiltersFromConstraints, PruneFilters).
+ // This changes which filters InferFiltersFromConstraints adds or removes (e.g., it may
+ // skip adding IsNotNull when the scan already implies it, or infer new filters across
+ // joins), so plan stability testing is needed first.
override def name: String = relation.name
@@ -197,7 +206,8 @@ case class DataSourceV2ScanRelation(
),
ordering = ordering.map(
_.map(o => o.copy(child = QueryPlan.normalizeExpressions(o.child, output)))
- )
+ ),
+ pushedFilters = pushedFilters.map(QueryPlan.normalizeExpressions(_, output))
)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 4a4ccab47cad0..5fbb40934e1c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.SparkException
import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT}
-import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExpressionSet, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
@@ -95,6 +95,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
+ // Compute the pushed filter expressions: the normalized filters that were fully pushed
+ // down (i.e., not in postScanFilters). These are stored on the scan relation for
+ // potential future use in constraint propagation.
+ val postScanFilterSet = ExpressionSet(postScanFiltersWithoutSubquery)
+ sHolder.pushedFilterExpressions = normalizedFiltersWithoutSubquery
+ .filterNot(postScanFilterSet.contains)
+ .filter(_.deterministic)
+
logInfo(
log"""
|Pushing operators to ${MDC(RELATION_NAME, sHolder.relation.name)}
@@ -698,6 +706,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
assert(realOutput.length == holder.output.length,
"The data source returns unexpected number of columns")
val wrappedScan = getWrappedScan(scan, holder)
+ // Note: holder.pushedFilterExpressions is not propagated here because the output schema
+ // changes to aggregate columns. When validConstraints is wired up, this needs revisiting.
val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
val projectList = realOutput.zip(holder.output).map { case (a1, a2) =>
// The data source may return columns with arbitrary data types and it's safer to cast them
@@ -715,6 +725,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
assert(realOutput.length == holder.output.length,
"The data source returns unexpected number of columns")
val wrappedScan = getWrappedScan(scan, holder)
+ // Note: holder.pushedFilterExpressions is not propagated here because the output schema
+ // changes with pushed join. When validConstraints is wired up, this needs revisiting.
val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
// When join is pushed down, the real output is going to be, for example,
@@ -737,6 +749,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
val scan = holder.builder.build()
val realOutput = toAttributes(scan.readSchema())
val wrappedScan = getWrappedScan(scan, holder)
+ // Note: holder.pushedFilterExpressions is not propagated here because the output schema
+ // changes with variant extraction. When validConstraints is wired up, this needs revisiting.
val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
// Create projection to map real output to expected output (with transformed types)
@@ -787,14 +801,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
val wrappedScan = getWrappedScan(scan, sHolder)
- val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-
val projectionOverSchema =
ProjectionOverSchema(output.toStructType, AttributeSet(output))
val projectionFunc = (expr: Expression) => expr transformDown {
case projectionOverSchema(newExpr) => newExpr
}
+ // Remap pushed filter attributes to the pruned output schema and drop filters
+ // whose references are no longer in the pruned output.
+ val remappedPushedFilters = sHolder.pushedFilterExpressions.map(projectionFunc)
+ .filter(_.references.subsetOf(AttributeSet(output)))
+ val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output,
+ pushedFilters = remappedPushedFilters)
+
val finalFilters = normalizedFilters.map(projectionFunc)
// bottom-most filters are put in the left of the list.
val withFilter = finalFilters.foldLeft[LogicalPlan](scanRelation)((plan, cond) => {
@@ -1018,6 +1037,8 @@ case class ScanBuilderHolder(
var pushedVariantAttributeMap: Map[ExprId, AttributeReference] = Map.empty
var pushedVariants: Option[VariantInRelation] = None
+
+ var pushedFilterExpressions: Seq[Expression] = Seq.empty
}
// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index a09b7e0827c49..6ea1ea3faa0ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -28,7 +28,7 @@ import test.org.apache.spark.sql.connector._
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan => CatalystGreaterThan, Literal => CatalystLiteral, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
@@ -48,7 +48,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ArrayImplicits._
@@ -1158,6 +1158,196 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
checkAnswer(query, Row(4, 1))
}
}
+
+ private def getScanRelation(query: DataFrame): DataSourceV2ScanRelation = {
+ query.queryExecution.optimizedPlan.collect {
+ case s: DataSourceV2ScanRelation => s
+ }.head
+ }
+
+ test("pushedFilters are set for fully pushed filters") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+ // AdvancedDataSourceV2 only supports pushing GreaterThan on column "i".
+ // i > 3 matches, so it is fully pushed.
+ val q = df.filter($"i" > 3)
+ checkAnswer(q, (4 until 10).map(i => Row(i, -i)))
+
+ val scanRelation = getScanRelation(q)
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should be non-empty when filters are fully pushed")
+ // The pushed filter should reference column i
+ assert(scanRelation.pushedFilters.flatMap(_.references.map(_.name)).contains("i"))
+ }
+
+ test("pushedFilters are empty when no filters are pushed") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+ // AdvancedDataSourceV2 only supports pushing GreaterThan on column "i".
+ // j < -10 does not match, so it is not pushed.
+ val q = df.filter($"j" < -10)
+ checkAnswer(q, Nil)
+
+ val scanRelation = getScanRelation(q)
+ assert(scanRelation.pushedFilters.isEmpty,
+ "pushedFilters should be empty when no filters are pushed")
+ }
+
+ test("pushedFilters are empty when no filter is present") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+ val q = df.select($"i", $"j")
+ checkAnswer(q, (0 until 10).map(i => Row(i, -i)))
+
+ val scanRelation = getScanRelation(q)
+ assert(scanRelation.pushedFilters.isEmpty,
+ "pushedFilters should be empty when there is no filter")
+ }
+
+ test("pushedFilters contains only pushed filters in mixed case") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+ // AdvancedDataSourceV2 only supports pushing GreaterThan on column "i".
+ // i > 3 matches so it is pushed; j < 0 does not match so it is not pushed.
+ val q = df.filter($"i" > 3 && $"j" < 0)
+ checkAnswer(q, (4 until 10).map(i => Row(i, -i)))
+
+ val scanRelation = getScanRelation(q)
+ // Only i > 3 should be in pushedFilters, not j < 0
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should be non-empty for the pushed portion")
+ val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet
+ assert(referencedCols.contains("i"),
+ "pushedFilters should contain the pushed filter on column i")
+ assert(!referencedCols.contains("j"),
+ "pushedFilters should not contain the unsupported filter on column j")
+ }
+
+ test("pushedFilters does not include filters that remain as post-scan") {
+ val df = spark.read.format(classOf[OverlappingFilterDataSourceV2].getName).load()
+ // OverlappingFilterDataSourceV2 only evaluates GreaterThan on column "i".
+ // i > 3 is fully pushed; j < 0 overlaps (reported by pushedFilters() but also post-scan).
+ val q = df.filter($"i" > 3 && $"j" < 0)
+ checkAnswer(q, (4 until 10).map(i => Row(i, -i)))
+
+ val scanRelation = getScanRelation(q)
+ // Only i > 3 should be in pushedFilters (fully pushed, not in post-scan).
+ // j < 0 is in post-scan so it should be excluded from pushedFilters despite being
+ // reported by the scan's pushedFilters() API.
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should contain the fully-pushed filter")
+ val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet
+ assert(referencedCols.contains("i"),
+ "pushedFilters should contain the fully-pushed filter on column i")
+ assert(!referencedCols.contains("j"),
+ "pushedFilters should not contain the overlapping filter on column j")
+ }
+
+ test("pushedFilters with V2 filter API") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load()
+ // i > 3 is fully pushed via V2 filter API
+ val q = df.filter($"i" > 3)
+ checkAnswer(q, (4 until 10).map(i => Row(i, -i)))
+
+ val scanRelation = getScanRelation(q)
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should be non-empty with V2 filter API")
+ assert(scanRelation.pushedFilters.flatMap(_.references.map(_.name)).contains("i"))
+ }
+
+ test("pushedFilters with V2 filter API contains only pushed filters in mixed case") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load()
+ // AdvancedScanBuilderWithV2Filter only supports pushing ">" predicates.
+ // i > 3 has predicate name ">" so it is pushed; j < 0 has predicate name "<" so it is not.
+ val q = df.filter($"i" > 3 && $"j" < 0)
+ checkAnswer(q, (4 until 10).map(i => Row(i, -i)))
+
+ val scanRelation = getScanRelation(q)
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should be non-empty for the pushed portion")
+ val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet
+ assert(referencedCols.contains("i"),
+ "pushedFilters should contain the pushed filter on column i")
+ assert(!referencedCols.contains("j"),
+ "pushedFilters should not contain the unsupported filter on column j")
+ }
+
+ test("pushedFilters are remapped by ProjectionOverSchema after nested schema pruning") {
+ val df = spark.read.format(classOf[NestedSchemaDataSourceV2].getName).load()
+ // NestedSchemaScanBuilder pushes GreaterThan on "s.a".
+ // Selecting only s.a triggers nested schema pruning: s goes from struct<a:int,b:int> to struct<a:int>.
+ val q = df.select($"s.a").filter($"s.a" > 3)
+ checkAnswer(q, (4 until 10).map(i => Row(i)))
+
+ val scanRelation = getScanRelation(q)
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should be non-empty")
+ // Find the struct attribute referenced by the pushed filter.
+ // Before remapping it would have type struct<a:int,b:int>; after remapping, struct<a:int>.
+ val structAttrs = scanRelation.pushedFilters
+ .flatMap(_.collect { case a: AttributeReference if a.name == "s" => a })
+ assert(structAttrs.nonEmpty, "pushed filter should reference struct column s")
+ val prunedStructType = structAttrs.head.dataType.asInstanceOf[StructType]
+ assert(prunedStructType.fieldNames.toSeq == Seq("a"),
+ s"struct column in pushed filter should be pruned to struct<a:int> but was $prunedStructType")
+ }
+
+ test("scan canonicalization with pushedFilters") {
+ // Use SimpleDataSourceV2 whose scan implements equals, so canonicalization comparison works
+ val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty())
+
+ val relation1 = DataSourceV2Relation.create(
+ table, None, None, CaseInsensitiveStringMap.empty())
+ val relation2 = DataSourceV2Relation.create(
+ table, None, None, CaseInsensitiveStringMap.empty())
+ val scan1 = relation1.table.asReadable.newScanBuilder(relation1.options).build()
+ val scan2 = relation2.table.asReadable.newScanBuilder(relation2.options).build()
+
+ val filter1 = CatalystGreaterThan(relation1.output.head, CatalystLiteral(3))
+ val filter2 = CatalystGreaterThan(relation2.output.head, CatalystLiteral(3))
+
+ val scanRelation1 = DataSourceV2ScanRelation(relation1, scan1, relation1.output,
+ pushedFilters = Seq(filter1))
+ val scanRelation2 = DataSourceV2ScanRelation(relation2, scan2, relation2.output,
+ pushedFilters = Seq(filter2))
+
+ assert(scanRelation1 != scanRelation2,
+ "Two instances should not be equal before canonicalization")
+ assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+ "Canonicalized instances with equivalent pushedFilters should be equal")
+ }
+
+ test("pushedFilters excludes non-deterministic filters") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+ // i > 3 is pushable and deterministic; rand() > 0.5 is non-deterministic and not pushable.
+ // Before the fix, ExpressionSet.contains would miss the non-deterministic filter,
+ // causing it to incorrectly appear in pushedFilterExpressions.
+ val q = df.filter($"i" > 3 && rand() > 0.5)
+
+ val scanRelation = getScanRelation(q)
+ // pushedFilters should only contain the deterministic pushed filter (i > 3).
+ assert(scanRelation.pushedFilters.nonEmpty,
+ "pushedFilters should contain the deterministic pushed filter")
+ assert(scanRelation.pushedFilters.forall(_.deterministic),
+ "pushedFilters should not contain non-deterministic filters")
+ val referencedCols = scanRelation.pushedFilters.flatMap(_.references.map(_.name)).toSet
+ assert(referencedCols.contains("i"),
+ "pushedFilters should contain the pushed filter on column i")
+ }
+
+ test("pushedFilters drops filters referencing pruned columns") {
+ // Disable constraint propagation so IsNotNull(i) is not added (it would keep
+ // column i in the scan output). This simulates a connector that pushes IsNotNull.
+ withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+ val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+ // i > 3 is fully pushed; selecting only j causes column pruning to drop i.
+ val q = df.filter($"i" > 3).select($"j")
+ checkAnswer(q, (4 until 10).map(i => Row(-i)))
+
+ val scanRelation = getScanRelation(q)
+ assert(!scanRelation.output.exists(_.name == "i"),
+ "column i should be pruned from scan output")
+ assert(scanRelation.pushedFilters.isEmpty,
+ "pushedFilters should drop filters referencing pruned columns")
+ }
+ }
+
}
case class RangeInputPartition(start: Int, end: Int) extends InputPartition
@@ -1443,6 +1633,171 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF
}
+// Data source where pushed filters overlap with post-scan filters.
+// pushFilters returns unsupported filters, but pushedFilters() returns ALL filters
+// (including the unsupported ones). This mimics the Parquet row-group filter pattern
+// where a filter is pushed for best-effort evaluation but must also be re-evaluated.
+class OverlappingFilterDataSourceV2 extends TestingV2Source {
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new OverlappingScanBuilder()
+ }
+ }
+}
+
+class OverlappingScanBuilder extends ScanBuilder
+ with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
+
+ var requiredSchema = TestingV2Source.schema
+ private var allFilters = Array.empty[Filter]
+ private var evaluableFilters = Array.empty[Filter]
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ this.requiredSchema = requiredSchema
+ }
+
+ override def readSchema(): StructType = requiredSchema
+
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ val (supported, unsupported) = filters.partition {
+ case GreaterThan("i", _: Int) => true
+ case _ => false
+ }
+ this.evaluableFilters = supported
+ this.allFilters = filters
+ // Return unsupported filters as post-scan, but report ALL filters as pushed
+ unsupported
+ }
+
+ // Reports all filters as pushed (including unsupported ones that overlap with post-scan)
+ override def pushedFilters(): Array[Filter] = allFilters
+
+ override def build(): Scan = this
+
+ override def toBatch: Batch = new AdvancedBatch(evaluableFilters, requiredSchema)
+}
+
+// Data source with a nested (struct) column to test ProjectionOverSchema remapping.
+ // Schema: s: struct<a: int, b: int>, i: int
+// Pushes GreaterThan on nested field "s.a".
+class NestedSchemaDataSourceV2 extends TableProvider {
+
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType =
+ NestedSchemaDataSourceV2.schema
+
+ override def getTable(
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ new Table with SupportsRead {
+ override def name(): String = "nested-schema-test"
+ override def schema(): StructType = NestedSchemaDataSourceV2.schema
+ override def capabilities(): util.Set[TableCapability] =
+ util.EnumSet.of(TableCapability.BATCH_READ)
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
+ new NestedSchemaScanBuilder()
+ }
+ }
+}
+
+object NestedSchemaDataSourceV2 {
+ val schema: StructType = StructType(Seq(
+ StructField("s", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", IntegerType)
+ ))),
+ StructField("i", IntegerType)
+ ))
+}
+
+class NestedSchemaScanBuilder extends ScanBuilder
+ with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
+
+ var requiredSchema: StructType = NestedSchemaDataSourceV2.schema
+ var filters = Array.empty[Filter]
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ this.requiredSchema = requiredSchema
+ }
+
+ override def readSchema(): StructType = requiredSchema
+
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ // Push GreaterThan on nested field "s.a"
+ val (supported, unsupported) = filters.partition {
+ case GreaterThan("s.a", _: Int) => true
+ case _ => false
+ }
+ this.filters = supported
+ unsupported
+ }
+
+ override def pushedFilters(): Array[Filter] = filters
+
+ override def build(): Scan = this
+
+ override def toBatch: Batch = new NestedSchemaBatch(filters, requiredSchema)
+}
+
+class NestedSchemaBatch(
+ val filters: Array[Filter],
+ val requiredSchema: StructType) extends Batch {
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ val lowerBound = filters.collectFirst {
+ case GreaterThan("s.a", v: Int) => v
+ }
+ val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
+ if (lowerBound.isEmpty) {
+ res.append(RangeInputPartition(0, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get < 4) {
+ res.append(RangeInputPartition(lowerBound.get + 1, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get < 9) {
+ res.append(RangeInputPartition(lowerBound.get + 1, 10))
+ }
+ res.toArray
+ }
+
+ override def createReaderFactory(): PartitionReaderFactory =
+ new NestedSchemaReaderFactory(requiredSchema)
+}
+
+class NestedSchemaReaderFactory(
+ requiredSchema: StructType) extends PartitionReaderFactory {
+
+ override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+ val RangeInputPartition(start, end) = partition
+ new PartitionReader[InternalRow] {
+ private var current = start - 1
+
+ override def next(): Boolean = {
+ current += 1
+ current < end
+ }
+
+ override def get(): InternalRow = {
+ val values = requiredSchema.map { field =>
+ field.name match {
+ case "s" =>
+ val structType = field.dataType.asInstanceOf[StructType]
+ val structValues = structType.map(_.name).map {
+ case "a" => current
+ case "b" => -current
+ }
+ InternalRow.fromSeq(structValues)
+ case "i" => current
+ }
+ }
+ InternalRow.fromSeq(values)
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+}
+
class SchemaRequiredDataSource extends TableProvider {
class MyScanBuilder(schema: StructType) extends SimpleScanBuilder {