Skip to content

Commit 67201a9

Browse files
authored
test: fix WS test port race; narrow to single smoke test covering both transport ends (#2267)
1 parent 7826ade commit 67201a9

File tree

4 files changed

+97
-187
lines changed

4 files changed

+97
-187
lines changed

src/mcp/server/websocket.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mcp.shared.message import SessionMessage
1111

1212

13-
@asynccontextmanager # pragma: no cover
13+
@asynccontextmanager
1414
async def websocket_server(scope: Scope, receive: Receive, send: Send):
1515
"""WebSocket server transport for MCP. This is an ASGI application, suitable for use
1616
with a framework like Starlette and a server like Hypercorn.
@@ -34,13 +34,13 @@ async def ws_reader():
3434
async for msg in websocket.iter_text():
3535
try:
3636
client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False)
37-
except ValidationError as exc:
37+
except ValidationError as exc: # pragma: no cover
3838
await read_stream_writer.send(exc)
3939
continue
4040

4141
session_message = SessionMessage(client_message)
4242
await read_stream_writer.send(session_message)
43-
except anyio.ClosedResourceError:
43+
except anyio.ClosedResourceError: # pragma: no cover
4444
await websocket.close()
4545

4646
async def ws_writer():
@@ -49,7 +49,7 @@ async def ws_writer():
4949
async for session_message in write_stream_reader:
5050
obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True)
5151
await websocket.send_text(obj)
52-
except anyio.ClosedResourceError:
52+
except anyio.ClosedResourceError: # pragma: no cover
5353
await websocket.close()
5454

5555
async with anyio.create_task_group() as tg:

tests/client/test_client.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
from inline_snapshot import snapshot
1010

11-
from mcp import types
11+
from mcp import MCPError, types
1212
from mcp.client._memory import InMemoryTransport
1313
from mcp.client.client import Client
1414
from mcp.server import Server, ServerRequestContext
@@ -175,6 +175,21 @@ async def test_read_resource(app: MCPServer):
175175
)
176176

177177

178+
async def test_read_resource_error_propagates():
179+
"""MCPError raised by a server handler propagates to the client with its code intact."""
180+
181+
async def handle_read_resource(
182+
ctx: ServerRequestContext, params: types.ReadResourceRequestParams
183+
) -> ReadResourceResult:
184+
raise MCPError(code=404, message="no resource with that URI was found")
185+
186+
server = Server("test", on_read_resource=handle_read_resource)
187+
async with Client(server) as client:
188+
with pytest.raises(MCPError) as exc_info:
189+
await client.read_resource("unknown://example")
190+
assert exc_info.value.error.code == 404
191+
192+
178193
async def test_get_prompt(app: MCPServer):
179194
"""Test getting a prompt."""
180195
async with Client(app) as client:

tests/shared/test_ws.py

Lines changed: 23 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -1,210 +1,51 @@
1-
import multiprocessing
2-
import socket
3-
from collections.abc import AsyncGenerator, Generator
4-
from urllib.parse import urlparse
1+
"""Smoke test for the WebSocket transport.
2+
3+
Runs the full WS stack end-to-end over a real TCP connection, covering both
4+
``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP
5+
semantics (error propagation, timeouts, etc.) are transport-agnostic and are
6+
covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``.
7+
"""
8+
9+
from collections.abc import Generator
510

6-
import anyio
711
import pytest
8-
import uvicorn
912
from starlette.applications import Starlette
1013
from starlette.routing import WebSocketRoute
1114
from starlette.websockets import WebSocket
1215

13-
from mcp import MCPError
1416
from mcp.client.session import ClientSession
1517
from mcp.client.websocket import websocket_client
16-
from mcp.server import Server, ServerRequestContext
18+
from mcp.server import Server
1719
from mcp.server.websocket import websocket_server
18-
from mcp.types import (
19-
CallToolRequestParams,
20-
CallToolResult,
21-
EmptyResult,
22-
InitializeResult,
23-
ListToolsResult,
24-
PaginatedRequestParams,
25-
ReadResourceRequestParams,
26-
ReadResourceResult,
27-
TextContent,
28-
TextResourceContents,
29-
Tool,
30-
)
31-
from tests.test_helpers import wait_for_server
20+
from mcp.types import EmptyResult, InitializeResult
21+
from tests.test_helpers import run_uvicorn_in_thread
3222

3323
SERVER_NAME = "test_server_for_WS"
3424

3525

36-
@pytest.fixture
37-
def server_port() -> int:
38-
with socket.socket() as s:
39-
s.bind(("127.0.0.1", 0))
40-
return s.getsockname()[1]
41-
42-
43-
@pytest.fixture
44-
def server_url(server_port: int) -> str:
45-
return f"ws://127.0.0.1:{server_port}"
46-
47-
48-
async def handle_read_resource( # pragma: no cover
49-
ctx: ServerRequestContext, params: ReadResourceRequestParams
50-
) -> ReadResourceResult:
51-
parsed = urlparse(str(params.uri))
52-
if parsed.scheme == "foobar":
53-
return ReadResourceResult(
54-
contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")]
55-
)
56-
elif parsed.scheme == "slow":
57-
await anyio.sleep(2.0)
58-
return ReadResourceResult(
59-
contents=[
60-
TextResourceContents(
61-
uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain"
62-
)
63-
]
64-
)
65-
raise MCPError(code=404, message="OOPS! no resource with that URI was found")
66-
67-
68-
async def handle_list_tools( # pragma: no cover
69-
ctx: ServerRequestContext, params: PaginatedRequestParams | None
70-
) -> ListToolsResult:
71-
return ListToolsResult(
72-
tools=[
73-
Tool(
74-
name="test_tool",
75-
description="A test tool",
76-
input_schema={"type": "object", "properties": {}},
77-
)
78-
]
79-
)
80-
26+
def make_server_app() -> Starlette:
27+
srv = Server(SERVER_NAME)
8128

82-
async def handle_call_tool( # pragma: no cover
83-
ctx: ServerRequestContext, params: CallToolRequestParams
84-
) -> CallToolResult:
85-
return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")])
86-
87-
88-
def _create_server() -> Server: # pragma: no cover
89-
return Server(
90-
SERVER_NAME,
91-
on_read_resource=handle_read_resource,
92-
on_list_tools=handle_list_tools,
93-
on_call_tool=handle_call_tool,
94-
)
95-
96-
97-
# Test fixtures
98-
def make_server_app() -> Starlette: # pragma: no cover
99-
"""Create test Starlette app with WebSocket transport"""
100-
server = _create_server()
101-
102-
async def handle_ws(websocket: WebSocket):
29+
async def handle_ws(websocket: WebSocket) -> None:
10330
async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
104-
await server.run(streams[0], streams[1], server.create_initialization_options())
105-
106-
app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])
107-
return app
108-
109-
110-
def run_server(server_port: int) -> None: # pragma: no cover
111-
app = make_server_app()
112-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
113-
print(f"starting server on {server_port}")
114-
server.run()
115-
116-
117-
@pytest.fixture()
118-
def server(server_port: int) -> Generator[None, None, None]:
119-
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
120-
print("starting process")
121-
proc.start()
31+
await srv.run(streams[0], streams[1], srv.create_initialization_options())
12232

123-
# Wait for server to be running
124-
print("waiting for server to start")
125-
wait_for_server(server_port)
33+
return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])
12634

127-
yield
12835

129-
print("killing server")
130-
# Signal the server to stop
131-
proc.kill()
132-
proc.join(timeout=2)
133-
if proc.is_alive(): # pragma: no cover
134-
print("server process failed to terminate")
135-
136-
137-
@pytest.fixture()
138-
async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
139-
"""Create and initialize a WebSocket client session"""
140-
async with websocket_client(server_url + "/ws") as streams:
141-
async with ClientSession(*streams) as session:
142-
# Test initialization
143-
result = await session.initialize()
144-
assert isinstance(result, InitializeResult)
145-
assert result.server_info.name == SERVER_NAME
146-
147-
# Test ping
148-
ping_result = await session.send_ping()
149-
assert isinstance(ping_result, EmptyResult)
150-
151-
yield session
36+
@pytest.fixture
37+
def ws_server_url() -> Generator[str, None, None]:
38+
with run_uvicorn_in_thread(make_server_app()) as base_url:
39+
yield base_url.replace("http://", "ws://") + "/ws"
15240

15341

154-
# Tests
15542
@pytest.mark.anyio
156-
async def test_ws_client_basic_connection(server: None, server_url: str) -> None:
157-
"""Test the WebSocket connection establishment"""
158-
async with websocket_client(server_url + "/ws") as streams:
43+
async def test_ws_client_basic_connection(ws_server_url: str) -> None:
44+
async with websocket_client(ws_server_url) as streams:
15945
async with ClientSession(*streams) as session:
160-
# Test initialization
16146
result = await session.initialize()
16247
assert isinstance(result, InitializeResult)
16348
assert result.server_info.name == SERVER_NAME
16449

165-
# Test ping
16650
ping_result = await session.send_ping()
16751
assert isinstance(ping_result, EmptyResult)
168-
169-
170-
@pytest.mark.anyio
171-
async def test_ws_client_happy_request_and_response(
172-
initialized_ws_client_session: ClientSession,
173-
) -> None:
174-
"""Test a successful request and response via WebSocket"""
175-
result = await initialized_ws_client_session.read_resource("foobar://example")
176-
assert isinstance(result, ReadResourceResult)
177-
assert isinstance(result.contents, list)
178-
assert len(result.contents) > 0
179-
assert isinstance(result.contents[0], TextResourceContents)
180-
assert result.contents[0].text == "Read example"
181-
182-
183-
@pytest.mark.anyio
184-
async def test_ws_client_exception_handling(
185-
initialized_ws_client_session: ClientSession,
186-
) -> None:
187-
"""Test exception handling in WebSocket communication"""
188-
with pytest.raises(MCPError) as exc_info:
189-
await initialized_ws_client_session.read_resource("unknown://example")
190-
assert exc_info.value.error.code == 404
191-
192-
193-
@pytest.mark.anyio
194-
async def test_ws_client_timeout(
195-
initialized_ws_client_session: ClientSession,
196-
) -> None:
197-
"""Test timeout handling in WebSocket communication"""
198-
# Set a very short timeout to trigger a timeout exception
199-
with pytest.raises(TimeoutError):
200-
with anyio.fail_after(0.1): # 100ms timeout
201-
await initialized_ws_client_session.read_resource("slow://example")
202-
203-
# Now test that we can still use the session after a timeout
204-
with anyio.fail_after(5): # Longer timeout to allow completion
205-
result = await initialized_ws_client_session.read_resource("foobar://example")
206-
assert isinstance(result, ReadResourceResult)
207-
assert isinstance(result.contents, list)
208-
assert len(result.contents) > 0
209-
assert isinstance(result.contents[0], TextResourceContents)
210-
assert result.contents[0].text == "Read example"

tests/test_helpers.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,61 @@
11
"""Common test utilities for MCP server tests."""
22

33
import socket
4+
import threading
45
import time
6+
from collections.abc import Generator
7+
from contextlib import contextmanager
8+
from typing import Any
9+
10+
import uvicorn
11+
12+
_SERVER_SHUTDOWN_TIMEOUT_S = 5.0
13+
14+
15+
@contextmanager
16+
def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]:
17+
"""Run a uvicorn server in a background thread on an ephemeral port.
18+
19+
The socket is bound and put into listening state *before* the thread
20+
starts, so the port is known immediately with no wait. The kernel's
21+
listen queue buffers any connections that arrive before uvicorn's event
22+
loop reaches ``accept()``, so callers can connect as soon as this
23+
function yields — no polling, no sleeps, no startup race.
24+
25+
This also avoids the TOCTOU race of the old pick-a-port-then-rebind
26+
pattern: the socket passed here is the one uvicorn serves on, with no
27+
gap where another pytest-xdist worker could claim it.
28+
29+
Args:
30+
app: ASGI application to serve.
31+
**config_kwargs: Additional keyword arguments for :class:`uvicorn.Config`
32+
(e.g. ``log_level``). ``host``/``port`` are ignored since the
33+
socket is pre-bound.
34+
35+
Yields:
36+
The base URL of the running server, e.g. ``http://127.0.0.1:54321``.
37+
"""
38+
host = "127.0.0.1"
39+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
40+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
41+
sock.bind((host, 0))
42+
sock.listen()
43+
port = sock.getsockname()[1]
44+
45+
config_kwargs.setdefault("log_level", "error")
46+
# Uvicorn's interface autodetection calls asyncio.iscoroutinefunction,
47+
# which Python 3.14 deprecates. Under filterwarnings=error this crashes
48+
# the server thread silently. Starlette is asgi3; skip the autodetect.
49+
config_kwargs.setdefault("interface", "asgi3")
50+
server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs))
51+
52+
thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True)
53+
thread.start()
54+
try:
55+
yield f"http://{host}:{port}"
56+
finally:
57+
server.should_exit = True
58+
thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S)
559

660

761
def wait_for_server(port: int, timeout: float = 20.0) -> None:

0 commit comments

Comments
 (0)