Skip to content

Commit 91f7da8

Browse files
committed
Initial implementation
1 parent c658a52 commit 91f7da8

File tree

9 files changed

+838
-116
lines changed

9 files changed

+838
-116
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## v1.4.0
9+
10+
ADDED
11+
12+
- Added `AsyncTaskHubGrpcClient` for asyncio-based applications using `grpc.aio`
13+
- Added `DefaultAsyncClientInterceptorImpl` for async gRPC metadata interceptors
14+
- Added `get_async_grpc_channel` helper for creating async gRPC channels
15+
16+
CHANGED
17+
18+
- Refactored `TaskHubGrpcClient` to share request-building and validation logic
19+
with `AsyncTaskHubGrpcClient` via module-level helper functions
20+
821
## v1.3.0
922

1023
ADDED

durabletask/client.py

Lines changed: 243 additions & 100 deletions
Large diffs are not rendered by default.
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from __future__ import annotations
5+
6+
import logging
7+
import uuid
8+
from datetime import datetime, timezone
9+
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
10+
11+
import durabletask.internal.helpers as helpers
12+
import durabletask.internal.orchestrator_service_pb2 as pb
13+
import durabletask.internal.shared as shared
14+
from durabletask import task
15+
from durabletask.internal.grpc_interceptor import (
16+
DefaultAsyncClientInterceptorImpl,
17+
DefaultClientInterceptorImpl,
18+
)
19+
20+
if TYPE_CHECKING:
21+
from durabletask.client import (
22+
EntityQuery,
23+
OrchestrationQuery,
24+
OrchestrationState,
25+
OrchestrationStatus,
26+
)
27+
from durabletask.entities import EntityInstanceId
28+
29+
TInput = TypeVar('TInput')
30+
TOutput = TypeVar('TOutput')
31+
32+
33+
def prepare_sync_interceptors(
        metadata: Optional[list[tuple[str, str]]],
        interceptors: Optional[Sequence[shared.ClientInterceptor]]
) -> Optional[list[shared.ClientInterceptor]]:
    """Assemble the sync gRPC interceptor list.

    Copies any caller-supplied interceptors and appends a
    ``DefaultClientInterceptorImpl`` when metadata headers were given.
    Returns None only when neither interceptors nor metadata were
    supplied (an explicitly-passed empty interceptor list is preserved).
    """
    if interceptors is None and metadata is None:
        return None
    combined: list = list(interceptors) if interceptors is not None else []
    if metadata is not None:
        combined.append(DefaultClientInterceptorImpl(metadata))
    return combined
46+
47+
48+
def prepare_async_interceptors(
        metadata: Optional[list[tuple[str, str]]],
        interceptors: Optional[Sequence[shared.AsyncClientInterceptor]]
) -> Optional[list[shared.AsyncClientInterceptor]]:
    """Assemble the async gRPC interceptor list.

    Mirrors prepare_sync_interceptors but appends a
    ``DefaultAsyncClientInterceptorImpl`` for metadata headers.
    Returns None only when neither interceptors nor metadata were
    supplied (an explicitly-passed empty interceptor list is preserved).
    """
    if interceptors is None and metadata is None:
        return None
    combined: list = list(interceptors) if interceptors is not None else []
    if metadata is not None:
        combined.append(DefaultAsyncClientInterceptorImpl(metadata))
    return combined
61+
62+
63+
def build_schedule_new_orchestration_req(
        orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
        input: Optional[TInput],
        instance_id: Optional[str],
        start_at: Optional[datetime],
        reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy],
        tags: Optional[dict[str, str]],
        version: Optional[str]) -> pb.CreateInstanceRequest:
    """Build a CreateInstanceRequest for scheduling a new orchestration.

    Resolves the orchestrator's registered name, generates a random hex
    instance id when none is supplied, and JSON-serializes the input
    payload before wrapping it in a protobuf string value.
    """
    if isinstance(orchestrator, str):
        name = orchestrator
    else:
        name = task.get_name(orchestrator)
    serialized_input = shared.to_json(input) if input is not None else None
    return pb.CreateInstanceRequest(
        name=name,
        instanceId=instance_id if instance_id else uuid.uuid4().hex,
        input=helpers.get_string_value(serialized_input),
        scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
        version=helpers.get_string_value(version),
        orchestrationIdReusePolicy=reuse_id_policy,
        tags=tags
    )
