|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | import logging |
5 | | -from typing import Any |
| 5 | +from types import SimpleNamespace |
| 6 | +from typing import Any, cast |
6 | 7 | from unittest.mock import AsyncMock, patch |
7 | 8 |
|
8 | 9 | import anyio |
@@ -62,6 +63,50 @@ async def try_run(): |
62 | 63 | assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0]) |
63 | 64 |
|
64 | 65 |
|
| 66 | +@pytest.mark.anyio |
| 67 | +async def test_run_terminates_active_streaming_session_before_shutdown(): |
| 68 | + """run() should close active SSE transports before task cancellation.""" |
| 69 | + app = Server("test-shutdown-cleanup") |
| 70 | + manager = StreamableHTTPSessionManager(app=app) |
| 71 | + transport = StreamableHTTPServerTransport(mcp_session_id="session-id") |
| 72 | + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1) |
| 73 | + |
| 74 | + try: |
| 75 | + transport._sse_stream_writers["request-id"] = sse_stream_writer |
| 76 | + |
| 77 | + async with manager.run(): |
| 78 | + manager._server_instances["session-id"] = transport |
| 79 | + |
| 80 | + assert transport.is_terminated |
| 81 | + assert transport._sse_stream_writers == {} |
| 82 | + assert manager._server_instances == {} |
| 83 | + with pytest.raises(anyio.ClosedResourceError): |
| 84 | + await sse_stream_writer.send({"data": "still-open"}) |
| 85 | + finally: |
| 86 | + await sse_stream_reader.aclose() |
| 87 | + |
| 88 | + |
| 89 | +@pytest.mark.anyio |
| 90 | +async def test_run_terminates_remaining_sessions_if_one_shutdown_fails(caplog: pytest.LogCaptureFixture): |
| 91 | + """One failed transport shutdown should not skip later active sessions.""" |
| 92 | + app = Server("test-shutdown-cleanup-error") |
| 93 | + manager = StreamableHTTPSessionManager(app=app) |
| 94 | + failing_terminate = AsyncMock(side_effect=RuntimeError("terminate failed")) |
| 95 | + healthy_terminate = AsyncMock() |
| 96 | + failing_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=failing_terminate)) |
| 97 | + healthy_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=healthy_terminate)) |
| 98 | + |
| 99 | + with caplog.at_level(logging.ERROR): |
| 100 | + async with manager.run(): |
| 101 | + manager._server_instances["bad-session"] = failing_transport |
| 102 | + manager._server_instances["healthy-session"] = healthy_transport |
| 103 | + |
| 104 | + failing_terminate.assert_awaited_once_with() |
| 105 | + healthy_terminate.assert_awaited_once_with() |
| 106 | + assert "Error terminating StreamableHTTP session during shutdown" in caplog.text |
| 107 | + assert manager._server_instances == {} |
| 108 | + |
| 109 | + |
65 | 110 | @pytest.mark.anyio |
66 | 111 | async def test_handle_request_without_run_raises_error(): |
67 | 112 | """Test that handle_request raises error if run() hasn't been called.""" |
@@ -269,6 +314,26 @@ async def mock_receive(): |
269 | 314 | assert len(transport._request_streams) == 0, "Transport should have no active request streams" |
270 | 315 |
|
271 | 316 |
|
| 317 | +@pytest.mark.anyio |
| 318 | +async def test_transport_terminate_closes_sse_stream_writers(): |
| 319 | + """terminate() should close active SSE writers so streaming responses can finish.""" |
| 320 | + transport = StreamableHTTPServerTransport(mcp_session_id="test-session") |
| 321 | + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1) |
| 322 | + |
| 323 | + try: |
| 324 | + transport._sse_stream_writers["request-id"] = sse_stream_writer |
| 325 | + |
| 326 | + await transport.terminate() |
| 327 | + |
| 328 | + assert transport._sse_stream_writers == {} |
| 329 | + with pytest.raises(anyio.ClosedResourceError): |
| 330 | + await sse_stream_writer.send({"data": "still-open"}) |
| 331 | + |
| 332 | + await transport.terminate() |
| 333 | + finally: |
| 334 | + await sse_stream_reader.aclose() |
| 335 | + |
| 336 | + |
272 | 337 | @pytest.mark.anyio |
273 | 338 | async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): |
274 | 339 | """Test that requests with unknown session IDs return HTTP 404 per MCP spec.""" |
|
0 commit comments