|
| 1 | +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. |
| 2 | +
|
| 3 | +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / |
| 4 | +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, |
| 5 | +typed params). One instance per client connection. It: |
| 6 | +
|
| 7 | +* handles the ``initialize`` handshake and populates `Connection` |
| 8 | +* gates requests until initialized (``ping`` exempt) |
| 9 | +* looks up the handler in the server's registry, validates params, builds |
| 10 | + `Context`, runs the middleware chain, returns the result dict |
| 11 | +* drives ``dispatcher.run()`` and the per-connection lifespan |
| 12 | +
|
| 13 | +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies |
| 14 | +it via additive methods so the existing ``Server.run()`` path is unaffected. |
| 15 | +""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import logging |
| 20 | +from collections.abc import Awaitable, Callable, Mapping, Sequence |
| 21 | +from dataclasses import dataclass, field |
| 22 | +from functools import partial, reduce |
| 23 | +from typing import Any, Generic, Protocol, cast |
| 24 | + |
| 25 | +import anyio.abc |
| 26 | +from opentelemetry.trace import SpanKind, StatusCode |
| 27 | +from pydantic import BaseModel |
| 28 | +from typing_extensions import TypeVar |
| 29 | + |
| 30 | +from mcp.server.connection import Connection |
| 31 | +from mcp.server.context import CallNext, Context, ContextMiddleware |
| 32 | +from mcp.server.lowlevel.server import NotificationOptions |
| 33 | +from mcp.shared._otel import extract_trace_context, otel_span |
| 34 | +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest |
| 35 | +from mcp.shared.exceptions import MCPError |
| 36 | +from mcp.shared.transport_context import TransportContext |
| 37 | +from mcp.types import ( |
| 38 | + INVALID_REQUEST, |
| 39 | + LATEST_PROTOCOL_VERSION, |
| 40 | + METHOD_NOT_FOUND, |
| 41 | + CallToolRequestParams, |
| 42 | + CompleteRequestParams, |
| 43 | + GetPromptRequestParams, |
| 44 | + Implementation, |
| 45 | + InitializeRequestParams, |
| 46 | + InitializeResult, |
| 47 | + NotificationParams, |
| 48 | + PaginatedRequestParams, |
| 49 | + ProgressNotificationParams, |
| 50 | + ReadResourceRequestParams, |
| 51 | + RequestParams, |
| 52 | + ServerCapabilities, |
| 53 | + SetLevelRequestParams, |
| 54 | + SubscribeRequestParams, |
| 55 | + UnsubscribeRequestParams, |
| 56 | +) |
| 57 | + |
| 58 | +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] |
| 59 | + |
| 60 | +logger = logging.getLogger(__name__) |
| 61 | + |
| 62 | +LifespanT = TypeVar("LifespanT", default=Any) |
| 63 | +ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) |
| 64 | + |
| 65 | +Handler = Callable[..., Awaitable[Any]] |
| 66 | +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely |
| 67 | +so the existing `ServerRequestContext`-based handlers and the new |
| 68 | +`Context`-based handlers both fit during the transition. |
| 69 | +""" |
| 70 | + |
| 71 | + |
| 72 | +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) |
| 73 | + |
| 74 | +# TODO: remove this lookup once `Server` stores (params_type, handler) in its |
| 75 | +# registry directly. This is scaffolding so ServerRunner can validate params |
| 76 | +# without changing the existing `_request_handlers` dict shape. |
| 77 | +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { |
| 78 | + "ping": RequestParams, |
| 79 | + "tools/list": PaginatedRequestParams, |
| 80 | + "tools/call": CallToolRequestParams, |
| 81 | + "prompts/list": PaginatedRequestParams, |
| 82 | + "prompts/get": GetPromptRequestParams, |
| 83 | + "resources/list": PaginatedRequestParams, |
| 84 | + "resources/templates/list": PaginatedRequestParams, |
| 85 | + "resources/read": ReadResourceRequestParams, |
| 86 | + "resources/subscribe": SubscribeRequestParams, |
| 87 | + "resources/unsubscribe": UnsubscribeRequestParams, |
| 88 | + "logging/setLevel": SetLevelRequestParams, |
| 89 | + "completion/complete": CompleteRequestParams, |
| 90 | +} |
| 91 | +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s |
| 92 | +`_request_handlers` stores handler-only; the registry refactor should make this |
| 93 | +the registry's responsibility (or store params types alongside handlers).""" |
| 94 | + |
| 95 | +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { |
| 96 | + "notifications/initialized": NotificationParams, |
| 97 | + "notifications/roots/list_changed": NotificationParams, |
| 98 | + "notifications/progress": ProgressNotificationParams, |
| 99 | +} |
| 100 | + |
| 101 | + |
| 102 | +class ServerRegistry(Protocol): |
| 103 | + """The handler registry `ServerRunner` consumes. |
| 104 | +
|
| 105 | + The lowlevel `Server` satisfies this via additive methods. |
| 106 | + """ |
| 107 | + |
| 108 | + @property |
| 109 | + def name(self) -> str: ... |
| 110 | + @property |
| 111 | + def version(self) -> str | None: ... |
| 112 | + |
| 113 | + @property |
| 114 | + def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... |
| 115 | + |
| 116 | + def get_request_handler(self, method: str) -> Handler | None: ... |
| 117 | + def get_notification_handler(self, method: str) -> Handler | None: ... |
| 118 | + def get_capabilities( |
| 119 | + self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] |
| 120 | + ) -> ServerCapabilities: ... |
| 121 | + |
| 122 | + |
| 123 | +def otel_middleware(next_on_request: OnRequest) -> OnRequest: |
| 124 | + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. |
| 125 | +
|
| 126 | + Mirrors the span shape of the existing `Server._handle_request`: span name |
| 127 | + ``"MCP handle <method> [<target>]"``, ``mcp.method.name`` attribute, W3C |
| 128 | + trace context extracted from ``params._meta`` (SEP-414), and an ERROR |
| 129 | + status if the handler raises. |
| 130 | + """ |
| 131 | + |
| 132 | + async def wrapped( |
| 133 | + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None |
| 134 | + ) -> dict[str, Any]: |
| 135 | + target: str | None |
| 136 | + match params: |
| 137 | + case {"name": str() as target}: |
| 138 | + pass |
| 139 | + case _: |
| 140 | + target = None |
| 141 | + parent: Any | None |
| 142 | + match params: |
| 143 | + case {"_meta": {**meta}}: |
| 144 | + parent = extract_trace_context(meta) |
| 145 | + case _: |
| 146 | + parent = None |
| 147 | + span_name = f"MCP handle {method}{f' {target}' if target else ''}" |
| 148 | + with otel_span( |
| 149 | + span_name, |
| 150 | + kind=SpanKind.SERVER, |
| 151 | + attributes={"mcp.method.name": method}, |
| 152 | + context=parent, |
| 153 | + record_exception=False, |
| 154 | + set_status_on_exception=False, |
| 155 | + ) as span: |
| 156 | + try: |
| 157 | + return await next_on_request(dctx, method, params) |
| 158 | + except MCPError as e: |
| 159 | + span.set_status(StatusCode.ERROR, e.error.message) |
| 160 | + raise |
| 161 | + except Exception as e: |
| 162 | + span.record_exception(e) |
| 163 | + span.set_status(StatusCode.ERROR, str(e)) |
| 164 | + raise |
| 165 | + |
| 166 | + return wrapped |
| 167 | + |
| 168 | + |
| 169 | +def _dump_result(result: Any) -> dict[str, Any]: |
| 170 | + if result is None: |
| 171 | + return {} |
| 172 | + if isinstance(result, BaseModel): |
| 173 | + return result.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 174 | + if isinstance(result, dict): |
| 175 | + return cast(dict[str, Any], result) |
| 176 | + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") |
| 177 | + |
| 178 | + |
| 179 | +@dataclass |
| 180 | +class ServerRunner(Generic[LifespanT, ServerTransportT]): |
| 181 | + """Per-connection orchestrator. One instance per client connection.""" |
| 182 | + |
| 183 | + server: ServerRegistry |
| 184 | + dispatcher: Dispatcher[ServerTransportT] |
| 185 | + lifespan_state: LifespanT |
| 186 | + has_standalone_channel: bool |
| 187 | + stateless: bool = False |
| 188 | + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) |
| 189 | + |
| 190 | + connection: Connection = field(init=False) |
| 191 | + _initialized: bool = field(init=False) |
| 192 | + |
| 193 | + def __post_init__(self) -> None: |
| 194 | + self._initialized = self.stateless |
| 195 | + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) |
| 196 | + |
| 197 | + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: |
| 198 | + """Drive the dispatcher until the underlying channel closes. |
| 199 | +
|
| 200 | + Composes `dispatch_middleware` over `_on_request` and hands the result |
| 201 | + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers |
| 202 | + can ``await tg.start(runner.run)`` and resume once the dispatcher is |
| 203 | + ready to accept requests. |
| 204 | + """ |
| 205 | + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) |
| 206 | + |
| 207 | + def _compose_on_request(self) -> OnRequest: |
| 208 | + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. |
| 209 | +
|
| 210 | + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` |
| 211 | + and wraps everything — initialize, METHOD_NOT_FOUND, validation |
| 212 | + failures included. `run()` calls this once and hands the result to |
| 213 | + `dispatcher.run()`. |
| 214 | + """ |
| 215 | + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) |
| 216 | + |
| 217 | + async def _on_request( |
| 218 | + self, |
| 219 | + dctx: DispatchContext[TransportContext], |
| 220 | + method: str, |
| 221 | + params: Mapping[str, Any] | None, |
| 222 | + ) -> dict[str, Any]: |
| 223 | + if method == "initialize": |
| 224 | + return self._handle_initialize(params) |
| 225 | + if not self._initialized and method not in _INIT_EXEMPT: |
| 226 | + raise MCPError( |
| 227 | + code=INVALID_REQUEST, |
| 228 | + message=f"Received {method!r} before initialization was complete", |
| 229 | + ) |
| 230 | + handler = self.server.get_request_handler(method) |
| 231 | + if handler is None: |
| 232 | + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") |
| 233 | + # TODO: scaffolding — params_type comes from a static lookup until the |
| 234 | + # registry stores it alongside the handler. |
| 235 | + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) |
| 236 | + # ValidationError propagates; the dispatcher's exception boundary maps |
| 237 | + # it to INVALID_PARAMS. |
| 238 | + typed_params = params_type.model_validate(params or {}) |
| 239 | + ctx = self._make_context(dctx, typed_params) |
| 240 | + call: CallNext = partial(handler, ctx, typed_params) |
| 241 | + for mw in reversed(self.server.middleware): |
| 242 | + call = partial(mw, ctx, method, typed_params, call) |
| 243 | + return _dump_result(await call()) |
| 244 | + |
| 245 | + async def _on_notify( |
| 246 | + self, |
| 247 | + dctx: DispatchContext[TransportContext], |
| 248 | + method: str, |
| 249 | + params: Mapping[str, Any] | None, |
| 250 | + ) -> None: |
| 251 | + if method == "notifications/initialized": |
| 252 | + self._initialized = True |
| 253 | + self.connection.initialized.set() |
| 254 | + return |
| 255 | + if not self._initialized: |
| 256 | + logger.debug("dropped %s: received before initialization", method) |
| 257 | + return |
| 258 | + handler = self.server.get_notification_handler(method) |
| 259 | + if handler is None: |
| 260 | + logger.debug("no handler for notification %s", method) |
| 261 | + return |
| 262 | + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) |
| 263 | + typed_params = params_type.model_validate(params or {}) |
| 264 | + ctx = self._make_context(dctx, typed_params) |
| 265 | + await handler(ctx, typed_params) |
| 266 | + |
| 267 | + def _make_context( |
| 268 | + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel |
| 269 | + ) -> Context[LifespanT, ServerTransportT]: |
| 270 | + # `OnRequest` delivers `DispatchContext[TransportContext]`; this |
| 271 | + # ServerRunner instance was constructed for a specific |
| 272 | + # `ServerTransportT`, so the narrow is safe by construction. |
| 273 | + narrowed = cast(DispatchContext[ServerTransportT], dctx) |
| 274 | + meta = getattr(typed_params, "meta", None) |
| 275 | + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) |
| 276 | + |
| 277 | + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: |
| 278 | + init = InitializeRequestParams.model_validate(params or {}) |
| 279 | + self.connection.client_info = init.client_info |
| 280 | + self.connection.client_capabilities = init.capabilities |
| 281 | + # TODO: real version negotiation. This always responds with LATEST, |
| 282 | + # which is wrong — the server should pick the highest version both |
| 283 | + # sides support and compute a per-connection feature set from it. |
| 284 | + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". |
| 285 | + self.connection.protocol_version = ( |
| 286 | + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION |
| 287 | + ) |
| 288 | + self._initialized = True |
| 289 | + self.connection.initialized.set() |
| 290 | + result = InitializeResult( |
| 291 | + protocol_version=self.connection.protocol_version, |
| 292 | + capabilities=self.server.get_capabilities(NotificationOptions(), {}), |
| 293 | + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), |
| 294 | + ) |
| 295 | + return _dump_result(result) |
0 commit comments