33import abc
44import asyncio
55import inspect
6+ from collections .abc import Awaitable
67from contextlib import AbstractAsyncContextManager , AsyncExitStack
78from datetime import timedelta
89from pathlib import Path
9- from typing import TYPE_CHECKING , Any , Literal
10+ from typing import TYPE_CHECKING , Any , Callable , Literal , TypeVar
1011
1112from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1213from mcp import ClientSession , StdioServerParameters , Tool as MCPTool , stdio_client
2122from ..run_context import RunContextWrapper
2223from .util import ToolFilter , ToolFilterContext , ToolFilterStatic
2324
25+ T = TypeVar ("T" )
26+
2427if TYPE_CHECKING :
2528 from ..agent import AgentBase
2629
@@ -98,6 +101,8 @@ def __init__(
98101 client_session_timeout_seconds : float | None ,
99102 tool_filter : ToolFilter = None ,
100103 use_structured_content : bool = False ,
104+ max_retry_attempts : int = 0 ,
105+ retry_backoff_seconds_base : float = 1.0 ,
101106 ):
102107 """
103108 Args:
@@ -115,6 +120,10 @@ def __init__(
115120 include the structured content in the `tool_result.content`, and using it by
116121 default will cause duplicate content. You can set this to True if you know the
117122 server will not duplicate the structured content in the `tool_result.content`.
123+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
124+ Defaults to no retries.
125+ retry_backoff_seconds_base: The base delay, in seconds, used for exponential
126+ backoff between retries.
118127 """
119128 super ().__init__ (use_structured_content = use_structured_content )
120129 self .session : ClientSession | None = None
@@ -124,6 +133,8 @@ def __init__(
124133 self .server_initialize_result : InitializeResult | None = None
125134
126135 self .client_session_timeout_seconds = client_session_timeout_seconds
136+ self .max_retry_attempts = max_retry_attempts
137+ self .retry_backoff_seconds_base = retry_backoff_seconds_base
127138
128139 # The cache is always dirty at startup, so that we fetch tools at least once
129140 self ._cache_dirty = True
@@ -233,6 +244,18 @@ def invalidate_tools_cache(self):
233244 """Invalidate the tools cache."""
234245 self ._cache_dirty = True
235246
async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
    """Await ``func()`` and retry it with exponential backoff when it raises.

    The delay before retry number ``k`` (1-based) is
    ``self.retry_backoff_seconds_base * 2 ** (k - 1)`` seconds.

    Args:
        func: Zero-argument callable that returns a fresh awaitable on every call.
            It is invoked again for each retry.

    Returns:
        The result of the first ``await func()`` that does not raise.

    Raises:
        Exception: The error raised by ``func`` is re-raised unchanged once the
            number of failed attempts exceeds ``self.max_retry_attempts``. A value
            of ``-1`` disables the limit, so the call is retried until it succeeds.
    """
    # 2 ** 1024 no longer fits in a float, so ``float * 2 ** n`` raises
    # OverflowError for n >= 1024. Clamping the exponent keeps long retry runs
    # (unlimited retries, or a zero/tiny base delay) from dying with an unrelated
    # error, and leaves every delay that was computable before unchanged.
    max_exponent = 1023

    attempts = 0
    while True:
        try:
            return await func()
        except Exception:
            attempts += 1
            if self.max_retry_attempts != -1 and attempts > self.max_retry_attempts:
                # Retry budget exhausted: surface the original error to the caller.
                raise
            exponent = min(attempts - 1, max_exponent)
            backoff = self.retry_backoff_seconds_base * (2**exponent)
            await asyncio.sleep(backoff)
258+
236259 async def connect (self ):
237260 """Connect to the server."""
238261 try :
@@ -267,15 +290,17 @@ async def list_tools(
267290 """List the tools available on the server."""
268291 if not self .session :
269292 raise UserError ("Server not initialized. Make sure you call `connect()` first." )
293+ session = self .session
294+ assert session is not None
270295
271296 # Return from cache if caching is enabled, we have tools, and the cache is not dirty
272297 if self .cache_tools_list and not self ._cache_dirty and self ._tools_list :
273298 tools = self ._tools_list
274299 else :
275- # Reset the cache dirty to False
276- self ._cache_dirty = False
277300 # Fetch the tools from the server
278- self ._tools_list = (await self .session .list_tools ()).tools
301+ result = await self ._run_with_retries (lambda : session .list_tools ())
302+ self ._tools_list = result .tools
303+ self ._cache_dirty = False
279304 tools = self ._tools_list
280305
281306 # Filter tools based on tool_filter
@@ -290,8 +315,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
290315 """Invoke a tool on the server."""
291316 if not self .session :
292317 raise UserError ("Server not initialized. Make sure you call `connect()` first." )
318+ session = self .session
319+ assert session is not None
293320
294- return await self .session .call_tool (tool_name , arguments )
321+ return await self ._run_with_retries ( lambda : session .call_tool (tool_name , arguments ) )
295322
296323 async def list_prompts (
297324 self ,
@@ -365,6 +392,8 @@ def __init__(
365392 client_session_timeout_seconds : float | None = 5 ,
366393 tool_filter : ToolFilter = None ,
367394 use_structured_content : bool = False ,
395+ max_retry_attempts : int = 0 ,
396+ retry_backoff_seconds_base : float = 1.0 ,
368397 ):
369398 """Create a new MCP server based on the stdio transport.
370399
@@ -388,12 +417,18 @@ def __init__(
388417 include the structured content in the `tool_result.content`, and using it by
389418 default will cause duplicate content. You can set this to True if you know the
390419 server will not duplicate the structured content in the `tool_result.content`.
420+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
421+ Defaults to no retries.
422+ retry_backoff_seconds_base: The base delay, in seconds, for exponential
423+ backoff between retries.
391424 """
392425 super ().__init__ (
393426 cache_tools_list ,
394427 client_session_timeout_seconds ,
395428 tool_filter ,
396429 use_structured_content ,
430+ max_retry_attempts ,
431+ retry_backoff_seconds_base ,
397432 )
398433
399434 self .params = StdioServerParameters (
@@ -455,6 +490,8 @@ def __init__(
455490 client_session_timeout_seconds : float | None = 5 ,
456491 tool_filter : ToolFilter = None ,
457492 use_structured_content : bool = False ,
493+ max_retry_attempts : int = 0 ,
494+ retry_backoff_seconds_base : float = 1.0 ,
458495 ):
459496 """Create a new MCP server based on the HTTP with SSE transport.
460497
@@ -480,12 +517,18 @@ def __init__(
480517 include the structured content in the `tool_result.content`, and using it by
481518 default will cause duplicate content. You can set this to True if you know the
482519 server will not duplicate the structured content in the `tool_result.content`.
520+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
521+ Defaults to no retries.
522+ retry_backoff_seconds_base: The base delay, in seconds, for exponential
523+ backoff between retries.
483524 """
484525 super ().__init__ (
485526 cache_tools_list ,
486527 client_session_timeout_seconds ,
487528 tool_filter ,
488529 use_structured_content ,
530+ max_retry_attempts ,
531+ retry_backoff_seconds_base ,
489532 )
490533
491534 self .params = params
@@ -547,6 +590,8 @@ def __init__(
547590 client_session_timeout_seconds : float | None = 5 ,
548591 tool_filter : ToolFilter = None ,
549592 use_structured_content : bool = False ,
593+ max_retry_attempts : int = 0 ,
594+ retry_backoff_seconds_base : float = 1.0 ,
550595 ):
551596 """Create a new MCP server based on the Streamable HTTP transport.
552597
@@ -573,12 +618,18 @@ def __init__(
573618 include the structured content in the `tool_result.content`, and using it by
574619 default will cause duplicate content. You can set this to True if you know the
575620 server will not duplicate the structured content in the `tool_result.content`.
621+ max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
622+ Defaults to no retries.
623+ retry_backoff_seconds_base: The base delay, in seconds, for exponential
624+ backoff between retries.
576625 """
577626 super ().__init__ (
578627 cache_tools_list ,
579628 client_session_timeout_seconds ,
580629 tool_filter ,
581630 use_structured_content ,
631+ max_retry_attempts ,
632+ retry_backoff_seconds_base ,
582633 )
583634
584635 self .params = params
0 commit comments