Skip to content

Commit d0c7a71

Browse files
committed
fix: propagate pre-endpoint errors in sse_client instead of deadlocking
When sse_reader encounters an error before receiving the endpoint event, the except handler tried to send the exception to read_stream_writer. With a zero-buffer stream and no reader (the caller is still blocked in tg.start() waiting for task_status.started()), send() blocks forever. Track whether started() has fired. Before it, re-raise so the exception propagates through tg.start(). After it, send to the stream as before. This also adds a guard for the case where a server sends a message event before the endpoint event, which would deadlock on the same send() call. The dedicated SSEError handler from #975 is removed since the started flag now handles all pre-endpoint exceptions uniformly. Github-Issue: #447
1 parent 7ba4fb8 commit d0c7a71

File tree

2 files changed

+117
-9
lines changed

2 files changed

+117
-9
lines changed

src/mcp/client/sse.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from anyio.abc import TaskStatus
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111
from httpx_sse import aconnect_sse
12-
from httpx_sse._exceptions import SSEError
1312

1413
from mcp import types
1514
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
@@ -69,6 +68,12 @@ async def sse_client(
6968
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
7069

7170
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
71+
# Before task_status.started() fires, the caller is blocked inside
72+
# tg.start() and nobody reads from read_stream. Sending to the
73+
# zero-buffer stream in that phase would deadlock, so errors must
74+
# be raised instead. After started(), the caller has the streams
75+
# and errors are delivered through read_stream.
76+
started = False
7277
try:
7378
async for sse in event_source.aiter_sse(): # pragma: no branch
7479
logger.debug(f"Received SSE event: {sse.event}")
@@ -79,27 +84,28 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
7984

8085
url_parsed = urlparse(url)
8186
endpoint_parsed = urlparse(endpoint_url)
82-
if ( # pragma: no cover
87+
if (
8388
url_parsed.netloc != endpoint_parsed.netloc
8489
or url_parsed.scheme != endpoint_parsed.scheme
8590
):
86-
error_msg = ( # pragma: no cover
91+
raise ValueError(
8792
f"Endpoint origin does not match connection origin: {endpoint_url}"
8893
)
89-
logger.error(error_msg) # pragma: no cover
90-
raise ValueError(error_msg) # pragma: no cover
9194

9295
if on_session_created:
9396
session_id = _extract_session_id_from_endpoint(endpoint_url)
9497
if session_id:
9598
on_session_created(session_id)
9699

97100
task_status.started(endpoint_url)
101+
started = True
98102

99103
case "message":
100104
# Skip empty data (keep-alive pings)
101105
if not sse.data:
102106
continue
107+
if not started:
108+
raise RuntimeError("Received message event before endpoint event")
103109
try:
104110
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
105111
logger.debug(f"Received server message: {message}")
@@ -112,11 +118,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
112118
await read_stream_writer.send(session_message)
113119
case _: # pragma: no cover
114120
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
115-
except SSEError as sse_exc: # pragma: lax no cover
116-
logger.exception("Encountered SSE exception")
117-
raise sse_exc
118-
except Exception as exc: # pragma: lax no cover
121+
except Exception as exc:
119122
logger.exception("Error in sse_reader")
123+
if not started:
124+
raise
120125
await read_stream_writer.send(exc)
121126
finally:
122127
await read_stream_writer.aclose()

tests/shared/test_sse.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import json
22
import multiprocessing
33
import socket
4+
import sys
45
from collections.abc import AsyncGenerator, Generator
56
from typing import Any
67
from unittest.mock import AsyncMock, MagicMock, Mock, patch
78
from urllib.parse import urlparse
89

10+
if sys.version_info < (3, 11): # pragma: lax no cover
11+
from exceptiongroup import BaseExceptionGroup
12+
913
import anyio
1014
import httpx
1115
import pytest
@@ -604,6 +608,105 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
604608
assert msg.message.id == 1
605609

606610

611+
def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any:
612+
"""Patch sse_client's HTTP layer to yield the given SSE event stream."""
613+
mock_event_source = MagicMock()
614+
mock_event_source.aiter_sse.return_value = aiter_sse
615+
mock_event_source.response.raise_for_status = MagicMock()
616+
617+
mock_aconnect_sse = MagicMock()
618+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
619+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
620+
621+
mock_client = MagicMock()
622+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
623+
mock_client.__aexit__ = AsyncMock(return_value=None)
624+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
625+
626+
return patch.multiple(
627+
"mcp.client.sse",
628+
create_mcp_http_client=Mock(return_value=mock_client),
629+
aconnect_sse=Mock(return_value=mock_aconnect_sse),
630+
)
631+
632+
633+
@pytest.mark.anyio
634+
async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None:
635+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
636+
637+
When the server sends an endpoint URL with a different origin than the
638+
connection URL, sse_client must raise promptly instead of deadlocking.
639+
Before the fix, the ValueError was caught and sent to a zero-buffer stream
640+
with no reader, hanging forever.
641+
"""
642+
643+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
644+
yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc")
645+
await anyio.sleep_forever() # pragma: no cover
646+
647+
with _mock_sse_connection(events()), anyio.fail_after(5):
648+
with pytest.raises(BaseExceptionGroup) as exc_info:
649+
async with sse_client("http://test/sse"): # pragma: no branch
650+
pytest.fail("sse_client should not yield on origin mismatch") # pragma: no cover
651+
assert exc_info.group_contains(ValueError, match="Endpoint origin does not match")
652+
653+
654+
@pytest.mark.anyio
655+
async def test_sse_client_raises_on_error_before_endpoint() -> None:
656+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
657+
658+
Any exception raised while waiting for the endpoint event must propagate
659+
instead of deadlocking on the zero-buffer read stream.
660+
"""
661+
662+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
663+
raise ConnectionError("connection reset by peer")
664+
yield # pragma: no cover
665+
666+
with _mock_sse_connection(events()), anyio.fail_after(5):
667+
with pytest.raises(BaseExceptionGroup) as exc_info:
668+
async with sse_client("http://test/sse"): # pragma: no branch
669+
pytest.fail("sse_client should not yield on pre-endpoint error") # pragma: no cover
670+
assert exc_info.group_contains(ConnectionError, match="connection reset")
671+
672+
673+
@pytest.mark.anyio
674+
async def test_sse_client_raises_on_message_before_endpoint() -> None:
675+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
676+
677+
If the server sends a message event before the endpoint event (protocol
678+
violation), sse_client must raise rather than deadlock trying to send the
679+
message to a stream nobody is reading yet.
680+
"""
681+
682+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
683+
yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}')
684+
await anyio.sleep_forever() # pragma: no cover
685+
686+
with _mock_sse_connection(events()), anyio.fail_after(5):
687+
with pytest.raises(BaseExceptionGroup) as exc_info:
688+
async with sse_client("http://test/sse"): # pragma: no branch
689+
pytest.fail("sse_client should not yield on protocol violation") # pragma: no cover
690+
assert exc_info.group_contains(RuntimeError, match="before endpoint event")
691+
692+
693+
@pytest.mark.anyio
694+
async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None:
695+
"""After the endpoint is received, errors in sse_reader are delivered on the
696+
read stream so the session can handle them, rather than crashing the task group.
697+
"""
698+
699+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
700+
yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc")
701+
raise ConnectionError("mid-stream failure")
702+
703+
with _mock_sse_connection(events()), anyio.fail_after(5):
704+
async with sse_client("http://test/sse") as (read_stream, _):
705+
received = await read_stream.receive()
706+
assert isinstance(received, ConnectionError)
707+
assert "mid-stream failure" in str(received)
708+
709+
607710
@pytest.mark.anyio
608711
async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None:
609712
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227

0 commit comments

Comments
 (0)