|
4 | 4 |
|
5 | 5 | from typing import Any |
6 | 6 |
|
7 | | -import anyio |
8 | 7 | import pytest |
9 | 8 |
|
| 9 | +from mcp.client._memory import InMemoryTransport |
10 | 10 | from mcp.client.session import ClientSession |
11 | 11 | from mcp.server.mcpserver import Context, MCPServer |
12 | 12 | from mcp.shared._context import RequestContext |
|
16 | 16 | OnNotificationFn, |
17 | 17 | OnRequestFn, |
18 | 18 | ) |
19 | | -from mcp.shared.memory import create_client_server_memory_streams |
20 | 19 | from mcp.shared.message import MessageMetadata |
21 | 20 | from mcp.types import ( |
22 | 21 | CreateMessageRequestParams, |
@@ -96,25 +95,18 @@ async def sampling_callback( |
96 | 95 | stop_reason="endTurn", |
97 | 96 | ) |
98 | 97 |
|
99 | | - async with create_client_server_memory_streams() as (client_streams, server_streams): |
100 | | - client_read, client_write = client_streams |
101 | | - server_read, server_write = server_streams |
102 | | - |
103 | | - # The spy wraps a real JSON-RPC dispatcher so the server side works unchanged. |
104 | | - # What matters is that ClientSession has no idea it isn't the default. |
| 98 | + # InMemoryTransport runs the server for us and yields client-side streams — |
| 99 | + # we intercept those streams and feed them through a custom dispatcher. |
| 100 | + async with InMemoryTransport(app) as (client_read, client_write): |
105 | 101 | inner = JSONRPCDispatcher(client_read, client_write, response_routers=[]) |
106 | 102 | spy = SpyDispatcher(inner) |
107 | 103 |
|
108 | | - async with anyio.create_task_group() as tg: |
109 | | - server = app._lowlevel_server # type: ignore[reportPrivateUsage] |
110 | | - tg.start_soon(lambda: server.run(server_read, server_write, server.create_initialization_options())) |
111 | | - |
112 | | - async with ClientSession(dispatcher=spy, sampling_callback=sampling_callback) as session: |
113 | | - await session.initialize() |
114 | | - result = await session.call_tool("ask", {"question": "meaning of life?"}) |
115 | | - assert result.content[0].text == "42" # type: ignore[union-attr] |
116 | | - |
117 | | - tg.cancel_scope.cancel() |
| 104 | + async with ClientSession(dispatcher=spy, sampling_callback=sampling_callback) as session: |
| 105 | + await session.initialize() |
| 106 | + result = await session.call_tool("ask", {"question": "meaning of life?"}) |
| 107 | + content = result.content[0] |
| 108 | + assert isinstance(content, TextContent) |
| 109 | + assert content.text == "42" |
118 | 110 |
|
119 | 111 | # initialize, tools/call (triggers sampling on the server), tools/list (schema refresh) |
120 | 112 | assert [r["method"] for r in spy.sent_requests] == ["initialize", "tools/call", "tools/list"] |
|
0 commit comments