82+
83+
84+
def build_query_instances_req(
        orchestration_query: OrchestrationQuery,
        continuation_token) -> pb.QueryInstancesRequest:
    """Build a QueryInstancesRequest from an OrchestrationQuery.

    Unset datetime filters and an unset runtime-status list are passed
    through as None so the server applies no filter for them.
    """
    q = orchestration_query
    statuses = [status.value for status in q.runtime_status] if q.runtime_status else None
    created_from = helpers.new_timestamp(q.created_time_from) if q.created_time_from else None
    created_to = helpers.new_timestamp(q.created_time_to) if q.created_time_to else None
    inner_query = pb.InstanceQuery(
        runtimeStatus=statuses,
        createdTimeFrom=created_from,
        createdTimeTo=created_to,
        maxInstanceCount=q.max_instance_count,
        fetchInputsAndOutputs=q.fetch_inputs_and_outputs,
        continuationToken=continuation_token
    )
    return pb.QueryInstancesRequest(query=inner_query)
98+
99+
100+
def build_purge_by_filter_req(
        created_time_from: Optional[datetime],
        created_time_to: Optional[datetime],
        runtime_status: Optional[List[OrchestrationStatus]],
        recursive: bool) -> pb.PurgeInstancesRequest:
    """Build a PurgeInstancesRequest for purging orchestrations by filter.

    Each unset filter argument is forwarded as None so it does not
    constrain the purge; `recursive` controls whether sub-orchestrations
    are purged as well.
    """
    statuses = [status.value for status in runtime_status] if runtime_status else None
    purge_filter = pb.PurgeInstanceFilter(
        createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
        createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
        runtimeStatus=statuses
    )
    return pb.PurgeInstancesRequest(purgeInstanceFilter=purge_filter, recursive=recursive)
114+
115+
116+
def build_query_entities_req(
        entity_query: EntityQuery,
        continuation_token) -> pb.QueryEntitiesRequest:
    """Build a QueryEntitiesRequest from an EntityQuery.

    Datetime filters are converted to protobuf timestamps only when set;
    the page size is wrapped in a protobuf int value.
    """
    eq = entity_query
    modified_from = helpers.new_timestamp(eq.last_modified_from) if eq.last_modified_from else None
    modified_to = helpers.new_timestamp(eq.last_modified_to) if eq.last_modified_to else None
    inner_query = pb.EntityQuery(
        instanceIdStartsWith=helpers.get_string_value(eq.instance_id_starts_with),
        lastModifiedFrom=modified_from,
        lastModifiedTo=modified_to,
        includeState=eq.include_state,
        includeTransient=eq.include_transient,
        pageSize=helpers.get_int_value(eq.page_size),
        continuationToken=continuation_token
    )
    return pb.QueryEntitiesRequest(query=inner_query)
131+
132+
133+
def check_continuation_token(resp_token, prev_token, logger: logging.Logger) -> bool:
    """Decide whether pagination should continue from a continuation token.

    Returns True when the response token indicates another page exists.
    Returns False when the token is absent, empty, or the sentinel "0",
    or when it repeats the previous token (guard against infinite loops).
    """
    has_more = bool(resp_token and resp_token.value and resp_token.value != "0")
    if not has_more:
        return False
    logger.info(f"Received continuation token with value {resp_token.value}, fetching next page...")
    repeated = bool(prev_token and prev_token.value and prev_token.value == resp_token.value)
    if repeated:
        logger.warning(f"Received the same continuation token value {resp_token.value} again, stopping to avoid infinite loop.")
        return False
    return True
142+
143+
144+
def log_completion_state(
        logger: logging.Logger,
        instance_id: str,
        state: Optional[OrchestrationState]):
    """Log the terminal outcome of a completed orchestration.

    Does nothing when no state is available. Failed instances also log
    the failure's error type and message when failure details exist.
    """
    if not state:
        return
    # Compare against proto constants to avoid circular imports with client.py
    status = state.runtime_status.value
    if status == pb.ORCHESTRATION_STATUS_FAILED and state.failure_details is not None:
        fd = state.failure_details
        logger.info(f"Instance '{instance_id}' failed: [{fd.error_type}] {fd.message}")
    elif status == pb.ORCHESTRATION_STATUS_TERMINATED:
        logger.info(f"Instance '{instance_id}' was terminated.")
    elif status == pb.ORCHESTRATION_STATUS_COMPLETED:
        logger.info(f"Instance '{instance_id}' completed.")
