diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f11dcbd1e7c1e..6c2ce58b884d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -787,6 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { outputAttr, stateInfo = None, batchTimestampMs = None, + prevBatchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, planLater(child), @@ -815,6 +816,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { func, t.leftAttributes, outputAttrs, outputMode, timeMode, stateInfo = None, batchTimestampMs = None, + prevBatchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, userFacingDataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index 1ceaf6c4bf81f..45f2af5c1dfe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -74,6 +74,7 @@ case class TransformWithStateInPySparkExec( timeMode: TimeMode, stateInfo: Option[StatefulOperatorStateInfo], batchTimestampMs: Option[Long], + prevBatchTimestampMs: Option[Long] = None, eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value, @@ -314,7 +315,8 @@ case class TransformWithStateInPySparkExec( val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes) val 
processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics) + groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, + prevBatchTimestampMs, metrics) val evalType = { if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) { @@ -442,6 +444,7 @@ object TransformWithStateInPySparkExec { Some(System.currentTimeMillis), None, None, + None, userFacingDataType, child, isStreaming = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala index cc1c3263ad743..b11f6d93a642b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala @@ -67,6 +67,7 @@ case class TransformWithStateExec( outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], batchTimestampMs: Option[Long], + prevBatchTimestampMs: Option[Long] = None, eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], child: SparkPlan, @@ -251,7 +252,7 @@ case class TransformWithStateExec( case ProcessingTime => assert(batchTimestampMs.isDefined) val batchTimestamp = batchTimestampMs.get - processorHandle.getExpiredTimers(batchTimestamp) + processorHandle.getExpiredTimers(batchTimestamp, prevBatchTimestampMs) .flatMap { case (keyObj, expiryTimestampMs) => numExpiredTimers += 1 handleTimerRows(keyObj, expiryTimestampMs, processorHandle) @@ -260,7 +261,13 @@ case class TransformWithStateExec( case EventTime => assert(eventTimeWatermarkForEviction.isDefined) val watermark = 
eventTimeWatermarkForEviction.get - processorHandle.getExpiredTimers(watermark) + // Only use the late-events watermark as the scan lower bound when a previous batch + // actually existed (prevBatchTimestampMs is set). In the very first batch the + // watermark propagation yields Some(0) for late events even though no timers have + // been processed yet, which would incorrectly skip timers registered at timestamp 0. + val prevWatermark = + if (prevBatchTimestampMs.isDefined) eventTimeWatermarkForLateEvents else None + processorHandle.getExpiredTimers(watermark, prevWatermark) .flatMap { case (keyObj, expiryTimestampMs) => numExpiredTimers += 1 handleTimerRows(keyObj, expiryTimestampMs, processorHandle) @@ -493,7 +500,7 @@ case class TransformWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, timeMode, - isStreaming, batchTimestampMs, metrics) + isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) withStatefulProcessorErrorHandling("init") { @@ -509,7 +516,7 @@ case class TransformWithStateExec( initStateIterator: Iterator[InternalRow]): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) + keyEncoder, timeMode, isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) withStatefulProcessorErrorHandling("init") { @@ -581,6 +588,7 @@ object TransformWithStateExec { Some(System.currentTimeMillis), None, None, + None, child, isStreaming = false, hasInitialState, diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala index dfba0e1f12146..291cc02ea989b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala @@ -114,6 +114,7 @@ class StatefulProcessorHandleImpl( timeMode: TimeMode, isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, + prevBatchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty) extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging { import StatefulProcessorHandleState._ @@ -171,13 +172,19 @@ class StatefulProcessorHandleImpl( /** * Function to retrieve all expired registered timers for all grouping keys - * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function - * will return all timers that have timestamp less than passed threshold + * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive), + * this function will return all timers that have timestamp + * less than or equal to the passed threshold. + * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range. + * Timers at or below this timestamp are assumed to have been + * already processed in the previous batch and will be skipped. 
* @return - iterator of registered timers for all grouping keys */ - def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = { + def getExpiredTimers( + expiryTimestampMs: Long, + prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = { verifyTimerOperations("get_expired_timers") - timerState.getExpiredTimers(expiryTimestampMs) + timerState.getExpiredTimers(expiryTimestampMs, prevExpiryTimestampMs) } /** @@ -237,7 +244,8 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics) + keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, + prevBatchTimestampMs, metrics) ttlStates.add(valueStateWithTTL) TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars") valueStateWithTTL @@ -286,7 +294,8 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics) + keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, + prevBatchTimestampMs, metrics) TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars") ttlStates.add(listStateWithTTL) listStateWithTTL @@ -324,7 +333,8 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get, metrics) + valEncoder, ttlConfig, batchTimestampMs.get, + prevBatchTimestampMs, metrics) TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) mapStateWithTTL diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala index 101265fd8d83b..f4a1a06974aa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala @@ -112,6 +112,27 @@ class TimerStateImpl( schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)), useMultipleValuesPerKey = false, isInternal = true) + private val secIndexProjection = UnsafeProjection.create(keySchemaForSecIndex) + + /** + * Encodes `tsMs + 1` into an UnsafeRow key for the secondary index. + * Used as the (inclusive) scan start it skips timers at or below tsMs; used as the + * (exclusive) scan end it still includes timers at tsMs. Returns None if tsMs is + * Long.MaxValue (tsMs + 1 would overflow), which leaves that side of the scan unbounded. + * + * The returned UnsafeRow is always a fresh copy, safe to hold alongside other + * rows produced by the same projection. + */ + private def encodeTimestampAsKey(tsMs: Long): Option[UnsafeRow] = { + if (tsMs < Long.MaxValue) { + val row = new GenericInternalRow(keySchemaForSecIndex.length) + row.setLong(0, tsMs + 1) + Some(secIndexProjection.apply(row).copy()) + } else { + None + } + } + private def getGroupingKey(cfName: String): Any = { val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption if (keyOption.isEmpty) { @@ -189,15 +210,22 @@ class TimerStateImpl( /** * Function to get all the expired registered timers for all grouping keys.
- * Perform a range scan on timestamp and will stop iterating once the key row timestamp equals or + * Perform a range scan on timestamp and will stop iterating once the key row timestamp * exceeds the limit (as timestamp key is increasingly sorted). - * @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function - * will return all timers that have timestamp less than passed threshold. + * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive), + * this function will return all timers that have timestamp + * less than or equal to the passed threshold. + * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range. + * Timers at or below this timestamp are assumed to have been + * already processed in the previous batch and will be skipped. * @return - iterator of all the registered timers for all grouping keys */ - def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = { - // this iter is increasingly sorted on timestamp - val iter = store.iterator(tsToKeyCFName) + def getExpiredTimers( + expiryTimestampMs: Long, + prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = { + val startKey = prevExpiryTimestampMs.flatMap(encodeTimestampAsKey) + val endKey = encodeTimestampAsKey(expiryTimestampMs) + val iter = store.rangeScan(startKey, endKey, tsToKeyCFName) new NextIterator[(Any, Long)] { override protected def getNext(): (Any, Long) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala index 08f97e38bd086..10ec3a58500af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala @@ -35,6 +35,10 @@ import org.apache.spark.util.NextIterator * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. + * @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive). + * Entries with expiration at or below this timestamp are assumed + * to have been already cleaned up and will be skipped during + * TTL eviction scans. * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ @@ -45,9 +49,11 @@ class ListStateImplWithTTL[S]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, + prevBatchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric]) extends OneToManyTTLState( - stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] { + stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, + prevBatchTimestampMs, metrics) with ListState[S] { private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala index f063354bc8c8c..03aa8aaa6ace2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala @@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator * @param valEncoder 
- SQL encoder for state variable * @param ttlConfig - the ttl configuration (time to live duration etc.) * @param batchTimestampMs - current batch processing timestamp. + * @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive). + * Entries with expiration at or below this timestamp are assumed + * to have been already cleaned up and will be skipped during + * TTL eviction scans. * @param metrics - metrics to be updated as part of stateful processing * @tparam K - type of key for map state variable * @tparam V - type of value for map state variable @@ -49,10 +53,11 @@ class MapStateImplWithTTL[K, V]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, -metrics: Map[String, SQLMetric]) + prevBatchTimestampMs: Option[Long] = None, + metrics: Map[String, SQLMetric]) extends OneToOneTTLState( stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig, - batchTimestampMs, metrics) with MapState[K, V] with Logging { + batchTimestampMs, prevBatchTimestampMs, metrics) with MapState[K, V] with Logging { private val stateTypesEncoder = new CompositeKeyStateEncoder( keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala index 548a47ea75e13..6219313b7e027 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/TTLState.scala @@ -88,6 +88,11 @@ trait TTLState { // an expiration at or before this timestamp must be cleaned up. 
private[sql] def batchTimestampMs: Long + // The batch timestamp from the previous micro-batch, used to derive the startKey + // for scan-based TTL eviction. Entries at or below prevBatchTimestampMs were already + // cleaned up in the previous batch. + private[sql] def prevBatchTimestampMs: Option[Long] + // The configuration for this run of the streaming query. It may change between runs // (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and then // resumes their query). @@ -105,6 +110,8 @@ trait TTLState { private final val TTL_ENCODER = new TTLEncoder(elementKeySchema) + private final val ELEMENT_KEY_PROJECTION = UnsafeProjection.create(elementKeySchema) + // Empty row used for values private final val TTL_EMPTY_VALUE_ROW = UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) @@ -155,16 +162,34 @@ trait TTLState { store.iterator(TTL_INDEX).map(kv => toTTLRow(kv.key)) } - // Returns an Iterator over all the keys in the TTL index that have expired. This method - // does not delete the keys from the TTL index; it is the responsibility of the caller - // to do so. + // Returns an Iterator over the keys in the TTL index that have expired. Uses a bounded + // range scan over [prevBatchTimestampMs+1, batchTimestampMs+1) to skip entries that + // were already evicted in previous batches. + // + // This method does not delete the keys from the TTL index; it is the responsibility of + // the caller to do so. // // The schema of the UnsafeRow returned by this iterator is (expirationMs, elementKey). 
private[sql] def ttlEvictionIterator(): Iterator[UnsafeRow] = { - val ttlIterator = store.iterator(TTL_INDEX) + val dummyElementKey = ELEMENT_KEY_PROJECTION + .apply(new GenericInternalRow(elementKeySchema.length)) + val startKey = prevBatchTimestampMs.flatMap { prevTs => + if (prevTs < Long.MaxValue) { + Some(TTL_ENCODER.encodeTTLRow(prevTs + 1, dummyElementKey).copy()) + } else { + None + } + } + val endKey = if (batchTimestampMs < Long.MaxValue) { + Some(TTL_ENCODER.encodeTTLRow(batchTimestampMs + 1, dummyElementKey).copy()) + } else { + None + } + val ttlIterator = store.rangeScan(startKey, endKey, TTL_INDEX) // Recall that the format is (expirationMs, elementKey) -> TTL_EMPTY_VALUE_ROW, so // kv.value doesn't ever need to be used. + // Defensive bound: stop at the first entry that has not actually expired. ttlIterator.takeWhile { kv => val expirationMs = kv.key.getLong(0) StateTTL.isExpired(expirationMs, batchTimestampMs) @@ -223,12 +248,14 @@ abstract class OneToOneTTLState( elementKeySchemaArg: StructType, ttlConfigArg: TTLConfig, batchTimestampMsArg: Long, + prevBatchTimestampMsArg: Option[Long], metricsArg: Map[String, SQLMetric]) extends TTLState { override private[sql] def stateName: String = stateNameArg override private[sql] def store: StateStore = storeArg override private[sql] def elementKeySchema: StructType = elementKeySchemaArg override private[sql] def ttlConfig: TTLConfig = ttlConfigArg override private[sql] def batchTimestampMs: Long = batchTimestampMsArg + override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg override private[sql] def metrics: Map[String, SQLMetric] = metricsArg /** @@ -340,12 +367,14 @@ abstract class OneToManyTTLState( elementKeySchemaArg: StructType, ttlConfigArg: TTLConfig, batchTimestampMsArg: Long, + prevBatchTimestampMsArg: Option[Long], metricsArg: Map[String, SQLMetric]) extends TTLState { override private[sql] def stateName: String = stateNameArg override private[sql] def store: StateStore = storeArg override
private[sql] def elementKeySchema: StructType = elementKeySchemaArg override private[sql] def ttlConfig: TTLConfig = ttlConfigArg override private[sql] def batchTimestampMs: Long = batchTimestampMsArg + override private[sql] def prevBatchTimestampMs: Option[Long] = prevBatchTimestampMsArg override private[sql] def metrics: Map[String, SQLMetric] = metricsArg // Schema of the min-expiry index: elementKey -> minExpirationMs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala index 587da75993610..1559acf7222cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ValueStateImplWithTTL.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState} * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. + * @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive). + * Entries with expiration at or below this timestamp are assumed + * to have been already cleaned up and will be skipped during + * TTL eviction scans. 
* @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ @@ -43,9 +47,11 @@ class ValueStateImplWithTTL[S]( valEncoder: ExpressionEncoder[Any], ttlConfig: TTLConfig, batchTimestampMs: Long, + prevBatchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty) extends OneToOneTTLState( - stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ValueState[S] { + stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, + prevBatchTimestampMs, metrics) with ValueState[S] { private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 169ab6f606dae..0d5d89db9334f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -384,6 +384,7 @@ class IncrementalExecution( t.copy( stateInfo = Some(nextStatefulOperationStateInfo()), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs), eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, hasInitialState = hasInitialState @@ -394,6 +395,7 @@ class IncrementalExecution( t.copy( stateInfo = Some(nextStatefulOperationStateInfo()), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + prevBatchTimestampMs = prevOffsetSeqMetadata.map(_.batchTimestampMs), eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, hasInitialState = hasInitialState diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index a24a76269828f..4ad3a662b4d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1662,6 +1662,93 @@ class RocksDB( } } + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to seek to the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan (encoded key bytes). + * @param cfName The column family name. + * @return An iterator of ByteArrayPairs in the given range. + */ + def scan( + startKey: Option[Array[Byte]], + endKey: Option[Array[Byte]], + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { + updateMemoryUsageIfNeeded() + + val upperBoundBytes: Option[Array[Byte]] = endKey match { + case Some(key) => + Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key) + case None => + if (useColumnFamilies) { + val cfPrefix = encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + RocksDB.prefixUpperBound(cfPrefix) + } else { + None + } + } + + val seekTarget = startKey match { + case Some(key) => + if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key + case None => + if (useColumnFamilies) encodeStateRowWithPrefix(Array.emptyByteArray, cfName) + else null + } + + val upperBoundSlice = upperBoundBytes.map(new Slice(_)) + val scanReadOptions = new ReadOptions() + upperBoundSlice.foreach(scanReadOptions.setIterateUpperBound) + + val iter = db.newIterator(scanReadOptions) + if (seekTarget != null) { + iter.seek(seekTarget) + } else { + iter.seekToFirst() + } + + def closeResources(): Unit = { + iter.close() + scanReadOptions.close() + upperBoundSlice.foreach(_.close()) + 
} + + Option(TaskContext.get()).foreach { tc => + tc.addTaskCompletionListener[Unit] { _ => closeResources() } + } + + new NextIterator[ByteArrayPair] { + override protected def getNext(): ByteArrayPair = { + if (iter.isValid) { + val key = if (useColumnFamilies) { + decodeStateRowWithPrefix(iter.key)._1 + } else { + iter.key + } + + val value = if (conf.rowChecksumEnabled) { + KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum( + readVerifier, iter.key, iter.value, delimiterSize) + } else { + iter.value + } + + byteArrayPair.set(key, value) + iter.next() + byteArrayPair + } else { + finished = true + closeResources() + null + } + } + + override protected def close(): Unit = closeResources() + } + } + def release(): Unit = {} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 5acf7cdc9b975..4d9c77348b493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -46,6 +46,7 @@ import org.apache.spark.unsafe.Platform sealed trait RocksDBKeyStateEncoder { def supportPrefixKeyScan: Boolean def supportsDeleteRange: Boolean + def supportsRangeScan: Boolean def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] def encodeKey(row: UnsafeRow): Array[Byte] def decodeKey(keyBytes: Array[Byte]): UnsafeRow @@ -1500,6 +1501,8 @@ class PrefixKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = false } /** @@ -1699,6 +1702,8 @@ class RangeKeyScanStateEncoder( override def supportPrefixKeyScan: Boolean = true override def supportsDeleteRange: Boolean = true + + override def supportsRangeScan: Boolean = true } /** @@ -1731,6 +1736,8 @@ class 
NoPrefixKeyStateEncoder( override def supportsDeleteRange: Boolean = false + override def supportsRangeScan: Boolean = false + override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { throw new IllegalStateException("This encoder doesn't support prefix key!") } @@ -1884,6 +1891,8 @@ class TimestampAsPrefixKeyStateEncoder( // TODO: [SPARK-55491] Revisit this to support delete range if needed. override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** @@ -1932,6 +1941,8 @@ class TimestampAsPostfixKeyStateEncoder( } override def supportsDeleteRange: Boolean = false + + override def supportsRangeScan: Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 8b023a0e9f9fc..e1490c71bc69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -549,6 +549,68 @@ private[sql] class RocksDBStateStoreProvider new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } + override def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("rangeScan", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") + + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) + + val rowPair = new UnsafeRowPair() + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + val iter = rocksDbIter.map { kv => + 
rowPair.withRows(kvEncoder._1.decodeKey(kv.key), + kvEncoder._2.decodeValue(kv.value)) + rowPair + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + + override def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + validateAndTransitionState(UPDATE) + verifyColFamilyOperations("rangeScanWithMultiValues", colFamilyName) + + val kvEncoder = keyValueEncoderMap.get(colFamilyName) + require(kvEncoder._1.supportsRangeScan, + "Range scan requires an encoder that supports range scanning!") + verify( + kvEncoder._2.supportsMultipleValuesPerKey, + "Multi-value iterator operation requires an encoder" + + " which supports multiple values for a single key") + + val encodedStartKey = startKey.map(kvEncoder._1.encodeKey) + val encodedEndKey = endKey.map(kvEncoder._1.encodeKey) + val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName) + + val rowPair = new UnsafeRowPair() + val iter = rocksDbIter.flatMap { kv => + val keyRow = kvEncoder._1.decodeKey(kv.key) + val valueRows = kvEncoder._2.decodeValues(kv.value) + valueRows.iterator.map { valueRow => + rowPair.withRows(keyRow, valueRow) + if (!isValidated && rowPair.value != null && !useColumnFamilies) { + StateStoreProvider.validateStateRowFormat( + rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf) + isValidated = true + } + rowPair + } + } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) + } + var checkpointInfo: Option[StateStoreCheckpointInfo] = None private var storedMetrics: Option[RocksDBMetrics] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6e08c10476ce7..e3601f1ef2246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -183,6 +183,51 @@ trait ReadStateStore { prefixKey: UnsafeRow, colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] + /** + * Scan key-value pairs in the range [startKey, endKey). + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + */ + def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("rangeScan", "") + } + + /** + * Scan key-value pairs in the range [startKey, endKey), expanding multi-valued entries. + * + * @param startKey None to scan from the beginning of the column family, + * or Some(key) to seek to the given start position (inclusive). + * @param endKey None to scan to the end of the column family, + * or Some(key) as the exclusive upper bound for the scan. + * @param colFamilyName The column family name. + * + * Callers must ensure the column family's key encoder produces lexicographically ordered + * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or + * RangeKeyScanStateEncoder). + * + * It is expected to throw exception if Spark calls this method without setting + * multipleValuesPerKey as true for the column family. 
+ */ + def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { + throw StateStoreErrors.unsupportedOperationException("rangeScanWithMultiValues", "") + } + /** Return an iterator containing all the key-value pairs in the StateStore. */ def iterator( colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] @@ -411,6 +456,20 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { store.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def rangeScan( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.rangeScan(startKey, endKey, colFamilyName) + } + + override def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + store.rangeScanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { store.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 14aa43d3234f7..fe09506023ccd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -172,6 +172,20 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName) } + override def rangeScan( + startKey: Option[UnsafeRow], + endKey: 
Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.rangeScan(startKey, endKey, colFamilyName) + } + + override def rangeScanWithMultiValues( + startKey: Option[UnsafeRow], + endKey: Option[UnsafeRow], + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + innerStore.rangeScanWithMultiValues(startKey, endKey, colFamilyName) + } + override def iteratorWithMultiValues( colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { innerStore.iteratorWithMultiValues(colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 3e4b4b7320f53..0c300192dd898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1629,6 +1629,217 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + private val diverseTimestamps = Seq(931L, 8000L, 452300L, 4200L, -1L, 90L, 1L, 2L, 8L, + -230L, -14569L, -92L, -7434253L, 35L, 6L, 9L, -323L, 5L, + -32L, -64L, -256L, 64L, 32L, 1024L, 4096L, 0L) + + testWithColumnFamiliesAndEncodingTypes("rocksdb range scan - rangeScan", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + diverseTimestamps.foreach { ts => + 
store.put(dataToKeyRowWithRangeScan(ts, "a"), dataToValueRow(ts.toInt), cfName) + } + + // Bounded positive range [0, 100) + val boundedIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(100L, "a")), cfName) + val boundedResults = boundedIter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + boundedIter.close() + val expectedBoundedTs = diverseTimestamps.filter(ts => ts >= 0 && ts < 100).sorted + assert(boundedResults.map(_._1) === expectedBoundedTs) + assert(boundedResults.map(_._2) === expectedBoundedTs.map(_.toInt)) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + // Scan [9, 90) should include 9 but exclude 90. + val exactIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + assert(exactResults === diverseTimestamps.filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResults.contains(9L)) + assert(!exactResults.contains(90L)) + + // None startKey scans from beginning to 0 + val noneStartIter = store.rangeScan( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + assert(noneStartResults === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + 
// Bounded negative range [-300, 0) + val negIter = store.rangeScan( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults === diverseTimestamps.filter(ts => ts >= -300 && ts < 0).sorted) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - scan with multiple key2 values within same key1 range", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + tryWithProviderResource(newStoreProvider(keySchemaWithRangeScan, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0))) + } + + Seq("a", "b", "c").foreach { key2 => + Seq(100L, 200L, 300L).foreach { ts => + store.put(dataToKeyRowWithRangeScan(ts, key2), dataToValueRow(ts.toInt), cfName) + } + } + + val startKey = dataToKeyRowWithRangeScan(100L, "a") + val endKey = dataToKeyRowWithRangeScan(201L, "a") + val iter = store.rangeScan(Some(startKey), Some(endKey), cfName) + val results = iter.map { pair => + (pair.key.getLong(0), pair.key.getUTF8String(1).toString) + }.toList + iter.close() + + val expectedResults = Seq( + (100L, "a"), (100L, "b"), (100L, "c"), + (200L, "a"), (200L, "b"), (200L, "c")) + assert(results === expectedResults) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + + testWithColumnFamiliesAndEncodingTypes( + "rocksdb range scan - rangeScanWithMultiValues", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + if (colFamiliesEnabled) { + tryWithProviderResource(newStoreProvider( + 
StateStoreId(newDir(), Random.nextInt(), 0), + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + keySchema = keySchemaWithRangeScan, + useColumnFamilies = colFamiliesEnabled, + useMultipleValuesPerKey = true)) { provider => + val store = provider.getStore(0) + try { + val cfName = "testColFamily" + store.createColFamilyIfAbsent(cfName, + keySchemaWithRangeScan, valueSchema, + RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)), + useMultipleValuesPerKey = true) + + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + store.putList(dataToKeyRowWithRangeScan(ts, "a"), + Array(dataToValueRow(idx * 10), dataToValueRow(idx * 10 + 1)), cfName) + } + + // Bounded range [0, 1001) + val boundedIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(0L, "a")), + Some(dataToKeyRowWithRangeScan(1001L, "a")), cfName) + val boundedResults = boundedIter.map { pair => + (pair.key.getLong(0), pair.value.getInt(0)) + }.toList + boundedIter.close() + + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + assert(boundedResults.map(_._1).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._2) === expectedValues) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. 
+ val exactIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(9L, "a")), + Some(dataToKeyRowWithRangeScan(90L, "a")), cfName) + val exactResults = exactIter.map(_.key.getLong(0)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + + // None startKey scans from beginning to 0 + val noneStartIter = store.rangeScanWithMultiValues( + None, Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val noneStartResults = noneStartIter.map(_.key.getLong(0)).toList + noneStartIter.close() + assert(noneStartResults.distinct === diverseTimestamps.filter(_ < 0).sorted) + + // None endKey scans from 1000 to end + val noneEndIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(1000L, "a")), None, cfName) + val noneEndResults = noneEndIter.map(_.key.getLong(0)).toList + noneEndIter.close() + assert(noneEndResults.distinct === diverseTimestamps.filter(_ >= 1000).sorted) + + // Empty range [10, 31) - no entries between 9 and 32 + val emptyIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(10L, "a")), + Some(dataToKeyRowWithRangeScan(31L, "a")), cfName) + assert(!emptyIter.hasNext) + emptyIter.close() + + // Bounded negative range [-300, 0) + val negIter = store.rangeScanWithMultiValues( + Some(dataToKeyRowWithRangeScan(-300L, "a")), + Some(dataToKeyRowWithRangeScan(0L, "a")), cfName) + val negResults = negIter.map(_.key.getLong(0)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + } + testWithColumnFamiliesAndEncodingTypes( "rocksdb key and value schema encoders for column families", TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala index a9540a4ad623e..5fcdfb12ba354 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala @@ -566,6 +566,141 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession } } + Seq("unsaferow", "avro").foreach { encoding => + test(s"rangeScan with postfix encoder (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "postfix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + diverseTimestamps.zipWithIndex.foreach { case (ts, idx) => + val keyRow = keyAndTimestampToRow("key1", 1, ts) + store.putList(keyRow, Array(valueToRow(idx * 10), valueToRow(idx * 10 + 1))) + } + + // key2 entry to verify prefix isolation + store.putList(keyAndTimestampToRow("key2", 1, 500L), + Array(valueToRow(999))) + + // Bounded range [0, 1001) + val boundedIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 0L)), + Some(keyAndTimestampToRow("key1", 1, 1001L))) + val boundedResults = boundedIter.map { pair => + (pair.key.getString(0), pair.key.getLong(2), pair.value.getInt(0)) + }.toList + boundedIter.close() + + val expectedTimestamps = diverseTimestamps.filter(ts => ts >= 0 && ts <= 1000).sorted + assert(boundedResults.map(_._2).distinct === expectedTimestamps) + val expectedValues = diverseTimestamps.zipWithIndex + .filter { case (ts, _) => ts >= 0 && ts <= 1000 } + .sortBy(_._1) + .flatMap { case (_, idx) => Seq(idx * 10, idx * 10 + 1) } + assert(boundedResults.map(_._3) === expectedValues) + 
assert(boundedResults.forall(_._1 == "key1")) + + // Exact bound: startKey is inclusive, endKey is exclusive. + // 9 exists in diverseTimestamps, 90 exists in diverseTimestamps. + val exactIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 9L)), + Some(keyAndTimestampToRow("key1", 1, 90L))) + val exactResults = exactIter.map(_.key.getLong(2)).toList + exactIter.close() + val exactResultsDistinct = exactResults.distinct + assert(exactResultsDistinct === diverseTimestamps + .filter(ts => ts >= 9 && ts < 90).sorted) + assert(exactResultsDistinct.contains(9L)) + assert(!exactResultsDistinct.contains(90L)) + + // Postfix timestamp encoder places the timestamp after the key prefix. + // With different key prefixes, None in startKey or endKey would scan across + // key boundaries, which is not meaningful for postfix encoding. Hence we only + // test bounded ranges with explicit keys here. + + // Full range [MinValue, MaxValue) + val fullIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, Long.MinValue)), + Some(keyAndTimestampToRow("key1", 1, Long.MaxValue))) + val fullResults = fullIter.map(_.key.getLong(2)).toList + fullIter.close() + + assert(fullResults.distinct === diverseTimestamps.sorted) + + // Bounded negative range [-300, 0) + val negIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, -300L)), + Some(keyAndTimestampToRow("key1", 1, 0L))) + val negResults = negIter.map(_.key.getLong(2)).toList + negIter.close() + assert(negResults.distinct === diverseTimestamps + .filter(ts => ts >= -300 && ts < 0).sorted) + + // Empty range [10, 31) - no diverseTimestamps entries between 9 and 32 + val emptyIter = store.rangeScanWithMultiValues( + Some(keyAndTimestampToRow("key1", 1, 10L)), + Some(keyAndTimestampToRow("key1", 1, 31L))) + assert(!emptyIter.hasNext) + emptyIter.close() + } finally { + store.abort() + } + } + } + + // Sanity test for prefix encoder scan. 
Full scan coverage is in RocksDBStateStoreSuite's + // "rocksdb range scan - rangeScan" and "rocksdb range scan - rangeScanWithMultiValues" tests. + // This test verifies the timestamp prefix encoder integration works correctly. + test(s"rangeScan with prefix encoder (encoding = $encoding)") { + tryWithProviderResource( + newStoreProviderWithTimestampEncoder( + encoderType = "prefix", + useMultipleValuesPerKey = true, + dataEncoding = encoding) + ) { provider => + val store = provider.getStore(0) + + try { + val timestamps = Seq(100L, 200L, 300L, 400L, 500L) + timestamps.foreach { ts => + store.merge(keyAndTimestampToRow("key1", 1, ts), valueToRow(ts.toInt)) + } + store.merge(keyAndTimestampToRow("key2", 2, 150L), valueToRow(150)) + + // None startKey scans from beginning up to 301 (exclusive) + val iter1 = store.rangeScanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 301L))) + val results1 = iter1.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList + iter1.close() + + assert(results1 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L), ("key1", 300L))) + + // Boundary safety: endKey at 201, includes everything up to 200 + // regardless of join key + val iter2 = store.rangeScanWithMultiValues(None, + Some(keyAndTimestampToRow("key1", 1, 201L))) + val results2 = iter2.map { pair => + (pair.key.getString(0), pair.key.getLong(2)) + }.toList + iter2.close() + + assert(results2 === Seq( + ("key1", 100L), ("key2", 150L), ("key1", 200L))) + } finally { + store.abort() + } + } + } + + } + // Helper methods to create test data private val keyProjection = UnsafeProjection.create(keySchema) private val keyAndTimestampProjection = UnsafeProjection.create(