Skip to content

Commit 9893653

Browse files
committed
fix: propagate pre-endpoint errors in sse_client instead of deadlocking
When sse_reader encounters an error before receiving the endpoint event, the except handler tried to send the exception to read_stream_writer. With a zero-buffer stream and no reader (the caller is still blocked in tg.start() waiting for task_status.started()), send() blocks forever. Track whether started() has fired. Before it, re-raise so the exception propagates through tg.start(). After it, send to the stream as before. This also adds a guard for the case where a server sends a message event before the endpoint event, which would deadlock on the same send() call. The dedicated SSEError handler from #975 is removed since the started flag now handles all pre-endpoint exceptions uniformly. Github-Issue: #447
1 parent 7ba4fb8 commit 9893653

File tree

2 files changed

+113
-9
lines changed

2 files changed

+113
-9
lines changed

src/mcp/client/sse.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from anyio.abc import TaskStatus
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111
from httpx_sse import aconnect_sse
12-
from httpx_sse._exceptions import SSEError
1312

1413
from mcp import types
1514
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
@@ -69,6 +68,12 @@ async def sse_client(
6968
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
7069

7170
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
71+
# Before task_status.started() fires, the caller is blocked inside
72+
# tg.start() and nobody reads from read_stream. Sending to the
73+
# zero-buffer stream in that phase would deadlock, so errors must
74+
# be raised instead. After started(), the caller has the streams
75+
# and errors are delivered through read_stream.
76+
started = False
7277
try:
7378
async for sse in event_source.aiter_sse(): # pragma: no branch
7479
logger.debug(f"Received SSE event: {sse.event}")
@@ -79,27 +84,28 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
7984

8085
url_parsed = urlparse(url)
8186
endpoint_parsed = urlparse(endpoint_url)
82-
if ( # pragma: no cover
87+
if (
8388
url_parsed.netloc != endpoint_parsed.netloc
8489
or url_parsed.scheme != endpoint_parsed.scheme
8590
):
86-
error_msg = ( # pragma: no cover
91+
raise ValueError(
8792
f"Endpoint origin does not match connection origin: {endpoint_url}"
8893
)
89-
logger.error(error_msg) # pragma: no cover
90-
raise ValueError(error_msg) # pragma: no cover
9194

9295
if on_session_created:
9396
session_id = _extract_session_id_from_endpoint(endpoint_url)
9497
if session_id:
9598
on_session_created(session_id)
9699

97100
task_status.started(endpoint_url)
101+
started = True
98102

99103
case "message":
100104
# Skip empty data (keep-alive pings)
101105
if not sse.data:
102106
continue
107+
if not started:
108+
raise RuntimeError("Received message event before endpoint event")
103109
try:
104110
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
105111
logger.debug(f"Received server message: {message}")
@@ -112,11 +118,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
112118
await read_stream_writer.send(session_message)
113119
case _: # pragma: no cover
114120
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
115-
except SSEError as sse_exc: # pragma: lax no cover
116-
logger.exception("Encountered SSE exception")
117-
raise sse_exc
118-
except Exception as exc: # pragma: lax no cover
121+
except Exception as exc:
119122
logger.exception("Error in sse_reader")
123+
if not started:
124+
raise
120125
await read_stream_writer.send(exc)
121126
finally:
122127
await read_stream_writer.aclose()

tests/shared/test_sse.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,105 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
604604
assert msg.message.id == 1
605605

606606

607+
def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any:
608+
"""Patch sse_client's HTTP layer to yield the given SSE event stream."""
609+
mock_event_source = MagicMock()
610+
mock_event_source.aiter_sse.return_value = aiter_sse
611+
mock_event_source.response.raise_for_status = MagicMock()
612+
613+
mock_aconnect_sse = MagicMock()
614+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
615+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
616+
617+
mock_client = MagicMock()
618+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
619+
mock_client.__aexit__ = AsyncMock(return_value=None)
620+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
621+
622+
return patch.multiple(
623+
"mcp.client.sse",
624+
create_mcp_http_client=Mock(return_value=mock_client),
625+
aconnect_sse=Mock(return_value=mock_aconnect_sse),
626+
)
627+
628+
629+
@pytest.mark.anyio
630+
async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None:
631+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
632+
633+
When the server sends an endpoint URL with a different origin than the
634+
connection URL, sse_client must raise promptly instead of deadlocking.
635+
Before the fix, the ValueError was caught and sent to a zero-buffer stream
636+
with no reader, hanging forever.
637+
"""
638+
639+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
640+
yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc")
641+
await anyio.sleep_forever() # pragma: no cover
642+
643+
with _mock_sse_connection(events()), anyio.fail_after(5):
644+
with pytest.raises(BaseExceptionGroup) as exc_info: # noqa: F821 # builtin via anyio on 3.10
645+
async with sse_client("http://test/sse"): # pragma: no branch
646+
pytest.fail("sse_client should not yield on origin mismatch") # pragma: no cover
647+
assert exc_info.group_contains(ValueError, match="Endpoint origin does not match")
648+
649+
650+
@pytest.mark.anyio
651+
async def test_sse_client_raises_on_error_before_endpoint() -> None:
652+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
653+
654+
Any exception raised while waiting for the endpoint event must propagate
655+
instead of deadlocking on the zero-buffer read stream.
656+
"""
657+
658+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
659+
raise ConnectionError("connection reset by peer")
660+
yield # pragma: no cover
661+
662+
with _mock_sse_connection(events()), anyio.fail_after(5):
663+
with pytest.raises(BaseExceptionGroup) as exc_info: # noqa: F821 # builtin via anyio on 3.10
664+
async with sse_client("http://test/sse"): # pragma: no branch
665+
pytest.fail("sse_client should not yield on pre-endpoint error") # pragma: no cover
666+
assert exc_info.group_contains(ConnectionError, match="connection reset")
667+
668+
669+
@pytest.mark.anyio
670+
async def test_sse_client_raises_on_message_before_endpoint() -> None:
671+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
672+
673+
If the server sends a message event before the endpoint event (protocol
674+
violation), sse_client must raise rather than deadlock trying to send the
675+
message to a stream nobody is reading yet.
676+
"""
677+
678+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
679+
yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}')
680+
await anyio.sleep_forever() # pragma: no cover
681+
682+
with _mock_sse_connection(events()), anyio.fail_after(5):
683+
with pytest.raises(BaseExceptionGroup) as exc_info: # noqa: F821 # builtin via anyio on 3.10
684+
async with sse_client("http://test/sse"): # pragma: no branch
685+
pytest.fail("sse_client should not yield on protocol violation") # pragma: no cover
686+
assert exc_info.group_contains(RuntimeError, match="before endpoint event")
687+
688+
689+
@pytest.mark.anyio
690+
async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None:
691+
"""After the endpoint is received, errors in sse_reader are delivered on the
692+
read stream so the session can handle them, rather than crashing the task group.
693+
"""
694+
695+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
696+
yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc")
697+
raise ConnectionError("mid-stream failure")
698+
699+
with _mock_sse_connection(events()), anyio.fail_after(5):
700+
async with sse_client("http://test/sse") as (read_stream, _):
701+
received = await read_stream.receive()
702+
assert isinstance(received, ConnectionError)
703+
assert "mid-stream failure" in str(received)
704+
705+
607706
@pytest.mark.anyio
608707
async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None:
609708
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227

0 commit comments

Comments
 (0)