Skip to content

Commit 4a4ab8c

Browse files
committed
Carry state over CAN
1 parent fdf011e commit 4a4ab8c

File tree

1 file changed

+50
-12
lines changed

1 file changed

+50
-12
lines changed

signals_and_updates/order_handling_of_n_messages.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import asyncio
22
import logging
33
import random
4+
from typing import Optional
45

56
from temporalio import common, workflow
67
from temporalio.client import Client, WorkflowHandle
78
from temporalio.worker import Worker
89

910
Payload = str
11+
SerializedQueueState = tuple[int, list[tuple[int, Payload]]]
1012

1113

12-
class Queue:
13-
def __init__(self) -> None:
14-
self._head = 0
15-
self._futures: dict[int, asyncio.Future[Payload]] = {}
14+
class OrderedQueue:
15+
    def __init__(self):
        """Create an empty ordered queue whose next expected message is number 0."""
        # Futures keyed by sequence number; each is resolved when that message arrives.
        self._futures: dict[int, asyncio.Future[Payload]] = {}
        # Sequence number of the next message that next() will hand out.
        self.head: int = 0
        # Serializes consumers so messages are delivered strictly in order.
        self.lock = asyncio.Lock()
1719

1820
def add(self, item: Payload, position: int):
@@ -24,33 +26,69 @@ def add(self, item: Payload, position: int):
2426

2527
    async def next(self) -> Payload:
        """Wait for and return the payload with sequence number ``head``.

        The lock is held across the await so that only one consumer at a time
        can advance ``head``, guaranteeing strictly in-order delivery.
        """
        async with self.lock:
            payload = await self._futures.setdefault(self.head, asyncio.Future())
            # Note: user must delete the payload to avoid unbounded memory usage
            del self._futures[self.head]
            self.head += 1
            return payload
3034

35+
def serialize(self) -> SerializedQueueState:
36+
payloads = [(i, fut.result()) for i, fut in self._futures.items() if fut.done()]
37+
return (self.head, payloads)
38+
39+
# This is bad, but AFAICS it's the best we can do currently until we have a workflow init
40+
# functionality in all SDKs (https://github.com/temporalio/features/issues/400). Currently the
41+
# problem is: if a signal/update handler is sync, then it cannot wait for anything in the main
42+
# wf coroutine. After CAN, a message may arrive in the first WFT, but the sync handler cannot
43+
# wait for wf state to be initialized. So we are forced to update an *existing* queue with the
44+
# carried-over state.
45+
def update_from_serialized_state(self, serialized_state: SerializedQueueState):
46+
head, payloads = serialized_state
47+
self.head = head
48+
for i, p in payloads:
49+
if i in self._futures:
50+
workflow.logger.error(f"duplicate message {i} encountered when deserializing state carried over CAN")
51+
else:
52+
self._futures[i] = resolved_future(p)
53+
54+
55+
def resolved_future[X](x: X) -> asyncio.Future[X]:
56+
fut = asyncio.Future[X]()
57+
fut.set_result(x)
58+
return fut
59+
60+
3161

3262
@workflow.defn
class MessageProcessor:
    """Workflow that consumes update messages strictly in sequence-number order,
    carrying buffered queue state across continue-as-new (CAN)."""

    def __init__(self) -> None:
        self.queue = OrderedQueue()

    @workflow.run
    async def run(self, serialized_queue_state: Optional[SerializedQueueState] = None):
        # Initialize workflow state after CAN. Note that handler is sync, so it cannot wait for
        # workflow initialization.
        if serialized_queue_state:
            self.queue.update_from_serialized_state(serialized_queue_state)
        while True:
            # NOTE(review): logs head + 1 although sequence numbers start at 0 in
            # the driver — confirm whether 1-based numbering here is intended.
            workflow.logger.info(f"waiting for msg {self.queue.head + 1}")
            payload = await self.queue.next()
            workflow.logger.info(payload)
            if workflow.info().is_continue_as_new_suggested():
                workflow.logger.info("CAN")
                # Carry resolved-but-unconsumed messages into the next run so
                # they are not lost across continue-as-new.
                workflow.continue_as_new(args=[self.queue.serialize()])

    # Note: sync handler
    @workflow.update
    def process_message(self, sequence_number: int, payload: Payload):
        self.queue.add(payload, sequence_number)
4885

4986

5087
async def app(wf: WorkflowHandle):
    """Client driver: send updates 0..99 to the workflow in random arrival order.

    The workflow is expected to process them in ascending sequence order
    regardless of the order in which they arrive.
    """
    sequence_numbers = list(range(100))
    random.shuffle(sequence_numbers)
    for i in sequence_numbers:
        print(f"sending update {i}")
        await wf.execute_update(
            MessageProcessor.process_message, args=[i, f"payload {i}"]
        )
)
@@ -70,7 +108,7 @@ async def main():
70108
task_queue="tq",
71109
id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING,
72110
)
73-
await app(wf)
111+
await asyncio.gather(app(wf), wf.result())
74112

75113

76114
if __name__ == "__main__":

0 commit comments

Comments
 (0)