diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 32b63c1ae..c79c1e378 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -15,6 +15,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from http import HTTPStatus +from types import TracebackType import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -23,6 +24,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send +from typing_extensions import Self from mcp.server.transport_security import ( TransportSecurityMiddleware, @@ -140,6 +142,7 @@ def __init__( is_json_response_enabled: bool = False, event_store: EventStore | None = None, security_settings: TransportSecuritySettings | None = None, + timeout: float | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -153,6 +156,9 @@ def __init__( resumability will be enabled, allowing clients to reconnect and resume messages. security_settings: Optional security settings for DNS rebinding protection. + timeout: Optional idle timeout for transport. If provided, the transport will + terminate if it remains idle for longer than the defined timeout + duration in seconds. Raises: ValueError: If the session ID contains invalid characters. @@ -172,6 +178,12 @@ def __init__( ], ] = {} self._terminated = False + self._timeout = timeout + + # for idle detection + self._processing_request_count = 0 + self._idle_condition = anyio.Condition() + self._has_request = False @property def is_terminated(self) -> bool: @@ -626,6 +638,9 @@ async def terminate(self) -> None: Once terminated, all requests with this session ID will receive 404 Not Found. """ + if self._terminated: + return + self._terminated = True logger.info(f"Terminating session: {self.mcp_session_id}") @@ -796,6 +811,42 @@ async def send_event(event_message: EventMessage) -> None: ) await response(request.scope, request.receive, send) + async def __aenter__(self) -> Self: + async with self._idle_condition: + self._processing_request_count += 1 + self._has_request = True + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + async with self._idle_condition: + self._processing_request_count -= 1 + if self._processing_request_count == 0: + self._idle_condition.notify_all() + + async def _idle_timeout_terminate(self, timeout: float) -> None: + """ + Terminate the transport if it remains idle for longer than the defined timeout duration. + """ + while not self._terminated: + # wait for transport to be idle + async with self._idle_condition: + if self._processing_request_count > 0: + await self._idle_condition.wait() + self._has_request = False + + # wait for idle timeout + await anyio.sleep(timeout) + + # If there are no requests during the wait period, terminate the transport + if not self._has_request: + logger.debug(f"Terminating transport due to idle timeout: {self.mcp_session_id}") + await self.terminate() + @asynccontextmanager async def connect( self, @@ -812,6 +863,10 @@ async def connect( Tuple of (read_stream, write_stream) for bidirectional communication """ + # Terminated transports should not be connected again + if self._terminated: + raise RuntimeError("Transport is terminated") + # Create the memory streams for this connection read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) @@ -884,20 +939,13 @@ async def message_router(): # Start the message router tg.start_soon(message_router) + # Start idle timeout task if timeout is set + if self._timeout is not None: + tg.start_soon(self._idle_timeout_terminate, self._timeout) + try: # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): - await self._clean_up_memory_streams(stream_id) - self._request_streams.clear() - - # Clean up the read and write streams - try: - await read_stream_writer.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() - await write_stream.aclose() - except Exception as e: - # During cleanup, we catch all exceptions since streams might be in various states - logger.debug(f"Error closing streams: {e}") + # Terminate the transport when the context manager exits + await self.terminate() diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index f38e6afec..ced6a879e 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -52,6 +52,9 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. + timeout: Optional idle timeout for the stateful transport. If specified, + the stateful transport will terminate if it remains idle for longer + than the defined timeout duration in seconds. """ def __init__( @@ -60,12 +63,14 @@ def __init__( event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, + timeout: float | None = None, security_settings: TransportSecuritySettings | None = None, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless + self.timeout = timeout self.security_settings = security_settings # Session tracking (only used if not stateless) @@ -187,11 +192,12 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA # Start the server task await self._task_group.start(run_stateless_server) - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - - # Terminate the transport after the request is handled - await http_transport.terminate() + try: + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + finally: + # Terminate the transport after the request is handled + await http_transport.terminate() async def _handle_stateful_request( self, @@ -214,7 +220,8 @@ async def _handle_stateful_request( if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") - await transport.handle_request(scope, receive, send) + async with transport: + await transport.handle_request(scope, receive, send) return if request_mcp_session_id is None: @@ -227,6 +234,7 @@ async def _handle_stateful_request( is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) security_settings=self.security_settings, + timeout=self.timeout, ) assert http_transport.mcp_session_id is not None @@ -251,11 +259,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE exc_info=True, ) finally: - # Only remove from instances if not terminated + # remove from instances, we do not need to terminate the transport + # as it will be terminated when the context manager exits if ( http_transport.mcp_session_id and http_transport.mcp_session_id in self._server_instances - and not http_transport.is_terminated ): logger.info( "Cleaning up crashed session " @@ -270,11 +278,13 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE await self._task_group.start(run_server) # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + async with http_transport: + await http_transport.handle_request(scope, receive, send) else: - # Invalid session ID + # Client may send a outdated session ID + # We should return 404 to notify the client to start a new session response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, + "Not Found: Session has been terminated", + status_code=HTTPStatus.NOT_FOUND, ) await response(scope, receive, send) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 9a4c695b8..70a547ca3 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -260,3 +260,70 @@ async def mock_receive(): # Verify internal state is cleaned up assert len(transport._request_streams) == 0, "Transport should have no active request streams" + + +@pytest.mark.anyio +async def test_stateful_session_cleanup_on_idle_timeout(): + """Test that stateful sessions are cleaned up when idle timeout with real transports and sessions.""" + app = Server("test-stateful-idle-timeout") + manager = StreamableHTTPSessionManager(app=app, timeout=0.01) + + created_transports: list[streamable_http_manager.StreamableHTTPServerTransport] = [] + + original_transport_constructor = streamable_http_manager.StreamableHTTPServerTransport + + def track_transport(*args, **kwargs): + transport = original_transport_constructor(*args, **kwargs) + created_transports.append(transport) + return transport + + with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport): + async with manager.run(): + sent_messages = [] + + async def mock_send(message): + sent_messages.append(message) + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [(b"content-type", b"application/json")], + } + + async def mock_receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + # Trigger session creation + await manager.handle_request(scope, mock_receive, mock_send) + + session_id = None + for msg in sent_messages: + if msg["type"] == "http.response.start": + for header_name, header_value in msg.get("headers", []): + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # Break outer loop if session_id is found + break + + assert session_id is not None, "Session ID not found in response headers" + + assert len(created_transports) == 1, "Should have created one transport" + + transport = created_transports[0] + + # the transport should not be terminated before idle timeout + assert not transport.is_terminated, "Transport should not be terminated before idle timeout" + assert session_id in manager._server_instances, ( + "Session ID should be tracked in _server_instances before idle timeout" + ) + + # wait for idle timeout + await anyio.sleep(0.1) + + assert transport.is_terminated, "Transport should be terminated after idle timeout" + assert session_id not in manager._server_instances, ( + "Session ID should be removed from _server_instances after idle timeout" + ) + assert not manager._server_instances, "No sessions should be tracked after the only session idle timeout"