|
27 | 27 | from starlette.requests import Request |
28 | 28 | from starlette.routing import Mount |
29 | 29 |
|
| 30 | +import mcp.client.streamable_http as streamable_http |
30 | 31 | from mcp import MCPError, types |
31 | 32 | from mcp.client.session import ClientSession |
32 | 33 | from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
@@ -139,6 +140,39 @@ async def replay_events_after( # pragma: no cover |
139 | 140 | return target_stream_id |
140 | 141 |
|
141 | 142 |
|
| 143 | +class FakeStreamResponse(httpx.Response): |
| 144 | + def __init__(self) -> None: |
| 145 | + super().__init__( |
| 146 | + 200, |
| 147 | + request=httpx.Request("POST", "http://localhost:8000/mcp"), |
| 148 | + ) |
| 149 | + self.close_count = 0 |
| 150 | + |
| 151 | + async def aclose(self) -> None: # pragma: no cover |
| 152 | + self.close_count += 1 |
| 153 | + |
| 154 | + |
| 155 | +class FakeEventSource: |
| 156 | + def __init__(self, events: list[ServerSentEvent]) -> None: |
| 157 | + self.response = FakeStreamResponse() |
| 158 | + self.events = events |
| 159 | + self.seen = 0 |
| 160 | + |
| 161 | + async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]: |
| 162 | + for event in self.events: |
| 163 | + self.seen += 1 |
| 164 | + yield event |
| 165 | + |
| 166 | + |
| 167 | +def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent: |
| 168 | + return ServerSentEvent( |
| 169 | + event="message", |
| 170 | + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), |
| 171 | + id=event_id, |
| 172 | + retry=None, |
| 173 | + ) |
| 174 | + |
| 175 | + |
142 | 176 | @dataclass |
143 | 177 | class ServerState: |
144 | 178 | lock: anyio.Event = field(default_factory=anyio.Event) |
@@ -1803,6 +1837,94 @@ async def test_handle_sse_event_skips_empty_data(): |
1803 | 1837 | await read_stream.aclose() |
1804 | 1838 |
|
1805 | 1839 |
|
| 1840 | +@pytest.mark.anyio |
| 1841 | +async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch): |
| 1842 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1843 | + response = FakeStreamResponse() |
| 1844 | + event_source = FakeEventSource( |
| 1845 | + [ |
| 1846 | + jsonrpc_response_event("request-1", "event-1"), |
| 1847 | + ServerSentEvent(event="message", data="", id="event-2", retry=None), |
| 1848 | + ] |
| 1849 | + ) |
| 1850 | + |
| 1851 | + def event_source_factory(_response: httpx.Response) -> FakeEventSource: |
| 1852 | + return event_source |
| 1853 | + |
| 1854 | + monkeypatch.setattr(streamable_http, "EventSource", event_source_factory) |
| 1855 | + |
| 1856 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1857 | + try: |
| 1858 | + async with httpx.AsyncClient() as client: |
| 1859 | + ctx = streamable_http.RequestContext( |
| 1860 | + client=client, |
| 1861 | + session_id=None, |
| 1862 | + session_message=SessionMessage( |
| 1863 | + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) |
| 1864 | + ), |
| 1865 | + metadata=None, |
| 1866 | + read_stream_writer=write_stream, |
| 1867 | + ) |
| 1868 | + |
| 1869 | + await transport._handle_sse_response(response, ctx) |
| 1870 | + |
| 1871 | + received = await read_stream.receive() |
| 1872 | + assert isinstance(received, SessionMessage) |
| 1873 | + assert isinstance(received.message, types.JSONRPCResponse) |
| 1874 | + assert received.message.id == "request-1" |
| 1875 | + assert event_source.seen == 2 |
| 1876 | + assert response.close_count == 0 |
| 1877 | + finally: |
| 1878 | + await write_stream.aclose() |
| 1879 | + await read_stream.aclose() |
| 1880 | + |
| 1881 | + |
| 1882 | +@pytest.mark.anyio |
| 1883 | +async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch): |
| 1884 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1885 | + event_source = FakeEventSource( |
| 1886 | + [ |
| 1887 | + jsonrpc_response_event("request-1", "event-2"), |
| 1888 | + ServerSentEvent(event="message", data="", id="event-3", retry=None), |
| 1889 | + ] |
| 1890 | + ) |
| 1891 | + |
| 1892 | + async def sleep_noop(_delay: float) -> None: |
| 1893 | + pass |
| 1894 | + |
| 1895 | + @asynccontextmanager |
| 1896 | + async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]: |
| 1897 | + yield event_source |
| 1898 | + |
| 1899 | + monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop) |
| 1900 | + monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse) |
| 1901 | + |
| 1902 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1903 | + try: |
| 1904 | + async with httpx.AsyncClient() as client: |
| 1905 | + ctx = streamable_http.RequestContext( |
| 1906 | + client=client, |
| 1907 | + session_id=None, |
| 1908 | + session_message=SessionMessage( |
| 1909 | + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) |
| 1910 | + ), |
| 1911 | + metadata=None, |
| 1912 | + read_stream_writer=write_stream, |
| 1913 | + ) |
| 1914 | + |
| 1915 | + await transport._handle_reconnection(ctx, last_event_id="event-1") |
| 1916 | + |
| 1917 | + received = await read_stream.receive() |
| 1918 | + assert isinstance(received, SessionMessage) |
| 1919 | + assert isinstance(received.message, types.JSONRPCResponse) |
| 1920 | + assert received.message.id == "request-1" |
| 1921 | + assert event_source.seen == 2 |
| 1922 | + assert event_source.response.close_count == 0 |
| 1923 | + finally: |
| 1924 | + await write_stream.aclose() |
| 1925 | + await read_stream.aclose() |
| 1926 | + |
| 1927 | + |
1806 | 1928 | @pytest.mark.anyio |
1807 | 1929 | async def test_priming_event_not_sent_for_old_protocol_version(): |
1808 | 1930 | """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" |
|
0 commit comments