Skip to content

Commit 25fb05f

Browse files
committed
feat(mrtr): add IncompleteResult types and client retry loop (SEP-2322)
Lowlevel plumbing for Multi Round-Trip Requests: Types: - IncompleteResult with result_type discriminator, input_requests, request_state - InputRequest/InputResponse unions (elicitation, sampling, roots) - input_responses + request_state fields on RequestParams Server (lowlevel): - on_call_tool return widened to include IncompleteResult Session: - send_request accepts TypeAdapter (overload) for union result parsing - call_tool_mrtr() returns CallToolResult | IncompleteResult - call_tool() stays narrow, raises on IncompleteResult with migration hint Client: - call_tool() drives MRTR retry loop internally — dispatches embedded input requests to elicitation/sampling/list_roots callbacks, retries with collected responses + echoed request_state - max_mrtr_rounds bound (default 8) The client-side delta from today's code is zero: elicitation_callback is the same function whether it fires from SSE push or MRTR retry.
1 parent 92c693b commit 25fb05f

File tree

7 files changed

+425
-14
lines changed

7 files changed

+425
-14
lines changed

src/mcp/client/client.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,24 @@
1212
from mcp.client.streamable_http import streamable_http_client
1313
from mcp.server import Server
1414
from mcp.server.mcpserver import MCPServer
15+
from mcp.shared._context import RequestContext
1516
from mcp.shared.session import ProgressFnT
1617
from mcp.types import (
1718
CallToolResult,
1819
CompleteResult,
20+
CreateMessageRequest,
21+
ElicitRequest,
1922
EmptyResult,
23+
ErrorData,
2024
GetPromptResult,
2125
Implementation,
26+
IncompleteResult,
2227
InitializeResult,
28+
InputResponses,
2329
ListPromptsResult,
2430
ListResourcesResult,
2531
ListResourceTemplatesResult,
32+
ListRootsRequest,
2633
ListToolsResult,
2734
LoggingLevel,
2835
PaginatedRequestParams,
@@ -32,6 +39,9 @@
3239
ResourceTemplateReference,
3340
)
3441

42+
MRTR_MAX_ROUNDS = 8
43+
"""Bound on MRTR retry rounds. A well-formed handler converges; an unbounded loop is a bug."""
44+
3545

3646
@dataclass
3747
class Client:
@@ -95,6 +105,9 @@ async def main():
95105
elicitation_callback: ElicitationFnT | None = None
96106
"""Callback for handling elicitation requests."""
97107

