Skip to content

Commit 08d234d

Browse files
committed
fix: gracefully close streamable HTTP sessions on shutdown
1 parent 616476f commit 08d234d

3 files changed

Lines changed: 109 additions & 7 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,10 +767,19 @@ async def terminate(self) -> None:
767767
768768
Once terminated, all requests with this session ID will receive 404 Not Found.
769769
"""
770+
if self._terminated:
771+
return
770772

771773
self._terminated = True
772774
logger.info(f"Terminating session: {self.mcp_session_id}")
773775

776+
# Close active SSE responses so ASGI response tasks can finish before
777+
# the session manager cancels the owning task group.
778+
sse_stream_writers = list(self._sse_stream_writers.values())
779+
self._sse_stream_writers.clear()
780+
for writer in sse_stream_writers:
781+
writer.close()
782+
774783
# We need a copy of the keys to avoid modification during iteration
775784
request_stream_keys = list(self._request_streams.keys())
776785

src/mcp/server/streamable_http_manager.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,23 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
133133
yield # Let the application run
134134
finally:
135135
logger.info("StreamableHTTP session manager shutting down")
136-
# Cancel task group to stop all spawned tasks
137-
tg.cancel_scope.cancel()
138-
self._task_group = None
139-
# Clear any remaining server instances
140-
self._server_instances.clear()
141-
self._session_owners.clear()
136+
try:
137+
await self._terminate_active_sessions()
138+
finally:
139+
# Cancel task group to stop all spawned tasks
140+
tg.cancel_scope.cancel()
141+
self._task_group = None
142+
# Clear any remaining server instances
143+
self._server_instances.clear()
144+
self._session_owners.clear()
145+
146+
async def _terminate_active_sessions(self) -> None:
147+
"""Terminate tracked transports before cancelling their task group."""
148+
for transport in list(self._server_instances.values()):
149+
try:
150+
await transport.terminate()
151+
except Exception:
152+
logger.exception("Error terminating StreamableHTTP session during shutdown")
142153

143154
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
144155
"""Process ASGI request with proper session handling and transport setup.

tests/server/test_streamable_http_manager.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import json
44
import logging
5-
from typing import Any
5+
from types import SimpleNamespace
6+
from typing import Any, cast
67
from unittest.mock import AsyncMock, patch
78

89
import anyio
@@ -64,6 +65,50 @@ async def try_run():
6465
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])
6566

6667

68+
@pytest.mark.anyio
69+
async def test_run_terminates_active_streaming_session_before_shutdown():
70+
"""run() should close active SSE transports before task cancellation."""
71+
app = Server("test-shutdown-cleanup")
72+
manager = StreamableHTTPSessionManager(app=app)
73+
transport = StreamableHTTPServerTransport(mcp_session_id="session-id")
74+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1)
75+
76+
try:
77+
transport._sse_stream_writers["request-id"] = sse_stream_writer
78+
79+
async with manager.run():
80+
manager._server_instances["session-id"] = transport
81+
82+
assert transport.is_terminated
83+
assert transport._sse_stream_writers == {}
84+
assert manager._server_instances == {}
85+
with pytest.raises(anyio.ClosedResourceError):
86+
await sse_stream_writer.send({"data": "still-open"})
87+
finally:
88+
await sse_stream_reader.aclose()
89+
90+
91+
@pytest.mark.anyio
92+
async def test_run_terminates_remaining_sessions_if_one_shutdown_fails(caplog: pytest.LogCaptureFixture):
93+
"""One failed transport shutdown should not skip later active sessions."""
94+
app = Server("test-shutdown-cleanup-error")
95+
manager = StreamableHTTPSessionManager(app=app)
96+
failing_terminate = AsyncMock(side_effect=RuntimeError("terminate failed"))
97+
healthy_terminate = AsyncMock()
98+
failing_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=failing_terminate))
99+
healthy_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=healthy_terminate))
100+
101+
with caplog.at_level(logging.ERROR):
102+
async with manager.run():
103+
manager._server_instances["bad-session"] = failing_transport
104+
manager._server_instances["healthy-session"] = healthy_transport
105+
106+
failing_terminate.assert_awaited_once_with()
107+
healthy_terminate.assert_awaited_once_with()
108+
assert "Error terminating StreamableHTTP session during shutdown" in caplog.text
109+
assert manager._server_instances == {}
110+
111+
67112
@pytest.mark.anyio
68113
async def test_handle_request_without_run_raises_error():
69114
"""Test that handle_request raises error if run() hasn't been called."""
@@ -271,6 +316,43 @@ async def mock_receive():
271316
assert len(transport._request_streams) == 0, "Transport should have no active request streams"
272317

273318

319+
@pytest.mark.anyio
320+
async def test_transport_terminate_closes_sse_stream_writers():
321+
"""terminate() should close active SSE writers so streaming responses can finish."""
322+
transport = StreamableHTTPServerTransport(mcp_session_id="test-session")
323+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1)
324+
325+
try:
326+
transport._sse_stream_writers["request-id"] = sse_stream_writer
327+
328+
await transport.terminate()
329+
330+
assert transport._sse_stream_writers == {}
331+
with pytest.raises(anyio.ClosedResourceError):
332+
await sse_stream_writer.send({"data": "still-open"})
333+
334+
await transport.terminate()
335+
finally:
336+
await sse_stream_reader.aclose()
337+
338+
339+
@pytest.mark.anyio
340+
async def test_transport_connect_cleans_request_streams_on_exit():
341+
"""connect() should close registered request streams when the transport exits."""
342+
transport = StreamableHTTPServerTransport(mcp_session_id="test-session")
343+
request_stream_writer, request_stream_reader = anyio.create_memory_object_stream[Any](1)
344+
345+
transport._request_streams["request-id"] = (request_stream_writer, request_stream_reader)
346+
347+
async with transport.connect():
348+
assert "request-id" in transport._request_streams
349+
transport._terminated = True
350+
351+
assert transport._request_streams == {}
352+
with pytest.raises(anyio.ClosedResourceError):
353+
await request_stream_writer.send(cast(Any, object()))
354+
355+
274356
@pytest.mark.anyio
275357
async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture):
276358
"""Test that requests with unknown session IDs return HTTP 404 per MCP spec."""

0 commit comments

Comments
 (0)