160+
161+
162+
def build_raise_event_req(
        instance_id: str,
        event_name: str,
        data: Optional[Any] = None) -> pb.RaiseEventRequest:
    """Build a RaiseEventRequest for raising an orchestration event.

    The event payload is JSON-serialized when provided; otherwise the
    input field carries an empty protobuf string value.
    """
    payload = None if data is None else shared.to_json(data)
    return pb.RaiseEventRequest(
        instanceId=instance_id,
        name=event_name,
        input=helpers.get_string_value(payload)
    )
172+
173+
174+
def build_terminate_req(
        instance_id: str,
        output: Optional[Any] = None,
        recursive: bool = True) -> pb.TerminateRequest:
    """Build a TerminateRequest for terminating an orchestration.

    `output` becomes the terminated instance's serialized output;
    `recursive` (default True) also terminates sub-orchestrations.
    """
    payload = None if output is None else shared.to_json(output)
    return pb.TerminateRequest(
        instanceId=instance_id,
        output=helpers.get_string_value(payload),
        recursive=recursive
    )
184+
185+
186+
def build_signal_entity_req(
        entity_instance_id: EntityInstanceId,
        operation_name: str,
        input: Optional[Any] = None) -> pb.SignalEntityRequest:
    """Build a SignalEntityRequest for signaling an entity.

    Stamps the request with a fresh UUID and the current UTC time;
    scheduling and trace-context fields are left unset.
    """
    payload = None if input is None else shared.to_json(input)
    return pb.SignalEntityRequest(
        instanceId=str(entity_instance_id),
        name=operation_name,
        input=helpers.get_string_value(payload),
        requestId=str(uuid.uuid4()),
        scheduledTime=None,       # signal immediately; no delayed delivery
        parentTraceContext=None,  # no trace context propagated from here
        requestTime=helpers.new_timestamp(datetime.now(timezone.utc))
    )

durabletask/internal/grpc_interceptor.py

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import namedtuple
55

66
import grpc
7+
import grpc.aio
78

89

