Skip to content

Commit 88071b0

Browse files
committed
fix: propagate pre-endpoint errors in sse_client instead of deadlocking
When sse_reader encounters an error before receiving the endpoint event, the except handler tried to send the exception to read_stream_writer. With a zero-buffer stream and no reader (the caller is still blocked in tg.start() waiting for task_status.started()), send() blocks forever. Track whether started() has fired. Before it, re-raise so the exception propagates through tg.start(). After it, send to the stream as before. This also adds a guard for the case where a server sends a message event before the endpoint event, which would deadlock on the same send() call. The dedicated SSEError handler from #975 is removed since the started flag now handles all pre-endpoint exceptions uniformly. Github-Issue: #447
1 parent 7ba4fb8 commit 88071b0

File tree

2 files changed

+110
-9
lines changed

2 files changed

+110
-9
lines changed

src/mcp/client/sse.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from anyio.abc import TaskStatus
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111
from httpx_sse import aconnect_sse
12-
from httpx_sse._exceptions import SSEError
1312

1413
from mcp import types
1514
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
@@ -69,6 +68,12 @@ async def sse_client(
6968
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
7069

7170
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
71+
# Before task_status.started() fires, the caller is blocked inside
72+
# tg.start() and nobody reads from read_stream. Sending to the
73+
# zero-buffer stream in that phase would deadlock, so errors must
74+
# be raised instead. After started(), the caller has the streams
75+
# and errors are delivered through read_stream.
76+
started = False
7277
try:
7378
async for sse in event_source.aiter_sse(): # pragma: no branch
7479
logger.debug(f"Received SSE event: {sse.event}")
@@ -79,27 +84,28 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
7984

8085
url_parsed = urlparse(url)
8186
endpoint_parsed = urlparse(endpoint_url)
82-
if ( # pragma: no cover
87+
if (
8388
url_parsed.netloc != endpoint_parsed.netloc
8489
or url_parsed.scheme != endpoint_parsed.scheme
8590
):
86-
error_msg = ( # pragma: no cover
91+
raise ValueError(
8792
f"Endpoint origin does not match connection origin: {endpoint_url}"
8893
)
89-
logger.error(error_msg) # pragma: no cover
90-
raise ValueError(error_msg) # pragma: no cover
9194

9295
if on_session_created:
9396
session_id = _extract_session_id_from_endpoint(endpoint_url)
9497
if session_id:
9598
on_session_created(session_id)
9699

97100
task_status.started(endpoint_url)
101+
started = True
98102

99103
case "message":
100104
# Skip empty data (keep-alive pings)
101105
if not sse.data:
102106
continue
107+
if not started:
108+
raise RuntimeError("Received message event before endpoint event")
103109
try:
104110
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
105111
logger.debug(f"Received server message: {message}")
@@ -112,11 +118,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
112118
await read_stream_writer.send(session_message)
113119
case _: # pragma: no cover
114120
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
115-
except SSEError as sse_exc: # pragma: lax no cover
116-
logger.exception("Encountered SSE exception")
117-
raise sse_exc
118-
except Exception as exc: # pragma: lax no cover
121+
except Exception as exc:
119122
logger.exception("Error in sse_reader")
123+
if not started:
124+
raise
120125
await read_stream_writer.send(exc)
121126
finally:
122127
await read_stream_writer.aclose()

tests/shared/test_sse.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,102 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
604604
assert msg.message.id == 1
605605

606606

607+
def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any:
608+
"""Patch sse_client's HTTP layer to yield the given SSE event stream."""
609+
mock_event_source = MagicMock()
610+
mock_event_source.aiter_sse.return_value = aiter_sse
611+
mock_event_source.response.raise_for_status = MagicMock()
612+
613+
mock_aconnect_sse = MagicMock()
614+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
615+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
616+
617+
mock_client = MagicMock()
618+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
619+
mock_client.__aexit__ = AsyncMock(return_value=None)
620+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
621+
622+
return patch.multiple(
623+
"mcp.client.sse",
624+
create_mcp_http_client=Mock(return_value=mock_client),
625+
aconnect_sse=Mock(return_value=mock_aconnect_sse),
626+
)
627+
628+
629+
@pytest.mark.anyio
630+
async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None:
631+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
632+
633+
When the server sends an endpoint URL with a different origin than the
634+
connection URL, sse_client must raise promptly instead of deadlocking.
635+
Before the fix, the ValueError was caught and sent to a zero-buffer stream
636+
with no reader, hanging forever.
637+
"""
638+
639+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
640+
yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc")
641+
await anyio.sleep_forever()
642+
643+
with _mock_sse_connection(events()), anyio.fail_after(5):
644+
with pytest.RaisesGroup(pytest.RaisesExc(ValueError, match="Endpoint origin does not match")):
645+
async with sse_client("http://test/sse"):
646+
pytest.fail("sse_client should not yield on origin mismatch")
647+
648+
649+
@pytest.mark.anyio
650+
async def test_sse_client_raises_on_error_before_endpoint() -> None:
651+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
652+
653+
Any exception raised while waiting for the endpoint event must propagate
654+
instead of deadlocking on the zero-buffer read stream.
655+
"""
656+
657+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
658+
raise ConnectionError("connection reset by peer")
659+
yield # pragma: no cover
660+
661+
with _mock_sse_connection(events()), anyio.fail_after(5):
662+
with pytest.RaisesGroup(pytest.RaisesExc(ConnectionError, match="connection reset")):
663+
async with sse_client("http://test/sse"):
664+
pytest.fail("sse_client should not yield on pre-endpoint error")
665+
666+
667+
@pytest.mark.anyio
668+
async def test_sse_client_raises_on_message_before_endpoint() -> None:
669+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
670+
671+
If the server sends a message event before the endpoint event (protocol
672+
violation), sse_client must raise rather than deadlock trying to send the
673+
message to a stream nobody is reading yet.
674+
"""
675+
676+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
677+
yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}')
678+
await anyio.sleep_forever()
679+
680+
with _mock_sse_connection(events()), anyio.fail_after(5):
681+
with pytest.RaisesGroup(pytest.RaisesExc(RuntimeError, match="before endpoint event")):
682+
async with sse_client("http://test/sse"):
683+
pytest.fail("sse_client should not yield on protocol violation")
684+
685+
686+
@pytest.mark.anyio
687+
async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None:
688+
"""After the endpoint is received, errors in sse_reader are delivered on the
689+
read stream so the session can handle them, rather than crashing the task group.
690+
"""
691+
692+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
693+
yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc")
694+
raise ConnectionError("mid-stream failure")
695+
696+
with _mock_sse_connection(events()), anyio.fail_after(5):
697+
async with sse_client("http://test/sse") as (read_stream, _):
698+
received = await read_stream.receive()
699+
assert isinstance(received, ConnectionError)
700+
assert "mid-stream failure" in str(received)
701+
702+
607703
@pytest.mark.anyio
608704
async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None:
609705
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227

0 commit comments

Comments
 (0)