Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ case class StreamingSymmetricHashJoinExec(
private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)

private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) =>
// inputSchema can be empty as expr should only have BoundReferences and does not require
// the schema to generate the predicate. See [[StreamingSymmetricHashJoinHelper]].
Predicate.create(expr, Seq.empty).eval _
Expand All @@ -672,7 +672,7 @@ case class StreamingSymmetricHashJoinExec(
}

private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
case Some(JoinStateValueWatermarkPredicate(expr, _, _)) =>
Predicate.create(expr, inputAttributes).eval _
case _ =>
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
Expand Down Expand Up @@ -893,21 +893,25 @@ case class StreamingSymmetricHashJoinExec(
*/
def removeOldState(): Long = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case _ => 0L
}
Expand All @@ -925,21 +929,25 @@ case class StreamingSymmetricHashJoinExec(
*/
def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictAndReturnByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictAndReturnByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case _ => Iterator.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,18 @@ object StreamingSymmetricHashJoinHelper extends Logging {
override def toString: String = s"$desc: $expr"
}
/** Predicate for watermark on state keys */
case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
case class JoinStateKeyWatermarkPredicate(
expr: Expression,
stateWatermark: Long,
prevStateWatermark: Option[Long] = None)
extends JoinStateWatermarkPredicate {
def desc: String = "key predicate"
}
/** Predicate for watermark on state values */
case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
case class JoinStateValueWatermarkPredicate(
expr: Expression,
stateWatermark: Long,
prevStateWatermark: Option[Long] = None)
extends JoinStateWatermarkPredicate {
def desc: String = "value predicate"
}
Expand Down Expand Up @@ -185,6 +191,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
rightKeys: Seq[Expression],
condition: Option[Expression],
eventTimeWatermarkForEviction: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = {

// Perform assertions against multiple event time columns in the same DataFrame. This method
Expand Down Expand Up @@ -215,20 +222,30 @@ object StreamingSymmetricHashJoinHelper extends Logging {
expr.map { e =>
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
// is defined
JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
JoinStateKeyWatermarkPredicate(
e,
eventTimeWatermarkForEviction.get,
eventTimeWatermarkForLateEvents)
}
} else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
condition,
eventTimeWatermarkForEviction)
val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ =>
StreamingJoinHelper.getStateValueWatermark(
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
condition,
eventTimeWatermarkForLateEvents)
}
val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
expr.map { e =>
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
// is defined
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark)
}
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.NextIterator

/**
Expand Down Expand Up @@ -184,15 +185,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
import SymmetricHashJoinStateManager._

/** Evict the state by timestamp. Returns the number of values evicted. */
def evictByTimestamp(endTimestamp: Long): Long
/**
* Evict the state by timestamp. Returns the number of values evicted.
*
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
* assumed to have been evicted already (e.g. from the previous batch). When provided,
* the scan starts from startTimestamp + 1.
*/
def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long

/**
* Evict the state by timestamp and return the evicted key-value pairs.
*
* It is the caller's responsibility to consume the whole iterator.
*
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
* assumed to have been evicted already (e.g. from the previous batch). When provided,
* the scan starts from startTimestamp + 1.
*/
def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
def evictAndReturnByTimestamp(
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair]
}

