Skip to content

Commit 75a80b6

Browse files
maxisbey and Kludex authored
refactor: connect-first stream lifecycle for sse and streamable_http (#2292)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent abfb482 commit 75a80b6

File tree

3 files changed

+161
-153
lines changed

3 files changed

+161
-153
lines changed

src/mcp/client/sse.py

Lines changed: 97 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -57,108 +57,101 @@ async def sse_client(
5757
write_stream: MemoryObjectSendStream[SessionMessage]
5858
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
5959

60-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
61-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
62-
63-
async with anyio.create_task_group() as tg:
64-
try:
65-
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
66-
async with httpx_client_factory(
67-
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
68-
) as client:
69-
async with aconnect_sse(
70-
client,
71-
"GET",
72-
url,
73-
) as event_source:
74-
event_source.response.raise_for_status()
75-
logger.debug("SSE connection established")
76-
77-
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
78-
try:
79-
async for sse in event_source.aiter_sse(): # pragma: no branch
80-
logger.debug(f"Received SSE event: {sse.event}")
81-
match sse.event:
82-
case "endpoint":
83-
endpoint_url = urljoin(url, sse.data)
84-
logger.debug(f"Received endpoint URL: {endpoint_url}")
85-
86-
url_parsed = urlparse(url)
87-
endpoint_parsed = urlparse(endpoint_url)
88-
if ( # pragma: no cover
89-
url_parsed.netloc != endpoint_parsed.netloc
90-
or url_parsed.scheme != endpoint_parsed.scheme
91-
):
92-
error_msg = ( # pragma: no cover
93-
f"Endpoint origin does not match connection origin: {endpoint_url}"
94-
)
95-
logger.error(error_msg) # pragma: no cover
96-
raise ValueError(error_msg) # pragma: no cover
97-
98-
if on_session_created:
99-
session_id = _extract_session_id_from_endpoint(endpoint_url)
100-
if session_id:
101-
on_session_created(session_id)
102-
103-
task_status.started(endpoint_url)
104-
105-
case "message":
106-
# Skip empty data (keep-alive pings)
107-
if not sse.data:
108-
continue
109-
try:
110-
message = types.jsonrpc_message_adapter.validate_json(
111-
sse.data, by_name=False
112-
)
113-
logger.debug(f"Received server message: {message}")
114-
except Exception as exc: # pragma: no cover
115-
logger.exception("Error parsing server message") # pragma: no cover
116-
await read_stream_writer.send(exc) # pragma: no cover
117-
continue # pragma: no cover
118-
119-
session_message = SessionMessage(message)
120-
await read_stream_writer.send(session_message)
121-
case _: # pragma: no cover
122-
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
123-
except SSEError as sse_exc: # pragma: lax no cover
124-
logger.exception("Encountered SSE exception")
125-
raise sse_exc
126-
except Exception as exc: # pragma: lax no cover
127-
logger.exception("Error in sse_reader")
128-
await read_stream_writer.send(exc)
129-
finally:
130-
await read_stream_writer.aclose()
131-
132-
async def post_writer(endpoint_url: str):
133-
try:
134-
async with write_stream_reader:
135-
async for session_message in write_stream_reader:
136-
logger.debug(f"Sending client message: {session_message}")
137-
response = await client.post(
138-
endpoint_url,
139-
json=session_message.message.model_dump(
140-
by_alias=True,
141-
mode="json",
142-
exclude_unset=True,
143-
),
60+
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
61+
async with httpx_client_factory(
62+
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
63+
) as client:
64+
async with aconnect_sse(client, "GET", url) as event_source:
65+
event_source.response.raise_for_status()
66+
logger.debug("SSE connection established")
67+
68+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
69+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
70+
71+
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
72+
try:
73+
async for sse in event_source.aiter_sse(): # pragma: no branch
74+
logger.debug(f"Received SSE event: {sse.event}")
75+
match sse.event:
76+
case "endpoint":
77+
endpoint_url = urljoin(url, sse.data)
78+
logger.debug(f"Received endpoint URL: {endpoint_url}")
79+
80+
url_parsed = urlparse(url)
81+
endpoint_parsed = urlparse(endpoint_url)
82+
if ( # pragma: no cover
83+
url_parsed.netloc != endpoint_parsed.netloc
84+
or url_parsed.scheme != endpoint_parsed.scheme
85+
):
86+
error_msg = ( # pragma: no cover
87+
f"Endpoint origin does not match connection origin: {endpoint_url}"
14488
)
145-
response.raise_for_status()
146-
logger.debug(f"Client message sent successfully: {response.status_code}")
147-
except Exception: # pragma: lax no cover
148-
logger.exception("Error in post_writer")
149-
finally:
150-
await write_stream.aclose()
151-
152-
endpoint_url = await tg.start(sse_reader)
153-
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
154-
tg.start_soon(post_writer, endpoint_url)
155-
156-
try:
157-
yield read_stream, write_stream
158-
finally:
159-
tg.cancel_scope.cancel()
160-
finally:
161-
await read_stream_writer.aclose()
162-
await write_stream.aclose()
163-
await read_stream.aclose()
164-
await write_stream_reader.aclose()
89+
logger.error(error_msg) # pragma: no cover
90+
raise ValueError(error_msg) # pragma: no cover
91+
92+
if on_session_created:
93+
session_id = _extract_session_id_from_endpoint(endpoint_url)
94+
if session_id:
95+
on_session_created(session_id)
96+
97+
task_status.started(endpoint_url)
98+
99+
case "message":
100+
# Skip empty data (keep-alive pings)
101+
if not sse.data:
102+
continue
103+
try:
104+
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
105+
logger.debug(f"Received server message: {message}")
106+
except Exception as exc: # pragma: no cover
107+
logger.exception("Error parsing server message") # pragma: no cover
108+
await read_stream_writer.send(exc) # pragma: no cover
109+
continue # pragma: no cover
110+
111+
session_message = SessionMessage(message)
112+
await read_stream_writer.send(session_message)
113+
case _: # pragma: no cover
114+
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
115+
except SSEError as sse_exc: # pragma: lax no cover
116+
logger.exception("Encountered SSE exception")
117+
raise sse_exc
118+
except Exception as exc: # pragma: lax no cover
119+
logger.exception("Error in sse_reader")
120+
await read_stream_writer.send(exc)
121+
finally:
122+
await read_stream_writer.aclose()
123+
124+
async def post_writer(endpoint_url: str):
125+
try:
126+
async with write_stream_reader, write_stream:
127+
async for session_message in write_stream_reader:
128+
logger.debug(f"Sending client message: {session_message}")
129+
response = await client.post(
130+
endpoint_url,
131+
json=session_message.message.model_dump(
132+
by_alias=True,
133+
mode="json",
134+
exclude_unset=True,
135+
),
136+
)
137+
response.raise_for_status()
138+
logger.debug(f"Client message sent successfully: {response.status_code}")
139+
except Exception: # pragma: lax no cover
140+
logger.exception("Error in post_writer")
141+
142+
# On Python 3.14, coverage.py reports a phantom branch arc on this
143+
# line (->yield) when nested two async-with levels deep. The branch
144+
# is the unreachable "did __aexit__ suppress?" arm for memory streams.
145+
async with ( # pragma: no branch
146+
read_stream_writer,
147+
read_stream,
148+
write_stream,
149+
write_stream_reader,
150+
anyio.create_task_group() as tg,
151+
):
152+
endpoint_url = await tg.start(sse_reader)
153+
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
154+
tg.start_soon(post_writer, endpoint_url)
155+
156+
yield read_stream, write_stream
157+
tg.cancel_scope.cancel()

src/mcp/client/streamable_http.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ async def post_writer(
440440
) -> None:
441441
"""Handle writing requests to the server."""
442442
try:
443-
async with write_stream_reader:
443+
async with write_stream_reader, read_stream_writer, write_stream:
444444
async for session_message in write_stream_reader:
445445
message = session_message.message
446446
metadata = (
@@ -480,9 +480,6 @@ async def handle_request_async():
480480

481481
except Exception: # pragma: lax no cover
482482
logger.exception("Error in post_writer")
483-
finally:
484-
await read_stream_writer.aclose()
485-
await write_stream.aclose()
486483

487484
async def terminate_session(self, client: httpx.AsyncClient) -> None:
488485
"""Terminate the session by sending a DELETE request."""
@@ -533,9 +530,6 @@ async def streamable_http_client(
533530
Example:
534531
See examples/snippets/clients/ for usage patterns.
535532
"""
536-
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
537-
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
538-
539533
# Determine if we need to create and manage the client
540534
client_provided = http_client is not None
541535
client = http_client
@@ -546,36 +540,40 @@ async def streamable_http_client(
546540

547541
transport = StreamableHTTPTransport(url)
548542

549-
async with anyio.create_task_group() as tg:
550-
try:
551-
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
552-
553-
async with contextlib.AsyncExitStack() as stack:
554-
# Only manage client lifecycle if we created it
555-
if not client_provided:
556-
await stack.enter_async_context(client)
557-
558-
def start_get_stream() -> None:
559-
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
560-
561-
tg.start_soon(
562-
transport.post_writer,
563-
client,
564-
write_stream_reader,
565-
read_stream_writer,
566-
write_stream,
567-
start_get_stream,
568-
tg,
569-
)
543+
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
544+
545+
async with contextlib.AsyncExitStack() as stack:
546+
# Only manage client lifecycle if we created it
547+
if not client_provided:
548+
await stack.enter_async_context(client)
549+
550+
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
551+
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
552+
553+
async with (
554+
read_stream_writer,
555+
read_stream,
556+
write_stream,
557+
write_stream_reader,
558+
anyio.create_task_group() as tg,
559+
):
560+
561+
def start_get_stream() -> None:
562+
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
563+
564+
tg.start_soon(
565+
transport.post_writer,
566+
client,
567+
write_stream_reader,
568+
read_stream_writer,
569+
write_stream,
570+
start_get_stream,
571+
tg,
572+
)
570573

571-
try:
572-
yield read_stream, write_stream
573-
finally:
574-
if transport.session_id and terminate_on_close:
575-
await transport.terminate_session(client)
576-
tg.cancel_scope.cancel()
577-
finally:
578-
await read_stream_writer.aclose()
579-
await write_stream.aclose()
580-
await read_stream.aclose()
581-
await write_stream_reader.aclose()
574+
try:
575+
yield read_stream, write_stream
576+
finally:
577+
if transport.session_id and terminate_on_close:
578+
await transport.terminate_session(client)
579+
tg.cancel_scope.cancel()

tests/client/test_transport_stream_cleanup.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,39 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover
5858

5959
@pytest.mark.anyio
6060
async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None:
61-
"""sse_client must close all 4 stream ends when the connection fails.
61+
"""sse_client creates streams only after the SSE connection succeeds, so a
62+
ConnectError propagates directly with nothing to leak.
6263
63-
Before the fix, only read_stream_writer and write_stream were closed in
64-
the finally block. read_stream and write_stream_reader were leaked.
64+
Before the fix, streams were created before connecting and only 2 of 4 were
65+
closed in the finally block.
6566
"""
6667
with _assert_no_memory_stream_leak():
67-
# sse_client enters a task group BEFORE connecting, so anyio wraps the
68-
# ConnectError from aconnect_sse in an ExceptionGroup.
69-
with pytest.raises(Exception) as exc_info: # noqa: B017
68+
with pytest.raises(httpx.ConnectError):
7069
async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"):
7170
pytest.fail("should not reach here") # pragma: no cover
7271

73-
assert exc_info.group_contains(httpx.ConnectError)
74-
# exc_info holds the traceback → holds frame locals → keeps leaked
75-
# streams alive. Must drop it before gc.collect() can detect a leak.
76-
del exc_info
72+
73+
@pytest.mark.anyio
74+
async def test_sse_client_closes_all_streams_on_http_error() -> None:
75+
"""sse_client creates streams only after raise_for_status() passes, so an
76+
HTTPStatusError from a 4xx/5xx response propagates bare (not wrapped in an
77+
ExceptionGroup) with nothing to leak — the task group is never entered.
78+
"""
79+
80+
def return_403(request: httpx.Request) -> httpx.Response:
81+
return httpx.Response(403)
82+
83+
def mock_factory(
84+
headers: dict[str, str] | None = None,
85+
timeout: httpx.Timeout | None = None,
86+
auth: httpx.Auth | None = None,
87+
) -> httpx.AsyncClient:
88+
return httpx.AsyncClient(transport=httpx.MockTransport(return_403))
89+
90+
with _assert_no_memory_stream_leak():
91+
with pytest.raises(httpx.HTTPStatusError):
92+
async with sse_client("http://test/sse", httpx_client_factory=mock_factory):
93+
pytest.fail("should not reach here") # pragma: no cover
7794

7895

7996
@pytest.mark.anyio

0 commit comments

Comments (0)