Skip to content

Commit 5ddf1dd

Browse files
committed
Extract JSON-RPC wrapping into a Dispatcher component
BaseSession previously did two jobs: MCP protocol semantics (progress tokens, cancellation, in-flight tracking, result validation) and JSON-RPC wire encoding (wrap/unwrap envelopes, ID correlation, the receive loop). This split those into a composition: BaseSession owns the MCP semantics, a Dispatcher owns the wire protocol. The default JSONRPCDispatcher is the old _receive_loop + send_request + send_notification + _send_response logic extracted verbatim. BaseSession constructs one from the supplied streams, so every existing transport (stdio, SHTTP, WebSocket, in-memory) works unchanged. The Dispatcher Protocol deals in {"method": str, "params": dict} — the same dict request.model_dump() already produces. A custom dispatcher (gRPC stub, message broker, anything) implements five methods and passes itself as ClientSession(dispatcher=...). All of initialize(), list_tools(), call_tool() work unchanged on top; no parallel Client hierarchy, no fat Protocol. This addresses the composition-not-inheritance point raised in review of the previous pluggable-transport PRs. Closes the SessionMessage/JSON-RPC coupling that was the remaining blocker. Github-Issue: #1690
1 parent 883d893 commit 5ddf1dd

File tree

7 files changed

+503
-223
lines changed

7 files changed

+503
-223
lines changed

src/mcp/client/session.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mcp.client.experimental import ExperimentalClientFeatures
1212
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
1313
from mcp.shared._context import RequestContext
14+
from mcp.shared.dispatcher import Dispatcher
1415
from mcp.shared.message import SessionMessage
1516
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
1617
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -109,8 +110,8 @@ class ClientSession(
109110
):
110111
def __init__(
111112
self,
112-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
113-
write_stream: MemoryObjectSendStream[SessionMessage],
113+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None,
114+
write_stream: MemoryObjectSendStream[SessionMessage] | None = None,
114115
read_timeout_seconds: float | None = None,
115116
sampling_callback: SamplingFnT | None = None,
116117
elicitation_callback: ElicitationFnT | None = None,
@@ -121,8 +122,9 @@ def __init__(
121122
*,
122123
sampling_capabilities: types.SamplingCapability | None = None,
123124
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
125+
dispatcher: Dispatcher | None = None,
124126
) -> None:
125-
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
127+
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds, dispatcher=dispatcher)
126128
self._client_info = client_info or DEFAULT_CLIENT_INFO
127129
self._sampling_callback = sampling_callback or _default_sampling_callback
128130
self._sampling_capabilities = sampling_capabilities

src/mcp/server/session.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
4040
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
4141
from mcp.server.models import InitializationOptions
4242
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
43+
from mcp.shared.dispatcher import JSONRPCDispatcher
4344
from mcp.shared.exceptions import StatelessModeNotSupported
4445
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
4546
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY
@@ -157,9 +158,9 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
157158

158159
return True
159160

160-
async def _receive_loop(self) -> None:
161+
async def _run(self) -> None:
161162
async with self._incoming_message_stream_writer:
162-
await super()._receive_loop()
163+
await super()._run()
163164

