11import asyncio
2+ from datetime import timedelta , datetime
23from unittest .mock import Mock , AsyncMock , PropertyMock
34
45import pytest
6+ from google .protobuf .timestamp_pb2 import Timestamp
7+ from google .protobuf .duration import from_timedelta
58
6- from cadence import Client
9+ from cadence import activity , Client
710from cadence ._internal .activity import ActivityExecutor
8- from cadence .api .v1 .common_pb2 import WorkflowExecution , ActivityType , Payload , Failure
11+ from cadence .activity import ActivityInfo
12+ from cadence .api .v1 .common_pb2 import WorkflowExecution , ActivityType , Payload , Failure , WorkflowType
913from cadence .api .v1 .service_worker_pb2 import RespondActivityTaskCompletedResponse , PollForActivityTaskResponse , \
1014 RespondActivityTaskCompletedRequest , RespondActivityTaskFailedResponse , RespondActivityTaskFailedRequest
1115from cadence .data_converter import DefaultDataConverter
@@ -19,7 +23,6 @@ def client() -> Client:
1923 return client
2024
2125
22- @pytest .mark .asyncio
2326async def test_activity_async_success (client ):
2427 worker_stub = client .worker_stub
2528 worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
@@ -37,7 +40,6 @@ async def activity_fn():
3740 identity = 'identity' ,
3841 ))
3942
40- @pytest .mark .asyncio
4143async def test_activity_async_failure (client ):
4244 worker_stub = client .worker_stub
4345 worker_stub .RespondActivityTaskFailed = AsyncMock (return_value = RespondActivityTaskFailedResponse ())
@@ -64,7 +66,6 @@ async def activity_fn():
6466 identity = 'identity' ,
6567 )
6668
67- @pytest .mark .asyncio
6869async def test_activity_args (client ):
6970 worker_stub = client .worker_stub
7071 worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
@@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str):
8283 identity = 'identity' ,
8384 ))
8485
85-
86- @pytest .mark .asyncio
8786async def test_activity_sync_success (client ):
8887 worker_stub = client .worker_stub
8988 worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
@@ -105,7 +104,6 @@ def activity_fn():
105104 identity = 'identity' ,
106105 ))
107106
108- @pytest .mark .asyncio
109107async def test_activity_sync_failure (client ):
110108 worker_stub = client .worker_stub
111109 worker_stub .RespondActivityTaskFailed = AsyncMock (return_value = RespondActivityTaskFailedResponse ())
@@ -132,7 +130,6 @@ def activity_fn():
132130 identity = 'identity' ,
133131 )
134132
135- @pytest .mark .asyncio
136133async def test_activity_unknown (client ):
137134 worker_stub = client .worker_stub
138135 worker_stub .RespondActivityTaskFailed = AsyncMock (return_value = RespondActivityTaskFailedResponse ())
@@ -148,7 +145,7 @@ def registry(name: str):
148145
149146 call = worker_stub .RespondActivityTaskFailed .call_args [0 ][0 ]
150147
151- assert 'unknown activity : any' in call .failure .details .decode ()
148+ assert 'Activity type not found : any' in call .failure .details .decode ()
152149 call .failure .details = bytes ()
153150 assert call == RespondActivityTaskFailedRequest (
154151 task_token = b'task_token' ,
@@ -158,15 +155,85 @@ def registry(name: str):
158155 identity = 'identity' ,
159156 )
160157
158+ async def test_activity_context (client ):
159+ worker_stub = client .worker_stub
160+ worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
161+
162+ async def activity_fn ():
163+ assert fake_info ("activity_type" ) == activity .info ()
164+ assert activity .in_activity ()
165+ assert activity .client () is not None
166+ return "success"
167+
168+ executor = ActivityExecutor (client , 'task_list' , 'identity' , 1 , lambda name : activity_fn )
169+
170+ await executor .execute (fake_task ("activity_type" , "" ))
171+
172+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (RespondActivityTaskCompletedRequest (
173+ task_token = b'task_token' ,
174+ result = Payload (data = '"success"' .encode ()),
175+ identity = 'identity' ,
176+ ))
177+
178+ async def test_activity_context_sync (client ):
179+ worker_stub = client .worker_stub
180+ worker_stub .RespondActivityTaskCompleted = AsyncMock (return_value = RespondActivityTaskCompletedResponse ())
181+
182+ def activity_fn ():
183+ assert fake_info ("activity_type" ) == activity .info ()
184+ assert activity .in_activity ()
185+ with pytest .raises (RuntimeError ):
186+ activity .client ()
187+ return "success"
188+
189+ executor = ActivityExecutor (client , 'task_list' , 'identity' , 1 , lambda name : activity_fn )
190+
191+ await executor .execute (fake_task ("activity_type" , "" ))
192+
193+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (RespondActivityTaskCompletedRequest (
194+ task_token = b'task_token' ,
195+ result = Payload (data = '"success"' .encode ()),
196+ identity = 'identity' ,
197+ ))
198+
199+
200+ def fake_info (activity_type : str ) -> ActivityInfo :
201+ return ActivityInfo (
202+ task_token = b'task_token' ,
203+ workflow_domain = "workflow_domain" ,
204+ workflow_id = "workflow_id" ,
205+ workflow_run_id = "run_id" ,
206+ activity_id = "activity_id" ,
207+ activity_type = activity_type ,
208+ attempt = 1 ,
209+ workflow_type = "workflow_type" ,
210+ task_list = "task_list" ,
211+ heartbeat_timeout = timedelta (seconds = 1 ),
212+ scheduled_timestamp = datetime (2020 , 1 , 2 ,3 ),
213+ started_timestamp = datetime (2020 , 1 , 2 ,4 ),
214+ start_to_close_timeout = timedelta (seconds = 2 ),
215+ )
216+
161217def fake_task (activity_type : str , input_json : str ) -> PollForActivityTaskResponse :
162218 return PollForActivityTaskResponse (
163219 task_token = b'task_token' ,
220+ workflow_domain = "workflow_domain" ,
221+ workflow_type = WorkflowType (name = "workflow_type" ),
164222 workflow_execution = WorkflowExecution (
165223 workflow_id = "workflow_id" ,
166224 run_id = "run_id" ,
167225 ),
168226 activity_id = "activity_id" ,
169227 activity_type = ActivityType (name = activity_type ),
170228 input = Payload (data = input_json .encode ()),
171- attempt = 0 ,
172- )
229+ attempt = 1 ,
230+ heartbeat_timeout = from_timedelta (timedelta (seconds = 1 )),
231+ scheduled_time = from_datetime (datetime (2020 , 1 , 2 , 3 )),
232+ started_time = from_datetime (datetime (2020 , 1 , 2 , 4 )),
233+ start_to_close_timeout = from_timedelta (timedelta (seconds = 2 )),
234+ )
235+
236+ def from_datetime (time : datetime ) -> Timestamp :
237+ t = Timestamp ()
238+ t .FromDatetime (time )
239+ return t
0 commit comments