/**
Expand Down Expand Up @@ -507,9 +521,9 @@ class SymmetricHashJoinStateManagerV4(
}
}

override def evictByTimestamp(endTimestamp: Long): Long = {
override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = {
var removed = 0L
tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted =>
val key = evicted.key
val timestamp = evicted.timestamp
val numValues = evicted.numValues
Expand All @@ -523,10 +537,11 @@ class SymmetricHashJoinStateManagerV4(
removed
}

override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
override def evictAndReturnByTimestamp(
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = {
val reusableKeyToValuePair = KeyToValuePair()

tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted =>
val key = evicted.key
val timestamp = evicted.timestamp
val values = keyWithTsToValues.get(key, timestamp)
Expand Down Expand Up @@ -647,17 +662,30 @@ class SymmetricHashJoinStateManagerV4(

/**
* Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
* Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
* When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
* range access; falls back to prefixScan otherwise to stay within the key's scope.
*
* When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
* filtered out so both code paths produce identical results.
*/
def getValuesInRange(
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
val reusableGetValuesResult = new GetValuesResult()
// Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
// an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
val useRangeScan = maxTs < Long.MaxValue

new NextIterator[GetValuesResult] {
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
private val iter = if (useRangeScan) {
val startKey = createKeyRow(key, minTs).copy()
// rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
val endKey = Some(createKeyRow(key, maxTs + 1))
stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName)
} else {
stateStore.prefixScanWithMultiValues(key, colFamilyName)
}

private var currentTs = -1L
private var pastUpperBound = false
private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()

private def flushAccumulated(): GetValuesResult = {
Expand All @@ -675,16 +703,16 @@ class SymmetricHashJoinStateManagerV4(

@tailrec
override protected def getNext(): GetValuesResult = {
if (pastUpperBound || !iter.hasNext) {
if (!iter.hasNext) {
flushAccumulated()
} else {
val unsafeRowPair = iter.next()
val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)

if (ts > maxTs) {
pastUpperBound = true
getNext()
} else if (ts < minTs) {
// Filter out entries outside [minTs, maxTs]. This is essential when using
// prefixScan (which returns all timestamps for the key) and serves as a
// safety guard for rangeScan as well.
if (ts < minTs || ts > maxTs) {
getNext()
} else if (currentTs == -1L || currentTs == ts) {
currentTs = ts
Expand Down Expand Up @@ -757,6 +785,8 @@ class SymmetricHashJoinStateManagerV4(
isInternal = true
)

// Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
// hold the row beyond the immediate state store call must invoke copy() on the result.
private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
TimestampKeyStateEncoder.attachTimestamp(
attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
Expand All @@ -772,9 +802,66 @@ class SymmetricHashJoinStateManagerV4(

case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)

// NOTE: This assumes we consume the whole iterator to trigger completion.
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
/**
 * Builds an [[InternalRow]] holding one placeholder value per field of `schema`, in
 * field order. Each placeholder is obtained from `defaultValueForType`, which in turn
 * calls back into this method for struct-typed fields.
 *
 * @param schema the schema whose fields determine the shape of the returned row
 * @return a row with as many values as `schema` has fields
 */
private def defaultInternalRow(schema: StructType): InternalRow = {
  val placeholders = schema.map { field =>
    defaultValueForType(field.dataType)
  }
  InternalRow.fromSeq(placeholders)
}

/**
 * Returns the placeholder value used to fill a field of type `dt` in a default row.
 *
 * Integral, floating-point, date and timestamp types map to zero, booleans to `false`,
 * strings and binaries to their empty value, and structs to a row of defaults.
 *
 * NOTE(review): any type without an explicit case below falls back to `null`.
 * Presumably these rows are only used as scan boundaries -- confirm that a null
 * placeholder is acceptable for non-nullable fields of such types.
 *
 * @param dt the data type to produce a placeholder for
 * @return the placeholder value, or `null` for types not listed here
 */
private def defaultValueForType(dt: DataType): Any = dt match {
  // Structs are filled field by field.
  case nested: StructType => defaultInternalRow(nested)
  case BooleanType => false
  case ByteType => 0.toByte
  case ShortType => 0.toShort
  // Placeholder is an Int zero.
  case IntegerType => 0
  case DateType => 0
  // Placeholder is a Long zero.
  case LongType => 0L
  case TimestampType => 0L
  case TimestampNTZType => 0L
  case FloatType => 0.0f
  case DoubleType => 0.0
  case StringType => UTF8String.EMPTY_UTF8
  case BinaryType => Array.emptyByteArray
  case _ => null
}

/**
 * Builds a boundary row for rangeScan at the given timestamp.
 *
 * The TsWithKeyTypeStore uses TimestampAsPrefixKeyStateEncoder, which encodes the row
 * as [timestamp][key_fields], and the encoder expects all key columns to be present.
 * The boundary therefore carries a full key row filled with default values rather than
 * just the timestamp; only the timestamp is meant to matter for ordering in the prefix
 * encoder.
 *
 * The returned row is a copy, so it does not alias the buffer reused by createKeyRow.
 *
 * @param timestamp the timestamp to place in the boundary row
 * @return a standalone boundary row for `timestamp` with default key fields
 */
private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
  val toUnsafeKey = UnsafeProjection.create(keySchema)
  val placeholderKey = toUnsafeKey(defaultInternalRow(keySchema))
  val boundary = createKeyRow(placeholderKey, timestamp)
  boundary.copy()
}

/**
* Scan keys eligible for eviction within the timestamp range.
*
* This assumes we consume the whole iterator to trigger completion.
*
* @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
* eligible for eviction.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
* are assumed to have been evicted already. The scan starts from startTimestamp + 1.
*/
def scanEvictedKeys(
endTimestamp: Long,
startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = {
// rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
// startTimestamp is exclusive (already evicted), so we seek from st + 1.
val startKeyRow = startTimestamp.flatMap { st =>
if (st < Long.MaxValue) Some(createScanBoundaryRow(st + 1))
else None
}
// endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
// When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
// safe because rangeScanWithMultiValues with no end key uses the column-family prefix
// as the upper bound, naturally scoping the scan within this column family.
val endKeyRow = if (endTimestamp < Long.MaxValue) {
Some(createScanBoundaryRow(endTimestamp + 1))
} else {
None
}
val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
new NextIterator[EvictedKeysResult]() {
var currentKeyRow: UnsafeRow = null
var currentEventTime: Long = -1L
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,19 @@ class IncrementalExecution(
case j: StreamingSymmetricHashJoinExec =>
val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get)
val iwEviction = inputWatermarkForEviction(j.stateInfo.get)
// Only use the late-events watermark as the scan lower bound when a previous
// batch actually existed. In the very first batch the watermark propagation
// yields Some(0) even though no state has been evicted yet, which would
// incorrectly skip entries at timestamp 0.
val prevBatchLateEventsWm =
if (prevOffsetSeqMetadata.isDefined) iwLateEvents else None
j.copy(
eventTimeWatermarkForLateEvents = iwLateEvents,
eventTimeWatermarkForEviction = iwEviction,
stateWatermarkPredicates =
StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
iwEviction, !allowMultipleStatefulOperators)
iwEviction, prevBatchLateEventsWm, !allowMultipleStatefulOperators)
)
}
}
Expand Down
Loading