910
class _ClientCallDetails(
@@ -18,6 +19,32 @@ class _ClientCallDetails(
1819
pass
1920

2021

22+
class _AsyncClientCallDetails(
        namedtuple(
            '_AsyncClientCallDetails',
            ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')),
        grpc.aio.ClientCallDetails):
    """Immutable description of an async RPC to be invoked.

    Combines a namedtuple holding the five call attributes with the
    grpc.aio.ClientCallDetails interface, so async interceptors can hand
    modified call details back to the aio channel machinery.
    """
    pass
32+
33+
34+
def _apply_metadata(client_call_details, metadata):
    """Merge interceptor metadata into the call's existing metadata.

    Returns the call's original metadata object unchanged when there is
    nothing to add — callers compare identity to detect "no change".
    Otherwise returns a fresh list of the existing pairs (if any) with
    the interceptor's pairs appended.
    """
    if metadata is None:
        # Identity of the original object signals "nothing changed".
        return client_call_details.metadata

    existing = client_call_details.metadata
    merged = list(existing) if existing is not None else []
    merged.extend(metadata)
    return merged
46+
47+
2148
class DefaultClientInterceptorImpl (
2249
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
2350
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
@@ -30,24 +57,17 @@ def __init__(self, metadata: list[tuple[str, str]]):
3057
self._metadata = metadata
3158

3259
def _intercept_call(
33-
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
60+
self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
3461
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
3562
call details."""
36-
if self._metadata is None:
63+
new_metadata = _apply_metadata(client_call_details, self._metadata)
64+
if new_metadata is client_call_details.metadata:
3765
return client_call_details
3866

39-
if client_call_details.metadata is not None:
40-
metadata = list(client_call_details.metadata)
41-
else:
42-
metadata = []
43-
44-
metadata.extend(self._metadata)
45-
client_call_details = _ClientCallDetails(
46-
client_call_details.method, client_call_details.timeout, metadata,
67+
return _ClientCallDetails(
68+
client_call_details.method, client_call_details.timeout, new_metadata,
4769
client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression)
4870

49-
return client_call_details
50-
5171
def intercept_unary_unary(self, continuation, client_call_details, request):
5272
new_client_call_details = self._intercept_call(client_call_details)
5373
return continuation(new_client_call_details, request)
@@ -63,3 +83,44 @@ def intercept_stream_unary(self, continuation, client_call_details, request):
6383
def intercept_stream_stream(self, continuation, client_call_details, request):
6484
new_client_call_details = self._intercept_call(client_call_details)
6585
return continuation(new_client_call_details, request)
86+
87+
88+
class DefaultAsyncClientInterceptorImpl(
        grpc.aio.UnaryUnaryClientInterceptor, grpc.aio.UnaryStreamClientInterceptor,
        grpc.aio.StreamUnaryClientInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """Async gRPC interceptor that adds metadata headers to all calls."""

    def __init__(self, metadata: list[tuple[str, str]]):
        # Header pairs appended to every outgoing RPC's metadata.
        self._metadata = metadata

    def _intercept_call(
            self, client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails:
        """Return call details with this interceptor's metadata merged in.

        When _apply_metadata hands back the original metadata object
        (identity check), nothing changed and the incoming details are
        reused as-is; otherwise new aio call details are constructed.
        """
        merged = _apply_metadata(client_call_details, self._metadata)
        if merged is client_call_details.metadata:
            return client_call_details

        return _AsyncClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            merged,
            client_call_details.credentials,
            client_call_details.wait_for_ready,
        )

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        return await continuation(self._intercept_call(client_call_details), request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        return await continuation(self._intercept_call(client_call_details), request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return await continuation(self._intercept_call(client_call_details), request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return await continuation(self._intercept_call(client_call_details), request_iterator)
return await continuation(new_client_call_details, request_iterator)

durabletask/internal/shared.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Optional, Sequence, Union
99

1010
import grpc
11+
import grpc.aio
1112

1213
ClientInterceptor = Union[
1314
grpc.UnaryUnaryClientInterceptor,
@@ -16,6 +17,13 @@
1617
grpc.StreamStreamClientInterceptor
1718
]
1819

20+
AsyncClientInterceptor = Union[
21+
grpc.aio.UnaryUnaryClientInterceptor,
22+
grpc.aio.UnaryStreamClientInterceptor,
23+
grpc.aio.StreamUnaryClientInterceptor,
24+
grpc.aio.StreamStreamClientInterceptor
25+
]
26+
1927
# Field name used to indicate that an object was automatically serialized
2028
# and should be deserialized as a SimpleNamespace
2129
AUTO_SERIALIZED = "__durabletask_autoobject__"
@@ -62,6 +70,38 @@ def get_grpc_channel(
6270
return channel
6371

6472

73+
def get_async_grpc_channel(
        host_address: Optional[str],
        secure_channel: bool = False,
        interceptors: Optional[Sequence[AsyncClientInterceptor]] = None) -> grpc.aio.Channel:
    """Create a grpc.aio channel to the given host address.

    A recognized protocol prefix on the address (secure or insecure)
    overrides the `secure_channel` flag and is stripped before dialing.
    Falls back to the default host address when none is given.
    """
    address = host_address if host_address is not None else get_default_host_address()

    # A secure protocol prefix forces TLS regardless of the flag.
    for prefix in SECURE_PROTOCOLS:
        if address.lower().startswith(prefix):
            secure_channel = True
            address = address[len(prefix):]
            break

    # An insecure protocol prefix forces plaintext regardless of the flag.
    for prefix in INSECURE_PROTOCOLS:
        if address.lower().startswith(prefix):
            secure_channel = False
            address = address[len(prefix):]
            break

    if secure_channel:
        return grpc.aio.secure_channel(
            address, grpc.ssl_channel_credentials(), interceptors=interceptors)
    return grpc.aio.insecure_channel(address, interceptors=interceptors)
103+
104+
65105
def get_logger(
66106
name_suffix: str,
67107
log_handler: Optional[logging.Handler] = None,

0 commit comments

Comments
 (0)