Skip to content

Commit e8cb12c

Browse files
committed
fix: drain terminal streamable HTTP responses
1 parent 616476f commit e8cb12c

2 files changed

Lines changed: 144 additions & 7 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,19 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
240240
event_source.response.raise_for_status()
241241
logger.debug("Resumption GET SSE connection established")
242242

243+
response_complete = False
243244
async for sse in event_source.aiter_sse(): # pragma: no branch
245+
if response_complete:
246+
continue
247+
244248
is_complete = await self._handle_sse_event(
245249
sse,
246250
ctx.read_stream_writer,
247251
original_request_id,
248252
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249253
)
250254
if is_complete:
251-
await event_source.response.aclose()
252-
break
255+
response_complete = True
253256

254257
async def _handle_post_request(self, ctx: RequestContext) -> None:
255258
"""Handle a POST request with response processing."""
@@ -340,9 +343,13 @@ async def _handle_sse_response(
340343
assert isinstance(ctx.session_message.message, JSONRPCRequest)
341344
original_request_id = ctx.session_message.message.id
342345

346+
response_complete = False
343347
try:
344348
event_source = EventSource(response)
345349
async for sse in event_source.aiter_sse(): # pragma: no branch
350+
if response_complete:
351+
continue
352+
346353
# Track last event ID for potential reconnection
347354
if sse.id:
348355
last_event_id = sse.id
@@ -359,13 +366,15 @@ async def _handle_sse_response(
359366
is_initialization=is_initialization,
360367
)
361368
# If the SSE event indicates completion, like returning response/error
362-
# break the loop
369+
# keep draining the response to EOF so the HTTP connection can be reused.
363370
if is_complete:
364-
await response.aclose()
365-
return # Normal completion, no reconnect needed
371+
response_complete = True
366372
except Exception:
367373
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
368374

375+
if response_complete:
376+
return # Normal completion, no reconnect needed
377+
369378
# Stream ended without response - reconnect if we received an event with ID
370379
if last_event_id is not None: # pragma: no branch
371380
logger.info("SSE stream disconnected, reconnecting...")
@@ -405,7 +414,11 @@ async def _handle_reconnection(
405414
reconnect_last_event_id: str = last_event_id
406415
reconnect_retry_ms = retry_interval_ms
407416

417+
response_complete = False
408418
async for sse in event_source.aiter_sse():
419+
if response_complete:
420+
continue
421+
409422
if sse.id: # pragma: no branch
410423
reconnect_last_event_id = sse.id
411424
if sse.retry is not None:
@@ -418,8 +431,10 @@ async def _handle_reconnection(
418431
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
419432
)
420433
if is_complete:
421-
await event_source.response.aclose()
422-
return
434+
response_complete = True
435+
436+
if response_complete:
437+
return
423438

424439
# Stream ended again without response - reconnect again (reset attempt counter)
425440
logger.info("SSE stream disconnected, reconnecting...")

tests/shared/test_streamable_http.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from starlette.requests import Request
2828
from starlette.routing import Mount
2929

30+
import mcp.client.streamable_http as streamable_http
3031
from mcp import MCPError, types
3132
from mcp.client.session import ClientSession
3233
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
@@ -139,6 +140,39 @@ async def replay_events_after( # pragma: no cover
139140
return target_stream_id
140141

141142

143+
class FakeStreamResponse(httpx.Response):
144+
def __init__(self) -> None:
145+
super().__init__(
146+
200,
147+
request=httpx.Request("POST", "http://localhost:8000/mcp"),
148+
)
149+
self.close_count = 0
150+
151+
async def aclose(self) -> None: # pragma: no cover
152+
self.close_count += 1
153+
154+
155+
class FakeEventSource:
156+
def __init__(self, events: list[ServerSentEvent]) -> None:
157+
self.response = FakeStreamResponse()
158+
self.events = events
159+
self.seen = 0
160+
161+
async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]:
162+
for event in self.events:
163+
self.seen += 1
164+
yield event
165+
166+
167+
def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent:
168+
return ServerSentEvent(
169+
event="message",
170+
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
171+
id=event_id,
172+
retry=None,
173+
)
174+
175+
142176
@dataclass
143177
class ServerState:
144178
lock: anyio.Event = field(default_factory=anyio.Event)
@@ -1803,6 +1837,94 @@ async def test_handle_sse_event_skips_empty_data():
18031837
await read_stream.aclose()
18041838

18051839

1840+
@pytest.mark.anyio
1841+
async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch):
1842+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1843+
response = FakeStreamResponse()
1844+
event_source = FakeEventSource(
1845+
[
1846+
jsonrpc_response_event("request-1", "event-1"),
1847+
ServerSentEvent(event="message", data="", id="event-2", retry=None),
1848+
]
1849+
)
1850+
1851+
def event_source_factory(_response: httpx.Response) -> FakeEventSource:
1852+
return event_source
1853+
1854+
monkeypatch.setattr(streamable_http, "EventSource", event_source_factory)
1855+
1856+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
1857+
try:
1858+
async with httpx.AsyncClient() as client:
1859+
ctx = streamable_http.RequestContext(
1860+
client=client,
1861+
session_id=None,
1862+
session_message=SessionMessage(
1863+
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
1864+
),
1865+
metadata=None,
1866+
read_stream_writer=write_stream,
1867+
)
1868+
1869+
await transport._handle_sse_response(response, ctx)
1870+
1871+
received = await read_stream.receive()
1872+
assert isinstance(received, SessionMessage)
1873+
assert isinstance(received.message, types.JSONRPCResponse)
1874+
assert received.message.id == "request-1"
1875+
assert event_source.seen == 2
1876+
assert response.close_count == 0
1877+
finally:
1878+
await write_stream.aclose()
1879+
await read_stream.aclose()
1880+
1881+
1882+
@pytest.mark.anyio
1883+
async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch):
1884+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1885+
event_source = FakeEventSource(
1886+
[
1887+
jsonrpc_response_event("request-1", "event-2"),
1888+
ServerSentEvent(event="message", data="", id="event-3", retry=None),
1889+
]
1890+
)
1891+
1892+
async def sleep_noop(_delay: float) -> None:
1893+
pass
1894+
1895+
@asynccontextmanager
1896+
async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]:
1897+
yield event_source
1898+
1899+
monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop)
1900+
monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse)
1901+
1902+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
1903+
try:
1904+
async with httpx.AsyncClient() as client:
1905+
ctx = streamable_http.RequestContext(
1906+
client=client,
1907+
session_id=None,
1908+
session_message=SessionMessage(
1909+
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
1910+
),
1911+
metadata=None,
1912+
read_stream_writer=write_stream,
1913+
)
1914+
1915+
await transport._handle_reconnection(ctx, last_event_id="event-1")
1916+
1917+
received = await read_stream.receive()
1918+
assert isinstance(received, SessionMessage)
1919+
assert isinstance(received.message, types.JSONRPCResponse)
1920+
assert received.message.id == "request-1"
1921+
assert event_source.seen == 2
1922+
assert event_source.response.close_count == 0
1923+
finally:
1924+
await write_stream.aclose()
1925+
await read_stream.aclose()
1926+
1927+
18061928
@pytest.mark.anyio
18071929
async def test_priming_event_not_sent_for_old_protocol_version():
18081930
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)