diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 1c50e6802c323..2b15c43bd3db6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -663,7 +663,7 @@ case class StreamingSymmetricHashJoinExec( private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match { - case Some(JoinStateKeyWatermarkPredicate(expr, _)) => + case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) => // inputSchema can be empty as expr should only have BoundReferences and does not require // the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]]. Predicate.create(expr, Seq.empty).eval _ @@ -672,7 +672,7 @@ case class StreamingSymmetricHashJoinExec( } private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match { - case Some(JoinStateValueWatermarkPredicate(expr, _)) => + case Some(JoinStateValueWatermarkPredicate(expr, _, _)) => Predicate.create(expr, inputAttributes).eval _ case _ => Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate @@ -893,21 +893,25 @@ case class StreamingSymmetricHashJoinExec( */ def removeOldState(): Long = { stateWatermarkPredicate match { - case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictByKeyCondition(stateKeyWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } - case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictByValueCondition(stateValueWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } case _ => 0L } @@ -925,21 +929,25 @@ case class StreamingSymmetricHashJoinExec( */ def removeAndReturnOldState(): Iterator[KeyToValuePair] = { stateWatermarkPredicate match { - case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictAndReturnByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } - case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) => + case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) => joinStateManager match { case s: SupportsEvictByCondition => s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc) case s: SupportsEvictByTimestamp => - s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark)) + s.evictAndReturnByTimestamp( + watermarkMsToStateTimestamp(stateWatermark), + prevStateWatermark.map(watermarkMsToStateTimestamp)) } case _ => Iterator.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala index cea6398f4e501..80a299b4e6bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala @@ -46,12 +46,18 @@ object StreamingSymmetricHashJoinHelper extends Logging { override def toString: String = s"$desc: $expr" } /** Predicate for watermark on state keys */ - case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long) + case class JoinStateKeyWatermarkPredicate( + expr: Expression, + stateWatermark: Long, + prevStateWatermark: Option[Long] = None) extends JoinStateWatermarkPredicate { def desc: String = "key predicate" } /** Predicate for watermark on state values */ - case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long) + case class JoinStateValueWatermarkPredicate( + expr: Expression, + stateWatermark: Long, + prevStateWatermark: Option[Long] = None) extends JoinStateWatermarkPredicate { def desc: String = "value predicate" } @@ -185,6 +191,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { rightKeys: Seq[Expression], condition: Option[Expression], eventTimeWatermarkForEviction: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = { // Perform assertions against multiple event time columns in the same DataFrame. This method @@ -215,7 +222,10 @@ object StreamingSymmetricHashJoinHelper extends Logging { expr.map { e => // watermarkExpression only provides the expression when eventTimeWatermarkForEviction // is defined - JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get) + JoinStateKeyWatermarkPredicate( + e, + eventTimeWatermarkForEviction.get, + eventTimeWatermarkForLateEvents) } } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( @@ -223,12 +233,19 @@ object StreamingSymmetricHashJoinHelper extends Logging { attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, eventTimeWatermarkForEviction) + val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ => + StreamingJoinHelper.getStateValueWatermark( + attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), + attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), + condition, + eventTimeWatermarkForLateEvents) + } val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey)) val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark) expr.map { e => // watermarkExpression only provides the expression when eventTimeWatermarkForEviction // is defined - JoinStateValueWatermarkPredicate(e, stateValueWatermark.get) + JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark) } } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index fc2a69312fe79..611f548f44b88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.NextIterator /** @@ -184,15 +185,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager => trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager => import SymmetricHashJoinStateManager._ - /** Evict the state by timestamp. Returns the number of values evicted. */ - def evictByTimestamp(endTimestamp: Long): Long + /** + * Evict the state by timestamp. Returns the number of values evicted. + * + * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp. + * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are + * assumed to have been evicted already (e.g. from the previous batch). When provided, + * the scan starts from startTimestamp + 1. + */ + def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long /** * Evict the state by timestamp and return the evicted key-value pairs. * * It is caller's responsibility to consume the whole iterator. + * + * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp. + * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are + * assumed to have been evicted already (e.g. from the previous batch). When provided, + * the scan starts from startTimestamp + 1. */ - def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] + def evictAndReturnByTimestamp( + endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] } /** @@ -507,9 +521,9 @@ class SymmetricHashJoinStateManagerV4( } } - override def evictByTimestamp(endTimestamp: Long): Long = { + override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = { var removed = 0L - tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted => + tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted => val key = evicted.key val timestamp = evicted.timestamp val numValues = evicted.numValues @@ -523,10 +537,11 @@ class SymmetricHashJoinStateManagerV4( removed } - override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = { + override def evictAndReturnByTimestamp( + endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = { val reusableKeyToValuePair = KeyToValuePair() - tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted => + tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted => val key = evicted.key val timestamp = evicted.timestamp val values = keyWithTsToValues.get(key, timestamp) @@ -647,17 +662,30 @@ class SymmetricHashJoinStateManagerV4( /** * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp. - * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted). + * When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient + * range access; falls back to prefixScan otherwise to stay within the key's scope. + * + * When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are + * filtered out so both code paths produce identical results. */ def getValuesInRange( key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { val reusableGetValuesResult = new GetValuesResult() + // Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires + // an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue. + val useRangeScan = maxTs < Long.MaxValue new NextIterator[GetValuesResult] { - private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) + private val iter = if (useRangeScan) { + val startKey = createKeyRow(key, minTs).copy() + // rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1 + val endKey = Some(createKeyRow(key, maxTs + 1)) + stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName) + } else { + stateStore.prefixScanWithMultiValues(key, colFamilyName) + } private var currentTs = -1L - private var pastUpperBound = false private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() private def flushAccumulated(): GetValuesResult = { @@ -675,16 +703,16 @@ class SymmetricHashJoinStateManagerV4( @tailrec override protected def getNext(): GetValuesResult = { - if (pastUpperBound || !iter.hasNext) { + if (!iter.hasNext) { flushAccumulated() } else { val unsafeRowPair = iter.next() val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) - if (ts > maxTs) { - pastUpperBound = true - getNext() - } else if (ts < minTs) { + // Filter out entries outside [minTs, maxTs]. This is essential when using + // prefixScan (which returns all timestamps for the key) and serves as a + // safety guard for rangeScan as well. + if (ts < minTs || ts > maxTs) { getNext() } else if (currentTs == -1L || currentTs == ts) { currentTs = ts @@ -757,6 +785,8 @@ class SymmetricHashJoinStateManagerV4( isInternal = true ) + // Returns an UnsafeRow backed by a reused projection buffer. Callers that need to + // hold the row beyond the immediate state store call must invoke copy() on the result. private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = { TimestampKeyStateEncoder.attachTimestamp( attachTimestampProjection, keySchemaWithTimestamp, key, timestamp) @@ -772,9 +802,66 @@ class SymmetricHashJoinStateManagerV4( case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int) - // NOTE: This assumes we consume the whole iterator to trigger completion. - def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = { - val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName) + private def defaultInternalRow(schema: StructType): InternalRow = { + InternalRow.fromSeq(schema.map(f => defaultValueForType(f.dataType))) + } + + private def defaultValueForType(dt: DataType): Any = dt match { + case BooleanType => false + case ByteType => 0.toByte + case ShortType => 0.toShort + case IntegerType | DateType => 0 + case LongType | TimestampType | TimestampNTZType => 0L + case FloatType => 0.0f + case DoubleType => 0.0 + case StringType => UTF8String.EMPTY_UTF8 + case BinaryType => Array.emptyByteArray + case st: StructType => defaultInternalRow(st) + case _ => null + } + + /** + * Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses + * TimestampAsPrefixKeyStateEncoder, which encodes the row as [timestamp][key_fields]. + * We need a full-schema row (not just the timestamp) because the encoder expects all + * key columns to be present. Default values are used for the key fields since only the + * timestamp matters for ordering in the prefix encoder. + */ + private def createScanBoundaryRow(timestamp: Long): UnsafeRow = { + val defaultKey = UnsafeProjection.create(keySchema) + .apply(defaultInternalRow(keySchema)) + createKeyRow(defaultKey, timestamp).copy() + } + + /** + * Scan keys eligible for eviction within the timestamp range. + * + * This assumes we consume the whole iterator to trigger completion. + * + * @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are + * eligible for eviction. + * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp + * are assumed to have been evicted already. The scan starts from startTimestamp + 1. + */ + def scanEvictedKeys( + endTimestamp: Long, + startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = { + // rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive. + // startTimestamp is exclusive (already evicted), so we seek from st + 1. + val startKeyRow = startTimestamp.flatMap { st => + if (st < Long.MaxValue) Some(createScanBoundaryRow(st + 1)) + else None + } + // endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound. + // When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is + // safe because rangeScanWithMultiValues with no end key uses the column-family prefix + // as the upper bound, naturally scoping the scan within this column family. + val endKeyRow = if (endTimestamp < Long.MaxValue) { + Some(createScanBoundaryRow(endTimestamp + 1)) + } else { + None + } + val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName) new NextIterator[EvictedKeysResult]() { var currentKeyRow: UnsafeRow = null var currentEventTime: Long = -1L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 169ab6f606dae..7e08c24e452f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -528,13 +528,19 @@ class IncrementalExecution( case j: StreamingSymmetricHashJoinExec => val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get) val iwEviction = inputWatermarkForEviction(j.stateInfo.get) + // Only use the late-events watermark as the scan lower bound when a previous + // batch actually existed. In the very first batch the watermark propagation + // yields Some(0) even though no state has been evicted yet, which would + // incorrectly skip entries at timestamp 0. + val prevBatchLateEventsWm = + if (prevOffsetSeqMetadata.isDefined) iwLateEvents else None j.copy( eventTimeWatermarkForLateEvents = iwLateEvents, eventTimeWatermarkForEviction = iwEviction, stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - iwEviction, !allowMultipleStatefulOperators) + iwEviction, prevBatchLateEventsWm, !allowMultipleStatefulOperators) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index a24a76269828f..4ad3a662b4d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1662,6 +1662,93 @@ class RocksDB( } } + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to seek to the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan (encoded key bytes). + * @param cfName The column family name. + * @return An iterator of ByteArrayPairs in the given range. + */ + def scan( + startKey: Option[Array[Byte]], + endKey: Option[Array[Byte]], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() + + val upperBoundBytes: Option[Array[Byte]] = endKey match { + case Some(key) => + Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key) + case None => + if (useColumnFamilies) { + val cfPrefix = encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + RocksDB.prefixUpperBound(cfPrefix) + } else { + None + } + } + + val seekTarget = startKey match { + case Some(key) => + if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key + case None => + if (useColumnFamilies) encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + else null + } + + val upperBoundSlice = upperBoundBytes.map(new Slice(_)) + val scanReadOptions = new ReadOptions() + upperBoundSlice.foreach(scanReadOptions.setIterateUpperBound) + + val iter = db.newIterator(scanReadOptions) + if (seekTarget != null) { + iter.seek(seekTarget) + } else { + iter.seekToFirst() + } + + def closeResources(): Unit = { + iter.close() + scanReadOptions.close() + upperBoundSlice.foreach(_.close()) + } + + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => closeResources() } + } + + new NextIterator[ByteArrayPair] { + override protected def getNext(): ByteArrayPair = { + if (iter.isValid) { + val key = if (useColumnFamilies) { + decodeStateRowWithPrefix(iter.key)._1 + } else { + iter.key + } + + val value = if (conf.rowChecksumEnabled) { + KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum( + readVerifier, iter.key, iter.value, delimiterSize) + } else { + iter.value + } + + byteArrayPair.set(key, value) + iter.next() + byteArrayPair + } else { + finished = true + closeResources() + null + } + } + + override protected def close(): Unit = closeResources() + } + } + def release(): Unit = {} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 5acf7cdc9b975..4d9c77348b493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -46,6 +46,7 @@ import org.apache.spark.unsafe.Platform sealed trait RocksDBKeyStateEncoder { def supportPrefixKeyScan: Boolean def supportsDeleteRange: Boolean + def supportsRangeScan: Boolean def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] def encodeKey(row: UnsafeRow): Array[Byte] def decodeKey(keyBytes: Array[Byte]): UnsafeRow @@ -1500,6 +1501,8 @@ class PrefixKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = false } /** @@ -1699,6 +1702,8 @@ class RangeKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = true + + override def supportsRangeScan: Boolean = true } /** @@ -1731,6 +1736,8 @@ class NoPrefixKeyStateEncoder( override def supportsDeleteRange: Boolean = false + override def supportsRangeScan: Boolean = false + override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { throw new IllegalStateException("This encoder doesn't support prefix key!") } @@ -1884,6 +1891,8 @@ class TimestampAsPrefixKeyStateEncoder( // TODO: [SPARK-55491] Revisit this to support delete range if needed. override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** @@ -1932,6 +1941,8 @@ class TimestampAsPostfixKeyStateEncoder( } override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 8b023a0e9f9fc..e1490c71bc69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -549,6 +549,68 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } + override def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("rangeScan", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") + + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) + + val rowPair = new UnsafeRowPair() + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + val iter = rocksDbIter.map { kv => + rowPair.withRows(kvEncoder._1.decodeKey(kv.key), + kvEncoder._2.decodeValue(kv.value)) + rowPair + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + + override def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("rangeScanWithMultiValues", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") + verify( + kvEncoder._2.supportsMultipleValuesPerKey, + "Multi-value iterator operation requires an encoder" + + " which supports multiple values for a single key") + + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + + val rowPair = new UnsafeRowPair() + val iter = rocksDbIter.flatMap { kv => + val keyRow = kvEncoder._1.decodeKey(kv.key) + val valueRows = kvEncoder._2.decodeValues(kv.value) + valueRows.iterator.map { valueRow => + rowPair.withRows(keyRow, valueRow) + if (!isValidated && rowPair.value != null && !useColumnFamilies) { + StateStoreProvider.validateStateRowFormat( + rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf) + isValidated = true + } + rowPair + } + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + var checkpointInfo: Option[StateStoreCheckpointInfo] = None private var storedMetrics: Option[RocksDBMetrics] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6e08c10476ce7..e3601f1ef2246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -183,6 +183,51 @@ trait ReadStateStore { prefixKey: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + */ + def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("rangeScan", "") + } + + /** + * Scan key-value pairs in the range [startKey, endKey), expanding multi-valued entries. + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + * + * It is expected to throw exception if Spark calls this method without setting + * multipleValuesPerKey as true for the column family. + */ + def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("rangeScanWithMultiValues", "") + } + /** Return an iterator containing all the key-value pairs in the StateStore. */ def iterator( colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] @@ -411,6 +456,20 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { store.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.rangeScan(startKey, endKey, colFamilyName) + } + + override def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.rangeScanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 14aa43d3234f7..fe09506023ccd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -172,6 +172,20 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.rangeScan(startKey, endKey, colFamilyName) + } + + override def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.rangeScanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 3e4b4b7320f53..0c300192dd898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1629,6 +1629,217 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + private val diverseTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L, + -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, + -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - rangeScan", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + diverseTimestamps.foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + // Bounded positive range [0, 100) + val boundedIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(100L, "a")), cfName) + val boundedResults = boundedIter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + boundedIter.close() + val expectedBoundedTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted + assert(boundedResults.map(_._1) === expectedBoundedTs) + assert(boundedResults.map(_._2) === expectedBoundedTs.map(_.toInt)) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + // Scan [9, 90) should include 9 but exclude 90. + val exactIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + assert(exactResults === diverseTimestamps.filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResults.contains(9L)) + assert(!exactResults.contains(90L)) + + // None startKey scans from beginning to 0 + val noneStartIter = store.rangeScan( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + assert(noneStartResults === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults === diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with multiple key2 values within same key1 range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + Seq("a", "b", "c").foreach { key2 => + Seq(100L, 200L, 300L).foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, key2), dataToValueRow(ts.toInt), cfName) + } + } + + val startKey = dataToKeyRowWithRangeScan(100L, "a") + val endKey = dataToKeyRowWithRangeScan(201L, "a") + val iter = store.rangeScan(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.key.getUTF8String(1).toString) + }.toList + iter.close() + + val expectedResults = Seq( + (100L, "a"), (100L, "b"), (100L, "c"), + (200L, "a"), (200L, "b"), (200L, "c")) + assert(results === expectedResults) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - rangeScanWithMultiValues", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.putList(dataToKeyRowWithRangeScan(ts, "a"), + Array(dataToValueRow(idx * 10), dataToValueRow(idx * 10 + 1)), cfName) + } + + // Bounded range [0, 1001) + val boundedIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(1001L, "a")), cfName) + val boundedResults = boundedIter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + boundedIter.close() + + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + assert(boundedResults.map(_._1).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._2) === expectedValues) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + val exactIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + + // None startKey scans from beginning to 0 + val noneStartIter = store.rangeScanWithMultiValues( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults.distinct === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + testWithColumnFamiliesAndEncodingTypes( "rocksdb key and value schema encoders for column families", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index a9540a4ad623e..5fcdfb12ba354 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -566,6 +566,141 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } + Seq("unsaferow", "avro").foreach { encoding => + test(s"rangeScan with postfix encoder (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + val keyRow = keyAndTimestampToRow("key1", 1, ts) + store.putList(keyRow, Array(valueToRow(idx * 10), valueToRow(idx * 10 + 1))) + } + + // key2 entry to verify prefix isolation + store.putList(keyAndTimestampToRow("key2", 1, 500L), + Array(valueToRow(999))) + + // Bounded range [0, 1001) + val boundedIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 0L)), + Some(keyAndTimestampToRow("key1", 1, 1001L))) + val boundedResults = boundedIter.map { pair => + (pair.key.getString(0), pair.key.getLong(2), pair.value.getInt(0)) + }.toList + boundedIter.close() + + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + assert(boundedResults.map(_._2).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._3) === expectedValues) + assert(boundedResults.forall(_._1 == "key1")) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + val exactIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 9L)), + Some(keyAndTimestampToRow("key1", 1, 90L))) + val exactResults = exactIter.map(_.key.getLong(2)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + + // Postfix timestamp encoder places the timestamp after the key prefix. + // With different key prefixes, None in startKey or endKey would scan across + // key boundaries, which is not meaningful for postfix encoding. Hence we only + // test bounded ranges with explicit keys here. + + // Full range [MinValue, MaxValue) + val fullIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, Long.MinValue)), + Some(keyAndTimestampToRow("key1", 1, Long.MaxValue))) + val fullResults = fullIter.map(_.key.getLong(2)).toList + fullIter.close() + + assert(fullResults.distinct === diverseTimestamps.sorted) + + // Bounded negative range [-300, 0) + val negIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, -300L)), + Some(keyAndTimestampToRow("key1", 1, 0L))) + val negResults = negIter.map(_.key.getLong(2)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) + + // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 + val emptyIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 10L)), + Some(keyAndTimestampToRow("key1", 1, 31L))) + assert(!emptyIter.hasNext) + emptyIter.close() + } finally { + store.abort() + } + } + } + + // Sanity test for prefix encoder scan. Full scan coverage is in RocksDBStateStoreSuite's + // "rocksdb range scan - rangeScan" and "rocksdb range scan - rangeScanWithMultiValues" tests. + // This test verifies the timestamp prefix encoder integration works correctly. + test(s"rangeScan with prefix encoder (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "prefix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) + + // None startKey scans from beginning up to 301 (exclusive) + val iter1 = store.rangeScanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 301L))) + val results1 = iter1.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList + iter1.close() + + assert(results1 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L), ("key1", 300L))) + + // Boundary safety: endKey at 201, includes everything up to 200 + // regardless of join key + val iter2 = store.rangeScanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 201L))) + val results2 = iter2.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList + iter2.close() + + assert(results2 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L))) + } finally { + store.abort() + } + } + } + + } + // Helper methods to create test data private val keyProjection = UnsafeProjection.create(keySchema) private val keyAndTimestampProjection = UnsafeProjection.create( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 1042f01463b05..ae7dce78151a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -1105,6 +1105,67 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite } } + test("StreamingJoinStateManager V4 - getValuesInRange boundary edge cases") { + withJoinStateManager( + inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager => + implicit val mgr = manager + + Seq(10, 20, 30, 40, 50).foreach(append(40, _)) + + // Exact boundary matches (both inclusive) + assert(getJoinedRowTimestamps(40, Some((10L, 10L))) === Seq(10)) + assert(getJoinedRowTimestamps(40, Some((50L, 50L))) === Seq(50)) + + // Range with Long.MinValue / Long.MaxValue + assert(getJoinedRowTimestamps(40, Some((Long.MinValue, 30L))) === Seq(10, 20, 30)) + assert(getJoinedRowTimestamps(40, Some((30L, Long.MaxValue))) === Seq(30, 40, 50)) + assert(getJoinedRowTimestamps(40, Some((Long.MinValue, Long.MaxValue))) === + Seq(10, 20, 30, 40, 50)) + + // Empty range (minTs > maxTs) + assert(getJoinedRowTimestamps(40, Some((50L, 10L))) === Seq.empty) + + // Range entirely outside stored timestamps + assert(getJoinedRowTimestamps(40, Some((100L, 200L))) === Seq.empty) + assert(getJoinedRowTimestamps(40, Some((1L, 5L))) === Seq.empty) + + // Full range via None (all entries) + assert(getJoinedRowTimestamps(40, None) === Seq(10, 20, 30, 40, 50)) + } + } + + test("StreamingJoinStateManager V4 - evictByTimestamp boundary edge cases") { + withJoinStateManager( + inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager => + implicit val mgr = manager + val evictByTs = manager.asInstanceOf[SupportsEvictByTimestamp] + + // --- Range eviction with startTimestamp (exclusive) and endTimestamp (inclusive) --- + Seq(10, 20, 30, 40, 50).foreach(append(40, _)) + // startTimestamp=20 is exclusive, endTimestamp=40 is inclusive: evicts timestamps 30, 40 + assert(evictByTs.evictByTimestamp(40, Some(20)) === 2) + assert(get(40) === Seq(10, 20, 50)) + + // --- evictAndReturnByTimestamp returns evicted values --- + Seq(30, 40).foreach(append(40, _)) // restore evicted entries + val evictedValues = evictByTs.evictAndReturnByTimestamp(30, Some(10)) + .map(p => toValueInt(p.value)).toSeq.sorted + // startTimestamp=10 is exclusive, endTimestamp=30 is inclusive: timestamps 20 and 30 + assert(evictedValues === Seq(20, 30)) + assert(get(40) === Seq(10, 40, 50)) + + // --- start equals end: empty range (exclusive start = inclusive end) --- + // startTimestamp=40 (exclusive) and endTimestamp=40 (inclusive): range is empty + assert(evictByTs.evictByTimestamp(40, Some(40)) === 0) + assert(get(40) === Seq(10, 40, 50)) + + // --- start just below entry: evicts exactly that entry --- + // startTimestamp=39 (exclusive) means entries >= 40 are scanned; endTimestamp=40 inclusive + assert(evictByTs.evictByTimestamp(40, Some(39)) === 1) + assert(get(40) === Seq(10, 50)) + } + } + // V1 excluded: V1 converter does not persist matched flags (SPARK-26154) versionsInTest.filter(_ >= 2).foreach { ver => test(s"StreamingJoinStateManager V$ver - skipUpdatingMatchedFlag skips matched flag update") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala index ef4615c1254f3..66406cc2afa7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala @@ -22,7 +22,8 @@ import org.scalatest.Tag import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinExec -import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{JoinStateKeyWatermarkPredicate, JoinStateValueWatermarkPredicate} +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -184,6 +185,72 @@ class StreamingInnerJoinV4Suite ) } } + + test("prevStateWatermark must be None in the first batch") { + // Regression test for the IncrementalExecution guard: in the first batch + // prevOffsetSeqMetadata is None, so eventTimeWatermarkForLateEvents must NOT + // be passed to getStateWatermarkPredicates. Without the guard the watermark + // propagation framework yields Some(0) even in batch 0, which would cause + // scanEvictedKeys to skip state entries at timestamp 0. + val input1 = MemoryStream[(Int, Int)] + val input2 = MemoryStream[(Int, Int)] + + val df1 = input1.toDF().toDF("key", "time") + .select($"key", timestamp_seconds($"time") as "leftTime", + ($"key" * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + val df2 = input2.toDF().toDF("key", "time") + .select($"key", timestamp_seconds($"time") as "rightTime", + ($"key" * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = df1.join(df2, + df1("key") === df2("key") && + expr("leftTime >= rightTime - interval 5 seconds " + + "AND leftTime <= rightTime + interval 5 seconds"), + "inner") + .select(df1("key"), $"leftTime".cast("long"), $"leftValue", $"rightValue") + + def extractPrevWatermarks(q: StreamExecution): (Option[Long], Option[Long]) = { + val joinExec = q.lastExecution.executedPlan.collect { + case j: StreamingSymmetricHashJoinExec => j + }.head + val leftPrev = joinExec.stateWatermarkPredicates.left.flatMap { + case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark + case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark + } + val rightPrev = joinExec.stateWatermarkPredicates.right.flatMap { + case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark + case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark + } + (leftPrev, rightPrev) + } + + testStream(joined)( + // First batch: prevStateWatermark must be None on both sides. + MultiAddData(input1, (1, 5))(input2, (1, 5)), + CheckNewAnswer((1, 5, 2, 3)), + Execute { q => + val (leftPrev, rightPrev) = extractPrevWatermarks(q) + assert(leftPrev.isEmpty, + s"Left prevStateWatermark should be None in the first batch, got $leftPrev") + assert(rightPrev.isEmpty, + s"Right prevStateWatermark should be None in the first batch, got $rightPrev") + }, + + // Second batch: after watermark advances, prevStateWatermark should be set. + MultiAddData(input1, (2, 30))(input2, (2, 30)), + CheckNewAnswer((2, 30, 4, 6)), + Execute { q => + val (leftPrev, rightPrev) = extractPrevWatermarks(q) + assert(leftPrev.isDefined, + "Left prevStateWatermark should be defined after the first batch") + assert(rightPrev.isDefined, + "Right prevStateWatermark should be defined after the first batch") + }, + StopStream + ) + } } @SlowSQLTest