164165
async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
165166
match responder.request:
@@ -676,12 +677,15 @@ async def send_message(self, message: SessionMessage) -> None:
676677
677678
WARNING: This is a low-level experimental method that may change without
678679
notice. Prefer using higher-level methods like send_notification() or
679-
send_request() for normal operations.
680+
send_request() for normal operations. Only works with the default
681+
JSON-RPC dispatcher.
680682
681683
Args:
682684
message: The session message to send
683685
"""
684-
await self._write_stream.send(message)
686+
if not isinstance(self._dispatcher, JSONRPCDispatcher): # pragma: no cover
687+
raise TypeError("send_message requires the default JSON-RPC dispatcher")
688+
await self._dispatcher._write_stream.send(message) # type: ignore[reportPrivateUsage]
685689

686690
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
687691
await self._incoming_message_stream_writer.send(req)

src/mcp/shared/dispatcher.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""Dispatcher abstraction: the wire-protocol layer beneath a session.
2+
3+
A ``Dispatcher`` is responsible for encoding MCP messages for the wire,
4+
correlating request/response pairs, and delivering incoming messages to
5+
session-provided handlers. The session itself deals only in MCP-level
6+
dicts (``{"method": ..., "params": ...}`` for requests/notifications, result
7+
dicts for responses) and never sees the wire encoding.
8+
9+
The default ``JSONRPCDispatcher`` wraps messages in JSON-RPC 2.0 envelopes
10+
and exchanges them over anyio memory streams — this is what every built-in
11+
transport (stdio, Streamable HTTP, WebSocket) feeds into. Custom dispatchers
12+
may use other encodings and request/response models as long as MCP's
13+
request/notification/response semantics are preserved.
14+
15+
!!! warning
16+
The ``Dispatcher`` Protocol is experimental. Custom transports that
17+
carry JSON-RPC should implement the ``Transport`` Protocol from
18+
``mcp.client._transport`` instead — that path is stable.
19+
"""
20+
21+
from __future__ import annotations
22+
23+
import logging
24+
from collections.abc import Awaitable, Callable
25+
from typing import Any, Protocol
26+
27+
import anyio
28+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
29+
30+
from mcp.shared.exceptions import MCPError
31+
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
32+
from mcp.shared.response_router import ResponseRouter
33+
from mcp.types import (
34+
CONNECTION_CLOSED,
35+
ErrorData,
36+
JSONRPCError,
37+
JSONRPCNotification,
38+
JSONRPCRequest,
39+
JSONRPCResponse,
40+
RequestId,
41+
)
42+
43+
OnRequestFn = Callable[[RequestId, dict[str, Any], MessageMetadata], Awaitable[None]]
44+
"""Called when the peer sends us a request. Receives ``(request_id, {"method", "params"}, metadata)``."""
45+
46+
OnNotificationFn = Callable[[dict[str, Any]], Awaitable[None]]
47+
"""Called when the peer sends us a notification. Receives ``{"method", "params"}``."""
48+
49+
OnErrorFn = Callable[[Exception], Awaitable[None]]
50+
"""Called for transport-level errors and orphaned error responses."""
51+
52+
53+
class Dispatcher(Protocol):
54+
"""Wire-protocol layer beneath ``BaseSession``.
55+
56+
Session generates request IDs (they double as progress tokens); the dispatcher
57+
uses them for correlation if its protocol needs that. ``send_request`` blocks
58+
until the correlated response arrives and returns the raw result dict, which
59+
the session then validates into an MCP result type.
60+
61+
Implementations must be cancellation-safe: if ``send_request`` is cancelled
62+
(e.g. by the session's timeout), any correlation state for that request must
63+
be cleaned up.
64+
"""
65+
66+
def set_handlers(
67+
self,
68+
on_request: OnRequestFn,
69+
on_notification: OnNotificationFn,
70+
on_error: OnErrorFn,
71+
) -> None:
72+
"""Wire incoming-message callbacks. Called once, before ``run``."""
73+
...
74+
75+
async def run(self) -> None:
76+
"""Run the receive loop. Returns when the connection closes.
77+
78+
Started in the session's task group; cancelled on session exit.
79+
"""
80+
...
81+
82+
async def send_request(
83+
self,
84+
request_id: RequestId,
85+
request: dict[str, Any],
86+
metadata: MessageMetadata = None,
87+
timeout: float | None = None,
88+
) -> dict[str, Any]:
89+
"""Send a request and wait for its response.
90+
91+
``request`` is ``{"method": str, "params": dict | None}``. Returns the raw
92+
result dict. Raises ``MCPError`` if the peer returns an error response.
93+
Raises ``TimeoutError`` if no response arrives within ``timeout``.
94+
95+
The send itself must not be subject to the timeout — only the wait for
96+
the response — so that requests are reliably delivered even when the
97+
caller sets an aggressive deadline.
98+
"""
99+
...
100+
101+
async def send_notification(
102+
self,
103+
notification: dict[str, Any],
104+
related_request_id: RequestId | None = None,
105+
) -> None:
106+
"""Send a fire-and-forget notification. ``notification`` is ``{"method", "params"}``."""
107+
...
108+
109+
async def send_response(
110+
self,
111+
request_id: RequestId,
112+
response: dict[str, Any] | ErrorData,
113+
) -> None:
114+
"""Send a response to a request we previously received via ``on_request``."""
115+
...
116+
117+
118+
class JSONRPCDispatcher:
119+
"""Default dispatcher: JSON-RPC 2.0 over anyio memory streams.
120+
121+
This is the behaviour ``BaseSession`` had before the dispatcher extraction —
122+
every built-in transport produces a pair of streams that feed into here.
123+
"""
124+
125+
def __init__(
126+
self,
127+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
128+
write_stream: MemoryObjectSendStream[SessionMessage],
129+
response_routers: list[ResponseRouter],
130+
) -> None:
131+
self._read_stream = read_stream
132+
self._write_stream = write_stream
133+
self._response_routers = response_routers
134+
self._response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] = {}
135+
self._on_request: OnRequestFn | None = None
136+
self._on_notification: OnNotificationFn | None = None
137+
self._on_error: OnErrorFn | None = None
138+
139+
def set_handlers(
140+
self,
141+
on_request: OnRequestFn,
142+
on_notification: OnNotificationFn,
143+
on_error: OnErrorFn,
144+
) -> None:
145+
self._on_request = on_request
146+
self._on_notification = on_notification
147+
self._on_error = on_error
148+
149+
async def send_request(
150+
self,
151+
request_id: RequestId,
152+
request: dict[str, Any],
153+
metadata: MessageMetadata = None,
154+
timeout: float | None = None,
155+
) -> dict[str, Any]:
156+
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
157+
self._response_streams[request_id] = response_stream
158+
try:
159+
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request)
160+
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
161+
with anyio.fail_after(timeout):
162+
response_or_error = await response_stream_reader.receive()
163+
if isinstance(response_or_error, JSONRPCError):
164+
raise MCPError.from_jsonrpc_error(response_or_error)
165+
return response_or_error.result
166+
finally:
167+
self._response_streams.pop(request_id, None)
168+
await response_stream.aclose()
169+
await response_stream_reader.aclose()
170+
171+
async def send_notification(
172+
self,
173+
notification: dict[str, Any],
174+
related_request_id: RequestId | None = None,
175+
) -> None:
176+
jsonrpc_notification = JSONRPCNotification(jsonrpc="2.0", **notification)
177+
session_message = SessionMessage(
178+
message=jsonrpc_notification,
179+
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
180+
)
181+
await self._write_stream.send(session_message)
182+
183+
async def send_response(
184+
self,
185+
request_id: RequestId,
186+
response: dict[str, Any] | ErrorData,
187+
) -> None:
188+
if isinstance(response, ErrorData):
189+
message: JSONRPCResponse | JSONRPCError = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
190+
else:
191+
message = JSONRPCResponse(jsonrpc="2.0", id=request_id, result=response)
192+
await self._write_stream.send(SessionMessage(message=message))
193+
194+
async def run(self) -> None:
195+
assert self._on_request is not None
196+
assert self._on_notification is not None
197+
assert self._on_error is not None
198+
199+
async with self._read_stream, self._write_stream:
200+
try:
201+
async for message in self._read_stream:
202+
if isinstance(message, Exception):
203+
await self._on_error(message)
204+
elif isinstance(message.message, JSONRPCRequest):
205+
await self._on_request(
206+
message.message.id,
207+
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
208+
message.metadata,
209+
)
210+
elif isinstance(message.message, JSONRPCNotification):
211+
await self._on_notification(
212+
message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
213+
)
214+
else:
215+
await self._route_response(message)
216+
except anyio.ClosedResourceError:
217+
# Expected when the peer disconnects abruptly.
218+
logging.debug("Read stream closed by client")
219+
except Exception as e:
220+
logging.exception(f"Unhandled exception in receive loop: {e}") # pragma: no cover
221+
finally:
222+
# Deliver CONNECTION_CLOSED to every request still awaiting a response.
223+
for id, stream in self._response_streams.items():
224+
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
225+
try:
226+
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
227+
await stream.aclose()
228+
except Exception: # pragma: no cover
229+
pass
230+
self._response_streams.clear()
231+
# Handlers are bound methods of the session; the session holds us. Break
232+
# the cycle so refcount GC can free both promptly.
233+
self._on_request = None
234+
self._on_notification = None
235+
self._on_error = None
236+
237+
async def _route_response(self, message: SessionMessage) -> None:
238+
# Runtime-true (run() only calls us in the response/error branch) but the
239+
# type checker can't see that, hence the guard.
240+
if not isinstance(message.message, JSONRPCResponse | JSONRPCError):
241+
return # pragma: no cover
242+
243+
assert self._on_error is not None
244+
245+
if message.message.id is None:
246+
error = message.message.error
247+
logging.warning(f"Received error with null ID: {error.message}")
248+
await self._on_error(MCPError(error.code, error.message, error.data))
249+
return
250+
251+
response_id = self._normalize_request_id(message.message.id)
252+
253+
# Response routers (experimental task support) get first look.
254+
if isinstance(message.message, JSONRPCError):
255+
for router in self._response_routers:
256+
if router.route_error(response_id, message.message.error):
257+
return
258+
else:
259+
response_data: dict[str, Any] = message.message.result or {}
260+
for router in self._response_routers:
261+
if router.route_response(response_id, response_data):
262+
return
263+
264+
stream = self._response_streams.pop(response_id, None)
265+
if stream:
266+
await stream.send(message.message)
267+
else:
268+
await self._on_error(RuntimeError(f"Received response with an unknown request ID: {message}"))
269+
270+
@staticmethod
271+
def _normalize_request_id(response_id: RequestId) -> RequestId:
272+
# We send integer IDs; some peers echo them back as strings.
273+
if isinstance(response_id, str):
274+
try:
275+
return int(response_id)
276+
except ValueError:
277+
logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests")
278+
return response_id

0 commit comments

Comments
 (0)