108+
max_mrtr_rounds: int = MRTR_MAX_ROUNDS
109+
"""Maximum MRTR retry rounds before raising (SEP-2322). A server that never converges is a bug."""
110+
98111
_session: ClientSession | None = field(init=False, default=None)
99112
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
100113
_transport: Transport = field(init=False)
@@ -238,6 +251,12 @@ async def call_tool(
238251
) -> CallToolResult:
239252
"""Call a tool on the server.
240253
254+
If the server returns an ``IncompleteResult`` (SEP-2322 MRTR), this
255+
method drives the retry loop internally: each embedded input request
256+
is dispatched to the matching callback (``elicitation_callback``,
257+
``sampling_callback``, or ``list_roots_callback``) and the tool is
258+
re-called with the collected responses plus echoed ``request_state``.
259+
241260
Args:
242261
name: The name of the tool to call
243262
arguments: Arguments to pass to the tool
@@ -248,14 +267,65 @@ async def call_tool(
248267
Returns:
249268
The tool result.
250269
"""
251-
return await self.session.call_tool(
252-
name=name,
253-
arguments=arguments,
254-
read_timeout_seconds=read_timeout_seconds,
255-
progress_callback=progress_callback,
256-
meta=meta,
270+
input_responses: InputResponses | None = None
271+
request_state: str | None = None
272+
273+
for _round in range(self.max_mrtr_rounds):
274+
result = await self.session.call_tool_mrtr(
275+
name=name,
276+
arguments=arguments,
277+
read_timeout_seconds=read_timeout_seconds,
278+
progress_callback=progress_callback,
279+
meta=meta,
280+
input_responses=input_responses,
281+
request_state=request_state,
282+
)
283+
284+
if isinstance(result, CallToolResult):
285+
return result
286+
287+
input_responses = await self._fulfil_input_requests(result)
288+
request_state = result.request_state
289+
290+
raise RuntimeError(
291+
f"MRTR retry loop for tool {name!r} exceeded {self.max_mrtr_rounds} rounds without converging"
257292
)
258293

294+
async def _fulfil_input_requests(self, incomplete: IncompleteResult) -> InputResponses | None:
295+
"""Dispatch each embedded input request to the matching callback."""
296+
if not incomplete.input_requests:
297+
return None
298+
299+
ctx = RequestContext[ClientSession](session=self.session)
300+
responses: InputResponses = {}
301+
302+
for key, req in incomplete.input_requests.items():
303+
match req:
304+
case ElicitRequest(params=params):
305+
if self.elicitation_callback is None:
306+
raise RuntimeError(
307+
f"Server sent elicitation input request {key!r} but no elicitation_callback is configured"
308+
)
309+
result = await self.elicitation_callback(ctx, params)
310+
case CreateMessageRequest(params=params):
311+
if self.sampling_callback is None: # pragma: no cover
312+
raise RuntimeError(
313+
f"Server sent sampling input request {key!r} but no sampling_callback is configured"
314+
)
315+
result = await self.sampling_callback(ctx, params)
316+
case ListRootsRequest():
317+
if self.list_roots_callback is None: # pragma: no cover
318+
raise RuntimeError(
319+
f"Server sent roots input request {key!r} but no list_roots_callback is configured"
320+
)
321+
result = await self.list_roots_callback(ctx)
322+
323+
if isinstance(result, ErrorData):
324+
raise RuntimeError(f"Input request {key!r} failed: {result.message}")
325+
responses[key] = result
326+
327+
return responses
328+
259329
async def list_prompts(
260330
self,
261331
*,

src/mcp/client/session.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ async def _default_logging_callback(
9797

9898
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
9999

100+
_call_tool_result_adapter: TypeAdapter[types.IncompleteResult | types.CallToolResult] = TypeAdapter(
101+
types.IncompleteResult | types.CallToolResult
102+
)
103+
100104

101105
class ClientSession(
102106
BaseSession[
@@ -305,18 +309,60 @@ async def call_tool(
305309
*,
306310
meta: RequestParamsMeta | None = None,
307311
) -> types.CallToolResult:
308-
"""Send a tools/call request with optional progress callback support."""
312+
"""Send a tools/call request with optional progress callback support.
309313
310-
result = await self.send_request(
314+
Raises:
315+
RuntimeError: If the server returns an IncompleteResult. Use
316+
``Client.call_tool`` or ``call_tool_mrtr`` to handle MRTR flows.
317+
"""
318+
result = await self.call_tool_mrtr(
319+
name,
320+
arguments,
321+
read_timeout_seconds,
322+
progress_callback,
323+
meta=meta,
324+
)
325+
if isinstance(result, types.IncompleteResult):
326+
raise RuntimeError(
327+
f"Server returned IncompleteResult for tool {name!r}. "
328+
"Use Client.call_tool or ClientSession.call_tool_mrtr to handle MRTR flows."
329+
)
330+
return result
331+
332+
async def call_tool_mrtr(
333+
self,
334+
name: str,
335+
arguments: dict[str, Any] | None = None,
336+
read_timeout_seconds: float | None = None,
337+
progress_callback: ProgressFnT | None = None,
338+
*,
339+
meta: RequestParamsMeta | None = None,
340+
input_responses: types.InputResponses | None = None,
341+
request_state: str | None = None,
342+
) -> types.CallToolResult | types.IncompleteResult:
343+
"""Send a single tools/call request; returns IncompleteResult if server needs input.
344+
345+
This is the MRTR-aware variant (SEP-2322). One request → one response.
346+
Higher-level ``mcp.client.Client.call_tool`` drives the retry loop; this
347+
method just surfaces whatever the server sent.
348+
"""
349+
350+
result: types.CallToolResult | types.IncompleteResult = await self.send_request(
311351
types.CallToolRequest(
312-
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=meta),
352+
params=types.CallToolRequestParams(
353+
name=name,
354+
arguments=arguments,
355+
input_responses=input_responses,
356+
request_state=request_state,
357+
_meta=meta,
358+
),
313359
),
314-
types.CallToolResult,
360+
_call_tool_result_adapter,
315361
request_read_timeout_seconds=read_timeout_seconds,
316362
progress_callback=progress_callback,
317363
)
318364

319-
if not result.is_error:
365+
if isinstance(result, types.CallToolResult) and not result.is_error:
320366
await self._validate_tool_result(name, result)
321367

322368
return result

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
| None = None,
119119
on_call_tool: Callable[
120120
[ServerRequestContext[LifespanResultT], types.CallToolRequestParams],
121-
Awaitable[types.CallToolResult | types.CreateTaskResult],
121+
Awaitable[types.CallToolResult | types.IncompleteResult | types.CreateTaskResult],
122122
]
123123
| None = None,
124124
on_list_resources: Callable[

src/mcp/shared/session.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Callable
55
from contextlib import AsyncExitStack
66
from types import TracebackType
7-
from typing import Any, Generic, Protocol, TypeVar
7+
from typing import Any, Generic, Protocol, TypeVar, overload
88

99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -230,14 +230,34 @@ async def __aexit__(
230230
self._task_group.cancel_scope.cancel()
231231
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
232232

233+
@overload
233234
async def send_request(
234235
self,
235236
request: SendRequestT,
236237
result_type: type[ReceiveResultT],
237238
request_read_timeout_seconds: float | None = None,
238239
metadata: MessageMetadata = None,
239240
progress_callback: ProgressFnT | None = None,
240-
) -> ReceiveResultT:
241+
) -> ReceiveResultT: ...
242+
243+
@overload
244+
async def send_request(
245+
self,
246+
request: SendRequestT,
247+
result_type: TypeAdapter[Any],
248+
request_read_timeout_seconds: float | None = None,
249+
metadata: MessageMetadata = None,
250+
progress_callback: ProgressFnT | None = None,
251+
) -> Any: ...
252+
253+
async def send_request(
254+
self,
255+
request: SendRequestT,
256+
result_type: type[ReceiveResultT] | TypeAdapter[Any],
257+
request_read_timeout_seconds: float | None = None,
258+
metadata: MessageMetadata = None,
259+
progress_callback: ProgressFnT | None = None,
260+
) -> ReceiveResultT | Any:
241261
"""Sends a request and waits for a response.
242262
243263
Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take
@@ -280,6 +300,8 @@ async def send_request(
280300

281301
if isinstance(response_or_error, JSONRPCError):
282302
raise MCPError.from_jsonrpc_error(response_or_error)
303+
elif isinstance(result_type, TypeAdapter):
304+
return result_type.validate_python(response_or_error.result)
283305
else:
284306
return result_type.model_validate(response_or_error.result, by_name=False)
285307

src/mcp/types/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,15 @@
7474
ImageContent,
7575
Implementation,
7676
IncludeContext,
77+
IncompleteResult,
7778
InitializedNotification,
7879
InitializeRequest,
7980
InitializeRequestParams,
8081
InitializeResult,
82+
InputRequest,
83+
InputRequests,
84+
InputResponse,
85+
InputResponses,
8186
ListPromptsRequest,
8287
ListPromptsResult,
8388
ListResourcesRequest,
@@ -179,6 +184,7 @@
179184
client_notification_adapter,
180185
client_request_adapter,
181186
client_result_adapter,
187+
input_request_adapter,
182188
server_notification_adapter,
183189
server_request_adapter,
184190
server_result_adapter,
@@ -342,6 +348,13 @@
342348
"SubscribeRequestParams",
343349
"UnsubscribeRequest",
344350
"UnsubscribeRequestParams",
351+
# MRTR (SEP-2322)
352+
"IncompleteResult",
353+
"InputRequest",
354+
"InputRequests",
355+
"InputResponse",
356+
"InputResponses",
357+
"input_request_adapter",
345358
# Results
346359
"CallToolResult",
347360
"CancelTaskResult",

src/mcp/types/_types.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ class RequestParams(MCPModel):
7676
for task augmentation of specific request types in their capabilities.
7777
"""
7878

79+
input_responses: dict[str, Any] | None = None
80+
"""Responses to input requests from a prior IncompleteResult (SEP-2322 MRTR).
81+
82+
Keys mirror the server's inputRequests keys; values are the corresponding
83+
ElicitResult, CreateMessageResult, or ListRootsResult payloads.
84+
"""
85+
86+
request_state: str | None = None
87+
"""Opaque state echoed from a prior IncompleteResult (SEP-2322 MRTR).
88+
89+
Clients MUST NOT inspect or modify this value.
90+
"""
91+
7992
meta: RequestParamsMeta | None = Field(alias="_meta", default=None)
8093

8194

@@ -1716,6 +1729,44 @@ class ElicitationRequiredErrorData(MCPModel):
17161729
"""List of URL mode elicitations that must be completed."""
17171730

17181731

1732+
# ─── Multi Round-Trip Requests (SEP-2322) ───────────────────────────────────
1733+
1734+
InputRequest: TypeAlias = CreateMessageRequest | ElicitRequest | ListRootsRequest
1735+
"""A server-initiated request embedded in an IncompleteResult."""
1736+
1737+
InputResponse: TypeAlias = CreateMessageResult | CreateMessageResultWithTools | ElicitResult | ListRootsResult
1738+
"""A client's response to an InputRequest, sent on the retry."""
1739+
1740+
InputRequests: TypeAlias = dict[str, InputRequest]
1741+
"""Keyed map of input requests. Keys are server-assigned and opaque to the client."""
1742+
1743+
InputResponses: TypeAlias = dict[str, InputResponse]
1744+
"""Keyed map of input responses. Keys mirror the server's InputRequests keys."""
1745+
1746+
1747+
class IncompleteResult(Result):
1748+
"""A result indicating the server needs more input before completing (SEP-2322).
1749+
1750+
The client MUST retry the original request with ``input_responses`` populated
1751+
for each key in ``input_requests``, and ``request_state`` echoed verbatim.
1752+
1753+
At least one of ``input_requests`` or ``request_state`` must be present.
1754+
"""
1755+
1756+
result_type: Literal["incomplete"] = "incomplete"
1757+
"""Discriminator marking this as an incomplete result."""
1758+
1759+
input_requests: InputRequests | None = None
1760+
"""Server-initiated requests the client must fulfil before retrying."""
1761+
1762+
request_state: str | None = None
1763+
"""Opaque state the client must echo back on retry. Not inspected by the client."""
1764+
1765+
1766+
input_request_adapter: TypeAdapter[InputRequest] = TypeAdapter(InputRequest)
1767+
"""Type adapter for validating embedded InputRequest payloads."""
1768+
1769+
17191770
ClientResult = (
17201771
EmptyResult
17211772
| CreateMessageResult
@@ -1774,5 +1825,6 @@ class ElicitationRequiredErrorData(MCPModel):
17741825
| ListTasksResult
17751826
| CancelTaskResult
17761827
| CreateTaskResult
1828+
| IncompleteResult
17771829
)
17781830
server_result_adapter = TypeAdapter[ServerResult](ServerResult)

0 commit comments

Comments
 (0)