Skip to content

Commit 6841fdc

Browse files
committed
fix: gracefully close streamable HTTP sessions on shutdown
1 parent 2472563 commit 6841fdc

3 files changed

Lines changed: 91 additions & 6 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,10 +767,19 @@ async def terminate(self) -> None:
767767
768768
Once terminated, all requests with this session ID will receive 404 Not Found.
769769
"""
770+
if self._terminated:
771+
return
770772

771773
self._terminated = True
772774
logger.info(f"Terminating session: {self.mcp_session_id}")
773775

776+
# Close active SSE responses so ASGI response tasks can finish before
777+
# the session manager cancels the owning task group.
778+
sse_stream_writers = list(self._sse_stream_writers.values())
779+
self._sse_stream_writers.clear()
780+
for writer in sse_stream_writers:
781+
writer.close()
782+
774783
# We need a copy of the keys to avoid modification during iteration
775784
request_stream_keys = list(self._request_streams.keys())
776785

src/mcp/server/streamable_http_manager.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,22 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
130130
yield # Let the application run
131131
finally:
132132
logger.info("StreamableHTTP session manager shutting down")
133-
# Cancel task group to stop all spawned tasks
134-
tg.cancel_scope.cancel()
135-
self._task_group = None
136-
# Clear any remaining server instances
137-
self._server_instances.clear()
133+
try:
134+
await self._terminate_active_sessions()
135+
finally:
136+
# Cancel task group to stop all spawned tasks
137+
tg.cancel_scope.cancel()
138+
self._task_group = None
139+
# Clear any remaining server instances
140+
self._server_instances.clear()
141+
142+
async def _terminate_active_sessions(self) -> None:
143+
"""Terminate tracked transports before cancelling their task group."""
144+
for transport in list(self._server_instances.values()):
145+
try:
146+
await transport.terminate()
147+
except Exception:
148+
logger.exception("Error terminating StreamableHTTP session during shutdown")
138149

139150
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140151
"""Process ASGI request with proper session handling and transport setup.

tests/server/test_streamable_http_manager.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import json
44
import logging
5-
from typing import Any
5+
from types import SimpleNamespace
6+
from typing import Any, cast
67
from unittest.mock import AsyncMock, patch
78

89
import anyio
@@ -62,6 +63,50 @@ async def try_run():
6263
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])
6364

6465

66+
@pytest.mark.anyio
67+
async def test_run_terminates_active_streaming_session_before_shutdown():
68+
"""run() should close active SSE transports before task cancellation."""
69+
app = Server("test-shutdown-cleanup")
70+
manager = StreamableHTTPSessionManager(app=app)
71+
transport = StreamableHTTPServerTransport(mcp_session_id="session-id")
72+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1)
73+
74+
try:
75+
transport._sse_stream_writers["request-id"] = sse_stream_writer
76+
77+
async with manager.run():
78+
manager._server_instances["session-id"] = transport
79+
80+
assert transport.is_terminated
81+
assert transport._sse_stream_writers == {}
82+
assert manager._server_instances == {}
83+
with pytest.raises(anyio.ClosedResourceError):
84+
await sse_stream_writer.send({"data": "still-open"})
85+
finally:
86+
await sse_stream_reader.aclose()
87+
88+
89+
@pytest.mark.anyio
90+
async def test_run_terminates_remaining_sessions_if_one_shutdown_fails(caplog: pytest.LogCaptureFixture):
91+
"""One failed transport shutdown should not skip later active sessions."""
92+
app = Server("test-shutdown-cleanup-error")
93+
manager = StreamableHTTPSessionManager(app=app)
94+
failing_terminate = AsyncMock(side_effect=RuntimeError("terminate failed"))
95+
healthy_terminate = AsyncMock()
96+
failing_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=failing_terminate))
97+
healthy_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=healthy_terminate))
98+
99+
with caplog.at_level(logging.ERROR):
100+
async with manager.run():
101+
manager._server_instances["bad-session"] = failing_transport
102+
manager._server_instances["healthy-session"] = healthy_transport
103+
104+
failing_terminate.assert_awaited_once_with()
105+
healthy_terminate.assert_awaited_once_with()
106+
assert "Error terminating StreamableHTTP session during shutdown" in caplog.text
107+
assert manager._server_instances == {}
108+
109+
65110
@pytest.mark.anyio
66111
async def test_handle_request_without_run_raises_error():
67112
"""Test that handle_request raises error if run() hasn't been called."""
@@ -269,6 +314,26 @@ async def mock_receive():
269314
assert len(transport._request_streams) == 0, "Transport should have no active request streams"
270315

271316

317+
@pytest.mark.anyio
318+
async def test_transport_terminate_closes_sse_stream_writers():
319+
"""terminate() should close active SSE writers so streaming responses can finish."""
320+
transport = StreamableHTTPServerTransport(mcp_session_id="test-session")
321+
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1)
322+
323+
try:
324+
transport._sse_stream_writers["request-id"] = sse_stream_writer
325+
326+
await transport.terminate()
327+
328+
assert transport._sse_stream_writers == {}
329+
with pytest.raises(anyio.ClosedResourceError):
330+
await sse_stream_writer.send({"data": "still-open"})
331+
332+
await transport.terminate()
333+
finally:
334+
await sse_stream_reader.aclose()
335+
336+
272337
@pytest.mark.anyio
273338
async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture):
274339
"""Test that requests with unknown session IDs return HTTP 404 per MCP spec."""

0 commit comments

Comments
 (0)