diff --git a/src/main/scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala b/src/main/scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala
new file mode 100644
index 00000000..2ffb5f53
--- /dev/null
+++ b/src/main/scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala
@@ -0,0 +1,22 @@
+package uk.co.gresearch.spark.source
+
+import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning}
+import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution}
+
+trait Reporting extends SupportsReportPartitioning {
+  this: Batch =>
+
+  def partitioned: Boolean
+  def ordered: Boolean
+
+  def outputPartitioning: partitioning.Partitioning =
+    Partitioning(this.planInputPartitions().length, partitioned)
+}
+
+case class Partitioning(partitions: Int, partitioned: Boolean) extends partitioning.Partitioning {
+  override def numPartitions(): Int = partitions
+  override def satisfy(distribution: Distribution): Boolean = distribution match {
+    case c: ClusteredDistribution => partitioned && c.clusteredColumns.contains("id")
+    case _ => false
+  }
+}
diff --git a/src/main/scala-spark-3.1/uk/co/gresearch/spark/source/Reporting.scala b/src/main/scala-spark-3.1/uk/co/gresearch/spark/source/Reporting.scala
new file mode 120000
index 00000000..0abba2a2
--- /dev/null
+++ b/src/main/scala-spark-3.1/uk/co/gresearch/spark/source/Reporting.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.2/uk/co/gresearch/spark/source/Reporting.scala b/src/main/scala-spark-3.2/uk/co/gresearch/spark/source/Reporting.scala
new file mode 120000
index 00000000..0abba2a2
--- /dev/null
+++ b/src/main/scala-spark-3.2/uk/co/gresearch/spark/source/Reporting.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.0/uk/co/gresearch/spark/source/Reporting.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.3/uk/co/gresearch/spark/source/Reporting.scala b/src/main/scala-spark-3.3/uk/co/gresearch/spark/source/Reporting.scala
new file mode 100644
index 00000000..72d3490e
--- /dev/null
+++ b/src/main/scala-spark-3.3/uk/co/gresearch/spark/source/Reporting.scala
@@ -0,0 +1,32 @@
+package uk.co.gresearch.spark.source
+
+import org.apache.spark.sql.connector.expressions.{Expression, NamedReference, Transform}
+import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning}
+import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
+import uk.co.gresearch.spark.source.Reporting.namedReference
+
+trait Reporting extends SupportsReportPartitioning {
+  this: Batch =>
+
+  def partitioned: Boolean
+  def ordered: Boolean
+
+  val partitionKeys: Array[Expression] = Array(namedReference("id"))
+
+  def outputPartitioning: partitioning.Partitioning = if (partitioned) {
+    new KeyGroupedPartitioning(partitionKeys, this.planInputPartitions().length)
+  } else {
+    new UnknownPartitioning(this.planInputPartitions().length)
+  }
+}
+
+object Reporting {
+  def namedReference(columnName: String): Expression =
+    new Transform {
+      override def name(): String = "identity"
+      override def references(): Array[NamedReference] = Array.empty
+      override def arguments(): Array[Expression] = Array(new NamedReference {
+        override def fieldNames(): Array[String] = Array(columnName)
+      })
+    }
+}
diff --git a/src/main/scala-spark-3.4/uk/co/gresearch/spark/source/Reporting.scala b/src/main/scala-spark-3.4/uk/co/gresearch/spark/source/Reporting.scala
new file mode 120000
index 00000000..8ec50af0
--- /dev/null
+++ b/src/main/scala-spark-3.4/uk/co/gresearch/spark/source/Reporting.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.3/uk/co/gresearch/spark/source/Reporting.scala
\ No newline at end of file
diff --git a/src/main/scala/uk/co/gresearch/spark/source/DefaultSource.scala b/src/main/scala/uk/co/gresearch/spark/source/DefaultSource.scala
new file mode 100644
index 00000000..41e97d2e
--- /dev/null
+++ b/src/main/scala/uk/co/gresearch/spark/source/DefaultSource.scala
@@ -0,0 +1,86 @@
+package uk.co.gresearch.spark.source
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan}
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.sql.Timestamp
+import java.util
+import scala.collection.JavaConverters._
+
+class DefaultSource() extends TableProvider with DataSourceRegister {
+  override def shortName(): String = "example"
+  override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = Array.empty
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = DefaultSource.schema
+  override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table =
+    BatchTable(
+      properties.getOrDefault("partitioned", "false").toBoolean,
+      properties.getOrDefault("ordered", "false").toBoolean
+    )
+}
+
+object DefaultSource {
+  val supportsReportingOrder: Boolean = false
+  val schema: StructType = StructType(Seq(
+    StructField("id", IntegerType),
+    StructField("time", TimestampType),
+    StructField("value", DoubleType),
+  ))
+  val ts: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2020-01-01 12:00:00"))
+  val data: Map[Int, Array[InternalRow]] = Map(
+    1 -> Array(
+      InternalRow(1, ts + 1000000, 1.1),
+      InternalRow(1, ts + 2000000, 1.2),
+      InternalRow(1, ts + 3000000, 1.3),
+      InternalRow(3, ts + 1000000, 3.1),
+      InternalRow(3, ts + 2000000, 3.2)
+    ),
+    2 -> Array(
+      InternalRow(2, ts + 1000000, 2.1),
+      InternalRow(2, ts + 2000000, 2.2),
+      InternalRow(4, ts + 1000000, 4.1),
+      InternalRow(4, ts + 2000000, 4.2),
+      InternalRow(4, ts + 3000000, 4.3)
+    )
+  )
+  val partitions: Int = data.size
+}
+
+case class BatchTable(partitioned: Boolean, ordered: Boolean) extends Table with SupportsRead {
+  override def name(): String = "table"
+  override def schema(): StructType = DefaultSource.schema
+  override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
+  override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): read.ScanBuilder =
+    new ScanBuilder(partitioned, ordered)
+}
+
+class ScanBuilder(partitioned: Boolean, ordered: Boolean) extends read.ScanBuilder {
+  override def build(): Scan = BatchScan(partitioned, ordered)
+}
+
+case class BatchScan(partitioned: Boolean, ordered: Boolean) extends read.Scan with read.Batch with Reporting {
+  override def readSchema(): StructType = DefaultSource.schema
+  override def toBatch: Batch = this
+  override def planInputPartitions(): Array[InputPartition] = DefaultSource.data.keys.map(Partition).toArray
+  override def createReaderFactory(): read.PartitionReaderFactory = PartitionReaderFactory()
+}
+
+case class Partition(id: Int) extends InputPartition
+
+case class PartitionReaderFactory() extends read.PartitionReaderFactory {
+  override def createReader(partition: InputPartition): read.PartitionReader[InternalRow] = PartitionReader(partition)
+}
+
+
+case class PartitionReader(partition: InputPartition) extends read.PartitionReader[InternalRow] {
+  val rows: Iterator[InternalRow] = DefaultSource.data.getOrElse(partition.asInstanceOf[Partition].id, Array.empty[InternalRow]).iterator
+  def next: Boolean = rows.hasNext
+  def get: InternalRow = rows.next()
+  def close(): Unit = { }
+}
diff --git a/src/test/scala/uk/co/gresearch/spark/SparkTestSession.scala b/src/test/scala/uk/co/gresearch/spark/SparkTestSession.scala
index 80cccc9c..e9e80320 100644
--- a/src/test/scala/uk/co/gresearch/spark/SparkTestSession.scala
+++ b/src/test/scala/uk/co/gresearch/spark/SparkTestSession.scala
@@ -28,6 +28,7 @@ trait SparkTestSession extends SQLHelper {
       .master("local[1]")
       .appName("spark test example")
       .config("spark.sql.shuffle.partitions", 2)
+      .config("spark.sql.adaptive.coalescePartitions.enabled", value = false)
       .config("spark.local.dir", ".")
       .getOrCreate()
   }
diff --git a/src/test/scala/uk/co/gresearch/spark/source/SourceSuite.scala b/src/test/scala/uk/co/gresearch/spark/source/SourceSuite.scala
new file mode 100644
index 00000000..650c694e
--- /dev/null
+++ b/src/test/scala/uk/co/gresearch/spark/source/SourceSuite.scala
@@ -0,0 +1,66 @@
+package uk.co.gresearch.spark.source
+
+import org.apache.spark.sql.{DataFrame, DataFrameReader}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.sum
+import org.scalatest.funsuite.AnyFunSuite
+import uk.co.gresearch.spark.SparkTestSession
+
+class SourceSuite extends AnyFunSuite with SparkTestSession with AdaptiveSparkPlanHelper {
+  import spark.implicits._
+
+  private val source = new DefaultSource().getClass.getPackage.getName
+  private def df: DataFrameReader = spark.read.format(source)
+  private val dfpartitioned = df.option("partitioned", value = true)
+  private val dfpartitionedAndSorted = df.option("partitioned", value = true).option("ordered", value = true)
+  private val window = Window.partitionBy($"id").orderBy($"time")
+
+  test("show") {
+    df.load().show()
+  }
+
+  test("groupBy without partition information") {
+    assertPlan(
+      df.load().groupBy($"id").count(),
+      { case e: Exchange => e },
+      expected = true
+    )
+  }
+
+  test("groupBy with partition information") {
+    assertPlan(
+      dfpartitioned.load().groupBy($"id").count(),
+      { case e: Exchange => e },
+      expected = false
+    )
+  }
+
+  test("window function without partition information") {
+    val df = this.df.load().select($"id", $"time", sum($"value").over(window))
+    assertPlan(df, { case e: Exchange => e }, expected = true)
+    assertPlan(df, { case s: SortExec => s }, expected = true)
+  }
+
+  test("window function with partition information") {
+    val df = this.dfpartitioned.load().select($"id", $"time", sum($"value").over(window))
+    assertPlan(df, { case e: Exchange => e }, expected = false)
+    assertPlan(df, { case s: SortExec => s }, expected = true)
+  }
+
+  test("window function with partition and order information") {
+    assertPlan(
+      dfpartitionedAndSorted.load().select($"id", $"time", sum($"value").over(window)),
+      { case e: Exchange => e; case s: SortExec => s },
+      expected = !DefaultSource.supportsReportingOrder
+    )
+  }
+
+  def assertPlan[T](df: DataFrame, func: PartialFunction[SparkPlan, T], expected: Boolean): Unit = {
+    df.explain()
+    assert(df.rdd.getNumPartitions === DefaultSource.partitions)
+    assert(collectFirst(df.queryExecution.executedPlan)(func).isDefined === expected)
+  }
+}
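
For reference, a minimal sketch (not part of the patch) of how the example source added above could be exercised from a Spark application. The package name, the "partitioned" option, and the expected plan behaviour come from DefaultSource and SourceSuite in this diff; the SparkSession setup and the ExampleUsage wrapper are assumed boilerplate, and using the "example" short name would additionally require a DataSourceRegister service-loader entry that this patch does not add.

import org.apache.spark.sql.SparkSession

object ExampleUsage {
  def main(args: Array[String]): Unit = {
    // Assumed boilerplate: a local session, mirroring SparkTestSession above.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("source example")
      .config("spark.sql.adaptive.coalescePartitions.enabled", value = false)
      .getOrCreate()

    // Resolve the source by package name, as SourceSuite does; "partitioned" makes
    // BatchScan report a partitioning over the "id" column via the Reporting trait.
    val df = spark.read
      .format("uk.co.gresearch.spark.source")
      .option("partitioned", value = true)
      .load()

    // Because the scan reports that rows are already clustered by "id", this
    // aggregation should plan without an Exchange (no shuffle), matching the
    // "groupBy with partition information" test above.
    df.groupBy("id").count().explain()

    spark.stop()
  }
}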