@@ -0,0 +1,22 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning}
import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution}

trait Reporting extends SupportsReportPartitioning {
  this: Batch =>

  def partitioned: Boolean
  def ordered: Boolean

  def outputPartitioning: partitioning.Partitioning =
    Partitioning(this.planInputPartitions().length, partitioned)
}

case class Partitioning(partitions: Int, partitioned: Boolean) extends partitioning.Partitioning {
  override def numPartitions(): Int = partitions
  override def satisfy(distribution: Distribution): Boolean = distribution match {
    case c: ClusteredDistribution => partitioned && c.clusteredColumns.contains("id")
    case _ => false
  }
}
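For context, the trait above reports a clustering on the "id" column so Spark can drop shuffles for operations that are already grouped by that column. A minimal usage sketch, assuming a running SparkSession and this example source on the classpath (the fully qualified package name is used as the format, as in the test suite below):

// Minimal sketch, not part of the change: with "partitioned" set, the reported
// clustering on "id" lets Spark plan the aggregation without an Exchange.
import spark.implicits._

val plain = spark.read.format("uk.co.gresearch.spark.source").load()
plain.groupBy($"id").count().explain()        // plan contains an Exchange

val clustered = spark.read
  .format("uk.co.gresearch.spark.source")
  .option("partitioned", "true")
  .load()
clustered.groupBy($"id").count().explain()    // no Exchange expected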
@@ -0,0 +1,32 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.connector.expressions.{Expression, NamedReference, Transform}
import org.apache.spark.sql.connector.read.{Batch, SupportsReportPartitioning, partitioning}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
import uk.co.gresearch.spark.source.Reporting.namedReference

trait Reporting extends SupportsReportPartitioning {
  this: Batch =>

  def partitioned: Boolean
  def ordered: Boolean

  val partitionKeys: Array[Expression] = Array(namedReference("id"))

  def outputPartitioning: partitioning.Partitioning = if (partitioned) {
    new KeyGroupedPartitioning(partitionKeys, this.planInputPartitions().length)
  } else {
    new UnknownPartitioning(this.planInputPartitions().length)
  }
}

object Reporting {
  def namedReference(columnName: String): Expression =
    new Transform {
      override def name(): String = "identity"
      override def references(): Array[NamedReference] = Array.empty
      override def arguments(): Array[Expression] = Array(new NamedReference {
        override def fieldNames(): Array[String] = Array(columnName)
      })
    }
}
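This second Reporting variant targets the key-grouped reporting API, where the clustering is expressed as an identity transform over the "id" column instead of being checked against a Distribution. A hedged alternative to the hand-rolled Transform above, assuming the targeted Spark version ships the Expressions factory in org.apache.spark.sql.connector.expressions:

// Alternative sketch only: build the identity transform with Spark's expression
// factory rather than implementing Transform by hand (assumes Expressions.identity
// is available in the Spark version being compiled against).
import org.apache.spark.sql.connector.expressions.{Expression, Expressions}

val partitionKeys: Array[Expression] = Array(Expressions.identity("id"))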
src/main/scala/uk/co/gresearch/spark/source/DefaultSource.scala (86 additions, 0 deletions)
@@ -0,0 +1,86 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read
import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.sql.Timestamp
import java.util
import scala.collection.JavaConverters._

class DefaultSource() extends TableProvider with DataSourceRegister {
  override def shortName(): String = "example"
  override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = Array.empty
  override def inferSchema(options: CaseInsensitiveStringMap): StructType = DefaultSource.schema
  override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table =
    BatchTable(
      properties.getOrDefault("partitioned", "false").toBoolean,
      properties.getOrDefault("ordered", "false").toBoolean
    )
}

object DefaultSource {
  val supportsReportingOrder: Boolean = false
  val schema: StructType = StructType(Seq(
    StructField("id", IntegerType),
    StructField("time", TimestampType),
    StructField("value", DoubleType),
  ))
  val ts: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2020-01-01 12:00:00"))
  val data: Map[Int, Array[InternalRow]] = Map(
    1 -> Array(
      InternalRow(1, ts + 1000000, 1.1),
      InternalRow(1, ts + 2000000, 1.2),
      InternalRow(1, ts + 3000000, 1.3),
      InternalRow(3, ts + 1000000, 3.1),
      InternalRow(3, ts + 2000000, 3.2)
    ),
    2 -> Array(
      InternalRow(2, ts + 1000000, 2.1),
      InternalRow(2, ts + 2000000, 2.2),
      InternalRow(4, ts + 1000000, 4.1),
      InternalRow(4, ts + 2000000, 4.2),
      InternalRow(4, ts + 3000000, 4.3)
    )
  )
  val partitions: Int = data.size
}

case class BatchTable(partitioned: Boolean, ordered: Boolean) extends Table with SupportsRead {
  override def name(): String = "table"
  override def schema(): StructType = DefaultSource.schema
  override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
  override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): read.ScanBuilder =
    new ScanBuilder(partitioned, ordered)
}

class ScanBuilder(partitioned: Boolean, ordered: Boolean) extends read.ScanBuilder {
  override def build(): Scan = BatchScan(partitioned, ordered)
}

case class BatchScan(partitioned: Boolean, ordered: Boolean) extends read.Scan with read.Batch with Reporting {
  override def readSchema(): StructType = DefaultSource.schema
  override def toBatch: Batch = this
  override def planInputPartitions(): Array[InputPartition] = DefaultSource.data.keys.map(Partition).toArray
  override def createReaderFactory(): read.PartitionReaderFactory = PartitionReaderFactory()
}

case class Partition(id: Int) extends InputPartition

case class PartitionReaderFactory() extends read.PartitionReaderFactory {
  override def createReader(partition: InputPartition): read.PartitionReader[InternalRow] = PartitionReader(partition)
}

case class PartitionReader(partition: InputPartition) extends read.PartitionReader[InternalRow] {
  private val rows: Iterator[InternalRow] =
    DefaultSource.data.getOrElse(partition.asInstanceOf[Partition].id, Array.empty[InternalRow]).iterator

  override def next(): Boolean = rows.hasNext
  override def get(): InternalRow = rows.next()
  override def close(): Unit = { }
}
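The example source above serves two fixed input partitions, five rows each: ids 1 and 3 in the first, ids 2 and 4 in the second. A minimal read sketch (the short name "example" would additionally require a META-INF/services registration for DataSourceRegister; the package name works regardless, as the test suite shows):

// Usage sketch, assuming a running SparkSession and the source on the classpath.
val df = spark.read
  .format("uk.co.gresearch.spark.source")
  .option("partitioned", "true")
  .load()

df.show()                               // 10 rows across ids 1 to 4
assert(df.rdd.getNumPartitions == 2)    // one Spark partition per entry in DefaultSource.data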
@@ -28,6 +28,7 @@ trait SparkTestSession extends SQLHelper {
.master("local[1]")
.appName("spark test example")
.config("spark.sql.shuffle.partitions", 2)
.config("spark.sql.adaptive.coalescePartitions.enabled", value = false)
.config("spark.local.dir", ".")
.getOrCreate()
}
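Disabling adaptive partition coalescing keeps the two shuffle partitions configured above (and the two partitions reported by the source) intact, which the getNumPartitions assertions in the suite below depend on. For illustration only, the equivalent runtime setting:

// Illustration only, not part of the diff: the same setting applied to a live session.
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")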
src/test/scala/uk/co/gresearch/spark/source/SourceSuite.scala (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
package uk.co.gresearch.spark.source

import org.apache.spark.sql.{DataFrame, DataFrameReader}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum
import org.scalatest.funsuite.AnyFunSuite
import uk.co.gresearch.spark.SparkTestSession

class SourceSuite extends AnyFunSuite with SparkTestSession with AdaptiveSparkPlanHelper {
  import spark.implicits._

  private val source = new DefaultSource().getClass.getPackage.getName
  private def df: DataFrameReader = spark.read.format(source)
  private val dfPartitioned = df.option("partitioned", value = true)
  private val dfPartitionedAndSorted = df.option("partitioned", value = true).option("ordered", value = true)
  private val window = Window.partitionBy($"id").orderBy($"time")

  test("show") {
    df.load().show()
  }

  test("groupBy without partition information") {
    assertPlan(
      df.load().groupBy($"id").count(),
      { case e: Exchange => e },
      expected = true
    )
  }

  test("groupBy with partition information") {
    assertPlan(
      dfPartitioned.load().groupBy($"id").count(),
      { case e: Exchange => e },
      expected = false
    )
  }

  test("window function without partition information") {
    val df = this.df.load().select($"id", $"time", sum($"value").over(window))
    assertPlan(df, { case e: Exchange => e }, expected = true)
    assertPlan(df, { case s: SortExec => s }, expected = true)
  }

  test("window function with partition information") {
    val df = this.dfPartitioned.load().select($"id", $"time", sum($"value").over(window))
    assertPlan(df, { case e: Exchange => e }, expected = false)
    assertPlan(df, { case s: SortExec => s }, expected = true)
  }

  test("window function with partition and order information") {
    assertPlan(
      dfPartitionedAndSorted.load().select($"id", $"time", sum($"value").over(window)),
      { case e: Exchange => e; case s: SortExec => s },
      expected = !DefaultSource.supportsReportingOrder
    )
  }

  def assertPlan[T](df: DataFrame, func: PartialFunction[SparkPlan, T], expected: Boolean): Unit = {
    df.explain()
    assert(df.rdd.getNumPartitions === DefaultSource.partitions)
    assert(collectFirst(df.queryExecution.executedPlan)(func).isDefined === expected)
  }
}