Skip to content
Closed

[Test] #2677

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ on:
branches: ["main", "v1.x"]
tags: ["v*.*.*"]
pull_request:
branches: ["main", "v1.x"]

permissions:
contents: read
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ dev = [
"pillow>=12.0",
"strict-no-cover",
"logfire>=3.0.0",
"opentelemetry-sdk>=1.39.1",
]
docs = [
"mkdocs>=1.6.1",
Expand Down
86 changes: 86 additions & 0 deletions src/mcp/server/_typed_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Typed ``send_request`` for server-to-client requests.

`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over
the host's raw `Outbound.send_raw_request`. Spec server-to-client request types
have their result type inferred via per-type overloads; custom requests pass
``result_type=`` explicitly.

If the spec's request set grows substantially, consider declaring the result
mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via
a structural protocol) so this overload ladder doesn't need maintaining
per-host-class.
"""

from typing import Any, TypeVar, overload

from pydantic import BaseModel

from mcp.shared.dispatcher import CallOptions, Outbound
from mcp.shared.peer import dump_params
from mcp.types import (
CreateMessageRequest,
CreateMessageResult,
ElicitRequest,
ElicitResult,
EmptyResult,
ListRootsRequest,
ListRootsResult,
PingRequest,
Request,
)

__all__ = ["TypedServerRequestMixin"]

ResultT = TypeVar("ResultT", bound=BaseModel)

_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = {
CreateMessageRequest: CreateMessageResult,
ElicitRequest: ElicitResult,
ListRootsRequest: ListRootsResult,
PingRequest: EmptyResult,
}


class TypedServerRequestMixin:
"""Typed ``send_request`` for the server-to-client request set.

Mixed into `Connection` and the server `Context`. Each method constrains
``self`` to `Outbound` so any host with ``send_raw_request`` works.
"""

@overload
async def send_request(
self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None
) -> CreateMessageResult: ...
@overload
async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ...
@overload
async def send_request(
self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None
) -> ListRootsResult: ...
@overload
async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ...
@overload
async def send_request(
self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None
) -> ResultT: ...
async def send_request(
self: Outbound,
req: Request[Any, Any],
*,
result_type: type[BaseModel] | None = None,
opts: CallOptions | None = None,
) -> BaseModel:
"""Send a typed server-to-client request and return its typed result.

For spec request types the result type is inferred. For custom requests
pass ``result_type=`` explicitly.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: No back-channel for server-initiated requests.
KeyError: ``result_type`` omitted for a non-spec request type.
"""
raw = await self.send_raw_request(req.method, dump_params(req.params), opts)
cls = result_type if result_type is not None else _RESULT_FOR[type(req)]
return cls.model_validate(raw)
146 changes: 146 additions & 0 deletions src/mcp/server/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""`Connection` — per-client connection state and the standalone outbound channel.

Always present on `Context` (never ``None``), even in stateless deployments.
Holds peer info populated at ``initialize`` time, the per-connection lifespan
output, and an `Outbound` for the standalone stream (the SSE GET stream in
streamable HTTP, or the single duplex stream in stdio).

`notify` is best-effort: it never raises. If there's no standalone channel
(stateless HTTP) or the stream has been dropped, the notification is
debug-logged and silently discarded — server-initiated notifications are
inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when
there's no channel; `ping` is the only spec-sanctioned standalone request.
"""

import logging
from collections.abc import Mapping
from typing import Any

import anyio

from mcp.server._typed_request import TypedServerRequestMixin
from mcp.shared.dispatcher import CallOptions, Outbound
from mcp.shared.exceptions import NoBackChannelError
from mcp.shared.peer import Meta, dump_params
from mcp.types import ClientCapabilities, Implementation, LoggingLevel

__all__ = ["Connection"]

logger = logging.getLogger(__name__)


def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None:
if not meta:
return payload
out = dict(payload or {})
out["_meta"] = meta
return out


class Connection(TypedServerRequestMixin):
"""Per-client connection state and standalone-stream `Outbound`.

Constructed by `ServerRunner` once per connection. The peer-info fields are
``None`` until ``initialize`` completes; ``initialized`` is set then.
"""

def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None:
self._outbound = outbound
self.has_standalone_channel = has_standalone_channel

self.client_info: Implementation | None = None
self.client_capabilities: ClientCapabilities | None = None
self.protocol_version: str | None = None
self.initialized: anyio.Event = anyio.Event()
# TODO: make this generic (Connection[StateT]) once connection_lifespan
# wiring lands in ServerRunner.
self.state: Any = None

async def send_raw_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
"""Send a raw request on the standalone stream.

Low-level `Outbound` channel. Prefer the typed ``send_request`` (from
`TypedServerRequestMixin`) or the convenience methods below; use this
directly only for off-spec messages.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``has_standalone_channel`` is ``False``.
"""
if not self.has_standalone_channel:
raise NoBackChannelError(method)
return await self._outbound.send_raw_request(method, params, opts)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
"""Send a best-effort notification on the standalone stream.

Never raises. If there's no standalone channel or the stream is broken,
the notification is dropped and debug-logged.
"""
if not self.has_standalone_channel:
logger.debug("dropped %s: no standalone channel", method)
return
try:
await self._outbound.notify(method, params)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped %s: standalone stream closed", method)

async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None:
"""Send a ``ping`` request on the standalone stream.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``has_standalone_channel`` is ``False``.
"""
await self.send_raw_request("ping", dump_params(None, meta), opts)

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a ``notifications/message`` log entry on the standalone stream. Best-effort."""
params: dict[str, Any] = {"level": level, "data": data}
if logger is not None:
params["logger"] = logger
await self.notify("notifications/message", _notification_params(params, meta))

async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/tools/list_changed", _notification_params(None, meta))

async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/prompts/list_changed", _notification_params(None, meta))

async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/resources/list_changed", _notification_params(None, meta))

async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None:
await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta))

def check_capability(self, capability: ClientCapabilities) -> bool:
"""Return whether the connected client declared the given capability.

Returns ``False`` if ``initialize`` hasn't completed yet.
"""
# TODO: redesign — mirrors v1 ServerSession.check_client_capability
# verbatim for parity.
if self.client_capabilities is None:
return False
have = self.client_capabilities
if capability.roots is not None:
if have.roots is None:
return False
if capability.roots.list_changed and not have.roots.list_changed:
return False
if capability.sampling is not None and have.sampling is None:
return False
if capability.elicitation is not None and have.elicitation is None:
return False
if capability.experimental is not None:
if have.experimental is None:
return False
for k in capability.experimental:
if k not in have.experimental:
return False
return True
96 changes: 95 additions & 1 deletion src/mcp/server/context.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Generic
from typing import Any, Generic, Protocol

from pydantic import BaseModel
from typing_extensions import TypeVar

from mcp.server._typed_request import TypedServerRequestMixin
from mcp.server.connection import Connection
from mcp.server.experimental.request_context import Experimental
from mcp.server.session import ServerSession
from mcp.shared._context import RequestContext
from mcp.shared.context import BaseContext
from mcp.shared.dispatcher import DispatchContext
from mcp.shared.message import CloseSSEStreamCallback
from mcp.shared.peer import Meta, PeerMixin
from mcp.shared.transport_context import TransportContext
from mcp.types import LoggingLevel, RequestParamsMeta

LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
RequestT = TypeVar("RequestT", default=Any)
Expand All @@ -21,3 +30,88 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex
request: RequestT | None = None
close_sse_stream: CloseSSEStreamCallback | None = None
close_standalone_sse_stream: CloseSSEStreamCallback | None = None


LifespanT = TypeVar("LifespanT", default=Any, covariant=True)
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)


class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]):
"""Server-side per-request context.

Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`),
`PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``),
and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds
``lifespan`` and ``connection``.

Constructed by `ServerRunner` per inbound request and handed to the user's
handler.
"""

def __init__(
self,
dctx: DispatchContext[TransportT],
*,
lifespan: LifespanT,
connection: Connection,
meta: RequestParamsMeta | None = None,
) -> None:
super().__init__(dctx, meta=meta)
self._lifespan = lifespan
self._connection = connection

@property
def lifespan(self) -> LifespanT:
"""The server-wide lifespan output (what `Server(..., lifespan=...)` yielded)."""
return self._lifespan

@property
def connection(self) -> Connection:
"""The per-client `Connection` for this request's connection."""
return self._connection

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a request-scoped ``notifications/message`` log entry.

Uses this request's back-channel (so the entry rides the request's SSE
stream in streamable HTTP), not the standalone stream — use
``ctx.connection.log(...)`` for that.
"""
params: dict[str, Any] = {"level": level, "data": data}
if logger is not None:
params["logger"] = logger
if meta:
params["_meta"] = meta
await self.notify("notifications/message", params)


HandlerResult = BaseModel | dict[str, Any] | None
"""What a request handler (or middleware) may return. `ServerRunner` serializes
all three to a result dict."""

CallNext = Callable[[], Awaitable[HandlerResult]]

_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True)


class ContextMiddleware(Protocol[_MwLifespanT]):
"""Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``.

Runs *inside* `ServerRunner._on_request` after params validation and
`Context` construction. Wraps registered handlers (including ``ping``) but
not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed
outermost-first on `Server.middleware`.

`Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific
middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific
types) can be typed `ContextMiddleware[object]` — `Context` is covariant in
`LifespanT`, so it registers on any `Server[L]`.
"""

async def __call__(
self,
ctx: Context[_MwLifespanT, TransportContext],
method: str,
params: BaseModel,
call_next: CallNext,
) -> HandlerResult: ...
15 changes: 14 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def main():
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
from mcp.server.auth.settings import AuthSettings
from mcp.server.context import ServerRequestContext
from mcp.server.context import ContextMiddleware, ServerRequestContext
from mcp.server.experimental.request_context import Experimental
from mcp.server.lowlevel.experimental import ExperimentalHandlers
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -199,6 +199,9 @@ def __init__(
] = {}
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
self._session_manager: StreamableHTTPSessionManager | None = None
# Context-tier middleware consumed by `ServerRunner`. Additive; the
# existing `run()` path ignores it.
self.middleware: list[ContextMiddleware[LifespanResultT]] = []
logger.debug("Initializing server %r", name)

# Populate internal handler dicts from on_* kwargs
Expand Down Expand Up @@ -246,6 +249,16 @@ def _has_handler(self, method: str) -> bool:
"""Check if a handler is registered for the given method."""
return method in self._request_handlers or method in self._notification_handlers

# --- ServerRegistry protocol (consumed by ServerRunner) ------------------

def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
"""Return the handler for a request method, or ``None``."""
return self._request_handlers.get(method)

def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
"""Return the handler for a notification method, or ``None``."""
return self._notification_handlers.get(method)

# TODO: Rethink capabilities API. Currently capabilities are derived from registered
# handlers but require NotificationOptions to be passed externally for list_changed
# flags, and experimental_capabilities as a separate dict. Consider deriving capabilities
Expand Down
Loading
Loading