Skip to content

Commit 2dbc2ec

Browse files
committed
PR 2491
1 parent 8edcf43 commit 2dbc2ec

10 files changed

Lines changed: 753 additions & 6 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ dev = [
9191
"pillow>=12.0",
9292
"strict-no-cover",
9393
"logfire>=3.0.0",
94+
"opentelemetry-sdk>=1.39.1",
9495
]
9596
docs = [
9697
"mkdocs>=1.6.1",

src/mcp/server/context.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
from collections.abc import Awaitable, Callable
34
from dataclasses import dataclass
4-
from typing import Any, Generic
5+
from typing import Any, Generic, Protocol
56

7+
from pydantic import BaseModel
68
from typing_extensions import TypeVar
79

810
from mcp.server._typed_request import TypedServerRequestMixin
@@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
8183
if meta:
8284
params["_meta"] = meta
8385
await self.notify("notifications/message", params)
86+
87+
88+
HandlerResult = BaseModel | dict[str, Any] | None
89+
"""What a request handler (or middleware) may return. `ServerRunner` serializes
90+
all three to a result dict."""
91+
92+
CallNext = Callable[[], Awaitable[HandlerResult]]
93+
94+
_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True)
95+
96+
97+
class ContextMiddleware(Protocol[_MwLifespanT]):
98+
"""Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``.
99+
100+
Runs *inside* `ServerRunner._on_request` after params validation and
101+
`Context` construction. Wraps registered handlers (including ``ping``) but
102+
not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed
103+
outermost-first on `Server.middleware`.
104+
105+
`Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific
106+
middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific
107+
types) can be typed `ContextMiddleware[object]` — `Context` is covariant in
108+
`LifespanT`, so it registers on any `Server[L]`.
109+
"""
110+
111+
async def __call__(
112+
self,
113+
ctx: Context[_MwLifespanT, TransportContext],
114+
method: str,
115+
params: BaseModel,
116+
call_next: CallNext,
117+
) -> HandlerResult: ...

src/mcp/server/lowlevel/server.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def main():
5858
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5959
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
6060
from mcp.server.auth.settings import AuthSettings
61-
from mcp.server.context import ServerRequestContext
61+
from mcp.server.context import ContextMiddleware, ServerRequestContext
6262
from mcp.server.experimental.request_context import Experimental
6363
from mcp.server.lowlevel.experimental import ExperimentalHandlers
6464
from mcp.server.models import InitializationOptions
@@ -199,6 +199,9 @@ def __init__(
199199
] = {}
200200
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
201201
self._session_manager: StreamableHTTPSessionManager | None = None
202+
# Context-tier middleware consumed by `ServerRunner`. Additive; the
203+
# existing `run()` path ignores it.
204+
self.middleware: list[ContextMiddleware[LifespanResultT]] = []
202205
logger.debug("Initializing server %r", name)
203206

204207
# Populate internal handler dicts from on_* kwargs
@@ -246,6 +249,16 @@ def _has_handler(self, method: str) -> bool:
246249
"""Check if a handler is registered for the given method."""
247250
return method in self._request_handlers or method in self._notification_handlers
248251

252+
# --- ServerRegistry protocol (consumed by ServerRunner) ------------------
253+
254+
def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
255+
"""Return the handler for a request method, or ``None``."""
256+
return self._request_handlers.get(method)
257+
258+
def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
259+
"""Return the handler for a notification method, or ``None``."""
260+
return self._notification_handlers.get(method)
261+
249262
# TODO: Rethink capabilities API. Currently capabilities are derived from registered
250263
# handlers but require NotificationOptions to be passed externally for list_changed
251264
# flags, and experimental_capabilities as a separate dict. Consider deriving capabilities

