|
1 | | -import multiprocessing |
2 | | -import socket |
3 | | -from collections.abc import AsyncGenerator, Generator |
4 | | -from urllib.parse import urlparse |
| 1 | +"""Smoke test for the WebSocket transport. |
| 2 | +
|
| 3 | +Runs the full WS stack end-to-end over a real TCP connection, covering both |
| 4 | +``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP |
| 5 | +semantics (error propagation, timeouts, etc.) are transport-agnostic and are |
| 6 | +covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``. |
| 7 | +""" |
| 8 | + |
| 9 | +from collections.abc import Generator |
5 | 10 |
|
6 | | -import anyio |
7 | 11 | import pytest |
8 | | -import uvicorn |
9 | 12 | from starlette.applications import Starlette |
10 | 13 | from starlette.routing import WebSocketRoute |
11 | 14 | from starlette.websockets import WebSocket |
12 | 15 |
|
13 | | -from mcp import MCPError |
14 | 16 | from mcp.client.session import ClientSession |
15 | 17 | from mcp.client.websocket import websocket_client |
16 | | -from mcp.server import Server, ServerRequestContext |
| 18 | +from mcp.server import Server |
17 | 19 | from mcp.server.websocket import websocket_server |
18 | | -from mcp.types import ( |
19 | | - CallToolRequestParams, |
20 | | - CallToolResult, |
21 | | - EmptyResult, |
22 | | - InitializeResult, |
23 | | - ListToolsResult, |
24 | | - PaginatedRequestParams, |
25 | | - ReadResourceRequestParams, |
26 | | - ReadResourceResult, |
27 | | - TextContent, |
28 | | - TextResourceContents, |
29 | | - Tool, |
30 | | -) |
31 | | -from tests.test_helpers import wait_for_server |
| 20 | +from mcp.types import EmptyResult, InitializeResult |
| 21 | +from tests.test_helpers import run_uvicorn_in_thread |
32 | 22 |
|
33 | 23 | SERVER_NAME = "test_server_for_WS" |
34 | 24 |
|
35 | 25 |
|
36 | | -@pytest.fixture |
37 | | -def server_port() -> int: |
38 | | - with socket.socket() as s: |
39 | | - s.bind(("127.0.0.1", 0)) |
40 | | - return s.getsockname()[1] |
41 | | - |
42 | | - |
43 | | -@pytest.fixture |
44 | | -def server_url(server_port: int) -> str: |
45 | | - return f"ws://127.0.0.1:{server_port}" |
46 | | - |
47 | | - |
48 | | -async def handle_read_resource( # pragma: no cover |
49 | | - ctx: ServerRequestContext, params: ReadResourceRequestParams |
50 | | -) -> ReadResourceResult: |
51 | | - parsed = urlparse(str(params.uri)) |
52 | | - if parsed.scheme == "foobar": |
53 | | - return ReadResourceResult( |
54 | | - contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] |
55 | | - ) |
56 | | - elif parsed.scheme == "slow": |
57 | | - await anyio.sleep(2.0) |
58 | | - return ReadResourceResult( |
59 | | - contents=[ |
60 | | - TextResourceContents( |
61 | | - uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" |
62 | | - ) |
63 | | - ] |
64 | | - ) |
65 | | - raise MCPError(code=404, message="OOPS! no resource with that URI was found") |
66 | | - |
67 | | - |
68 | | -async def handle_list_tools( # pragma: no cover |
69 | | - ctx: ServerRequestContext, params: PaginatedRequestParams | None |
70 | | -) -> ListToolsResult: |
71 | | - return ListToolsResult( |
72 | | - tools=[ |
73 | | - Tool( |
74 | | - name="test_tool", |
75 | | - description="A test tool", |
76 | | - input_schema={"type": "object", "properties": {}}, |
77 | | - ) |
78 | | - ] |
79 | | - ) |
80 | | - |
| 26 | +def make_server_app() -> Starlette: |
| 27 | + srv = Server(SERVER_NAME) |
81 | 28 |
|
82 | | -async def handle_call_tool( # pragma: no cover |
83 | | - ctx: ServerRequestContext, params: CallToolRequestParams |
84 | | -) -> CallToolResult: |
85 | | - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) |
86 | | - |
87 | | - |
88 | | -def _create_server() -> Server: # pragma: no cover |
89 | | - return Server( |
90 | | - SERVER_NAME, |
91 | | - on_read_resource=handle_read_resource, |
92 | | - on_list_tools=handle_list_tools, |
93 | | - on_call_tool=handle_call_tool, |
94 | | - ) |
95 | | - |
96 | | - |
97 | | -# Test fixtures |
98 | | -def make_server_app() -> Starlette: # pragma: no cover |
99 | | - """Create test Starlette app with WebSocket transport""" |
100 | | - server = _create_server() |
101 | | - |
102 | | - async def handle_ws(websocket: WebSocket): |
| 29 | + async def handle_ws(websocket: WebSocket) -> None: |
103 | 30 | async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: |
104 | | - await server.run(streams[0], streams[1], server.create_initialization_options()) |
105 | | - |
106 | | - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) |
107 | | - return app |
108 | | - |
109 | | - |
110 | | -def run_server(server_port: int) -> None: # pragma: no cover |
111 | | - app = make_server_app() |
112 | | - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) |
113 | | - print(f"starting server on {server_port}") |
114 | | - server.run() |
115 | | - |
116 | | - |
117 | | -@pytest.fixture() |
118 | | -def server(server_port: int) -> Generator[None, None, None]: |
119 | | - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) |
120 | | - print("starting process") |
121 | | - proc.start() |
| 31 | + await srv.run(streams[0], streams[1], srv.create_initialization_options()) |
122 | 32 |
|
123 | | - # Wait for server to be running |
124 | | - print("waiting for server to start") |
125 | | - wait_for_server(server_port) |
| 33 | + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) |
126 | 34 |
|
127 | | - yield |
128 | 35 |
|
129 | | - print("killing server") |
130 | | - # Signal the server to stop |
131 | | - proc.kill() |
132 | | - proc.join(timeout=2) |
133 | | - if proc.is_alive(): # pragma: no cover |
134 | | - print("server process failed to terminate") |
135 | | - |
136 | | - |
137 | | -@pytest.fixture() |
138 | | -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: |
139 | | - """Create and initialize a WebSocket client session""" |
140 | | - async with websocket_client(server_url + "/ws") as streams: |
141 | | - async with ClientSession(*streams) as session: |
142 | | - # Test initialization |
143 | | - result = await session.initialize() |
144 | | - assert isinstance(result, InitializeResult) |
145 | | - assert result.server_info.name == SERVER_NAME |
146 | | - |
147 | | - # Test ping |
148 | | - ping_result = await session.send_ping() |
149 | | - assert isinstance(ping_result, EmptyResult) |
150 | | - |
151 | | - yield session |
| 36 | +@pytest.fixture |
| 37 | +def ws_server_url() -> Generator[str, None, None]: |
| 38 | + with run_uvicorn_in_thread(make_server_app()) as base_url: |
| 39 | + yield base_url.replace("http://", "ws://") + "/ws" |
152 | 40 |
|
153 | 41 |
|
154 | | -# Tests |
155 | 42 | @pytest.mark.anyio |
156 | | -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: |
157 | | - """Test the WebSocket connection establishment""" |
158 | | - async with websocket_client(server_url + "/ws") as streams: |
| 43 | +async def test_ws_client_basic_connection(ws_server_url: str) -> None: |
| 44 | + async with websocket_client(ws_server_url) as streams: |
159 | 45 | async with ClientSession(*streams) as session: |
160 | | - # Test initialization |
161 | 46 | result = await session.initialize() |
162 | 47 | assert isinstance(result, InitializeResult) |
163 | 48 | assert result.server_info.name == SERVER_NAME |
164 | 49 |
|
165 | | - # Test ping |
166 | 50 | ping_result = await session.send_ping() |
167 | 51 | assert isinstance(ping_result, EmptyResult) |
168 | | - |
169 | | - |
170 | | -@pytest.mark.anyio |
171 | | -async def test_ws_client_happy_request_and_response( |
172 | | - initialized_ws_client_session: ClientSession, |
173 | | -) -> None: |
174 | | - """Test a successful request and response via WebSocket""" |
175 | | - result = await initialized_ws_client_session.read_resource("foobar://example") |
176 | | - assert isinstance(result, ReadResourceResult) |
177 | | - assert isinstance(result.contents, list) |
178 | | - assert len(result.contents) > 0 |
179 | | - assert isinstance(result.contents[0], TextResourceContents) |
180 | | - assert result.contents[0].text == "Read example" |
181 | | - |
182 | | - |
183 | | -@pytest.mark.anyio |
184 | | -async def test_ws_client_exception_handling( |
185 | | - initialized_ws_client_session: ClientSession, |
186 | | -) -> None: |
187 | | - """Test exception handling in WebSocket communication""" |
188 | | - with pytest.raises(MCPError) as exc_info: |
189 | | - await initialized_ws_client_session.read_resource("unknown://example") |
190 | | - assert exc_info.value.error.code == 404 |
191 | | - |
192 | | - |
193 | | -@pytest.mark.anyio |
194 | | -async def test_ws_client_timeout( |
195 | | - initialized_ws_client_session: ClientSession, |
196 | | -) -> None: |
197 | | - """Test timeout handling in WebSocket communication""" |
198 | | - # Set a very short timeout to trigger a timeout exception |
199 | | - with pytest.raises(TimeoutError): |
200 | | - with anyio.fail_after(0.1): # 100ms timeout |
201 | | - await initialized_ws_client_session.read_resource("slow://example") |
202 | | - |
203 | | - # Now test that we can still use the session after a timeout |
204 | | - with anyio.fail_after(5): # Longer timeout to allow completion |
205 | | - result = await initialized_ws_client_session.read_resource("foobar://example") |
206 | | - assert isinstance(result, ReadResourceResult) |
207 | | - assert isinstance(result.contents, list) |
208 | | - assert len(result.contents) > 0 |
209 | | - assert isinstance(result.contents[0], TextResourceContents) |
210 | | - assert result.contents[0].text == "Read example" |
0 commit comments