diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index ca637d0489..9be634218d 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -9,7 +9,7 @@ # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and +# See the License for the specific language governing permissions and: # limitations under the License. from __future__ import annotations @@ -54,6 +54,9 @@ def __init__( timeout: Optional[float], sse_read_timeout: Optional[float], is_stdio: bool = False, + *, + sampling_callback: Optional[object] = None, + sampling_capabilities: Optional[object] = None, ): """ Args: @@ -63,6 +66,9 @@ def __init__( sse_read_timeout: Timeout in seconds for reading data from the MCP SSE server. is_stdio: Whether this is a stdio connection (affects read timeout). + sampling_callback: Optional callback to handle sampling requests from the + MCP server. + sampling_capabilities: Optional capabilities for sampling. """ self._client = client self._timeout = timeout @@ -140,7 +146,7 @@ async def close(self): except Exception as e: logger.warning(f'Failed to close MCP session: {e}') - async def __aenter__(self) -> ClientSession: + async def __a_enter__(self) -> ClientSession: return await self.start() async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -150,10 +156,7 @@ async def _run(self): """Run the complete session context within a single task.""" try: async with AsyncExitStack() as exit_stack: - transports = await asyncio.wait_for( - exit_stack.enter_async_context(self._client), - timeout=self._timeout, - ) + transports = await exit_stack.enter_async_context(self._client) # The streamable http client returns a GetSessionCallback in addition # to the read/write MemoryObjectStreams needed to build the # ClientSession. We limit to the first two values to be compatible