9 changes: 9 additions & 0 deletions core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -36,6 +36,15 @@ private[spark] object RpcUtils {
    rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name)
  }

  def makeDriverRef(
      name: String,
      driverHost: String,
      driverPort: Int,
      rpcEnv: RpcEnv): RpcEndpointRef = {
    Utils.checkHost(driverHost)
    rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name)
  }

  /** Returns the default Spark timeout to use for RPC ask operations. */
  def askRpcTimeout(conf: SparkConf): RpcTimeout = {
    RpcTimeout(conf, Seq(RPC_ASK_TIMEOUT.key, NETWORK_TIMEOUT.key), "120s")
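For context on how the new overload is expected to be used: the executor-side writer added later in this PR receives the driver endpoint's name and RpcAddress through its factory and resolves the reference roughly like the sketch below. The names here are placeholders for illustration, not part of the change.

// `endpointName`, `driverAddr` and `internalRow` are hypothetical; in this PR the name and
// address are shipped to executors inside RealTimeRowWriterFactory.
val ref = RpcUtils.makeDriverRef(endpointName, driverAddr.host, driverAddr.port,
  SparkEnv.get.rpcEnv)
ref.send(Array(internalRow.copy()))  // fire-and-forget, as RealTimeRowWriter.write does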
@@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import java.util

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, SparkUnsupportedOperationException}
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.write.{
  LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage
}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.types.StructType

/**
 * A sink that stores the results in memory. It is primarily intended for use in unit tests
 * and does not provide durability.
 * This is mostly copied from [[MemorySink]], except that the data must be available after each
 * write rather than only at commit().
 */
class ContinuousMemorySink extends MemorySink with SupportsWrite {

  private val batches = new ArrayBuffer[Row]()

  override def name(): String = "ContinuousMemorySink"

  override def schema(): StructType = StructType(Nil)

  override def capabilities(): util.Set[TableCapability] = {
    util.EnumSet.of(TableCapability.STREAMING_WRITE)
  }

  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
    new WriteBuilder with SupportsStreamingUpdateAsAppend {
      private val inputSchema: StructType = info.schema()

      override def build(): Write = {
        new ContinuousMemoryWrite(batches, inputSchema)
      }
    }
  }

  /** Returns all rows that are stored in this [[Sink]]. */
  override def allData: Seq[Row] = {
    val batches = getBatches()
    batches.synchronized {
      batches.toSeq
    }
  }

  override def latestBatchId: Option[Long] = {
    None
  }

  override def latestBatchData: Seq[Row] = {
    throw new SparkUnsupportedOperationException(
      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
      messageParameters = Map("operation" -> "latestBatchData")
    )
  }

  override def dataSinceBatch(sinceBatchId: Long): Seq[Row] = {
    throw new SparkUnsupportedOperationException(
      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
      messageParameters = Map("operation" -> "dataSinceBatch")
    )
  }

  override def toDebugString: String = {
    s"${allData}"
  }

  override def write(batchId: Long, needTruncate: Boolean, newRows: Array[Row]): Unit = {
    throw new SparkUnsupportedOperationException(
      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
      messageParameters = Map("operation" -> "write")
    )
  }

  // Lock on the buffer itself so clear() is consistent with the writes from the RPC endpoint
  // and the reads in allData, which synchronize on `batches`.
  override def clear(): Unit = batches.synchronized {
    batches.clear()
  }

  private def getBatches(): ArrayBuffer[Row] = {
    batches
  }

  override def toString(): String = "ContinuousMemorySink"
}

class ContinuousMemoryWrite(batches: ArrayBuffer[Row], schema: StructType) extends Write {
  override def toStreaming: StreamingWrite = {
    new ContinuousMemoryStreamingWrite(batches, schema)
  }
}

/**
 * An RPC endpoint that receives rows and stores them in the ArrayBuffer in real time.
 */
class MemoryRealTimeRpcEndpoint(
    override val rpcEnv: RpcEnv,
    schema: StructType,
    batches: ArrayBuffer[Row]
  ) extends ThreadSafeRpcEndpoint {

  private val encoder = ExpressionEncoder(schema).resolveAndBind().createDeserializer()

  override def receive: PartialFunction[Any, Unit] = {
    case rows: Array[InternalRow] =>
      // ThreadSafeRpcEndpoint already serializes message handling, but readers such as
      // allData lock on the same buffer, so synchronize here as well.
      batches.synchronized {
        rows.foreach { row =>
          batches += encoder(row)
        }
      }
  }
}

class ContinuousMemoryStreamingWrite(val batches: ArrayBuffer[Row], schema: StructType)
  extends StreamingWrite {

  private val memoryEndpoint =
    new MemoryRealTimeRpcEndpoint(
      SparkEnv.get.rpcEnv,
      schema,
      batches
    )

  @volatile private var endpointRef: RpcEndpointRef = _

  override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
    val endpointName = s"MemoryRealTimeRpcEndpoint-${java.util.UUID.randomUUID()}"
    endpointRef = memoryEndpoint.rpcEnv.setupEndpoint(endpointName, memoryEndpoint)
    RealTimeRowWriterFactory(endpointName, endpointRef.address)
  }

  override def useCommitCoordinator(): Boolean = false

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // Nothing to commit here: the rows have already been sent to the driver and stored.
    // Just tear down the RPC endpoint used for this write.
    if (endpointRef != null) {
      memoryEndpoint.rpcEnv.stop(endpointRef)
    }
  }

Comment on lines +166 to +168

Member:
If this commit is called for each batch, does it mean the endpoint stops working after that?

Member:
Oh okay, I see. For each batch, there will be a new MemoryRealTimeRpcEndpoint created.

Contributor Author:
Yup.

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    if (endpointRef != null) {
      memoryEndpoint.rpcEnv.stop(endpointRef)
    }
  }
}
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.util.RpcUtils

/**
* A [[StreamingDataWriterFactory]] that creates [[RealTimeRowWriter]], which sends rows to
* the driver in real-time through RPC.
*
* Note that, because it sends all rows to the driver, this factory will generally be unsuitable
* for production-quality sinks. It's intended for use in tests.
*
*/
case class RealTimeRowWriterFactory(
    driverEndpointName: String,
    driverEndpointAddr: RpcAddress
  ) extends StreamingDataWriterFactory {

  override def createWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = {
    new RealTimeRowWriter(driverEndpointName, driverEndpointAddr)
  }
}

Member:
I wonder why this doesn't follow the same naming pattern, like ContinuousMemoryRowWriterFactory?

Contributor Author:
Ah, good question. The reason is that this code is going to be shared with the future RTM version of ConsoleStreamingWrite.

Contributor Author:
Though I can rename it if you feel strongly about it.

/**
 * A [[DataWriter]] that sends arrays of rows to the driver in real time through RPC.
 */
class RealTimeRowWriter(
    driverEndpointName: String,
    driverEndpointAddr: RpcAddress
  ) extends DataWriter[InternalRow] {

  private val endpointRef = RpcUtils.makeDriverRef(
    driverEndpointName,
    driverEndpointAddr.host,
    driverEndpointAddr.port,
    SparkEnv.get.rpcEnv
  )

  // Spark may reuse the same `InternalRow` instance, so copy the row before sending it.
  override def write(row: InternalRow): Unit = {
    endpointRef.send(Array(row.copy()))
  }

  override def commit(): WriterCommitMessage = { null }

  override def abort(): Unit = {}

  override def close(): Unit = {}
}