|
7 | 7 | from typing import Any |
8 | 8 | from unittest.mock import MagicMock |
9 | 9 |
|
| 10 | +import nexusrpc.handler |
10 | 11 | import pytest |
11 | 12 | from langsmith import traceable, tracing_context |
12 | 13 |
|
13 | | -from temporalio import activity, common, workflow |
| 14 | +from temporalio import activity, common, nexus, workflow |
14 | 15 | from temporalio.client import Client, WorkflowFailureError |
15 | 16 | from temporalio.contrib.langsmith import LangSmithPlugin |
16 | 17 | from temporalio.exceptions import ApplicationError |
17 | 18 | from temporalio.testing import WorkflowEnvironment |
18 | 19 | from tests.contrib.langsmith.conftest import InMemoryRunCollector, dump_runs |
19 | 20 | from tests.helpers import new_worker |
| 21 | +from tests.helpers.nexus import make_nexus_endpoint_name |
20 | 22 |
|
21 | 23 | # --------------------------------------------------------------------------- |
22 | 24 | # Shared @traceable functions and activities |
@@ -64,6 +66,29 @@ async def run(self) -> str: |
64 | 66 | ) |
65 | 67 |
|
66 | 68 |
|
| 69 | +@workflow.defn |
| 70 | +class SimpleNexusWorkflow: |
| 71 | + @workflow.run |
| 72 | + async def run(self, input: str) -> str: |
| 73 | + return await workflow.execute_activity( |
| 74 | + traceable_activity, |
| 75 | + start_to_close_timeout=timedelta(seconds=10), |
| 76 | + ) |
| 77 | + |
| 78 | + |
| 79 | +@nexusrpc.handler.service_handler |
| 80 | +class NexusService: |
| 81 | + @nexus.workflow_run_operation |
| 82 | + async def run_operation( |
| 83 | + self, ctx: nexus.WorkflowRunOperationContext, input: str |
| 84 | + ) -> nexus.WorkflowHandle[str]: |
| 85 | + return await ctx.start_workflow( |
| 86 | + SimpleNexusWorkflow.run, |
| 87 | + input, |
| 88 | + id=f"nexus-wf-{ctx.request_id}", |
| 89 | + ) |
| 90 | + |
| 91 | + |
67 | 92 | # --------------------------------------------------------------------------- |
68 | 93 | # Simple/basic workflows and activities |
69 | 94 | # --------------------------------------------------------------------------- |
@@ -113,7 +138,17 @@ async def run(self) -> str: |
113 | 138 | TraceableActivityWorkflow.run, |
114 | 139 | id=f"child-{workflow.info().workflow_id}", |
115 | 140 | ) |
116 | | - # 4. Wait for signal |
| 141 | + # 4. Nexus operation |
| 142 | + nexus_client = workflow.create_nexus_client( |
| 143 | + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), |
| 144 | + service=NexusService, |
| 145 | + ) |
| 146 | + nexus_handle = await nexus_client.start_operation( |
| 147 | + operation=NexusService.run_operation, |
| 148 | + input="test-input", |
| 149 | + ) |
| 150 | + await nexus_handle |
| 151 | + # 5. Wait for signal |
117 | 152 | await workflow.wait_condition(lambda: self._signal_received) |
118 | 153 | # 5. Wait for update to complete |
119 | 154 | await workflow.wait_condition(lambda: self._complete) |
@@ -449,8 +484,14 @@ async def user_pipeline() -> str: |
449 | 484 | temporal_client, |
450 | 485 | ComprehensiveWorkflow, |
451 | 486 | TraceableActivityWorkflow, |
| 487 | + SimpleNexusWorkflow, |
452 | 488 | activities=[nested_traceable_activity, traceable_activity], |
| 489 | + nexus_service_handlers=[NexusService()], |
453 | 490 | ) as worker: |
| 491 | + await env.create_nexus_endpoint( |
| 492 | + make_nexus_endpoint_name(worker.task_queue), |
| 493 | + worker.task_queue, |
| 494 | + ) |
454 | 495 | handle = await temporal_client.start_workflow( |
455 | 496 | ComprehensiveWorkflow.run, |
456 | 497 | id=f"comprehensive-{uuid.uuid4()}", |
@@ -489,6 +530,13 @@ async def user_pipeline() -> str: |
489 | 530 | " StartActivity:traceable_activity", |
490 | 531 | " RunActivity:traceable_activity", |
491 | 532 | " inner_llm_call", |
| 533 | + " StartNexusOperation:NexusService/run_operation", |
| 534 | + " RunStartNexusOperationHandler:NexusService/run_operation", |
| 535 | + " StartWorkflow:SimpleNexusWorkflow", |
| 536 | + " RunWorkflow:SimpleNexusWorkflow", |
| 537 | + " StartActivity:traceable_activity", |
| 538 | + " RunActivity:traceable_activity", |
| 539 | + " inner_llm_call", |
492 | 540 | " QueryWorkflow:my_query", |
493 | 541 | " HandleQuery:my_query", |
494 | 542 | " SignalWorkflow:my_signal", |
@@ -517,8 +565,14 @@ async def user_pipeline() -> str: |
517 | 565 | temporal_client, |
518 | 566 | ComprehensiveWorkflow, |
519 | 567 | TraceableActivityWorkflow, |
| 568 | + SimpleNexusWorkflow, |
520 | 569 | activities=[nested_traceable_activity, traceable_activity], |
| 570 | + nexus_service_handlers=[NexusService()], |
521 | 571 | ) as worker: |
| 572 | + await env.create_nexus_endpoint( |
| 573 | + make_nexus_endpoint_name(worker.task_queue), |
| 574 | + worker.task_queue, |
| 575 | + ) |
522 | 576 | handle = await temporal_client.start_workflow( |
523 | 577 | ComprehensiveWorkflow.run, |
524 | 578 | id=f"comprehensive-no-runs-{uuid.uuid4()}", |
@@ -547,6 +601,7 @@ async def user_pipeline() -> str: |
547 | 601 | " outer_chain", |
548 | 602 | " inner_llm_call", |
549 | 603 | " inner_llm_call", |
| 604 | + " inner_llm_call", |
550 | 605 | ] |
551 | 606 | assert hierarchy == expected, ( |
552 | 607 | f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" |
|
0 commit comments