Skip to content

Commit da65336

Browse files
committed
test: cover raw invalid UTF-8 stdio regression
1 parent fb2276b commit da65336

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Regression test for issue #2328 - raw invalid UTF-8 over stdio."""
2+
3+
import io
4+
from io import TextIOWrapper
5+
6+
import anyio
7+
import pytest
8+
from pydantic import AnyHttpUrl, TypeAdapter
9+
10+
from mcp.server.mcpserver import MCPServer
11+
from mcp.server.stdio import stdio_server
12+
from mcp.types import JSONRPCResponse, jsonrpc_message_adapter
13+
14+
15+
@pytest.mark.anyio
16+
async def test_raw_invalid_utf8_stdio_request_does_not_crash_server() -> None:
17+
mcp = MCPServer("test")
18+
19+
@mcp.tool()
20+
async def fetch(url: str) -> str:
21+
# Delay validation so stdin can reach EOF and close the session write
22+
# stream before the tool returns its validation failure.
23+
await anyio.sleep(0.1)
24+
return str(TypeAdapter(AnyHttpUrl).validate_python(url))
25+
26+
initialize = (
27+
b'{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": '
28+
b'{"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": '
29+
b'{"name": "test", "version": "1.0"}}}\n'
30+
)
31+
initialized = b'{"jsonrpc": "2.0", "method": "notifications/initialized"}\n'
32+
malformed_call = (
33+
b'{"jsonrpc": "2.0", "id": 3, "method": "tools/call", "params": '
34+
b'{"name": "fetch", "arguments": {"url": "http://x\xff\xfe"}}}\n'
35+
)
36+
raw_stdin = io.BytesIO(initialize + initialized + malformed_call)
37+
stdout = io.StringIO()
38+
39+
async with stdio_server(
40+
stdin=anyio.AsyncFile(TextIOWrapper(raw_stdin, encoding="utf-8", errors="replace")),
41+
stdout=anyio.AsyncFile(stdout),
42+
) as (read_stream, write_stream):
43+
with anyio.fail_after(5):
44+
await mcp._lowlevel_server.run(
45+
read_stream,
46+
write_stream,
47+
mcp._lowlevel_server.create_initialization_options(),
48+
)
49+
50+
stdout.seek(0)
51+
output_lines = [line.strip() for line in stdout.readlines() if line.strip()]
52+
53+
assert output_lines
54+
initialize_response = jsonrpc_message_adapter.validate_json(output_lines[0])
55+
assert isinstance(initialize_response, JSONRPCResponse)
56+
assert initialize_response.id == 1

0 commit comments

Comments
 (0)