From 9bfd6c9788df1d68bde0e453c9390be5dd36af38 Mon Sep 17 00:00:00 2001 From: Pawan Bhardwaj Date: Sat, 23 May 2026 10:57:19 +0000 Subject: [PATCH 1/4] PR 2452 --- src/mcp/shared/direct_dispatcher.py | 173 +++++++++++++++++++ src/mcp/shared/dispatcher.py | 146 ++++++++++++++++ src/mcp/shared/exceptions.py | 21 ++- src/mcp/shared/transport_context.py | 30 ++++ src/mcp/types/__init__.py | 2 + src/mcp/types/jsonrpc.py | 1 + tests/shared/test_dispatcher.py | 251 ++++++++++++++++++++++++++++ 7 files changed, 623 insertions(+), 1 deletion(-) create mode 100644 src/mcp/shared/direct_dispatcher.py create mode 100644 src/mcp/shared/dispatcher.py create mode 100644 src/mcp/shared/transport_context.py create mode 100644 tests/shared/test_dispatcher.py diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py new file mode 100644 index 0000000000..bb5639a136 --- /dev/null +++ b/src/mcp/shared/direct_dispatcher.py @@ -0,0 +1,173 @@ +"""In-memory `Dispatcher` that wires two peers together with no transport. + +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a +request on one side directly invokes the other side's `on_request`. There is no +serialization, no JSON-RPC framing, and no streams. It exists to: + +* prove the `Dispatcher` Protocol is implementable without JSON-RPC +* provide a fast substrate for testing the layers above the dispatcher + (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts +* embed a server in-process when the JSON-RPC overhead is unnecessary + +Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly +to the caller — there is no exception-to-`ErrorData` boundary here. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + +import anyio + +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT + +__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] + +DIRECT_TRANSPORT_KIND = "direct" + + +_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] + + +@dataclass +class _DirectDispatchContext: + """`DispatchContext` for an inbound request on a `DirectDispatcher`. + + The back-channel callables target the *originating* side, so a handler's + `send_raw_request` reaches the peer that made the inbound request. + """ + + transport: TransportContext + _back_request: _Request + _back_notify: _Notify + _on_progress: ProgressFnT | None = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._back_notify(method, params) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.transport.can_send_request: + raise NoBackChannelError(method) + return await self._back_request(method, params, opts) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._on_progress is not None: + await self._on_progress(progress, total, message) + + +class DirectDispatcher: + """A `Dispatcher` that calls a peer's handlers directly, in-process. + + Two instances are wired together with `create_direct_dispatcher_pair`; each + holds a reference to the other. `send_raw_request` on one awaits the peer's + `on_request`. `run` parks until `close` is called. + """ + + def __init__(self, transport_ctx: TransportContext): + self._transport_ctx = transport_ctx + self._peer: DirectDispatcher | None = None + self._on_request: OnRequest | None = None + self._on_notify: OnNotify | None = None + self._ready = anyio.Event() + self._closed = anyio.Event() + + def connect_to(self, peer: DirectDispatcher) -> None: + self._peer = peer + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + return await self._peer._dispatch_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + await self._peer._dispatch_notify(method, params) + + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + self._on_request = on_request + self._on_notify = on_notify + self._ready.set() + await self._closed.wait() + + def close(self) -> None: + self._closed.set() + + def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + assert self._peer is not None + peer = self._peer + return _DirectDispatchContext( + transport=self._transport_ctx, + _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), + _back_notify=lambda m, p: peer._dispatch_notify(m, p), + _on_progress=on_progress, + ) + + async def _dispatch_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None, + ) -> dict[str, Any]: + await self._ready.wait() + assert self._on_request is not None + opts = opts or {} + dctx = self._make_context(on_progress=opts.get("on_progress")) + try: + with anyio.fail_after(opts.get("timeout")): + try: + return await self._on_request(dctx, method, params) + except MCPError: + raise + except Exception as e: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + except TimeoutError: + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", + ) from None + + async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._ready.wait() + assert self._on_notify is not None + dctx = self._make_context() + await self._on_notify(dctx, method, params) + + +def create_direct_dispatcher_pair( + *, + can_send_request: bool = True, +) -> tuple[DirectDispatcher, DirectDispatcher]: + """Create two `DirectDispatcher` instances wired to each other. + + Args: + can_send_request: Sets `TransportContext.can_send_request` on both + sides. Pass ``False`` to simulate a transport with no back-channel. + + Returns: + A ``(left, right)`` pair. Conventionally ``left`` is the client side + and ``right`` is the server side, but the wiring is symmetric. + """ + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + left = DirectDispatcher(ctx) + right = DirectDispatcher(ctx) + left.connect_to(right) + right.connect_to(left) + return left, right diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 0000000000..ee02e23896 --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,146 @@ +"""Dispatcher Protocol — the call/return boundary between transports and handlers. + +A Dispatcher turns a duplex message channel into two things: + +* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop + and invokes the supplied handlers for each incoming request/notification + +It is deliberately *not* MCP-aware. Method names are strings, params and +results are ``dict[str, Any]``. The MCP type layer (request/result models, +capability negotiation, ``Context``) sits above this; the wire encoding +(JSON-RPC, gRPC, in-process direct calls) sits below it. + +See ``JSONRPCDispatcher`` for the production implementation and +``DirectDispatcher`` for an in-memory implementation used in tests and for +embedding a server in-process. +""" + +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable + +import anyio + +from mcp.shared.transport_context import TransportContext + +__all__ = [ + "CallOptions", + "DispatchContext", + "DispatchMiddleware", + "Dispatcher", + "OnNotify", + "OnRequest", + "Outbound", + "ProgressFnT", +] + +TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) + + +class ProgressFnT(Protocol): + """Callback invoked when a progress notification arrives for a pending request.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class CallOptions(TypedDict, total=False): + """Per-call options for `Outbound.send_raw_request`. + + All keys are optional. Dispatchers ignore keys they do not understand. + """ + + timeout: float + """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" + + on_progress: ProgressFnT + """Receive ``notifications/progress`` updates for this request.""" + + resumption_token: str + """Opaque token to resume a previously interrupted request (transport-dependent).""" + + on_resumption_token: Callable[[str], Awaitable[None]] + """Receive a resumption token when the transport issues one.""" + + +@runtime_checkable +class Outbound(Protocol): + """Anything that can send requests and notifications to the peer. + + Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel + during an inbound request) extend this. The MCP type layer (`PeerMixin`, + `Connection`, `Context`) builds typed ``send_request`` / convenience methods + on top of this raw channel. + """ + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request and await its raw result dict. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... + + +class DispatchContext(Outbound, Protocol[TransportT_co]): + """Per-request context handed to ``on_request`` / ``on_notify``. + + Carries the transport metadata for the inbound message and provides the + back-channel for sending requests/notifications to the peer while handling + it. `send_raw_request` raises `NoBackChannelError` if + ``transport.can_send_request`` is ``False``. + """ + + @property + def transport(self) -> TransportT_co: + """Transport-specific metadata for this inbound message.""" + ... + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + ... + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + ... + + +OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" + +OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] +"""Handler for inbound notifications: ``(ctx, method, params)``.""" + +DispatchMiddleware = Callable[[OnRequest], OnRequest] +"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first.""" + + +class Dispatcher(Outbound, Protocol[TransportT_co]): + """A duplex request/notification channel with call-return semantics. + + Implementations own correlation of outbound requests to inbound results, the + receive loop, per-request concurrency, and cancellation/progress wiring. + """ + + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + """Drive the receive loop until the underlying channel closes. + + Each inbound request is dispatched to ``on_request`` in its own task; + the returned dict (or raised ``MCPError``) is sent back as the response. + Inbound notifications go to ``on_notify``. + """ + ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319d..b62629b6c8 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -2,7 +2,7 @@ from typing import Any, cast -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError class MCPError(Exception): @@ -41,6 +41,25 @@ def __str__(self) -> str: return self.message +class NoBackChannelError(MCPError): + """Raised when sending a server-initiated request over a transport that cannot deliver it. + + Stateless HTTP and JSON-response-mode HTTP have no channel for the server to + push requests (sampling, elicitation, roots/list) to the client. This is + raised by `DispatchContext.send_raw_request` when `transport.can_send_request` + is ``False``, and serializes to an ``INVALID_REQUEST`` error response. + """ + + def __init__(self, method: str): + super().__init__( + code=INVALID_REQUEST, + message=( + f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests." + ), + ) + self.method = method + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py new file mode 100644 index 0000000000..832cead515 --- /dev/null +++ b/src/mcp/shared/transport_context.py @@ -0,0 +1,30 @@ +"""Transport-specific metadata attached to each inbound message. + +`TransportContext` is the base; each transport defines its own subclass with +whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, +etc.). The dispatcher passes it through opaquely; only the layers above the +dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +""" + +from dataclasses import dataclass + +__all__ = ["TransportContext"] + + +@dataclass(kw_only=True, frozen=True) +class TransportContext: + """Base transport metadata for an inbound message. + + Subclass per transport and add fields as needed. Instances are immutable. + """ + + kind: str + """Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``).""" + + can_send_request: bool + """Whether the transport can deliver server-initiated requests to the peer. + + ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for + stdio, SSE, and stateful streamable HTTP. When ``False``, + `DispatchContext.send_raw_request` raises `NoBackChannelError`. + """ diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..ca1c328939 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -192,6 +192,7 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, + REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -401,6 +402,7 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", + "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 84304a37c1..14743c33b0 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,6 +43,7 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 +REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 0000000000..784ef6698f --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,251 @@ +"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. + +These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using +the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in +``test_jsonrpc_dispatcher.py``. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT + + +class Recorder: + def __init__(self) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.contexts: list[DispatchContext[TransportContext]] = [] + self.notified = anyio.Event() + + +def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + recorder.requests.append((method, params)) + recorder.contexts.append(ctx) + return {"echoed": method, "params": dict(params or {})} + + async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + recorder.notifications.append((method, params)) + recorder.notified.set() + + return on_request, on_notify + + +@asynccontextmanager +async def running_pair( + *, + server_on_request: OnRequest | None = None, + server_on_notify: OnNotify | None = None, + client_on_request: OnRequest | None = None, + client_on_notify: OnNotify | None = None, + can_send_request: bool = True, +) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: + """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client_rec, server_rec = Recorder(), Recorder() + c_req, c_notify = echo_handlers(client_rec) + s_req, s_notify = echo_handlers(server_rec) + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + client.close() + server.close() + + +@pytest.mark.anyio +async def test_send_raw_request_returns_result_from_peer_on_request(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/list", {"cursor": "abc"}) + assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} + assert srec.requests == [("tools/list", {"cursor": "abc"})] + + +@pytest.mark.anyio +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise MCPError(code=INVALID_PARAMS, message="bad cursor") + + async with running_pair(server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" + + +@pytest.mark.anyio +async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await anyio.sleep_forever() + raise NotImplementedError + + async with running_pair(server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0}) + assert exc.value.error.code == REQUEST_TIMEOUT + + +@pytest.mark.anyio +async def test_notify_invokes_peer_on_notify(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/initialized", {"v": 1}) + await srec.notified.wait() + assert srec.notifications == [("notifications/initialized", {"v": 1})] + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_round_trips_to_calling_side(): + """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + return {"sampled": sample} + + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] + assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", None) + + async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + await client.send_raw_request("tools/call", None) + assert exc.value.method == "sampling/createMessage" + assert exc.value.error.code == INVALID_REQUEST + + +@pytest.mark.anyio +async def test_ctx_notify_invokes_calling_side_on_notify(): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/message", {"level": "info"}) + return {} + + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + await crec.notified.wait() + assert crec.notifications == [("notifications/message", {"level": "info"})] + + +@pytest.mark.anyio +async def test_ctx_progress_invokes_caller_on_progress_callback(): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with running_pair(server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): + client, server = create_direct_dispatcher_pair() + s_req, s_notify = echo_handlers(Recorder()) + c_req, c_notify = echo_handlers(Recorder()) + + async def late_start(): + await anyio.sleep(0) + await server.run(s_req, s_notify) + + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, c_req, c_notify) + tg.start_soon(late_start) + with anyio.fail_after(5): + result = await client.send_raw_request("ping", None) + assert result == {"echoed": "ping", "params": {}} + client.close() + server.close() + + +@pytest.mark.anyio +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): + d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + with pytest.raises(RuntimeError, match="no peer"): + await d.send_raw_request("ping", None) + with pytest.raises(RuntimeError, match="no peer"): + await d.notify("ping", None) + + +@pytest.mark.anyio +async def test_close_makes_run_return(): + client, server = create_direct_dispatcher_pair() + on_request, on_notify = echo_handlers(Recorder()) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, on_request, on_notify) + tg.start_soon(client.run, on_request, on_notify) + client.close() + server.close() + + +if TYPE_CHECKING: + _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + _o: Outbound = _d From 0433ef5c0392553b5c9031c3b1c78f03af7d290f Mon Sep 17 00:00:00 2001 From: Pawan Bhardwaj Date: Sat, 23 May 2026 11:01:08 +0000 Subject: [PATCH 2/4] PR 2458 --- .github/workflows/main.yml | 1 - src/mcp/shared/direct_dispatcher.py | 10 +- src/mcp/shared/dispatcher.py | 13 +- src/mcp/shared/jsonrpc_dispatcher.py | 543 ++++++++++++++++++++++++ tests/shared/conftest.py | 61 +++ tests/shared/test_dispatcher.py | 135 +++--- tests/shared/test_jsonrpc_dispatcher.py | 531 +++++++++++++++++++++++ 7 files changed, 1227 insertions(+), 67 deletions(-) create mode 100644 src/mcp/shared/jsonrpc_dispatcher.py create mode 100644 tests/shared/conftest.py create mode 100644 tests/shared/test_jsonrpc_dispatcher.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d34e438fc9..341df0abb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,6 @@ on: branches: ["main", "v1.x"] tags: ["v*.*.*"] pull_request: - branches: ["main", "v1.x"] permissions: contents: read diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index bb5639a136..27443ec874 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -20,6 +20,7 @@ from typing import Any import anyio +import anyio.abc from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError @@ -101,10 +102,17 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: self._on_request = on_request self._on_notify = on_notify self._ready.set() + task_status.started() await self._closed.wait() def close(self) -> None: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index ee02e23896..20c090323b 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -20,6 +20,7 @@ from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable import anyio +import anyio.abc from mcp.shared.transport_context import TransportContext @@ -136,11 +137,21 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: """Drive the receive loop until the underlying channel closes. Each inbound request is dispatched to ``on_request`` in its own task; the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. + + ``task_status.started()`` is called once the dispatcher is ready to + accept ``send_request``/``notify`` calls, so callers can use + ``await tg.start(dispatcher.run, on_request, on_notify)``. """ ... diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py new file mode 100644 index 0000000000..f1e7b3675e --- /dev/null +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -0,0 +1,543 @@ +"""JSON-RPC `Dispatcher` implementation. + +Consumes the existing `SessionMessage`-based stream contract that all current +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, +the receive loop, per-request task isolation, cancellation/progress wiring, and +the single exception-to-wire boundary. + +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and +sees only `(ctx, method, params) -> dict`. Transports sit below and see only +`SessionMessage` reads/writes. + +The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and +dicts — but it intercepts ``notifications/cancelled`` and +``notifications/progress`` because request correlation, cancellation and +progress are exactly the wiring this layer exists to provide. Those few wire +shapes are extracted with structural ``match`` patterns (no casts, no +``mcp.types`` model coupling); a malformed payload simply fails to match and +the correlation is skipped. +""" + +from __future__ import annotations + +import contextvars +import logging +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar, cast, overload + +import anyio +import anyio.abc +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError + +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + ServerMessageMetadata, + SessionMessage, +) +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + REQUEST_CANCELLED, + REQUEST_TIMEOUT, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ProgressToken, + RequestId, +) + +__all__ = ["JSONRPCDispatcher"] + +logger = logging.getLogger(__name__) + +TransportT = TypeVar("TransportT", bound=TransportContext) + +PeerCancelMode = Literal["interrupt", "signal"] +"""How inbound ``notifications/cancelled`` is applied to a running handler. + +``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets +``ctx.cancel_requested`` and lets the handler observe it cooperatively. +""" + +TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] +"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and +the `SessionMessage.metadata` the transport attached. Defaults to a plain +`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" + + +@dataclass(slots=True) +class _Pending: + """An outbound request awaiting its response.""" + + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] + on_progress: ProgressFnT | None = None + + +@dataclass(slots=True) +class _InFlight(Generic[TransportT]): + """An inbound request currently being handled.""" + + scope: anyio.CancelScope + dctx: _JSONRPCDispatchContext[TransportT] + cancelled_by_peer: bool = False + + +@dataclass +class _JSONRPCDispatchContext(Generic[TransportT]): + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" + + transport: TransportT + _dispatcher: JSONRPCDispatcher[TransportT] + _request_id: RequestId | None + _progress_token: ProgressToken | None = None + _closed: bool = False + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request and not self._closed + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._dispatcher.send_raw_request(method, params, opts, _related_request_id=self._request_id) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._progress_token is None: + return + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) + + def close(self) -> None: + self._closed = True + + +def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: + """Choose the `SessionMessage.metadata` for an outgoing request/notification. + + `ServerMessageMetadata` tags a server-to-client message with the inbound + request it belongs to (so streamable-HTTP can route it onto that request's + SSE stream). `ClientMessageMetadata` carries resumption hints to the + client transport. ``None`` is the common case. + """ + if related_request_id is not None: + return ServerMessageMetadata(related_request_id=related_request_id) + if opts: + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") + if token is not None or on_token is not None: + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) + return None + + +class JSONRPCDispatcher(Dispatcher[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract. + + Inherits the `Dispatcher` Protocol explicitly so pyright checks + conformance at the class definition rather than at first use. + """ + + @overload + def __init__( + self: JSONRPCDispatcher[TransportContext], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + ) -> None: ... + @overload + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: ... + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + # The overloads guarantee that when `transport_builder` is omitted, + # `TransportT` is `TransportContext`, so the default is type-correct; + # pyright can't see across overloads, hence the cast. + self._transport_builder = cast( + "Callable[[RequestId | None, MessageMetadata], TransportT]", + transport_builder or _default_transport_builder, + ) + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode + self._raise_handler_exceptions = raise_handler_exceptions + + self._next_id = 0 + self._pending: dict[RequestId, _Pending] = {} + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._tg: anyio.abc.TaskGroup | None = None + self._running = False + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: RequestId | None = None, + ) -> dict[str, Any]: + """Send a JSON-RPC request and await its response. + + ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a + handler makes a server-to-client request mid-flight; it routes the + outgoing message onto the correct per-request SSE stream (SHTTP) via + `ServerMessageMetadata`. Top-level callers leave it ``None``. + + Raises: + MCPError: The peer responded with a JSON-RPC error; or + ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or + ``CONNECTION_CLOSED`` if the dispatcher shut down while + awaiting the response. + RuntimeError: Called before ``run()`` has started or after it has + finished. + """ + if not self._running: + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") + opts = opts or {} + request_id = self._allocate_id() + out_params = dict(params) if params is not None else None + on_progress = opts.get("on_progress") + if on_progress is not None: + # The caller wants progress updates. The spec mechanism is: include + # `_meta.progressToken` on the request; the peer echoes that token on + # any `notifications/progress` it sends. We use the request id as the + # token so the receive loop can find this `_Pending.on_progress` by + # `_pending[token]` without a second lookup table. + meta = dict((out_params or {}).get("_meta") or {}) + meta["progressToken"] = request_id + out_params = {**(out_params or {}), "_meta": meta} + + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an + # outcome and dropping the late/redundant signal is correct. buffer=0 + # is unsafe — there's a window between registering `_pending[id]` and + # parking in `receive()` where a close signal would be lost. + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + pending = _Pending(send=send, receive=receive, on_progress=on_progress) + self._pending[request_id] = pending + + metadata = _outbound_metadata(_related_request_id, opts) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + try: + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() + except TimeoutError: + # Spec-recommended courtesy: tell the peer we've given up so it can + # stop work and free resources. v1's BaseSession.send_request does + # NOT do this; it's new behaviour. + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None + except anyio.get_cancelled_exc_class(): + # Our caller's scope was cancelled. We're already inside a cancelled + # scope, so any bare `await` here re-raises immediately — shield to + # let the courtesy cancel notification go out before we propagate. + with anyio.CancelScope(shield=True): + await self._cancel_outbound(request_id, "caller cancelled") + raise + finally: + # Always remove the waiter, even on cancel/timeout, so a late + # response from the peer (race) hits a closed stream and is dropped + # in `_dispatch` rather than leaking. + self._pending.pop(request_id, None) + send.close() + receive.close() + + if isinstance(outcome, ErrorData): + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) + return outcome + + async def notify( + self, + method: str, + params: Mapping[str, Any] | None, + *, + _related_request_id: RequestId | None = None, + ) -> None: + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) + await self._write(msg, _outbound_metadata(_related_request_id, None)) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the read stream closes. + + Each inbound request is handled in its own task in an internal task + group; ``task_status.started()`` fires once that group is open, so + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_raw_request`` + is usable. + """ + try: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + async with self._read_stream: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + self._dispatch(item, on_request, on_notify, sender_ctx) + # Read stream EOF: wake any blocked `send_raw_request` waiters now, + # *before* the task group joins, so handlers parked in + # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. + self._running = False + self._fan_out_closed() + finally: + # Covers the cancel/crash paths where the inline fan-out above is + # never reached. Idempotent. + self._running = False + self._tg = None + self._fan_out_closed() + + def _dispatch( + self, + item: SessionMessage | Exception, + on_request: OnRequest, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + """Route one inbound item. Synchronous: never awaits. + + Everything here is `send_nowait` or `_spawn`. An `await` would let one + slow message head-of-line block the entire read loop. + """ + if isinstance(item, Exception): + logger.debug("transport yielded exception: %r", item) + return + metadata = item.metadata + msg = item.message + match msg: + case JSONRPCRequest(): + self._dispatch_request(msg, metadata, on_request, sender_ctx) + case JSONRPCNotification(): + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) + case JSONRPCResponse(): + self._resolve_pending(msg.id, msg.result) + case JSONRPCError(): # pragma: no branch + # `id` may be None per JSON-RPC (parse error before id known). + # The match is exhaustive over JSONRPCMessage; the no-match arc + # on this final case is unreachable. + self._resolve_pending(msg.id, msg.error) + + def _dispatch_request( + self, + req: JSONRPCRequest, + metadata: MessageMetadata, + on_request: OnRequest, + sender_ctx: contextvars.Context | None, + ) -> None: + progress_token: ProgressToken | None + match req.params: + case {"_meta": {"progressToken": str() | int() as progress_token}}: + pass + case _: + progress_token = None + transport_ctx = self._transport_builder(req.id, metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, + _dispatcher=self, + _request_id=req.id, + _progress_token=progress_token, + ) + scope = anyio.CancelScope() + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + + def _dispatch_notification( + self, + msg: JSONRPCNotification, + metadata: MessageMetadata, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + if msg.method == "notifications/cancelled": + match msg.params: + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: + in_flight.cancelled_by_peer = True + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() + case _: + pass + return + if msg.method == "notifications/progress": + match msg.params: + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( + pending := self._pending.get(token) + ) is not None and pending.on_progress is not None: + total = msg.params.get("total") + message = msg.params.get("message") + self._spawn( + pending.on_progress, + float(progress), + float(total) if isinstance(total, int | float) else None, + message if isinstance(message, str) else None, + sender_ctx=sender_ctx, + ) + case _: + pass + # fall through: progress is also teed to on_notify + transport_ctx = self._transport_builder(None, metadata) + dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: + pending = self._pending.get(request_id) if request_id is not None else None + if pending is None: + logger.debug("dropping response for unknown/late request id %r", request_id) + return + try: + pending.send.send_nowait(outcome) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("waiter for request id %r already gone", request_id) + + def _spawn( + self, + fn: Callable[..., Awaitable[Any]], + *args: object, + sender_ctx: contextvars.Context | None, + ) -> None: + """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. + + ASGI middleware (auth, OTel) sets contextvars on the request task that + wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes + the spawned handler inherit *that* context instead of the receive + loop's, so ``auth_context_var`` and OTel spans survive. + """ + assert self._tg is not None + if sender_ctx is not None: + sender_ctx.run(self._tg.start_soon, fn, *args) + else: + self._tg.start_soon(fn, *args) + + def _fan_out_closed(self) -> None: + """Wake every pending ``send_raw_request`` waiter with ``CONNECTION_CLOSED``. + + Synchronous (uses ``send_nowait``) because it's called from ``finally`` + which may be inside a cancelled scope. Idempotent. + """ + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + for pending in self._pending.values(): + try: + pending.send.send_nowait(closed) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + pass + self._pending.clear() + + async def _handle_request( + self, + req: JSONRPCRequest, + dctx: _JSONRPCDispatchContext[TransportT], + scope: anyio.CancelScope, + on_request: OnRequest, + ) -> None: + """Run ``on_request`` for one inbound request and write its response. + + This is the single exception-to-wire boundary: handler exceptions are + caught here and serialized to ``JSONRPCError``. Nothing above this in + the stack constructs wire errors. + """ + try: + with scope: + try: + result = await on_request(dctx, req.method, req.params) + finally: + # Close the back-channel the moment the handler exits + # (success or raise), before the response write — a handler + # spawning detached work that later calls + # `dctx.send_raw_request()` should see `NoBackChannelError`. + dctx.close() + await self._write_result(req.id, result) + # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio + # swallows a scope's *own* cancel at __exit__, so the result write + # (or the handler) is interrupted and execution lands here without + # reaching the `except cancelled` arm below. Spec SHOULD: send no + # response — fall through to `finally`. + except anyio.get_cancelled_exc_class(): + # Outer-cancel: run()'s task group is shutting down. Any bare + # `await` here re-raises immediately, so shield the courtesy write. + with anyio.CancelScope(shield=True): + await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + raise + except MCPError as e: + await self._write_error(req.id, e.error) + except ValidationError as e: + await self._write_error(req.id, ErrorData(code=INVALID_PARAMS, message=str(e))) + except Exception as e: + logger.exception("handler for %r raised", req.method) + await self._write_error(req.id, ErrorData(code=INTERNAL_ERROR, message=str(e))) + if self._raise_handler_exceptions: + raise + finally: + self._in_flight.pop(req.id, None) + + def _allocate_id(self) -> int: + self._next_id += 1 + return self._next_id + + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + + async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + try: + await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped result for %r: write stream closed", request_id) + + async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + try: + await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped error for %r: write stream closed", request_id) + + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: + try: + await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + pass diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py new file mode 100644 index 0000000000..1222c05aba --- /dev/null +++ b/tests/shared/conftest.py @@ -0,0 +1,61 @@ +"""Shared fixtures for `Dispatcher` contract tests. + +The `pair_factory` fixture parametrizes contract tests over every `Dispatcher` +implementation, so the same behavioral assertions run against `DirectDispatcher` +(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams). +""" + +from collections.abc import Callable + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import Dispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext + +DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] +PairFactory = Callable[..., DispatcherTriple] + + +def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple: + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + + def close() -> None: + client.close() + server.close() + + return client, server, close + + +def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: + """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=can_send_request) + + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder) + + def close() -> None: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + return client, server, close + + +@pytest.fixture( + params=[ + pytest.param(direct_pair, id="direct"), + pytest.param(jsonrpc_pair, id="jsonrpc"), + ] +) +def pair_factory(request: pytest.FixtureRequest) -> PairFactory: + return request.param + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 784ef6698f..bdadd4cdae 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -1,8 +1,9 @@ -"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. +"""Behavioral tests for the Dispatcher Protocol. -These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using -the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in -``test_jsonrpc_dispatcher.py``. +The contract tests are parametrized over every `Dispatcher` implementation via +the `pair_factory` fixture (see ``conftest.py``); they must pass for both +`DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass +a concrete factory directly. """ from collections.abc import AsyncIterator, Mapping @@ -14,10 +15,12 @@ from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound -from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT +from .conftest import PairFactory, direct_pair + class Recorder: def __init__(self) -> None: @@ -44,31 +47,34 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: @asynccontextmanager async def running_pair( + factory: PairFactory, *, server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, -) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: +) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" - client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client, server, close = factory(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() c_req, c_notify = echo_handlers(client_rec) s_req, s_notify = echo_handlers(server_rec) - async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) - try: - yield client, server, client_rec, server_rec - finally: - client.close() - server.close() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + tg.cancel_scope.cancel() + finally: + close() @pytest.mark.anyio -async def test_send_raw_request_returns_result_from_peer_on_request(): - async with running_pair() as (client, _server, _crec, srec): +async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} @@ -76,13 +82,13 @@ async def test_send_raw_request_returns_result_from_peer_on_request(): @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS @@ -90,36 +96,22 @@ async def on_request( @pytest.mark.anyio -async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): - async def on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - raise ValueError("oops") - - async with running_pair(server_on_request=on_request) as (client, *_): - with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_raw_request("tools/list", {}) - assert exc.value.error.code == INTERNAL_ERROR - assert isinstance(exc.value.__cause__, ValueError) - - -@pytest.mark.anyio -async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @pytest.mark.anyio -async def test_notify_invokes_peer_on_notify(): - async with running_pair() as (client, _server, _crec, srec): +async def test_notify_invokes_peer_on_notify(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.notify("notifications/initialized", {"v": 1}) await srec.notified.wait() @@ -127,7 +119,7 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_raw_request_round_trips_to_calling_side(): +async def test_ctx_send_raw_request_round_trips_to_calling_side(pair_factory: PairFactory): """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( @@ -136,7 +128,7 @@ async def server_on_request( sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] @@ -144,28 +136,27 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_raw_request("sampling/createMessage", None) - async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): - with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + async with running_pair(pair_factory, server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/call", None) - assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio -async def test_ctx_notify_invokes_calling_side_on_notify(): +async def test_ctx_notify_invokes_calling_side_on_notify(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) await crec.notified.wait() @@ -173,7 +164,7 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_progress_invokes_caller_on_progress_callback(): +async def test_ctx_progress_invokes_caller_on_progress_callback(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -185,14 +176,44 @@ async def server_on_request( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_request=server_on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): + """DirectDispatcher-specific: the original exception is chained via __cause__.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -212,21 +233,7 @@ async def late_start(): @pytest.mark.anyio -async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - await ctx.progress(0.5) - return {"ok": True} - - async with running_pair(server_on_request=server_on_request) as (client, *_): - with anyio.fail_after(5): - result = await client.send_raw_request("tools/call", None) - assert result == {"ok": True} - - -@pytest.mark.anyio -async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_direct_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): await d.send_raw_request("ping", None) @@ -235,7 +242,7 @@ async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_conne @pytest.mark.anyio -async def test_close_makes_run_return(): +async def test_direct_close_makes_run_return(): client, server = create_direct_dispatcher_pair() on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py new file mode 100644 index 0000000000..7f9f11718b --- /dev/null +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -0,0 +1,531 @@ +"""JSON-RPC-specific Dispatcher tests. + +Behaviors with no `DirectDispatcher` analog: request-id correlation, the +exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. +The contract tests shared with `DirectDispatcher` live in +``test_dispatcher.py``. +""" + +import contextvars +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] + JSONRPCDispatcher, + _outbound_metadata, + _Pending, +) +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + ErrorData, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Tool, +) + +from .conftest import jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): + release_first = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await release_first.wait() + return {"m": method} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + results: dict[str, dict[str, Any]] = {} + + async def call(method: str) -> None: + results[method] = await client.send_raw_request(method, None) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(call, "first") + await anyio.sleep(0) + tg.start_soon(call, "second") + await anyio.sleep(0) + # second resolves while first is still parked + assert "first" not in results + release_first.set() + assert results == {"first": {"m": "first"}, "second": {"m": "second"}} + + +@pytest.mark.anyio +async def test_handler_raising_exception_sends_internal_error_with_str_message(): + """Per design: INTERNAL_ERROR carries str(e), not a scrubbed message.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("kaboom") + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "kaboom" + assert exc.value.__cause__ is None # cause does not survive the wire + + +@pytest.mark.anyio +async def test_peer_cancel_interrupt_mode_sets_cancel_requested_and_sends_no_response(): + handler_started = anyio.Event() + handler_exited = anyio.Event() + seen_ctx: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen_ctx.append(ctx) + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call_then_record() -> None: + with pytest.raises(MCPError): # we'll cancel via tg below + await client.send_raw_request("slow", None) + + tg.start_soon(call_then_record) + await handler_started.wait() + # cancel just the handler (peer-cancel), not our caller + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + # Handler torn down, no response was written; caller is still parked. + # Cancel the caller's task to end the test. + tg.cancel_scope.cancel() + assert seen_ctx[0].cancel_requested.is_set() + + +@pytest.mark.anyio +async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): + handler_started = anyio.Event() + cancel_seen = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + cancel_seen.set() + return {"finished": True} + + def factory(*, can_send_request: bool = True): + client, server, close = jsonrpc_pair(can_send_request=can_send_request) + # Reach in to set signal mode on the server side. + assert isinstance(server, JSONRPCDispatcher) + server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] + return client, server, close + + result_box: list[dict[str, Any]] = [] + async with running_pair(factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result_box.append(await client.send_raw_request("slow", None)) + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await cancel_seen.wait() + assert result_box == [{"finished": True}] + + +@pytest.mark.anyio +async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + + tg.start_soon(caller) + await anyio.sleep(0) + # No server: simulate the peer dropping by closing the read side. + s2c_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_late_response_after_timeout_is_dropped_without_crashing(): + handler_started = anyio.Event() + proceed = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await proceed.wait() + return {"late": True} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await client.send_raw_request("slow", None, {"timeout": 0}) + # The server handler is still running; let it finish and write a + # response for an id the client has already discarded. + await handler_started.wait() + proceed.set() + # One more round-trip proves the dispatcher is still healthy. + assert await client.send_raw_request("ping", None) == {"late": True} + + +@pytest.mark.anyio +async def test_raise_handler_exceptions_true_propagates_out_of_run(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, transport_builder=builder, raise_handler_exceptions=True + ) + + async def boom(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("propagate me") + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + with pytest.raises(BaseException) as exc: + async with anyio.create_task_group() as tg: + await tg.start(server.run, boom, on_notify) + # Inject a request directly onto the server's read stream. + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) + ) + assert exc.group_contains(RuntimeError, match="propagate me") + # The error response was still written before re-raising. + sent = s2c_recv.receive_nowait() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCError) + assert sent.message.error.code == INTERNAL_ERROR + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): + """Server-to-client requests carry related_request_id for SHTTP routing.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + # Kick the server with an inbound request id=7. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ServerMessageMetadata) + assert outbound.metadata.related_request_id == 7 + # Reply so the handler completes cleanly. + await c2s_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) + ) + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.id == 7 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.25) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.25, None, None)] + + +@pytest.mark.anyio +async def test_handler_raising_validation_error_sends_invalid_params(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("t", None) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_send_raw_request_before_run_raises_runtimeerror(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + try: + with pytest.raises(RuntimeError, match="before run"): + await d.send_raw_request("ping", None) + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_transport_exception_in_read_stream_is_logged_and_dropped(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + # Dispatcher must remain healthy after the dropped exception. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/progress", {"progressToken": 999, "progress": 0.5}) + await srec.notified.wait() + assert srec.notifications == [("notifications/progress", {"progressToken": 999, "progress": 0.5})] + + +@pytest.mark.anyio +async def test_cancelled_notification_for_unknown_request_id_is_noop(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/cancelled", {"requestId": 999}) + # No effect; dispatcher remains healthy. + assert await client.send_raw_request("t", None) == {"echoed": "t", "params": {}} + assert srec.notifications == [] # cancelled is fully consumed, never teed + + +_probe: contextvars.ContextVar[str] = contextvars.ContextVar("probe", default="unset") + + +@pytest.mark.anyio +async def test_handler_inherits_sender_contextvars_via_spawn(): + """The handler task sees contextvars set by the task that wrote into the read stream.""" + raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) + read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) + write_send = ContextSendStream[SessionMessage | Exception](raw_send) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send) + + seen: list[str] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(_probe.get()) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + + async def sender() -> None: + _probe.set("from-sender") + await write_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + + tg.start_soon(sender) + with anyio.fail_after(5): + resp = await out_recv.receive() + assert isinstance(resp, SessionMessage) + tg.cancel_scope.cancel() + finally: + for s in (raw_send, raw_recv, out_send, out_recv): + s.close() + assert seen == ["from-sender"] + + +@pytest.mark.anyio +async def test_response_write_after_peer_drop_is_swallowed(): + """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + proceed = anyio.Event() + handlers_done = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await proceed.wait() + if method == "raise": + handlers_done.set() + raise MCPError(code=INTERNAL_ERROR, message="x") + return {"ok": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ok", params=None))) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="raise", params=None)) + ) + await anyio.sleep(0) + # Peer drops: close the receive end so the server's writes hit BrokenResourceError. + s2c_recv.close() + proceed.set() + with anyio.fail_after(5): + await handlers_done.wait() + # run() must still be healthy — close the read side to let it exit cleanly. + c2s_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): + """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + caller_done = anyio.Event() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + caller_scope = anyio.CancelScope() + + async def caller() -> None: + with caller_scope: + await client.send_raw_request("slow", None) + caller_done.set() + + tg.start_soon(caller) + # Deterministic proof the request write completed: pull it off the wire. + with anyio.fail_after(5): + sent = await c2s_recv.receive() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCRequest) + # Now safe: close the client's write end, then cancel the caller. The + # shielded `_cancel_outbound` write hits ClosedResourceError and is + # swallowed; cancellation propagates cleanly. + c2s_send.close() + caller_scope.cancel() + with anyio.fail_after(5): + await caller_done.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): + """White-box: a response for an id still in _pending but whose waiter has gone.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + recv.close() # waiter gone — send_nowait will raise BrokenResourceError + d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): + s.close() + + +def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): + """White-box: the buffer=1 invariant — WouldBlock means waiter already has an outcome.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + # Register a fake pending and pre-fill its single buffer slot. + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + send.send_nowait({"real": "result"}) + d._fan_out_closed() # pyright: ignore[reportPrivateUsage] + # The real result is still there; the close signal was dropped. + assert recv.receive_nowait() == {"real": "result"} + assert d._pending == {} # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send, recv): + s.close() + + +def test_outbound_metadata_with_resumption_token_returns_client_metadata(): + md = _outbound_metadata(None, {"resumption_token": "abc"}) + assert isinstance(md, ClientMessageMetadata) + assert md.resumption_token == "abc" + assert _outbound_metadata(None, None) is None + assert _outbound_metadata(None, {}) is None + + +@pytest.mark.anyio +async def test_jsonrpc_error_response_with_null_id_is_dropped(): + """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + await s2c_send.send( + SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) + ) + await anyio.sleep(0) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() From 8edcf43e16262090d4606caac979484304f09adb Mon Sep 17 00:00:00 2001 From: Pawan Bhardwaj Date: Sat, 23 May 2026 11:01:20 +0000 Subject: [PATCH 3/4] PR 2460 --- src/mcp/server/_typed_request.py | 86 +++++++++++ src/mcp/server/connection.py | 146 +++++++++++++++++++ src/mcp/server/context.py | 60 ++++++++ src/mcp/shared/context.py | 82 +++++++++++ src/mcp/shared/peer.py | 216 ++++++++++++++++++++++++++++ tests/server/test_connection.py | 205 ++++++++++++++++++++++++++ tests/server/test_server_context.py | 156 ++++++++++++++++++++ tests/shared/test_context.py | 115 +++++++++++++++ tests/shared/test_peer.py | 164 +++++++++++++++++++++ 9 files changed, 1230 insertions(+) create mode 100644 src/mcp/server/_typed_request.py create mode 100644 src/mcp/server/connection.py create mode 100644 src/mcp/shared/context.py create mode 100644 src/mcp/shared/peer.py create mode 100644 tests/server/test_connection.py create mode 100644 tests/server/test_server_context.py create mode 100644 tests/shared/test_context.py create mode 100644 tests/shared/test_peer.py diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py new file mode 100644 index 0000000000..4334b20a94 --- /dev/null +++ b/src/mcp/server/_typed_request.py @@ -0,0 +1,86 @@ +"""Typed ``send_request`` for server-to-client requests. + +`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over +the host's raw `Outbound.send_raw_request`. Spec server-to-client request types +have their result type inferred via per-type overloads; custom requests pass +``result_type=`` explicitly. + +If the spec's request set grows substantially, consider declaring the result +mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via +a structural protocol) so this overload ladder doesn't need maintaining +per-host-class. +""" + +from typing import Any, TypeVar, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.peer import dump_params +from mcp.types import ( + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + Request, +) + +__all__ = ["TypedServerRequestMixin"] + +ResultT = TypeVar("ResultT", bound=BaseModel) + +_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = { + CreateMessageRequest: CreateMessageResult, + ElicitRequest: ElicitResult, + ListRootsRequest: ListRootsResult, + PingRequest: EmptyResult, +} + + +class TypedServerRequestMixin: + """Typed ``send_request`` for the server-to-client request set. + + Mixed into `Connection` and the server `Context`. Each method constrains + ``self`` to `Outbound` so any host with ``send_raw_request`` works. + """ + + @overload + async def send_request( + self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None + ) -> CreateMessageResult: ... + @overload + async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ... + @overload + async def send_request( + self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None + ) -> ListRootsResult: ... + @overload + async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ... + @overload + async def send_request( + self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None + ) -> ResultT: ... + async def send_request( + self: Outbound, + req: Request[Any, Any], + *, + result_type: type[BaseModel] | None = None, + opts: CallOptions | None = None, + ) -> BaseModel: + """Send a typed server-to-client request and return its typed result. + + For spec request types the result type is inferred. For custom requests + pass ``result_type=`` explicitly. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + KeyError: ``result_type`` omitted for a non-spec request type. + """ + raw = await self.send_raw_request(req.method, dump_params(req.params), opts) + cls = result_type if result_type is not None else _RESULT_FOR[type(req)] + return cls.model_validate(raw) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py new file mode 100644 index 0000000000..df3652ce0e --- /dev/null +++ b/src/mcp/server/connection.py @@ -0,0 +1,146 @@ +"""`Connection` — per-client connection state and the standalone outbound channel. + +Always present on `Context` (never ``None``), even in stateless deployments. +Holds peer info populated at ``initialize`` time, the per-connection lifespan +output, and an `Outbound` for the standalone stream (the SSE GET stream in +streamable HTTP, or the single duplex stream in stdio). + +`notify` is best-effort: it never raises. If there's no standalone channel +(stateless HTTP) or the stream has been dropped, the notification is +debug-logged and silently discarded — server-initiated notifications are +inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when +there's no channel; `ping` is the only spec-sanctioned standalone request. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio + +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.peer import Meta, dump_params +from mcp.types import ClientCapabilities, Implementation, LoggingLevel + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + + +def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None: + if not meta: + return payload + out = dict(payload or {}) + out["_meta"] = meta + return out + + +class Connection(TypedServerRequestMixin): + """Per-client connection state and standalone-stream `Outbound`. + + Constructed by `ServerRunner` once per connection. The peer-info fields are + ``None`` until ``initialize`` completes; ``initialized`` is set then. + """ + + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: + self._outbound = outbound + self.has_standalone_channel = has_standalone_channel + + self.client_info: Implementation | None = None + self.client_capabilities: ClientCapabilities | None = None + self.protocol_version: str | None = None + self.initialized: anyio.Event = anyio.Event() + # TODO: make this generic (Connection[StateT]) once connection_lifespan + # wiring lands in ServerRunner. + self.state: Any = None + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a raw request on the standalone stream. + + Low-level `Outbound` channel. Prefer the typed ``send_request`` (from + `TypedServerRequestMixin`) or the convenience methods below; use this + directly only for off-spec messages. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + if not self.has_standalone_channel: + raise NoBackChannelError(method) + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a best-effort notification on the standalone stream. + + Never raises. If there's no standalone channel or the stream is broken, + the notification is dropped and debug-logged. + """ + if not self.has_standalone_channel: + logger.debug("dropped %s: no standalone channel", method) + return + try: + await self._outbound.notify(method, params) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped %s: standalone stream closed", method) + + async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request on the standalone stream. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a ``notifications/message`` log entry on the standalone stream. Best-effort.""" + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + await self.notify("notifications/message", _notification_params(params, meta)) + + async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/tools/list_changed", _notification_params(None, meta)) + + async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/prompts/list_changed", _notification_params(None, meta)) + + async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/list_changed", _notification_params(None, meta)) + + async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta)) + + def check_capability(self, capability: ClientCapabilities) -> bool: + """Return whether the connected client declared the given capability. + + Returns ``False`` if ``initialize`` hasn't completed yet. + """ + # TODO: redesign — mirrors v1 ServerSession.check_client_capability + # verbatim for parity. + if self.client_capabilities is None: + return False + have = self.client_capabilities + if capability.roots is not None: + if have.roots is None: + return False + if capability.roots.list_changed and not have.roots.list_changed: + return False + if capability.sampling is not None and have.sampling is None: + return False + if capability.elicitation is not None and have.elicitation is None: + return False + if capability.experimental is not None: + if have.experimental is None: + return False + for k in capability.experimental: + if k not in have.experimental: + return False + return True diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..4d35f8a902 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -5,10 +5,17 @@ from typing_extensions import TypeVar +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.server.connection import Connection from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext from mcp.shared.message import CloseSSEStreamCallback +from mcp.shared.peer import Meta, PeerMixin +from mcp.shared.transport_context import TransportContext +from mcp.types import LoggingLevel, RequestParamsMeta LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @@ -21,3 +28,56 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None + + +LifespanT = TypeVar("LifespanT", default=Any, covariant=True) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) + + +class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): + """Server-side per-request context. + + Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), + `PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``), + and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds + ``lifespan`` and ``connection``. + + Constructed by `ServerRunner` per inbound request and handed to the user's + handler. + """ + + def __init__( + self, + dctx: DispatchContext[TransportT], + *, + lifespan: LifespanT, + connection: Connection, + meta: RequestParamsMeta | None = None, + ) -> None: + super().__init__(dctx, meta=meta) + self._lifespan = lifespan + self._connection = connection + + @property + def lifespan(self) -> LifespanT: + """The server-wide lifespan output (what `Server(..., lifespan=...)` yielded).""" + return self._lifespan + + @property + def connection(self) -> Connection: + """The per-client `Connection` for this request's connection.""" + return self._connection + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a request-scoped ``notifications/message`` log entry. + + Uses this request's back-channel (so the entry rides the request's SSE + stream in streamable HTTP), not the standalone stream — use + ``ctx.connection.log(...)`` for that. + """ + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + if meta: + params["_meta"] = meta + await self.notify("notifications/message", params) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py new file mode 100644 index 0000000000..ff69c48401 --- /dev/null +++ b/src/mcp/shared/context.py @@ -0,0 +1,82 @@ +"""`BaseContext` — the user-facing per-request context. + +Composition over a `DispatchContext`: forwards the transport metadata, the +back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel +event. Adds `meta` (the inbound request's `_meta` field). + +Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` +mixes that in directly). Shared between client and server: the server's +`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an +alias. +""" + +from collections.abc import Mapping +from typing import Any, Generic + +import anyio +from typing_extensions import TypeVar + +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import RequestParamsMeta + +__all__ = ["BaseContext"] + +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) + + +class BaseContext(Generic[TransportT]): + """Per-request context wrapping a `DispatchContext`. + + `ServerRunner` constructs one per inbound request and passes it to the + user's handler. + """ + + def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: + self._dctx = dctx + self._meta = meta + + @property + def transport(self) -> TransportT: + """Transport-specific metadata for this inbound request.""" + return self._dctx.transport + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + return self._dctx.cancel_requested + + @property + def can_send_request(self) -> bool: + """Whether the back-channel can deliver server-initiated requests.""" + return self._dctx.transport.can_send_request + + @property + def meta(self) -> RequestParamsMeta | None: + """The inbound request's ``_meta`` field, if present.""" + return self._meta + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``can_send_request`` is ``False``. + """ + return await self._dctx.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer on the back-channel.""" + await self._dctx.notify(method, params) + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for this request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + await self._dctx.progress(progress, total, message) diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py new file mode 100644 index 0000000000..47b64c7769 --- /dev/null +++ b/src/mcp/shared/peer.py @@ -0,0 +1,216 @@ +"""Typed MCP request sugar over an `Outbound`. + +`PeerMixin` defines the server-to-client request methods (sampling, elicitation, +roots, ping) once. Any class that satisfies `Outbound` (i.e. has +``send_raw_request`` and ``notify``) can mix it in and get the typed methods for +free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. + +The mixin does no capability gating: it builds the params, calls +``self.send_raw_request(method, params)``, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_raw_request`'s job. +""" + +from collections.abc import Mapping +from typing import Any, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + IncludeContext, + ListRootsResult, + ModelPreferences, + SamplingMessage, + Tool, + ToolChoice, +) + +__all__ = ["Meta", "Peer", "PeerMixin", "dump_params"] + +Meta = dict[str, Any] +"""Type alias for the ``_meta`` field carried on request/notification params.""" + + +def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: + """Serialize a params model to a wire dict, merging ``meta`` into ``_meta``. + + Shared by `PeerMixin`, `Connection`, and `TypedServerRequestMixin` so every + typed convenience method gets the same `_meta` handling. ``meta`` keys take + precedence over any ``_meta`` already present on the model. + """ + out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None + if meta: + out = dict(out or {}) + out["_meta"] = {**out.get("_meta", {}), **meta} + return out + + +class PeerMixin: + """Typed server-to-client request methods. + + Each method constrains ``self`` to `Outbound` so the mixin can be applied + to anything with ``send_raw_request``/``notify`` — pyright checks the host + class structurally at the call site. + """ + + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: None = None, + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult: ... + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool], + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResultWithTools: ... + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult | CreateMessageResultWithTools: + """Send a ``sampling/createMessage`` request to the peer. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: The host's transport context has no + back-channel for server-initiated requests. + """ + params = CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ) + result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) + if tools is not None: + return CreateMessageResultWithTools.model_validate(result) + return CreateMessageResult.model_validate(result) + + async def elicit_form( + self: Outbound, + message: str, + requested_schema: ElicitRequestedSchema, + *, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a form-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) + return ElicitResult.model_validate(result) + + async def elicit_url( + self: Outbound, + message: str, + url: str, + elicitation_id: str, + *, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a URL-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) + return ElicitResult.model_validate(result) + + async def list_roots( + self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None + ) -> ListRootsResult: + """Send a ``roots/list`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) + return ListRootsResult.model_validate(result) + + async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request and ignore the result. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + +class Peer(PeerMixin): + """Standalone wrapper that gives any `Outbound` the `PeerMixin` sugar. + + `Context` and `Connection` mix `PeerMixin` in directly; use `Peer` when + you have a bare dispatcher (or any `Outbound`) and want the typed methods + without writing your own host class. + """ + + def __init__(self, outbound: Outbound) -> None: + self._outbound = outbound + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._outbound.notify(method, params) diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py new file mode 100644 index 0000000000..ded9dfd6ac --- /dev/null +++ b/tests/server/test_connection.py @@ -0,0 +1,205 @@ +"""Tests for `Connection`. + +`Connection` wraps an `Outbound` (the standalone stream). Its `notify` is +best-effort (never raises); `send_raw_request` is gated on +``has_standalone_channel``. Tested with a stub `Outbound` so we can assert wire +shape and inject failures. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.types import ( + ClientCapabilities, + ElicitationCapability, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + RootsCapability, + SamplingCapability, +) + + +class StubOutbound: + def __init__( + self, *, result: dict[str, Any] | None = None, raise_on_send: type[BaseException] | None = None + ) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self._result = result if result is not None else {} + self._raise_on_send = raise_on_send + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.requests.append((method, params)) + return self._result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._raise_on_send is not None: + raise self._raise_on_send() + self.notifications.append((method, params)) + + +@pytest.mark.anyio +async def test_connection_notify_forwards_to_outbound(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"level": "info", "data": "hi"}) + assert out.notifications == [("notifications/message", {"level": "info", "data": "hi"})] + + +@pytest.mark.anyio +async def test_connection_notify_swallows_broken_stream_and_debug_logs(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound(raise_on_send=anyio.BrokenResourceError) + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert "stream closed" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_notify_drops_when_no_standalone_channel(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound() + conn = Connection(out, has_standalone_channel=False) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert out.notifications == [] + assert "no standalone channel" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalone_channel(): + conn = Connection(StubOutbound(), has_standalone_channel=False) + with pytest.raises(NoBackChannelError): + await conn.send_raw_request("ping", None) + + +@pytest.mark.anyio +async def test_connection_send_raw_request_forwards_when_standalone_channel_present(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_raw_request("ping", None) + assert out.requests == [("ping", None)] + assert result == {} + + +@pytest.mark.anyio +async def test_connection_send_request_with_spec_type_infers_result_type(): + out = StubOutbound(result={"roots": [{"uri": "file:///ws"}]}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(ListRootsRequest()) + method, _ = out.requests[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert str(result.roots[0].uri) == "file:///ws" + + +@pytest.mark.anyio +async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): + out = StubOutbound(result={}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(PingRequest(), result_type=EmptyResult) + assert isinstance(result, EmptyResult) + + +@pytest.mark.anyio +async def test_connection_ping_sends_ping_on_standalone(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.ping() + assert out.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_connection_log_sends_logging_message_notification(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", {"k": "v"}, logger="my.logger") + method, params = out.notifications[0] + assert method == "notifications/message" + assert params is not None + assert params["level"] == "info" + assert params["data"] == {"k": "v"} + assert params["logger"] == "my.logger" + + +@pytest.mark.anyio +async def test_connection_log_with_meta_includes_meta_in_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", "x", meta={"traceId": "abc"}) + _, params = out.notifications[0] + assert params is not None + assert params["_meta"] == {"traceId": "abc"} + + +@pytest.mark.anyio +async def test_connection_list_changed_notifications_send_correct_methods(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed() + await conn.send_prompt_list_changed() + await conn.send_resource_list_changed() + await conn.send_resource_updated("file:///workspace/a.txt") + methods = [m for m, _ in out.notifications] + assert methods == [ + "notifications/tools/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/list_changed", + "notifications/resources/updated", + ] + assert out.notifications[-1][1] == {"uri": "file:///workspace/a.txt"} + + +@pytest.mark.anyio +async def test_connection_send_tool_list_changed_with_meta_includes_meta_only_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed(meta={"k": 1}) + assert out.notifications == [("notifications/tools/list_changed", {"_meta": {"k": 1}})] + + +def test_connection_check_capability_false_before_initialized(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False + + +@pytest.mark.parametrize( + ("have", "want", "expected"), + [ + (ClientCapabilities(roots=None), ClientCapabilities(roots=RootsCapability()), False), + ( + ClientCapabilities(roots=RootsCapability(list_changed=False)), + ClientCapabilities(roots=RootsCapability(list_changed=True)), + False, + ), + (ClientCapabilities(sampling=None), ClientCapabilities(sampling=SamplingCapability()), False), + (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"a": {}}), True), + ], +) +def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = have + assert conn.check_capability(want) is expected + + +def test_connection_check_capability_true_when_client_declares_it(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = ClientCapabilities( + sampling=SamplingCapability(), roots=RootsCapability(list_changed=True) + ) + conn.initialized.set() + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert conn.check_capability(ClientCapabilities(roots=RootsCapability(list_changed=True))) is True + assert conn.check_capability(ClientCapabilities(elicitation=ElicitationCapability())) is False diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py new file mode 100644 index 0000000000..e01de34d33 --- /dev/null +++ b/tests/server/test_server_context.py @@ -0,0 +1,156 @@ +"""Tests for the server-side `Context`. + +`Context` composes `BaseContext` (forwarding to a `DispatchContext`) with +`PeerMixin` (typed sample/elicit/roots/ping) plus `lifespan` and `connection`. +End-to-end tested over `DirectDispatcher`. +""" + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import CreateMessageResult, ListRootsRequest, ListRootsResult, SamplingMessage, TextContent + +from ..shared.conftest import direct_pair +from ..shared.test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@dataclass +class _Lifespan: + name: str + + +@pytest.mark.anyio +async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): + captured: list[Context[_Lifespan, TransportContext]] = [] + conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + captured.append(ctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): + # Now we have the server dispatcher; build the real Connection bound to it. + conn.__init__(server, has_standalone_channel=True) + with anyio.fail_after(5): + await client.send_raw_request("t", None) + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True + + +@pytest.mark.anyio +async def test_context_sample_round_trips_via_peer_mixin_on_base_context_outbound(): + crec = Recorder() + + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + crec.requests.append((method, params)) + return {"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"} + + results: list[CreateMessageResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append( + await ctx.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hi"))], + max_tokens=5, + ) + ) + return {} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=client_on_request, + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) + + +@pytest.mark.anyio +async def test_context_send_request_with_spec_type_infers_result_via_typed_mixin(): + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return {"roots": []} + + results: list[ListRootsResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append(await ctx.send_request(ListRootsRequest())) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_request=client_on_request) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert isinstance(results[0], ListRootsResult) + + +@pytest.mark.anyio +async def test_context_log_sends_request_scoped_message_notification(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("debug", "hello") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" + + +@pytest.mark.anyio +async def test_context_log_includes_logger_and_meta_when_supplied(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py new file mode 100644 index 0000000000..882f90bfab --- /dev/null +++ b/tests/shared/test_context.py @@ -0,0 +1,115 @@ +"""Tests for `BaseContext`. + +`BaseContext` is composition over a `DispatchContext` — it forwards +``transport``/``cancel_requested``/``send_raw_request``/``notify``/``progress`` +and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext + +from .conftest import direct_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_base_context_forwards_transport_and_cancel_requested(): + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + captured.append(bctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None + + +@pytest.mark.anyio +async def test_base_context_send_raw_request_and_notify_forward_to_dispatch_context(): + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + sample = await bctx.send_raw_request("sampling/createMessage", {"x": 1}) + await bctx.notify("notifications/message", {"level": "info"}) + return {"sample": sample} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=c_req, + client_on_notify=c_notify, + ) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + await crec.notified.wait() + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + + +@pytest.mark.anyio +async def test_base_context_report_progress_invokes_caller_on_progress(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await bctx.report_progress(0.5, total=1.0, message="halfway") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_base_context_satisfies_outbound_so_peer_mixin_works(): + """Wrapping a BaseContext in Peer proves it satisfies Outbound structurally.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await Peer(bctx).ping() + return {} + + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + async with running_pair( + direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert crec.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_base_context_meta_holds_supplied_request_params_meta(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx, meta={"progressToken": "abc"}) + assert bctx.meta is not None and bctx.meta.get("progressToken") == "abc" + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py new file mode 100644 index 0000000000..0be4225818 --- /dev/null +++ b/tests/shared/test_peer.py @@ -0,0 +1,164 @@ +"""Tests for `PeerMixin` and `Peer`. + +Each PeerMixin method is tested by wrapping a `DirectDispatcher` in `Peer`, +calling the typed method, and asserting (a) the right method+params went out +and (b) the return value is the typed result model. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer, dump_params +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ElicitResult, + ListRootsResult, + SamplingMessage, + TextContent, + Tool, +) + +from .conftest import direct_pair +from .test_dispatcher import running_pair + +DCtx = DispatchContext[TransportContext] + + +class _Recorder: + def __init__(self, result: dict[str, Any]) -> None: + self.result = result + self.seen: list[tuple[str, Mapping[str, Any] | None]] = [] + + async def on_request(self, ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + self.seen.append((method, params)) + return self.result + + +@pytest.mark.anyio +async def test_peer_sample_sends_create_message_and_returns_typed_result(): + rec = _Recorder({"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], + max_tokens=10, + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" + + +@pytest.mark.anyio +async def test_peer_sample_with_tools_returns_with_tools_result(): + rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], + max_tokens=5, + tools=[Tool(name="t", input_schema={"type": "object"})], + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) + + +@pytest.mark.anyio +async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): + rec = _Recorder({"action": "accept", "content": {"name": "Max"}}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): + rec = _Recorder({"action": "accept"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): + rec = _Recorder({"roots": [{"uri": "file:///workspace"}]}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.list_roots() + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" + + +@pytest.mark.anyio +async def test_peer_list_roots_with_meta_sends_meta_in_params(): + rec = _Recorder({"roots": []}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + await peer.list_roots(meta={"traceId": "t1"}) + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} + + +def test_dump_params_merges_meta_over_model_meta(): + out = dump_params(None, None) + assert out is None + out = dump_params(None, {"k": 1}) + assert out == {"_meta": {"k": 1}} + + +@pytest.mark.anyio +async def test_peer_notify_forwards_to_wrapped_outbound(): + sent: list[tuple[str, Mapping[str, Any] | None]] = [] + + class _Out: + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: Any = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + sent.append((method, params)) + + await Peer(_Out()).notify("n", {"x": 1}) + assert sent == [("n", {"x": 1})] + + +@pytest.mark.anyio +async def test_peer_ping_sends_ping_and_returns_none(): + rec = _Recorder({}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.ping() + method, _ = rec.seen[0] + assert method == "ping" + assert result is None From 2dbc2ec89168344f06b78418c1530246273c171f Mon Sep 17 00:00:00 2001 From: Pawan Bhardwaj Date: Sat, 23 May 2026 11:01:32 +0000 Subject: [PATCH 4/4] PR 2491 --- pyproject.toml | 1 + src/mcp/server/context.py | 36 +++- src/mcp/server/lowlevel/server.py | 15 +- src/mcp/server/runner.py | 295 ++++++++++++++++++++++++++ src/mcp/shared/_otel.py | 11 +- tests/conftest.py | 11 + tests/server/conftest.py | 45 ++++ tests/server/test_runner.py | 340 ++++++++++++++++++++++++++++++ tests/shared/test_otel.py | 3 - uv.lock | 2 + 10 files changed, 753 insertions(+), 6 deletions(-) create mode 100644 src/mcp/server/runner.py create mode 100644 tests/server/conftest.py create mode 100644 tests/server/test_runner.py diff --git a/pyproject.toml b/pyproject.toml index d88869da1c..20a5d9362d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "pillow>=12.0", "strict-no-cover", "logfire>=3.0.0", + "opentelemetry-sdk>=1.39.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4d35f8a902..1c855ae48a 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server._typed_request import TypedServerRequestMixin @@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * if meta: params["_meta"] = meta await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) + + +class ContextMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + + Runs *inside* `ServerRunner._on_request` after params validation and + `Context` construction. Wraps registered handlers (including ``ping``) but + not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific + types) can be typed `ContextMiddleware[object]` — `Context` is covariant in + `LifespanT`, so it registers on any `Server[L]`. + """ + + async def __call__( + self, + ctx: Context[_MwLifespanT, TransportContext], + method: str, + params: BaseModel, + call_next: CallNext, + ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..9dc44708f0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -58,7 +58,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import ContextMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -199,6 +199,9 @@ def __init__( ] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. + self.middleware: list[ContextMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) # Populate internal handler dicts from on_* kwargs @@ -246,6 +249,16 @@ def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a request method, or ``None``.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a notification method, or ``None``.""" + return self._notification_handlers.get(method) + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 0000000000..bb3af04435 --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,295 @@ +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. It: + +* handles the ``initialize`` handshake and populates `Connection` +* gates requests until initialized (``ping`` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives ``dispatcher.run()`` and the per-connection lifespan + +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies +it via additive methods so the existing ``Server.run()`` path is unaffected. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable, Mapping, Sequence +from dataclasses import dataclass, field +from functools import partial, reduce +from typing import Any, Generic, Protocol, cast + +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import CallNext, Context, ContextMiddleware +from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + CompleteRequestParams, + GetPromptRequestParams, + Implementation, + InitializeRequestParams, + InitializeResult, + NotificationParams, + PaginatedRequestParams, + ProgressNotificationParams, + ReadResourceRequestParams, + RequestParams, + ServerCapabilities, + SetLevelRequestParams, + SubscribeRequestParams, + UnsubscribeRequestParams, +) + +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) +ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) + +Handler = Callable[..., Awaitable[Any]] +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely +so the existing `ServerRequestContext`-based handlers and the new +`Context`-based handlers both fit during the transition. +""" + + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + +# TODO: remove this lookup once `Server` stores (params_type, handler) in its +# registry directly. This is scaffolding so ServerRunner can validate params +# without changing the existing `_request_handlers` dict shape. +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { + "ping": RequestParams, + "tools/list": PaginatedRequestParams, + "tools/call": CallToolRequestParams, + "prompts/list": PaginatedRequestParams, + "prompts/get": GetPromptRequestParams, + "resources/list": PaginatedRequestParams, + "resources/templates/list": PaginatedRequestParams, + "resources/read": ReadResourceRequestParams, + "resources/subscribe": SubscribeRequestParams, + "resources/unsubscribe": UnsubscribeRequestParams, + "logging/setLevel": SetLevelRequestParams, + "completion/complete": CompleteRequestParams, +} +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s +`_request_handlers` stores handler-only; the registry refactor should make this +the registry's responsibility (or store params types alongside handlers).""" + +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { + "notifications/initialized": NotificationParams, + "notifications/roots/list_changed": NotificationParams, + "notifications/progress": ProgressNotificationParams, +} + + +class ServerRegistry(Protocol): + """The handler registry `ServerRunner` consumes. + + The lowlevel `Server` satisfies this via additive methods. + """ + + @property + def name(self) -> str: ... + @property + def version(self) -> str | None: ... + + @property + def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... + + def get_request_handler(self, method: str) -> Handler | None: ... + def get_notification_handler(self, method: str) -> Handler | None: ... + def get_capabilities( + self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] + ) -> ServerCapabilities: ... + + +def otel_middleware(next_on_request: OnRequest) -> OnRequest: + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. + + Mirrors the span shape of the existing `Server._handle_request`: span name + ``"MCP handle []"``, ``mcp.method.name`` attribute, W3C + trace context extracted from ``params._meta`` (SEP-414), and an ERROR + status if the handler raises. + """ + + async def wrapped( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + target: str | None + match params: + case {"name": str() as target}: + pass + case _: + target = None + parent: Any | None + match params: + case {"_meta": {**meta}}: + parent = extract_trace_context(meta) + case _: + parent = None + span_name = f"MCP handle {method}{f' {target}' if target else ''}" + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": method}, + context=parent, + record_exception=False, + set_status_on_exception=False, + ) as span: + try: + return await next_on_request(dctx, method, params) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except Exception as e: + span.record_exception(e) + span.set_status(StatusCode.ERROR, str(e)) + raise + + return wrapped + + +def _dump_result(result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, BaseModel): + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(result, dict): + return cast(dict[str, Any], result) + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") + + +@dataclass +class ServerRunner(Generic[LifespanT, ServerTransportT]): + """Per-connection orchestrator. One instance per client connection.""" + + server: ServerRegistry + dispatcher: Dispatcher[ServerTransportT] + lifespan_state: LifespanT + has_standalone_channel: bool + stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) + + connection: Connection = field(init=False) + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers + can ``await tg.start(runner.run)`` and resume once the dispatcher is + ready to accept requests. + """ + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` + and wraps everything — initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. + """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + raise MCPError( + code=INVALID_REQUEST, + message=f"Received {method!r} before initialization was complete", + ) + handler = self.server.get_request_handler(method) + if handler is None: + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + # TODO: scaffolding — params_type comes from a static lookup until the + # registry stores it alongside the handler. + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + call: CallNext = partial(handler, ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + handler = self.server.get_notification_handler(method) + if handler is None: + logger.debug("no handler for notification %s", method) + return + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + await handler(ctx, typed_params) + + def _make_context( + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel + ) -> Context[LifespanT, ServerTransportT]: + # `OnRequest` delivers `DispatchContext[TransportContext]`; this + # ServerRunner instance was constructed for a specific + # `ServerTransportT`, so the narrow is safe by construction. + narrowed = cast(DispatchContext[ServerTransportT], dctx) + meta = getattr(typed_params, "meta", None) + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}) + self.connection.client_info = init.client_info + self.connection.client_capabilities = init.capabilities + # TODO: real version negotiation. This always responds with LATEST, + # which is wrong — the server should pick the highest version both + # sides support and compute a per-connection feature set from it. + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". + self.connection.protocol_version = ( + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION + ) + self._initialized = True + self.connection.initialized.set() + result = InitializeResult( + protocol_version=self.connection.protocol_version, + capabilities=self.server.get_capabilities(NotificationOptions(), {}), + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + ) + return _dump_result(result) diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 170e873a0f..553b8a0bce 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -20,9 +20,18 @@ def otel_span( kind: SpanKind, attributes: dict[str, Any] | None = None, context: Context | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, ) -> Iterator[Any]: """Create an OTel span.""" - with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + with _tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes, + context=context, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: yield span diff --git a/tests/conftest.py b/tests/conftest.py index af7e479932..b83c472135 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,16 @@ +import os + import pytest +# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite +# uses a single span-capture mechanism: logfire's `capfire` fixture (its +# `configure()` swaps span processors on repeat calls rather than re-setting +# the provider). Logfire's default `distributed_tracing=None` emits a +# RuntimeWarning + diagnostic span when incoming W3C trace context is +# extracted; several tests exercise that propagation deliberately, so opt in +# suite-wide. Set before logfire is imported anywhere. +os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") + @pytest.fixture def anyio_backend(): diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 0000000000..290ccc957a --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,45 @@ +"""Shared fixtures for server-side tests.""" + +from collections.abc import Iterator + +import pytest +from logfire.testing import CaptureLogfire, TestExporter +from opentelemetry.sdk.trace import ReadableSpan + + +class SpanCapture: + """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. + + `finished()` returns the raw `ReadableSpan` objects emitted by the + ``mcp-python-sdk`` instrumentation scope, filtered to exclude logfire's + synthetic ``pending_span`` markers, so tests can assert directly on + `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. + """ + + def __init__(self, exporter: TestExporter) -> None: + self._exporter = exporter + + def clear(self) -> None: + self._exporter.clear() + + def finished(self) -> list[ReadableSpan]: + return [ + s + for s in self._exporter.exported_spans + if s.instrumentation_scope is not None + and s.instrumentation_scope.name == "mcp-python-sdk" + and not (s.attributes and s.attributes.get("logfire.span_type") == "pending_span") + ] + + +@pytest.fixture +def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: + """In-memory MCP span capture, cleared before and after each test. + + Backed by the project-level `capfire` override (see ``tests/conftest.py``) + so there is a single global tracer provider for the suite. + """ + capture = SpanCapture(capfire.exporter) + capture.clear() + yield capture + capture.clear() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 0000000000..843b0ae8b9 --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,340 @@ +"""Tests for `ServerRunner`. + +End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +registry. The `connected_runner` helper starts both sides and (by default) +performs the initialize handshake, so each test exercises only the behaviour +under test. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import anyio.lowlevel +import pytest +from opentelemetry.trace import SpanKind, StatusCode + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import Server +from mcp.server.runner import ServerRunner, otel_middleware +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchMiddleware +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INTERNAL_ERROR, + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + ClientCapabilities, + Implementation, + InitializeRequestParams, + Tool, +) + +from ..shared.test_dispatcher import Recorder, echo_handlers +from .conftest import SpanCapture + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Context[Any, TransportContext]] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Any, params: Any) -> Any: + # ctx is typed `Any` because Server's on_list_tools kwarg expects the + # legacy ServerRequestContext shape; ServerRunner passes the new + # `Context`. The transition is intentional — Handler is loosely typed. + _seen_ctx.append(ctx) + return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@asynccontextmanager +async def connected_runner( + server: SrvT, + *, + initialized: bool = True, + stateless: bool = False, + has_standalone_channel: bool = True, + dispatch_middleware: list[DispatchMiddleware] | None = None, +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: + """Yield ``(client, runner)`` running over an in-memory dispatcher pair. + + Starts the client (echo handlers) and `runner.run()` in a task group, wraps + the body in ``anyio.fail_after(5)``, and cancels on exit. When + ``initialized`` is true the helper performs the real ``initialize`` request + before yielding, so tests start past the init-gate via the public path. + """ + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=has_standalone_channel, + stateless=stateless, + dispatch_middleware=dispatch_middleware or [], + ) + c_req, c_notify = echo_handlers(Recorder()) + body_exc: BaseException | None = None + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + try: + with anyio.fail_after(5): + if initialized: + await client.send_raw_request("initialize", _initialize_params()) + yield client, runner + except BaseException as e: + # Capture and re-raise outside the task group so test failures + # surface as the original exception, not an ExceptionGroup wrapper. + body_exc = e + client.close() + server_d.close() + if body_exc is not None: + raise body_exc + + +@pytest.mark.anyio +async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): + """The harness re-raises body exceptions as-is, not as ``ExceptionGroup``.""" + with pytest.raises(RuntimeError, match="boom"): + async with connected_runner(server): + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt from the gate + assert await client.send_raw_request("ping", None) == {} + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_and_builds_context(server: SrvT): + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" + + +@pytest.mark.anyio +async def test_runner_unknown_method_raises_method_not_found(server: SrvT): + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + + async def on_roots_changed(ctx: Any, params: Any) -> None: + seen.append((ctx, params)) + + server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. + + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + async with connected_runner(server, dispatch_middleware=[trace_mw]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + async with connected_runner(server) as (client, _): + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize (sent by the helper) NOT wrapped; ping and tools/list ARE. + assert seen_methods == ["ping", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + + +@pytest.mark.anyio +async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): + async def set_level(ctx: Any, params: Any) -> None: + return None + + server._request_handlers["logging/setLevel"] = set_level + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert result == {} + + +@pytest.mark.anyio +async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_error(server: SrvT): + async def bad_return(ctx: Any, params: Any) -> int: + return 42 + + server._request_handlers["tools/list"] = bad_return + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert "int" in exc.value.error.message + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + + +@pytest.mark.anyio +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): + async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: + return {"content": [], "isError": False} + + server._request_handlers["tools/call"] = call_tool + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + [span] = spans.finished() + assert span.name == "MCP handle tools/call mytool" + assert span.kind == SpanKind.SERVER + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert span.status.status_code == StatusCode.UNSET + + +@pytest.mark.anyio +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: SpanCapture): + parent_span_id = "b7ad6b7169203331" + traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) + [span] = spans.finished() + assert span.parent is not None + assert format(span.parent.span_id, "016x") == parent_span_id + assert span.context is not None + assert format(span.context.trace_id, "032x") == "0af7651916cd43dd8448eb211c80319c" + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + [span] = spans.finished() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found: nonexistent/method" + # MCPError is a protocol-level response, not a crash — no traceback event. + assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): + async def failing(ctx: Any, params: Any) -> Any: + raise ValueError("handler blew up") + + server._request_handlers["tools/list"] = failing + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + [span] = spans.finished() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index ec7ff78cc1..a7df4c4294 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -10,9 +10,6 @@ pytestmark = pytest.mark.anyio -# Logfire warns about propagated trace context by default (distributed_tracing=None). -# This is expected here since we're testing cross-boundary context propagation. -@pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_client_and_server_spans(capfire: CaptureLogfire): """Verify that calling a tool produces client and server spans with correct attributes.""" server = MCPServer("test") diff --git a/uv.lock b/uv.lock index b396898b66..f59be2d36e 100644 --- a/uv.lock +++ b/uv.lock @@ -885,6 +885,7 @@ dev = [ { name = "inline-snapshot" }, { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -937,6 +938,7 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, + { name = "opentelemetry-sdk", specifier = ">=1.39.1" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.3.4" },