Skip to content
266 changes: 207 additions & 59 deletions scripts/gen_payload_visitor.py

Large diffs are not rendered by default.

318 changes: 230 additions & 88 deletions temporalio/bridge/_visitor.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ async def decode_activation(
activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
data_converter: temporalio.converter.DataConverter,
decode_headers: bool,
concurrency_limit: int,
) -> temporalio.converter._extstore.StorageOperationMetrics:
"""Decode all payloads in the activation.

Expand All @@ -312,7 +313,9 @@ async def decode_activation(
metrics = temporalio.converter._extstore.StorageOperationMetrics()
with metrics.track():
await CommandAwarePayloadVisitor(
skip_search_attributes=True, skip_headers=not decode_headers
skip_search_attributes=True,
skip_headers=not decode_headers,
concurrency_limit=concurrency_limit,
).visit(_Visitor(data_converter._decode_payload_sequence), activation)
return metrics

Expand All @@ -321,6 +324,7 @@ async def encode_completion(
completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
data_converter: temporalio.converter.DataConverter,
encode_headers: bool,
concurrency_limit: int,
) -> temporalio.converter._extstore.StorageOperationMetrics:
"""Encode all payloads in the completion.

Expand All @@ -330,6 +334,8 @@ async def encode_completion(
metrics = temporalio.converter._extstore.StorageOperationMetrics()
with metrics.track():
await CommandAwarePayloadVisitor(
skip_search_attributes=True, skip_headers=not encode_headers
skip_search_attributes=True,
skip_headers=not encode_headers,
concurrency_limit=concurrency_limit,
).visit(_Visitor(data_converter._encode_payload_sequence), completion)
return metrics
21 changes: 21 additions & 0 deletions temporalio/worker/_command_aware_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,27 @@ class CommandAwarePayloadVisitor(PayloadVisitor):
activation jobs that have both a 'seq' field and payloads to visit.
"""

def __init__(
    self,
    *,
    skip_search_attributes: bool = False,
    skip_headers: bool = False,
    concurrency_limit: int = 1,
) -> None:
    """Initialize the command-aware payload visitor.

    All options are forwarded unchanged to the base ``PayloadVisitor``.

    Args:
        skip_search_attributes: When True, search attributes are not visited.
        skip_headers: When True, headers are not visited.
        concurrency_limit: Upper bound on how many payload visits may be in
            flight at once during a single call to visit(). Defaults to 1,
            i.e. fully serialized visitation.
    """
    forwarded = dict(
        skip_search_attributes=skip_search_attributes,
        skip_headers=skip_headers,
        concurrency_limit=concurrency_limit,
    )
    super().__init__(**forwarded)

# Workflow commands with payloads
async def _visit_coresdk_workflow_commands_ScheduleActivity(
self, fs: VisitorFunctions, o: ScheduleActivity
Expand Down
1 change: 1 addition & 0 deletions temporalio/worker/_replayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def on_eviction_hook(
"header_codec_behavior", HeaderCodecBehavior.NO_CODEC
)
!= HeaderCodecBehavior.NO_CODEC,
max_workflow_task_payload_concurrency=1,
)
external_storage = data_converter.external_storage
storage_driver_types = (
Expand Down
16 changes: 15 additions & 1 deletion temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ._nexus import _NexusWorker
from ._plugin import Plugin
from ._tuning import WorkerTuner
from ._workflow import _WorkflowWorker
from ._workflow import _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY, _WorkflowWorker
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
from .workflow_sandbox import SandboxedWorkflowRunner

Expand Down Expand Up @@ -142,6 +142,7 @@ def __init__(
maximum=5
),
disable_payload_error_limit: bool = False,
max_workflow_task_payload_concurrency: int = _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY,
) -> None:
"""Create a worker to process workflows and/or activities.

Expand Down Expand Up @@ -316,6 +317,10 @@ def __init__(
and cause a task failure if the size limit is exceeded. The default is False.
See https://docs.temporal.io/troubleshooting/blob-size-limit-error for more
details.
max_workflow_task_payload_concurrency: Maximum number of payload
operations (codec encode/decode, external storage I/O, etc.)
that may run concurrently within a single workflow task
activation. Defaults to 1. WARNING: This setting is experimental.

"""
config = WorkerConfig(
Expand Down Expand Up @@ -361,6 +366,7 @@ def __init__(
activity_task_poller_behavior=activity_task_poller_behavior,
nexus_task_poller_behavior=nexus_task_poller_behavior,
disable_payload_error_limit=disable_payload_error_limit,
max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency,
)

plugins_from_client = cast(
Expand Down Expand Up @@ -414,6 +420,12 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
raise ValueError(
"default_versioning_behavior must be UNSPECIFIED when use_worker_versioning is False"
)
max_workflow_task_payload_concurrency = config.get(
"max_workflow_task_payload_concurrency",
_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY,
)
if max_workflow_task_payload_concurrency < 1:
raise ValueError("max_workflow_task_payload_concurrency must be positive")

# Prepend applicable client interceptors to the given ones
client_config = config["client"].config(active_config=True) # type: ignore[reportTypedDictNotRequiredAccess]
Expand Down Expand Up @@ -518,6 +530,7 @@ def check_activity(activity: str):
assert_local_activity_valid=check_activity,
encode_headers=client_config["header_codec_behavior"]
!= HeaderCodecBehavior.NO_CODEC,
max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency,
)

tuner = config.get("tuner")
Expand Down Expand Up @@ -964,6 +977,7 @@ class WorkerConfig(TypedDict, total=False):
activity_task_poller_behavior: PollerBehavior
nexus_task_poller_behavior: PollerBehavior
disable_payload_error_limit: bool
max_workflow_task_payload_concurrency: int


def _warn_if_activity_executor_max_workers_is_inconsistent(
Expand Down
8 changes: 8 additions & 0 deletions temporalio/worker/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
# Set to true to log all activations and completions
LOG_PROTOS = False

_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY: int = 1


class _WorkflowWorker: # type:ignore[reportUnusedClass]
def __init__(
Expand Down Expand Up @@ -74,6 +76,7 @@ def __init__(
should_enforce_versioning_behavior: bool,
assert_local_activity_valid: Callable[[str], None],
encode_headers: bool,
max_workflow_task_payload_concurrency: int,
) -> None:
self._bridge_worker = bridge_worker
self._namespace = namespace
Expand Down Expand Up @@ -112,6 +115,9 @@ def __init__(
self._on_eviction_hook = on_eviction_hook
self._disable_safe_eviction = disable_safe_eviction
self._encode_headers = encode_headers
self._max_workflow_task_payload_concurrency = (
max_workflow_task_payload_concurrency
)
self._throw_after_activation: Exception | None = None

# If there's a debug mode or a truthy TEMPORAL_DEBUG env var, disable
Expand Down Expand Up @@ -299,6 +305,7 @@ async def _handle_activation(
act,
data_converter,
decode_headers=self._encode_headers,
concurrency_limit=self._max_workflow_task_payload_concurrency,
)
if not workflow:
assert init_job
Expand Down Expand Up @@ -410,6 +417,7 @@ async def _handle_activation(
completion,
data_converter,
encode_headers=self._encode_headers,
concurrency_limit=self._max_workflow_task_payload_concurrency,
)
except temporalio.converter._payload_limits._PayloadSizeError as err:
logger.warning(err.message)
Expand Down
70 changes: 69 additions & 1 deletion tests/worker/test_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import dataclasses
import time
from collections.abc import MutableSequence

from google.protobuf.duration_pb2 import Duration
Expand Down Expand Up @@ -205,6 +207,72 @@ async def test_visit_payloads_on_other_commands():
assert ur.completed.metadata["visited"]


async def test_concurrent_throughput():
    """Demonstrate that concurrent visitation is faster than serialized for I/O-bound codecs.

    A visitor that awaits asyncio.sleep per payload (simulating an I/O-bound
    codec) records how many visits are in flight simultaneously. The default
    visitor must stay fully serialized (peak concurrency of 1), while a
    visitor constructed with concurrency_limit=5 must reach a peak of 5.
    """
    N_CMDS = 10
    N_ARGS = 5
    SLEEP = 0.02

    class SlowVisitor(VisitorFunctions):
        # Counts total payloads visited and tracks the peak number of
        # concurrently-active visits so the test can assert on overlap.
        def __init__(self):
            self.visit_count = 0
            self._active = 0
            self.max_concurrent = 0

        async def visit_payload(self, payload: Payload) -> None:
            return await self._visit(1)

        async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
            return await self._visit(len(payloads))

        async def _visit(self, count: int) -> None:
            # Record entry before sleeping so overlapping visits are counted.
            self._active += 1
            self.max_concurrent = max(self.max_concurrent, self._active)
            try:
                # Yield to the event loop; overlapping visits raise _active.
                await asyncio.sleep(SLEEP * count)
                self.visit_count += count
            finally:
                self._active -= 1

    completion = WorkflowActivationCompletion(
        run_id="1",
        successful=Success(
            commands=[
                WorkflowCommand(
                    schedule_activity=ScheduleActivity(
                        seq=i,
                        activity_id=str(i),
                        activity_type="",
                        task_queue="",
                        arguments=[
                            Payload(data=f"cmd_{i}_arg_{j}".encode())
                            for j in range(N_ARGS)
                        ],
                        priority=Priority(),
                    )
                )
                for i in range(N_CMDS)
            ]
        ),
    )

    # Default visitor: every visit completes before the next begins.
    visitor_default = SlowVisitor()
    await PayloadVisitor().visit(visitor_default, completion)

    assert visitor_default.visit_count == N_CMDS * N_ARGS
    assert visitor_default.max_concurrent == 1

    # Limited-concurrency visitor: visits overlap up to the configured cap.
    visitor_concurrent = SlowVisitor()
    await PayloadVisitor(concurrency_limit=5).visit(visitor_concurrent, completion)

    assert visitor_concurrent.visit_count == N_CMDS * N_ARGS
    assert visitor_concurrent.max_concurrent == 5


async def test_bridge_encoding():
comp = WorkflowActivationCompletion(
run_id="1",
Expand Down Expand Up @@ -235,7 +303,7 @@ async def test_bridge_encoding():
payload_codec=SimpleCodec(),
)

await temporalio.bridge.worker.encode_completion(comp, data_converter, True)
await temporalio.bridge.worker.encode_completion(comp, data_converter, True, 1)

cmd = comp.successful.commands[0]
sa = cmd.schedule_activity
Expand Down
Loading