Skip to content

Commit 0244bf9

Browse files
committed
review: Server.middleware wraps every request and notification (t04/t18); t25 TODO
ServerMiddleware now wraps the entire _on_request and _on_notify body so it observes initialize, notifications/initialized, METHOD_NOT_FOUND, validation failures, and pre-init drops - everything the old incoming_messages stream saw. params is the raw Mapping[str, Any] | None (not the typed model); ctx.request_id is None distinguishes a notification; call_next() raises MCPError for request-side failures so middleware can observe them, returns None for notifications. _on_notify keeps its swallow-at-boundary behavior but middleware sees the raise first. ServerRunner._on_request/_on_notify restructured: build ctx from raw inputs first (_extract_meta validates only _meta via RequestParams so ctx.meta keeps the alias-converted shape), then compose Server.middleware around an _inner() containing validation/init/gate/lookup/handler. _compose_server_middleware helper shared by both paths. _handle_initialize returns InitializeResult (outer _dump_result serializes) so middleware can transform it. _make_context now takes meta directly. t25: TODO at the duplicate-inbound-request-id site (parity with v1/TS; spec puts uniqueness on the sender).
1 parent 2a2cacd commit 0244bf9

6 files changed

Lines changed: 214 additions & 91 deletions

File tree

docs/migration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr
11131113

11141114
- `InitializationState` enum and `ServerSession._initialization_state` — initialization tracking is now on `Connection` (`connection.initialized` is an `anyio.Event`, `connection.client_params` holds the init params).
11151115
- `ServerRequestResponder` type alias.
1116-
- `ServerSession.incoming_messages` stream — there is no longer a public stream of inbound messages to iterate. Register handlers via the `on_*` constructor params (or `add_request_handler`) and use `Server.middleware` to observe every request.
1116+
- `ServerSession.incoming_messages` stream — there is no longer a public stream of inbound messages to iterate. Register handlers via the `on_*` constructor params (or `add_request_handler`) and use `Server.middleware` to observe every inbound request and notification (`initialize`, unknown methods, validation failures, and `notifications/initialized` included).
11171117
- `ServerSession.__aenter__` / `__aexit__``ServerSession` is no longer an async context manager.
11181118
- The private `_receive_loop`, `_received_request`, `_received_notification`, and `_handle_incoming` overrides — there is nothing to override on `ServerSession` anymore. To intercept inbound messages, use `Server.middleware` or `DispatchMiddleware` (see the `_handle_*` removal section above).
11191119

src/mcp/server/context.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,25 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
117117

118118

