diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 07ab434..2037eb4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,7 @@ on: branches: [ main ] jobs: - lint-test-docs: + lint-test: runs-on: ${{ matrix.os }} strategy: matrix: @@ -17,6 +17,12 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 + - name: Install protoc + uses: arduino/setup-protoc@v3 + with: + version: '23.x' + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Install uv uses: astral-sh/setup-uv@v6 with: diff --git a/nexusmcp/__init__.py b/nexusmcp/__init__.py index de60797..8b70b41 100644 --- a/nexusmcp/__init__.py +++ b/nexusmcp/__init__.py @@ -1,9 +1,10 @@ -from temporalio import workflow +import temporalio.workflow + +with temporalio.workflow.unsafe.imports_passed_through(): + from nexusmcp import workflow -with workflow.unsafe.imports_passed_through(): from .inbound_gateway import InboundGateway from .service import MCPService from .service_handler import MCPServiceHandler, exclude - from .workflow_transport import WorkflowTransport -__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude", "WorkflowTransport"] +__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude", "workflow"] diff --git a/nexusmcp/workflow/__init__.py b/nexusmcp/workflow/__init__.py new file mode 100644 index 0000000..31561f6 --- /dev/null +++ b/nexusmcp/workflow/__init__.py @@ -0,0 +1,3 @@ +from .mcp_client import MCPClient + +__all__ = ["MCPClient"] diff --git a/nexusmcp/workflow/mcp_client.py b/nexusmcp/workflow/mcp_client.py new file mode 100644 index 0000000..dd7a3f8 --- /dev/null +++ b/nexusmcp/workflow/mcp_client.py @@ -0,0 +1,66 @@ +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.shared.memory import create_connected_server_and_client_session +from temporalio import workflow + +from nexusmcp.service import MCPService + + +class MCPClient: + """ + An MCP client for use in Temporal workflows. + + This class provides a client that proxies MCP traffic from a Temporal Workflow to a Temporal + Nexus service. It works by running an MCP server in the workflow whose handlers delegate to + nexus operations, and connecting to it via an in-memory transport. + + Example: + ```python + client = MCPClient("my-endpoint") + async with client.connect() as session: + await session.list_tools() + await session.call_tool("my-service_my-operation", {"arg": "value"}) + ``` + """ + + def __init__( + self, + endpoint: str, + ): + self.endpoint = endpoint + # Run an in-workflow MCP server whose handlers make nexus calls + self.mcp_server = Server("workflow-gateway-mcp-server") + self.mcp_server.list_tools()(self._handle_list_tools) # type: ignore[no-untyped-call] + self.mcp_server.call_tool()(self._handle_call_tool) + + @asynccontextmanager + async def connect(self) -> AsyncGenerator[ClientSession, None]: + """ + Create a connected MCP ClientSession. + + The session is automatically initialized before being yielded. + """ + async with create_connected_server_and_client_session( + self.mcp_server, + raise_exceptions=True, + ) as session: + yield session + + async def _handle_list_tools(self) -> list[types.Tool]: + nexus_client = workflow.create_nexus_client( + endpoint=self.endpoint, + service=MCPService, + ) + return await nexus_client.execute_operation(MCPService.list_tools, None) + + async def _handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + service, _, operation = name.partition("_") + nexus_client = workflow.create_nexus_client( + endpoint=self.endpoint, + service=service, + ) + return await nexus_client.execute_operation(operation, arguments) diff --git a/nexusmcp/workflow_transport.py b/nexusmcp/workflow_transport.py deleted file mode 100644 index 682480d..0000000 --- a/nexusmcp/workflow_transport.py +++ /dev/null @@ -1,152 +0,0 @@ -import asyncio -from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator - -import anyio -import mcp.types as types -import pydantic -from mcp.shared.message import SessionMessage -from temporalio import workflow - -from .service import MCPService - - -class WorkflowTransport: - """ - An MCP Transport for use in Temporal workflows. - - This class provides a transport that proxies MCP requests from a Temporal Workflow to a Temporal - Nexus service. It can be used to make MCP calls via `mcp.ClientSession` from Temporal workflow - code. - - Example: - ```python async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() await session.list_tools() await - session.call_tool("my-service/my-operation", {"arg": "value"}) - ``` - """ - - def __init__( - self, - endpoint: str, - ): - self.endpoint = endpoint - - @asynccontextmanager - async def connect( - self, - ) -> AsyncGenerator[ - tuple[ - anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage], # pyright: ignore[reportAttributeAccessIssue] - anyio.streams.memory.MemoryObjectSendStream[SessionMessage], # pyright: ignore[reportAttributeAccessIssue] - ], - None, - ]: - client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] - transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] - - async def message_router() -> None: - try: - async for session_message in transport_read: - request = session_message.message.root - if not isinstance(request, types.JSONRPCRequest): - # Ignore e.g. types.JSONRPCNotification - continue - result: types.Result | types.ErrorData - try: - match request: - case types.JSONRPCRequest(method="initialize"): - result = self._handle_initialize( - types.InitializeRequestParams.model_validate(request.params) - ) - case types.JSONRPCRequest(method="tools/list"): - result = await self._handle_list_tools() - case types.JSONRPCRequest(method="tools/call"): - result = await self._handle_call_tool( - types.CallToolRequestParams.model_validate(request.params) - ) - case _: - result = types.ErrorData( - code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}" - ) - except pydantic.ValidationError as e: - result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}") - - match result: - case types.Result(): - response = self._json_rpc_result_response(request, result) - case types.ErrorData(): - response = self._json_rpc_error_response(request, result) - - await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response))) - - except anyio.ClosedResourceError: - pass - finally: - await transport_write.aclose() - - router_task = asyncio.create_task(message_router()) - - try: - yield client_read, client_write - finally: - await client_write.aclose() - router_task.cancel() - try: - await router_task - except asyncio.CancelledError: - pass - await transport_read.aclose() - - def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult: - # TODO: MCPService should implement this - return types.InitializeResult( - protocolVersion="2024-11-05", - capabilities=types.ServerCapabilities(tools=types.ToolsCapability()), - serverInfo=types.Implementation( - name="nexus-mcp-transport", - version="0.1.0", - ), - ) - - async def _handle_list_tools(self) -> types.ListToolsResult: - nexus_client = workflow.create_nexus_client( - endpoint=self.endpoint, - service=MCPService, - ) - tools = await nexus_client.execute_operation(MCPService.list_tools, None) - return types.ListToolsResult(tools=tools) - - async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult: - service, _, operation = params.name.partition("_") - nexus_client = workflow.create_nexus_client( - endpoint=self.endpoint, - service=service, - ) - result: Any = await nexus_client.execute_operation( - operation, - params.arguments or {}, - ) - if isinstance(result, dict): - return types.CallToolResult(content=[], structuredContent=result) - else: - return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))]) - - def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse: - return types.JSONRPCResponse.model_validate( - { - "jsonrpc": "2.0", - "id": request.id, - "error": error.model_dump(), - } - ) - - def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse: - return types.JSONRPCResponse.model_validate( - { - "jsonrpc": "2.0", - "id": request.id, - "result": result.model_dump(), - } - ) diff --git a/pyproject.toml b/pyproject.toml index bcda848..3ea0497 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "mcp>=1.13.0", "nexus-rpc>=1.1.0", "pydantic>=2.11.7", - "temporalio>=1.15.0", + "temporalio>=1.16.0", ] [dependency-groups] @@ -54,6 +54,3 @@ format = [ {cmd = "uv run ruff check --select I --fix"}, {cmd = "uv run ruff format"}, ] - -[tool.uv.sources] -temporalio = { git = "https://github.com/temporalio/sdk-python" } diff --git a/tests/test_workflow_caller.py b/tests/test_workflow_caller.py index b3458e8..ecd4d89 100644 --- a/tests/test_workflow_caller.py +++ b/tests/test_workflow_caller.py @@ -2,7 +2,6 @@ from dataclasses import dataclass import pytest -from mcp import ClientSession from temporalio import workflow from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest @@ -10,7 +9,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from nexusmcp import WorkflowTransport +import nexusmcp.workflow from .service import TestServiceHandler, mcp_service @@ -26,18 +25,15 @@ class MCPCallerWorkflowInput: class MCPCallerWorkflow: @workflow.run async def run(self, input: MCPCallerWorkflowInput) -> None: - transport = WorkflowTransport(input.endpoint) - async with transport.connect() as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with nexusmcp.workflow.MCPClient(input.endpoint).connect() as session: + # Session is already initialized + list_tools_result = await session.list_tools() + assert len(list_tools_result.tools) == 2 + assert list_tools_result.tools[0].name == "modified-service-name_modified-op-name" + assert list_tools_result.tools[1].name == "modified-service-name_op2" - list_tools_result = await session.list_tools() - assert len(list_tools_result.tools) == 2 - assert list_tools_result.tools[0].name == "modified-service-name_modified-op-name" - assert list_tools_result.tools[1].name == "modified-service-name_op2" - - call_result = await session.call_tool("modified-service-name_modified-op-name", {"name": "World"}) - assert call_result.structuredContent == {"message": "Hello, World"} + call_result = await session.call_tool("modified-service-name_modified-op-name", {"name": "World"}) + assert call_result.structuredContent == {"message": "Hello, World"} @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index 6d32849..69d4881 100644 --- a/uv.lock +++ b/uv.lock @@ -239,7 +239,7 @@ requires-dist = [ { name = "mcp", specifier = ">=1.13.0" }, { name = "nexus-rpc", specifier = ">=1.1.0" }, { name = "pydantic", specifier = ">=2.11.7" }, - { name = "temporalio", git = "https://github.com/temporalio/sdk-python" }, + { name = "temporalio", specifier = ">=1.16.0" }, ] [package.metadata.requires-dev] @@ -631,14 +631,21 @@ wheels = [ [[package]] name = "temporalio" -version = "1.15.0" -source = { git = "https://github.com/temporalio/sdk-python#2b5de918ce73fbab7a07ab6752a468e9e8f47be6" } +version = "1.16.0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nexus-rpc" }, { name = "protobuf" }, { name = "types-protobuf" }, { name = "typing-extensions" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/f3/32/375ab75d0ebb468cf9c8abbc450a03d3a8c66401fc320b338bd8c00d36b4/temporalio-1.16.0.tar.gz", hash = "sha256:dd926f3e30626fd4edf5e0ce596b75ecb5bbe0e4a0281e545ac91b5577967c91", size = 1733873, upload-time = "2025-08-21T22:12:50.879Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/36/12bb7234c83ddca4b8b032c8f1a9e07a03067c6ed6d2ddb39c770a4c87c6/temporalio-1.16.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:547c0853310350d3e5b5b9c806246cbf2feb523f685b05bf14ec1b0ece8a7bb6", size = 12540769, upload-time = "2025-08-21T22:11:24.551Z" }, + { url = "https://files.pythonhosted.org/packages/3c/16/a7d402435b8f994979abfeffd3f5ffcaaeada467ac16438e61c51c9f7abe/temporalio-1.16.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b05bb0d06025645aed6f936615311a6774eb8dc66280f32a810aac2283e1258", size = 12968631, upload-time = "2025-08-21T22:11:48.375Z" }, + { url = "https://files.pythonhosted.org/packages/11/6f/16663eef877b61faa5fd917b3a63497416ec4319195af75f6169a1594479/temporalio-1.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a08aed4e0f6c2b6bfc779b714e91dfe8c8491a0ddb4c4370627bb07f9bddcfd", size = 13164612, upload-time = "2025-08-21T22:12:16.366Z" }, + { url = "https://files.pythonhosted.org/packages/af/0e/8c6704ca7033aa09dc084f285d70481d758972cc341adc3c84d5f82f7b01/temporalio-1.16.0-cp39-abi3-win_amd64.whl", hash = "sha256:7c190362b0d7254f1f93fb71456063e7b299ac85a89f6227758af82c6a5aa65b", size = 13177058, upload-time = "2025-08-21T22:12:44.239Z" }, +] [[package]] name = "types-protobuf"