Skip to content

Add idle timeout termination for StreamableHTTPServerTransport #1159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 61 additions & 13 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass
from http import HTTPStatus
from types import TracebackType

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand All @@ -23,6 +24,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from typing_extensions import Self

from mcp.server.transport_security import (
TransportSecurityMiddleware,
Expand Down Expand Up @@ -140,6 +142,7 @@ def __init__(
is_json_response_enabled: bool = False,
event_store: EventStore | None = None,
security_settings: TransportSecuritySettings | None = None,
timeout: float | None = None,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
Expand All @@ -153,6 +156,9 @@ def __init__(
resumability will be enabled, allowing clients to
reconnect and resume messages.
security_settings: Optional security settings for DNS rebinding protection.
timeout: Optional idle timeout for transport. If provided, the transport will
terminate if it remains idle for longer than the defined timeout
duration in seconds.

Raises:
ValueError: If the session ID contains invalid characters.
Expand All @@ -172,6 +178,12 @@ def __init__(
],
] = {}
self._terminated = False
self._timeout = timeout

# for idle detection
self._processing_request_count = 0
self._idle_condition = anyio.Condition()
self._has_request = False

@property
def is_terminated(self) -> bool:
Expand Down Expand Up @@ -626,6 +638,9 @@ async def terminate(self) -> None:
Once terminated, all requests with this session ID will receive 404 Not Found.
"""

if self._terminated:
return

self._terminated = True
logger.info(f"Terminating session: {self.mcp_session_id}")

Expand Down Expand Up @@ -796,6 +811,42 @@ async def send_event(event_message: EventMessage) -> None:
)
await response(request.scope, request.receive, send)

async def __aenter__(self) -> Self:
async with self._idle_condition:
self._processing_request_count += 1
self._has_request = True
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
async with self._idle_condition:
self._processing_request_count -= 1
if self._processing_request_count == 0:
self._idle_condition.notify_all()

async def _idle_timeout_terminate(self, timeout: float) -> None:
"""
Terminate the transport if it remains idle for longer than the defined timeout duration.
"""
while not self._terminated:
# wait for transport to be idle
async with self._idle_condition:
if self._processing_request_count > 0:
await self._idle_condition.wait()
self._has_request = False

# wait for idle timeout
await anyio.sleep(timeout)

# If there are no requests during the wait period, terminate the transport
if not self._has_request:
logger.debug(f"Terminating transport due to idle timeout: {self.mcp_session_id}")
await self.terminate()

@asynccontextmanager
async def connect(
self,
Expand All @@ -812,6 +863,10 @@ async def connect(
Tuple of (read_stream, write_stream) for bidirectional communication
"""

# Terminated transports should not be connected again
if self._terminated:
raise RuntimeError("Transport is terminated")

# Create the memory streams for this connection

read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
Expand Down Expand Up @@ -884,20 +939,13 @@ async def message_router():
# Start the message router
tg.start_soon(message_router)

# Start idle timeout task if timeout is set
if self._timeout is not None:
tg.start_soon(self._idle_timeout_terminate, self._timeout)

try:
# Yield the streams for the caller to use
yield read_stream, write_stream
finally:
for stream_id in list(self._request_streams.keys()):
await self._clean_up_memory_streams(stream_id)
self._request_streams.clear()

# Clean up the read and write streams
try:
await read_stream_writer.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
await write_stream.aclose()
except Exception as e:
# During cleanup, we catch all exceptions since streams might be in various states
logger.debug(f"Error closing streams: {e}")
# Terminate the transport when the context manager exits
await self.terminate()
34 changes: 22 additions & 12 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class StreamableHTTPSessionManager:
json_response: Whether to use JSON responses instead of SSE streams
stateless: If True, creates a completely fresh transport for each request
with no session tracking or state persistence between requests.
timeout: Optional idle timeout for the stateful transport. If specified,
the stateful transport will terminate if it remains idle for longer
than the defined timeout duration in seconds.
"""

def __init__(
Expand All @@ -60,12 +63,14 @@ def __init__(
event_store: EventStore | None = None,
json_response: bool = False,
stateless: bool = False,
timeout: float | None = None,
security_settings: TransportSecuritySettings | None = None,
):
self.app = app
self.event_store = event_store
self.json_response = json_response
self.stateless = stateless
self.timeout = timeout
self.security_settings = security_settings

# Session tracking (only used if not stateless)
Expand Down Expand Up @@ -187,11 +192,12 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
# Start the server task
await self._task_group.start(run_stateless_server)

# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)

# Terminate the transport after the request is handled
await http_transport.terminate()
try:
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
finally:
# Terminate the transport after the request is handled
await http_transport.terminate()

async def _handle_stateful_request(
self,
Expand All @@ -214,7 +220,8 @@ async def _handle_stateful_request(
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
transport = self._server_instances[request_mcp_session_id]
logger.debug("Session already exists, handling request directly")
await transport.handle_request(scope, receive, send)
async with transport:
await transport.handle_request(scope, receive, send)
return

if request_mcp_session_id is None:
Expand All @@ -227,6 +234,7 @@ async def _handle_stateful_request(
is_json_response_enabled=self.json_response,
event_store=self.event_store, # May be None (no resumability)
security_settings=self.security_settings,
timeout=self.timeout,
)

assert http_transport.mcp_session_id is not None
Expand All @@ -251,11 +259,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
exc_info=True,
)
finally:
# Only remove from instances if not terminated
# remove from instances, we do not need to terminate the transport
# as it will be terminated when the context manager exits
if (
http_transport.mcp_session_id
and http_transport.mcp_session_id in self._server_instances
and not http_transport.is_terminated
):
logger.info(
"Cleaning up crashed session "
Expand All @@ -270,11 +278,13 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
await self._task_group.start(run_server)

# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
async with http_transport:
await http_transport.handle_request(scope, receive, send)
else:
# Invalid session ID
# Client may send a outdated session ID
# We should return 404 to notify the client to start a new session
response = Response(
"Bad Request: No valid session ID provided",
status_code=HTTPStatus.BAD_REQUEST,
"Not Found: Session has been terminated",
status_code=HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
67 changes: 67 additions & 0 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,70 @@ async def mock_receive():

# Verify internal state is cleaned up
assert len(transport._request_streams) == 0, "Transport should have no active request streams"


@pytest.mark.anyio
async def test_stateful_session_cleanup_on_idle_timeout():
"""Test that stateful sessions are cleaned up when idle timeout with real transports and sessions."""
app = Server("test-stateful-idle-timeout")
manager = StreamableHTTPSessionManager(app=app, timeout=0.01)

created_transports: list[streamable_http_manager.StreamableHTTPServerTransport] = []

original_transport_constructor = streamable_http_manager.StreamableHTTPServerTransport

def track_transport(*args, **kwargs):
transport = original_transport_constructor(*args, **kwargs)
created_transports.append(transport)
return transport

with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport):
async with manager.run():
sent_messages = []

async def mock_send(message):
sent_messages.append(message)

scope = {
"type": "http",
"method": "POST",
"path": "/mcp",
"headers": [(b"content-type", b"application/json")],
}

async def mock_receive():
return {"type": "http.request", "body": b"", "more_body": False}

# Trigger session creation
await manager.handle_request(scope, mock_receive, mock_send)

session_id = None
for msg in sent_messages:
if msg["type"] == "http.response.start":
for header_name, header_value in msg.get("headers", []):
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
session_id = header_value.decode()
break
if session_id: # Break outer loop if session_id is found
break

assert session_id is not None, "Session ID not found in response headers"

assert len(created_transports) == 1, "Should have created one transport"

transport = created_transports[0]

# the transport should not be terminated before idle timeout
assert not transport.is_terminated, "Transport should not be terminated before idle timeout"
assert session_id in manager._server_instances, (
"Session ID should be tracked in _server_instances before idle timeout"
)

# wait for idle timeout
await anyio.sleep(0.1)

assert transport.is_terminated, "Transport should be terminated after idle timeout"
assert session_id not in manager._server_instances, (
"Session ID should be removed from _server_instances after idle timeout"
)
assert not manager._server_instances, "No sessions should be tracked after the only session idle timeout"
Loading