119119
class ServerMiddleware(Protocol[_MwLifespanT]):
120-
"""Context-tier middleware: `(ctx, method, typed_params, call_next) -> result`.
121-
122-
Runs *inside* `ServerRunner._on_request` after params validation and
123-
context construction. Wraps registered handlers (including `ping`) but
124-
not `initialize`, `METHOD_NOT_FOUND`, or validation failures. Listed
125-
outermost-first on `Server.middleware`.
120+
"""Context-tier middleware: `(ctx, method, params, call_next) -> result`.
121+
122+
Runs at the top of `ServerRunner._on_request` / `_on_notify` after `ctx`
123+
is built but before any validation, lookup, or handshake. Wraps every
124+
inbound request and notification: `initialize`, the pre-init gate,
125+
`METHOD_NOT_FOUND`, params validation, the handler call, and
126+
`notifications/initialized` all run inside `call_next()`. A request-side
127+
failure reaches the middleware as a raised `MCPError` (or
128+
`ValidationError` for malformed params) so observation/logging middleware
129+
can record it. Listed outermost-first on `Server.middleware`.
130+
131+
`ctx.request_id is None` distinguishes a notification from a request. For
132+
notifications `call_next()` returns `None` (a dropped or unhandled
133+
notification also returns `None`) and the middleware's own return value is
134+
discarded.
135+
136+
`params` is the raw inbound mapping (no model validation has happened
137+
yet). For typed inspection, validate against the model the middleware
138+
expects.
126139
127140
`Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific
128141
middleware sees `ctx.lifespan_context: L`. While the context is the
@@ -140,6 +153,6 @@ async def __call__(
140153
self,
141154
ctx: ServerRequestContext[_MwLifespanT, Any],
142155
method: str,
143-
params: BaseModel | None,
156+
params: Mapping[str, Any] | None,
144157
call_next: CallNext,
145158
) -> HandlerResult: ...

src/mcp/server/lowlevel/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ def __init__(
215215
self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {}
216216
self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {}
217217
self._session_manager: StreamableHTTPSessionManager | None = None
218-
# Context-tier middleware: wraps each request handler with
218+
# Context-tier middleware: wraps every inbound request (including
219+
# `initialize`, lookup, validation, handler) with
219220
# `(ctx, method, params, call_next)`. Applied in `ServerRunner._on_request`.
220221
# TODO(maxisbey): provisional - signature and semantics change with the
221222
# Context/middleware rework (covariant `Context[L]`, outbound seam) before

src/mcp/server/runner.py

Lines changed: 122 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from typing_extensions import TypeVar
2828

2929
from mcp.server.connection import Connection
30-
from mcp.server.context import CallNext, ServerMiddleware, ServerRequestContext
30+
from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext
3131
from mcp.server.models import InitializationOptions
3232
from mcp.server.session import ServerSession
3333
from mcp.shared._otel import extract_trace_context, otel_span
@@ -47,6 +47,7 @@
4747
InitializeRequestParams,
4848
InitializeResult,
4949
RequestParams,
50+
RequestParamsMeta,
5051
client_request_adapter,
5152
)
5253

@@ -62,6 +63,24 @@
6263

6364
_INIT_EXEMPT: frozenset[str] = frozenset({"ping"})
6465

66+
67+
def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None:
68+
"""Lift `_meta` from raw params with the same key-aliasing pydantic applies.
69+
70+
`RequestParams` only declares `meta` (alias `_meta`) and `MCPModel` does
71+
not forbid extras, so this validate ignores everything else and never
72+
rejects on the caller's other fields. Returns `None` for absent or
73+
malformed `_meta` so context construction is independent of params
74+
validity (which `_on_request` checks separately).
75+
"""
76+
if not params or "_meta" not in params:
77+
return None
78+
try:
79+
return RequestParams.model_validate(params, by_name=False).meta
80+
except ValidationError:
81+
return None
82+
83+
6584
_SPEC_CLIENT_METHODS: frozenset[str] = frozenset(
6685
cast(type[BaseModel], arm).model_fields["method"].default for arm in get_args(ClientRequest)
6786
)
@@ -206,45 +225,50 @@ async def _on_request(
206225
method: str,
207226
params: Mapping[str, Any] | None,
208227
) -> dict[str, Any]:
209-
# TODO(maxisbey): pinned compat. `BaseSession._receive_loop` validates
210-
# every inbound request against the spec `ClientRequest` discriminated
211-
# union *before* handler lookup, so a spec method with malformed params
212-
# surfaces as INVALID_PARAMS via the dispatcher's ValidationError
213-
# boundary even when no handler is registered. v2 wanted to decouple
214-
# the runner from the spec union; revisit once the suite's divergence
215-
# entry is resolved. Gated on spec methods so custom methods registered
216-
# via `add_request_handler` still route (the existing server rejects
217-
# those too, but nothing pins that and routing them is strictly better).
218-
if method in _SPEC_CLIENT_METHODS:
219-
payload: dict[str, Any] = {"method": method}
220-
if params is not None:
221-
payload["params"] = dict(params)
222-
client_request_adapter.validate_python(payload, by_name=False)
223-
if method == "initialize":
224-
return self._handle_initialize(params)
225-
if not self._initialized and method not in _INIT_EXEMPT:
226-
# TODO(maxisbey): pinned compat. The existing server has no
227-
# dedicated pre-init check; the request dies in ClientRequest
228-
# validation, so the client sees the generic invalid-params shape.
229-
raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="")
230-
entry = self.server.get_request_handler(method)
231-
if entry is None:
232-
raise MCPError(code=METHOD_NOT_FOUND, message="Method not found")
233-
# ValidationError propagates; the dispatcher's exception boundary maps
234-
# it to INVALID_PARAMS. Absent wire params reach the handler as None
235-
# (matches the existing `Server._handle_request`, where `req.params`
236-
# is None for optional-params requests like tools/list); the empty-dict
237-
# validate is a required-field check so a required-params model still
238-
# surfaces as INVALID_PARAMS rather than reaching the handler as None.
239-
if params is None:
240-
entry.params_type.model_validate({}, by_name=False)
241-
typed_params = None
242-
else:
243-
typed_params = entry.params_type.model_validate(params, by_name=False)
244-
ctx = self._make_context(dctx, typed_params)
245-
call: CallNext = partial(entry.handler, ctx, typed_params)
246-
for mw in reversed(self.server.middleware):
247-
call = partial(mw, ctx, method, typed_params, call)
228+
ctx = self._make_context(dctx, _extract_meta(params))
229+
230+
async def _inner() -> HandlerResult:
231+
# TODO(maxisbey): pinned compat. `BaseSession._receive_loop`
232+
# validates every inbound request against the spec `ClientRequest`
233+
# discriminated union *before* handler lookup, so a spec method
234+
# with malformed params surfaces as INVALID_PARAMS via the
235+
# dispatcher's ValidationError boundary even when no handler is
236+
# registered. v2 wanted to decouple the runner from the spec union;
237+
# revisit once the suite's divergence entry is resolved. Gated on
238+
# spec methods so custom methods registered via
239+
# `add_request_handler` still route (the existing server rejects
240+
# those too, but nothing pins that and routing is strictly better).
241+
if method in _SPEC_CLIENT_METHODS:
242+
payload: dict[str, Any] = {"method": method}
243+
if params is not None:
244+
payload["params"] = dict(params)
245+
client_request_adapter.validate_python(payload, by_name=False)
246+
if method == "initialize":
247+
return self._handle_initialize(params)
248+
if not self._initialized and method not in _INIT_EXEMPT:
249+
# TODO(maxisbey): pinned compat. The existing server has no
250+
# dedicated pre-init check; the request dies in ClientRequest
251+
# validation, so the client sees the generic invalid-params
252+
# shape.
253+
raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="")
254+
entry = self.server.get_request_handler(method)
255+
if entry is None:
256+
raise MCPError(code=METHOD_NOT_FOUND, message="Method not found")
257+
# ValidationError propagates; the dispatcher's exception boundary
258+
# maps it to INVALID_PARAMS. Absent wire params reach the handler
259+
# as None (matches the existing `Server._handle_request`, where
260+
# `req.params` is None for optional-params requests like
261+
# tools/list); the empty-dict validate is a required-field check
262+
# so a required-params model still surfaces as INVALID_PARAMS
263+
# rather than reaching the handler as None.
264+
if params is None:
265+
entry.params_type.model_validate({}, by_name=False)
266+
typed_params = None
267+
else:
268+
typed_params = entry.params_type.model_validate(params, by_name=False)
269+
return await entry.handler(ctx, typed_params)
270+
271+
call = self._compose_server_middleware(ctx, method, params, _inner)
248272
return _dump_result(await call())
249273

250274
async def _on_notify(
@@ -253,45 +277,68 @@ async def _on_notify(
253277
method: str,
254278
params: Mapping[str, Any] | None,
255279
) -> None:
256-
if method == "notifications/initialized":
257-
self._initialized = True
258-
self.connection.initialized.set()
259-
return
260-
if not self._initialized:
261-
logger.debug("dropped %s: received before initialization", method)
262-
return
263-
entry = self.server.get_notification_handler(method)
264-
if entry is None:
265-
logger.debug("no handler for notification %s", method)
266-
return
267-
# Absent wire params reach the handler as None, not an empty model
268-
# (matches the existing `Server._handle_notification`). The empty-dict
269-
# validate is a required-field check: a required-params model (e.g.
270-
# ProgressNotificationParams) takes the malformed-params drop path
271-
# instead of reaching a non-Optional handler as None.
272-
try:
273-
if params is None:
274-
entry.params_type.model_validate({}, by_name=False)
275-
typed_params = None
276-
else:
277-
typed_params = entry.params_type.model_validate(params, by_name=False)
278-
except ValidationError:
279-
logger.warning("dropped %r: malformed params", method)
280-
return
281-
ctx = self._make_context(dctx, typed_params)
282-
try:
280+
ctx = self._make_context(dctx, _extract_meta(params))
281+
282+
async def _inner() -> None:
283+
if method == "notifications/initialized":
284+
self._initialized = True
285+
self.connection.initialized.set()
286+
return
287+
if not self._initialized:
288+
logger.debug("dropped %s: received before initialization", method)
289+
return
290+
entry = self.server.get_notification_handler(method)
291+
if entry is None:
292+
logger.debug("no handler for notification %s", method)
293+
return
294+
# Absent wire params reach the handler as None, not an empty model
295+
# (matches the existing `Server._handle_notification`). The
296+
# empty-dict validate is a required-field check: a required-params
297+
# model (e.g. ProgressNotificationParams) takes the
298+
# malformed-params drop path instead of reaching a non-Optional
299+
# handler as None.
300+
try:
301+
if params is None:
302+
entry.params_type.model_validate({}, by_name=False)
303+
typed_params = None
304+
else:
305+
typed_params = entry.params_type.model_validate(params, by_name=False)
306+
except ValidationError:
307+
logger.warning("dropped %r: malformed params", method)
308+
return
283309
await entry.handler(ctx, typed_params)
310+
311+
call = self._compose_server_middleware(ctx, method, params, _inner)
312+
try:
313+
await call()
284314
except Exception:
285-
# Top-level boundary: a notification handler crashing must not
286-
# tear down the connection (it runs as a bare task in the
287-
# dispatcher's task group; an uncaught exception would cancel
288-
# every sibling, including the read loop and in-flight requests).
315+
# Top-level boundary: a notification handler (or middleware)
316+
# crashing must not tear down the connection (it runs as a bare
317+
# task in the dispatcher's task group; an uncaught exception would
318+
# cancel every sibling, including the read loop and in-flight
319+
# requests). Middleware sees the raise out of `call_next()` first.
289320
logger.exception("notification handler for %r raised", method)
290321

322+
def _compose_server_middleware(
323+
self,
324+
ctx: ServerRequestContext[LifespanT, Any],
325+
method: str,
326+
params: Mapping[str, Any] | None,
327+
inner: CallNext,
328+
) -> CallNext:
329+
"""Wrap `inner` in `Server.middleware`, outermost-first.
330+
331+
Shared by `_on_request` and `_on_notify` so the same middleware chain
332+
observes every inbound message.
333+
"""
334+
call = inner
335+
for mw in reversed(self.server.middleware):
336+
call = partial(mw, ctx, method, params, call)
337+
return call
338+
291339
def _make_context(
292-
self, dctx: DispatchContext[TransportContext], typed_params: BaseModel | None
340+
self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None
293341
) -> ServerRequestContext[LifespanT, Any]:
294-
meta = typed_params.meta if isinstance(typed_params, RequestParams) else None
295342
# TODO(maxisbey): remove for Context rework. Reads the SHTTP per-request
296343
# data off the raw `dctx.message_metadata` carrier; replace with the
297344
# per-transport context once that lands.
@@ -312,7 +359,7 @@ def _make_context(
312359
close_standalone_sse_stream=close_standalone_sse_stream,
313360
)
314361

315-
def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]:
362+
def _handle_initialize(self, params: Mapping[str, Any] | None) -> InitializeResult:
316363
init = InitializeRequestParams.model_validate(params or {}, by_name=False)
317364
self.connection.client_params = init
318365
requested = init.protocol_version
@@ -335,4 +382,4 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]
335382
),
336383
instructions=opts.instructions,
337384
)
338-
return _dump_result(result)
385+
return result

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ async def _dispatch_request(
477477
_progress_token=progress_token,
478478
)
479479
scope = anyio.CancelScope()
480+
# TODO(maxisbey): the spec puts request-id uniqueness on the sender;
481+
# neither v1 nor the TS SDK guards a duplicate id here, so for now we
482+
# blind-overwrite (parity). Revisit rejecting with INVALID_REQUEST.
480483
self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx)
481484
if req.method in self._inline_methods:
482485
# Spawn (so `sender_ctx` applies, matching the concurrent path) but

0 commit comments

Comments
 (0)