@@ -255,6 +255,14 @@ def row(self):
return RowCompositeOutputProcessor()


class LargeValueStatefulProcessorFactory(StatefulProcessorFactory):
def pandas(self):
return PandasLargeValueStatefulProcessor()

def row(self):
return RowLargeValueStatefulProcessor()


# StatefulProcessor implementations


@@ -2039,3 +2047,87 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:

def close(self) -> None:
pass


class PandasLargeValueStatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle):
# Test all three state types with large values
value_state_schema = StructType([StructField("value", StringType(), True)])
self.value_state = handle.getValueState("valueState", value_state_schema)

list_state_schema = StructType([StructField("value", StringType(), True)])
self.list_state = handle.getListState("listState", list_state_schema)

self.map_state = handle.getMapState("mapState", "key string", "value string")
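# Note: the map state schemas are given here as DDL-formatted strings; the
# same handle methods also accept StructType, as used for the value and
# list states above.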

def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
# Create a large string (512 KB)
target_size_bytes = 512 * 1024
large_string = "a" * target_size_bytes

# Test ValueState with large string
self.value_state.update((large_string,))
value_retrieved = self.value_state.get()[0]

# Test ListState with large strings
self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
list_retrieved = list(self.list_state.get())
list_elements = ",".join([elem[0] for elem in list_retrieved])

# Test MapState with large strings
map_key = ("large_string_key",)
self.map_state.updateValue(map_key, (large_string,))
map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"

yield pd.DataFrame(
{
"id": key,
"valueStateResult": [value_retrieved],
"listStateResult": [list_elements],
"mapStateResult": [map_retrieved],
}
)

def close(self) -> None:
pass


class RowLargeValueStatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle):
# Test all three state types with large values
value_state_schema = StructType([StructField("value", StringType(), True)])
self.value_state = handle.getValueState("valueState", value_state_schema)

list_state_schema = StructType([StructField("value", StringType(), True)])
self.list_state = handle.getListState("listState", list_state_schema)

self.map_state = handle.getMapState("mapState", "key string", "value string")

def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
# Create a large string (512 KB)
target_size_bytes = 512 * 1024
large_string = "a" * target_size_bytes

# Test ValueState with large string
self.value_state.update((large_string,))
value_retrieved = self.value_state.get()[0]

# Test ListState with large strings
self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
list_retrieved = list(self.list_state.get())
list_elements = ",".join([elem[0] for elem in list_retrieved])

# Test MapState with large strings
map_key = ("large_string_key",)
self.map_state.updateValue(map_key, (large_string,))
map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"

yield Row(
id=key[0],
valueStateResult=value_retrieved,
listStateResult=list_elements,
mapStateResult=map_retrieved,
)

def close(self) -> None:
pass
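
The factory registered at the top of this file mirrors the existing ones: pandas() returns the variant that yields pd.DataFrame batches, row() the variant that yields individual Row objects, so the same large-value scenario exercises both APIs. For context, a minimal sketch of how such a processor is attached to a grouped streaming DataFrame, assuming the PySpark 4.x transformWithStateInPandas API (the source DataFrame df, output_schema, and check_results are borrowed from the test below, not part of this file):

query = (
    df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=PandasLargeValueStatefulProcessor(),
        outputStructType=output_schema,  # id plus the three *StateResult columns
        outputMode="Update",
        timeMode="None",
    )
    .writeStream.foreachBatch(check_results)
    .start()
)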
@@ -65,6 +65,7 @@
ListStateLargeTTLProcessorFactory,
MapStateProcessorFactory,
MapStateLargeTTLProcessorFactory,
LargeValueStatefulProcessorFactory,
BasicProcessorFactory,
BasicProcessorNotNullableFactory,
AddFieldsProcessorFactory,
@@ -2264,6 +2265,50 @@ def check_results(batch_df, batch_id):
),
)

# test all state types (value, list, map) with large values (512 KB)
def test_transform_with_state_large_values(self):
def check_results(batch_df, batch_id):
batch_df.collect()
# Create expected large string (512 KB)
target_size_bytes = 512 * 1024
large_string = "a" * target_size_bytes
expected_list_elements = ",".join(
[large_string, large_string + "b", large_string + "c"]
)
expected_map_result = f"large_string_key:{large_string}"

assert set(batch_df.sort("id").collect()) == {
Row(
id="0",
valueStateResult=large_string,
listStateResult=expected_list_elements,
mapStateResult=expected_map_result,
),
Row(
id="1",
valueStateResult=large_string,
listStateResult=expected_list_elements,
mapStateResult=expected_map_result,
),
}

output_schema = StructType(
[
StructField("id", StringType(), True),
StructField("valueStateResult", StringType(), True),
StructField("listStateResult", StringType(), True),
StructField("mapStateResult", StringType(), True),
]
)
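# The shared _test_transform_with_state_basic helper drives the streaming
# query; judging from its use elsewhere in this suite, the positional
# True and "None" below presumably select single-batch execution and the
# timer-less "None" time mode (assumption based on neighboring tests).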

self._test_transform_with_state_basic(
LargeValueStatefulProcessorFactory(),
check_results,
True,
"None",
output_schema=output_schema,
)


@unittest.skipIf(
not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
@@ -208,7 +208,7 @@ class TransformWithStateInPySparkStateServer(
private def parseProtoMessage(): StateRequest = {
val messageLen = inputStream.readInt()
val messageBytes = new Array[Byte](messageLen)
- inputStream.read(messageBytes)
+ inputStream.readFully(messageBytes)
StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
}
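
This one-line change is the core fix: InputStream.read(byte[]) may return after filling only part of the buffer, and once state values reach hundreds of kilobytes (such as the 512 KB payloads the new tests exercise) a single read on the socket-backed stream can come back short, handing truncated bytes to the protobuf parser. DataInputStream.readFully instead loops until the buffer is completely filled and throws EOFException if the stream ends first. A minimal sketch of the equivalent loop, in Python for illustration (the read_fully helper and its stream argument are illustrative, not part of this PR):

def read_fully(stream, n):
    # Keep reading until exactly n bytes have arrived; a single read() on a
    # socket-backed stream may legally return fewer bytes than requested.
    buf = bytearray()
    while len(buf) < n:
        chunk = stream.read(n - len(buf))
        if not chunk:  # EOF before the message completed
            raise EOFError(f"expected {n} bytes, got only {len(buf)}")
        buf.extend(chunk)
    return bytes(buf)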