src/mcp/server/runner.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`.
2+
3+
`ServerRunner` is the bridge between the dispatcher layer (`on_request` /
4+
`on_notify`, untyped dicts) and the user's handler layer (typed `Context`,
5+
typed params). One instance per client connection. It:
6+
7+
* handles the ``initialize`` handshake and populates `Connection`
8+
* gates requests until initialized (``ping`` exempt)
9+
* looks up the handler in the server's registry, validates params, builds
10+
`Context`, runs the middleware chain, returns the result dict
11+
* drives ``dispatcher.run()`` and the per-connection lifespan
12+
13+
`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies
14+
it via additive methods so the existing ``Server.run()`` path is unaffected.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import logging
20+
from collections.abc import Awaitable, Callable, Mapping, Sequence
21+
from dataclasses import dataclass, field
22+
from functools import partial, reduce
23+
from typing import Any, Generic, Protocol, cast
24+
25+
import anyio.abc
26+
from opentelemetry.trace import SpanKind, StatusCode
27+
from pydantic import BaseModel
28+
from typing_extensions import TypeVar
29+
30+
from mcp.server.connection import Connection
31+
from mcp.server.context import CallNext, Context, ContextMiddleware
32+
from mcp.server.lowlevel.server import NotificationOptions
33+
from mcp.shared._otel import extract_trace_context, otel_span
34+
from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest
35+
from mcp.shared.exceptions import MCPError
36+
from mcp.shared.transport_context import TransportContext
37+
from mcp.types import (
38+
INVALID_REQUEST,
39+
LATEST_PROTOCOL_VERSION,
40+
METHOD_NOT_FOUND,
41+
CallToolRequestParams,
42+
CompleteRequestParams,
43+
GetPromptRequestParams,
44+
Implementation,
45+
InitializeRequestParams,
46+
InitializeResult,
47+
NotificationParams,
48+
PaginatedRequestParams,
49+
ProgressNotificationParams,
50+
ReadResourceRequestParams,
51+
RequestParams,
52+
ServerCapabilities,
53+
SetLevelRequestParams,
54+
SubscribeRequestParams,
55+
UnsubscribeRequestParams,
56+
)
57+
58+
__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"]
59+
60+
logger = logging.getLogger(__name__)
61+
62+
LifespanT = TypeVar("LifespanT", default=Any)
63+
ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext)
64+
65+
Handler = Callable[..., Awaitable[Any]]
66+
"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely
67+
so the existing `ServerRequestContext`-based handlers and the new
68+
`Context`-based handlers both fit during the transition.
69+
"""
70+
71+
72+
_INIT_EXEMPT: frozenset[str] = frozenset({"ping"})
73+
74+
# TODO: remove this lookup once `Server` stores (params_type, handler) in its
75+
# registry directly. This is scaffolding so ServerRunner can validate params
76+
# without changing the existing `_request_handlers` dict shape.
77+
_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = {
78+
"ping": RequestParams,
79+
"tools/list": PaginatedRequestParams,
80+
"tools/call": CallToolRequestParams,
81+
"prompts/list": PaginatedRequestParams,
82+
"prompts/get": GetPromptRequestParams,
83+
"resources/list": PaginatedRequestParams,
84+
"resources/templates/list": PaginatedRequestParams,
85+
"resources/read": ReadResourceRequestParams,
86+
"resources/subscribe": SubscribeRequestParams,
87+
"resources/unsubscribe": UnsubscribeRequestParams,
88+
"logging/setLevel": SetLevelRequestParams,
89+
"completion/complete": CompleteRequestParams,
90+
}
91+
"""Spec method → params model. Scaffolding while the lowlevel `Server`'s
92+
`_request_handlers` stores handler-only; the registry refactor should make this
93+
the registry's responsibility (or store params types alongside handlers)."""
94+
95+
_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = {
96+
"notifications/initialized": NotificationParams,
97+
"notifications/roots/list_changed": NotificationParams,
98+
"notifications/progress": ProgressNotificationParams,
99+
}
100+
101+
102+
class ServerRegistry(Protocol):
103+
"""The handler registry `ServerRunner` consumes.
104+
105+
The lowlevel `Server` satisfies this via additive methods.
106+
"""
107+
108+
@property
109+
def name(self) -> str: ...
110+
@property
111+
def version(self) -> str | None: ...
112+
113+
@property
114+
def middleware(self) -> Sequence[ContextMiddleware[Any]]: ...
115+
116+
def get_request_handler(self, method: str) -> Handler | None: ...
117+
def get_notification_handler(self, method: str) -> Handler | None: ...
118+
def get_capabilities(
119+
self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]]
120+
) -> ServerCapabilities: ...
121+
122+
123+
def otel_middleware(next_on_request: OnRequest) -> OnRequest:
124+
"""Dispatch-tier middleware that wraps each request in an OpenTelemetry span.
125+
126+
Mirrors the span shape of the existing `Server._handle_request`: span name
127+
``"MCP handle <method> [<target>]"``, ``mcp.method.name`` attribute, W3C
128+
trace context extracted from ``params._meta`` (SEP-414), and an ERROR
129+
status if the handler raises.
130+
"""
131+
132+
async def wrapped(
133+
dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
134+
) -> dict[str, Any]:
135+
target: str | None
136+
match params:
137+
case {"name": str() as target}:
138+
pass
139+
case _:
140+
target = None
141+
parent: Any | None
142+
match params:
143+
case {"_meta": {**meta}}:
144+
parent = extract_trace_context(meta)
145+
case _:
146+
parent = None
147+
span_name = f"MCP handle {method}{f' {target}' if target else ''}"
148+
with otel_span(
149+
span_name,
150+
kind=SpanKind.SERVER,
151+
attributes={"mcp.method.name": method},
152+
context=parent,
153+
record_exception=False,
154+
set_status_on_exception=False,
155+
) as span:
156+
try:
157+
return await next_on_request(dctx, method, params)
158+
except MCPError as e:
159+
span.set_status(StatusCode.ERROR, e.error.message)
160+
raise
161+
except Exception as e:
162+
span.record_exception(e)
163+
span.set_status(StatusCode.ERROR, str(e))
164+
raise
165+
166+
return wrapped
167+
168+
169+
def _dump_result(result: Any) -> dict[str, Any]:
170+
if result is None:
171+
return {}
172+
if isinstance(result, BaseModel):
173+
return result.model_dump(by_alias=True, mode="json", exclude_none=True)
174+
if isinstance(result, dict):
175+
return cast(dict[str, Any], result)
176+
raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None")
177+
178+
179+
@dataclass
180+
class ServerRunner(Generic[LifespanT, ServerTransportT]):
181+
"""Per-connection orchestrator. One instance per client connection."""
182+
183+
server: ServerRegistry
184+
dispatcher: Dispatcher[ServerTransportT]
185+
lifespan_state: LifespanT
186+
has_standalone_channel: bool
187+
stateless: bool = False
188+
dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware])
189+
190+
connection: Connection = field(init=False)
191+
_initialized: bool = field(init=False)
192+
193+
def __post_init__(self) -> None:
194+
self._initialized = self.stateless
195+
self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel)
196+
197+
async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
198+
"""Drive the dispatcher until the underlying channel closes.
199+
200+
Composes `dispatch_middleware` over `_on_request` and hands the result
201+
to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers
202+
can ``await tg.start(runner.run)`` and resume once the dispatcher is
203+
ready to accept requests.
204+
"""
205+
await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status)
206+
207+
def _compose_on_request(self) -> OnRequest:
208+
"""Wrap `_on_request` in `dispatch_middleware`, outermost-first.
209+
210+
Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict``
211+
and wraps everything — initialize, METHOD_NOT_FOUND, validation
212+
failures included. `run()` calls this once and hands the result to
213+
`dispatcher.run()`.
214+
"""
215+
return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request)
216+
217+
async def _on_request(
218+
self,
219+
dctx: DispatchContext[TransportContext],
220+
method: str,
221+
params: Mapping[str, Any] | None,
222+
) -> dict[str, Any]:
223+
if method == "initialize":
224+
return self._handle_initialize(params)
225+
if not self._initialized and method not in _INIT_EXEMPT:
226+
raise MCPError(
227+
code=INVALID_REQUEST,
228+
message=f"Received {method!r} before initialization was complete",
229+
)
230+
handler = self.server.get_request_handler(method)
231+
if handler is None:
232+
raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}")
233+
# TODO: scaffolding — params_type comes from a static lookup until the
234+
# registry stores it alongside the handler.
235+
params_type = _PARAMS_FOR_METHOD.get(method, RequestParams)
236+
# ValidationError propagates; the dispatcher's exception boundary maps
237+
# it to INVALID_PARAMS.
238+
typed_params = params_type.model_validate(params or {})
239+
ctx = self._make_context(dctx, typed_params)
240+
call: CallNext = partial(handler, ctx, typed_params)
241+
for mw in reversed(self.server.middleware):
242+
call = partial(mw, ctx, method, typed_params, call)
243+
return _dump_result(await call())
244+
245+
async def _on_notify(
246+
self,
247+
dctx: DispatchContext[TransportContext],
248+
method: str,
249+
params: Mapping[str, Any] | None,
250+
) -> None:
251+
if method == "notifications/initialized":
252+
self._initialized = True
253+
self.connection.initialized.set()
254+
return
255+
if not self._initialized:
256+
logger.debug("dropped %s: received before initialization", method)
257+
return
258+
handler = self.server.get_notification_handler(method)
259+
if handler is None:
260+
logger.debug("no handler for notification %s", method)
261+
return
262+
params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams)
263+
typed_params = params_type.model_validate(params or {})
264+
ctx = self._make_context(dctx, typed_params)
265+
await handler(ctx, typed_params)
266+
267+
def _make_context(
268+
self, dctx: DispatchContext[TransportContext], typed_params: BaseModel
269+
) -> Context[LifespanT, ServerTransportT]:
270+
# `OnRequest` delivers `DispatchContext[TransportContext]`; this
271+
# ServerRunner instance was constructed for a specific
272+
# `ServerTransportT`, so the narrow is safe by construction.
273+
narrowed = cast(DispatchContext[ServerTransportT], dctx)
274+
meta = getattr(typed_params, "meta", None)
275+
return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta)
276+
277+
def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]:
278+
init = InitializeRequestParams.model_validate(params or {})
279+
self.connection.client_info = init.client_info
280+
self.connection.client_capabilities = init.capabilities
281+
# TODO: real version negotiation. This always responds with LATEST,
282+
# which is wrong — the server should pick the highest version both
283+
# sides support and compute a per-connection feature set from it.
284+
# See FOLLOWUPS: "Consolidate per-connection mode/negotiation".
285+
self.connection.protocol_version = (
286+
init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION
287+
)
288+
self._initialized = True
289+
self.connection.initialized.set()
290+
result = InitializeResult(
291+
protocol_version=self.connection.protocol_version,
292+
capabilities=self.server.get_capabilities(NotificationOptions(), {}),
293+
server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"),
294+
)
295+
return _dump_result(result)

0 commit comments

Comments
 (0)