Skip to content

Commit 51c53f2

Browse files
shivama205Shivam
andauthored
fix: accept wildcard media types in Accept header per RFC 7231 (#2152)
Co-authored-by: Shivam <shivam@Shivams-MacBook-Air-2.local>
1 parent 7ba41dc commit 51c53f2

File tree

2 files changed

+92
-9
lines changed

2 files changed

+92
-9
lines changed

src/mcp/server/streamable_http.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,12 +391,19 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
391391
await self._handle_unsupported_request(request, send)
392392

393393
def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
394-
"""Check if the request accepts the required media types."""
394+
"""Check if the request accepts the required media types.
395+
396+
Supports wildcard media types per RFC 7231, section 5.3.2:
397+
- */* matches any media type
398+
- application/* matches any application/ subtype
399+
- text/* matches any text/ subtype
400+
"""
395401
accept_header = request.headers.get("accept", "")
396-
accept_types = [media_type.strip() for media_type in accept_header.split(",")]
402+
accept_types = [media_type.strip().split(";")[0].strip().lower() for media_type in accept_header.split(",")]
397403

398-
has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
399-
has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)
404+
has_wildcard = "*/*" in accept_types
405+
has_json = has_wildcard or any(t in (CONTENT_TYPE_JSON, "application/*") for t in accept_types)
406+
has_sse = has_wildcard or any(t in (CONTENT_TYPE_SSE, "text/*") for t in accept_types)
400407

401408
return has_json, has_sse
402409

tests/shared/test_streamable_http.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,10 @@ def json_server_url(json_server_port: int) -> str:
572572
# Basic request validation tests
573573
def test_accept_header_validation(basic_server: None, basic_server_url: str):
574574
"""Test that Accept header is properly validated."""
575-
# Test without Accept header
576-
response = requests.post(
575+
# Test without Accept header (suppress requests library default Accept: */*)
576+
session = requests.Session()
577+
session.headers.pop("Accept")
578+
response = session.post(
577579
f"{basic_server_url}/mcp",
578580
headers={"Content-Type": "application/json"},
579581
json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
@@ -582,6 +584,52 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str):
582584
assert "Not Acceptable" in response.text
583585

584586

587+
@pytest.mark.parametrize(
588+
"accept_header",
589+
[
590+
"*/*",
591+
"application/*, text/*",
592+
"text/*, application/json",
593+
"application/json, text/*",
594+
"*/*;q=0.8",
595+
"application/*;q=0.9, text/*;q=0.8",
596+
],
597+
)
598+
def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str):
599+
"""Test that wildcard Accept headers are accepted per RFC 7231."""
600+
response = requests.post(
601+
f"{basic_server_url}/mcp",
602+
headers={
603+
"Accept": accept_header,
604+
"Content-Type": "application/json",
605+
},
606+
json=INIT_REQUEST,
607+
)
608+
assert response.status_code == 200
609+
610+
611+
@pytest.mark.parametrize(
612+
"accept_header",
613+
[
614+
"text/html",
615+
"application/*",
616+
"text/*",
617+
],
618+
)
619+
def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str):
620+
"""Test that incompatible Accept headers are rejected for SSE mode."""
621+
response = requests.post(
622+
f"{basic_server_url}/mcp",
623+
headers={
624+
"Accept": accept_header,
625+
"Content-Type": "application/json",
626+
},
627+
json=INIT_REQUEST,
628+
)
629+
assert response.status_code == 406
630+
assert "Not Acceptable" in response.text
631+
632+
585633
def test_content_type_validation(basic_server: None, basic_server_url: str):
586634
"""Test that Content-Type header is properly validated."""
587635
# Test with incorrect Content-Type
@@ -826,7 +874,10 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_
826874
def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str):
827875
"""Test that json_response servers reject requests without Accept header."""
828876
mcp_url = f"{json_server_url}/mcp"
829-
response = requests.post(
877+
# Suppress requests library default Accept: */* header
878+
session = requests.Session()
879+
session.headers.pop("Accept")
880+
response = session.post(
830881
mcp_url,
831882
headers={
832883
"Content-Type": "application/json",
@@ -853,6 +904,29 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_
853904
assert "Not Acceptable" in response.text
854905

855906

907+
@pytest.mark.parametrize(
908+
"accept_header",
909+
[
910+
"*/*",
911+
"application/*",
912+
"application/*;q=0.9",
913+
],
914+
)
915+
def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str):
916+
"""Test that json_response servers accept wildcard Accept headers per RFC 7231."""
917+
mcp_url = f"{json_server_url}/mcp"
918+
response = requests.post(
919+
mcp_url,
920+
headers={
921+
"Accept": accept_header,
922+
"Content-Type": "application/json",
923+
},
924+
json=INIT_REQUEST,
925+
)
926+
assert response.status_code == 200
927+
assert response.headers.get("Content-Type") == "application/json"
928+
929+
856930
def test_get_sse_stream(basic_server: None, basic_server_url: str):
857931
"""Test establishing an SSE stream via GET request."""
858932
# First, we need to initialize a session
@@ -941,8 +1015,10 @@ def test_get_validation(basic_server: None, basic_server_url: str):
9411015
assert init_data is not None
9421016
negotiated_version = init_data["result"]["protocolVersion"]
9431017

944-
# Test without Accept header
945-
response = requests.get(
1018+
# Test without Accept header (suppress requests library default Accept: */*)
1019+
session = requests.Session()
1020+
session.headers.pop("Accept")
1021+
response = session.get(
9461022
mcp_url,
9471023
headers={
9481024
MCP_SESSION_ID_HEADER: session_id,

0 commit comments

Comments
 (0)