Skip to content

Commit 1ab51df

Browse files
committed
Fix Streamable HTTP Accept negotiation
1 parent fb2276b commit 1ab51df

File tree

3 files changed

+68
-13
lines changed

3 files changed

+68
-13
lines changed

src/mcp/server/streamable_http.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -426,16 +426,24 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se
426426
)
427427
await response(scope, request.receive, send)
428428
return False
429-
# For SSE responses, require both content types
430-
elif not (has_json and has_sse):
429+
# For SSE-capable responses, accept either JSON or SSE and negotiate later.
430+
elif not (has_json or has_sse):
431431
response = self._create_error_response(
432-
"Not Acceptable: Client must accept both application/json and text/event-stream",
432+
"Not Acceptable: Client must accept application/json or text/event-stream",
433433
HTTPStatus.NOT_ACCEPTABLE,
434434
)
435435
await response(scope, request.receive, send)
436436
return False
437437
return True
438438

439+
def _should_use_json_response(self, request: Request) -> bool:
440+
"""Choose JSON when required or when the client does not accept SSE."""
441+
if self.is_json_response_enabled:
442+
return True
443+
444+
has_json, has_sse = self._check_accept_headers(request)
445+
return has_json and not has_sse
446+
439447
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
440448
"""Handle POST requests containing JSON-RPC messages."""
441449
writer = self._read_stream_writer
@@ -476,6 +484,8 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
476484
await response(scope, receive, send)
477485
return
478486

487+
use_json_response = self._should_use_json_response(request)
488+
479489
# Check if this is an initialization request
480490
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"
481491

@@ -527,7 +537,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
527537
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
528538
request_stream_reader = self._request_streams[request_id][1]
529539

530-
if self.is_json_response_enabled:
540+
if use_json_response:
531541
# Process the message
532542
metadata = ServerMessageMetadata(request_context=request)
533543
session_message = SessionMessage(message, metadata=metadata)

tests/issues/test_1363_race_condition_streamable_http.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFi
123123
"""Test the race condition with invalid Accept headers.
124124
125125
This test reproduces the exact scenario described in issue #1363:
126-
- Send POST request with incorrect Accept headers (missing either application/json or text/event-stream)
126+
- Send POST request with incorrect Accept headers that match neither JSON nor SSE
127127
- Request fails validation early and returns quickly
128128
- This should trigger the race condition where message_router encounters ClosedResourceError
129129
"""
@@ -137,34 +137,34 @@ async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFi
137137

138138
# Suppress WARNING logs (expected validation errors) and capture ERROR logs
139139
with caplog.at_level(logging.ERROR):
140-
# Test with missing text/event-stream in Accept header
140+
# Test with an incompatible text media type
141141
async with httpx.AsyncClient(
142142
transport=httpx.ASGITransport(app=app), base_url="http://testserver", timeout=5.0
143143
) as client:
144144
response = await client.post(
145145
"/",
146146
json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}},
147147
headers={
148-
"Accept": "application/json", # Missing text/event-stream
148+
"Accept": "text/plain",
149149
"Content-Type": "application/json",
150150
},
151151
)
152-
# Should get 406 Not Acceptable due to missing text/event-stream
152+
# Should get 406 Not Acceptable for an unsupported media type
153153
assert response.status_code == 406
154154

155-
# Test with missing application/json in Accept header
155+
# Test with an incompatible application media type
156156
async with httpx.AsyncClient(
157157
transport=httpx.ASGITransport(app=app), base_url="http://testserver", timeout=5.0
158158
) as client:
159159
response = await client.post(
160160
"/",
161161
json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}},
162162
headers={
163-
"Accept": "text/event-stream", # Missing application/json
163+
"Accept": "application/xml",
164164
"Content-Type": "application/json",
165165
},
166166
)
167-
# Should get 406 Not Acceptable due to missing application/json
167+
# Should get 406 Not Acceptable for an unsupported media type
168168
assert response.status_code == 406
169169

170170
# Test with completely invalid Accept header

tests/shared/test_streamable_http.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,12 +608,57 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep
608608
assert response.status_code == 200
609609

610610

611+
@pytest.mark.parametrize(
612+
("accept_header", "expected_content_type"),
613+
[
614+
("application/json", "application/json"),
615+
("text/event-stream", "text/event-stream"),
616+
],
617+
)
618+
def test_accept_header_single_media_type_negotiates_response(
619+
basic_server: None, basic_server_url: str, accept_header: str, expected_content_type: str
620+
):
621+
"""Test that SSE-capable servers negotiate JSON or SSE from a single accepted media type."""
622+
mcp_url = f"{basic_server_url}/mcp"
623+
init_response = requests.post(
624+
mcp_url,
625+
headers={
626+
"Accept": accept_header,
627+
"Content-Type": "application/json",
628+
},
629+
json=INIT_REQUEST,
630+
)
631+
assert init_response.status_code == 200
632+
assert init_response.headers.get("Content-Type") == expected_content_type
633+
634+
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
635+
assert session_id is not None
636+
637+
if expected_content_type == "application/json":
638+
negotiated_version = init_response.json()["result"]["protocolVersion"]
639+
else:
640+
negotiated_version = extract_protocol_version_from_sse(init_response)
641+
642+
tools_response = requests.post(
643+
mcp_url,
644+
headers={
645+
"Accept": accept_header,
646+
"Content-Type": "application/json",
647+
MCP_SESSION_ID_HEADER: session_id,
648+
MCP_PROTOCOL_VERSION_HEADER: negotiated_version,
649+
},
650+
json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-accept-single-media"},
651+
)
652+
assert tools_response.status_code == 200
653+
assert tools_response.headers.get("Content-Type") == expected_content_type
654+
655+
611656
@pytest.mark.parametrize(
612657
"accept_header",
613658
[
614659
"text/html",
615-
"application/*",
616-
"text/*",
660+
"text/plain",
661+
"application/xml",
617662
],
618663
)
619664
def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str):

0 commit comments

Comments
 (0)