Skip to content
17 changes: 10 additions & 7 deletions src/google/adk/tools/mcp_tool/session_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Google LLC
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -9,7 +9,7 @@
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and:
# limitations under the License.

from __future__ import annotations
Expand Down Expand Up @@ -54,6 +54,9 @@ def __init__(
timeout: Optional[float],
sse_read_timeout: Optional[float],
is_stdio: bool = False,
*,
sampling_callback: Optional[object] = None,
sampling_capabilities: Optional[object] = None,
):
"""
Args:
Expand All @@ -63,6 +66,9 @@ def __init__(
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
server.
is_stdio: Whether this is a stdio connection (affects read timeout).
sampling_callback: Optional callback to handle sampling requests from the
MCP server.
sampling_capabilities: Optional capabilities for sampling.
"""
self._client = client
self._timeout = timeout
Expand Down Expand Up @@ -140,7 +146,7 @@ async def close(self):
except Exception as e:
logger.warning(f'Failed to close MCP session: {e}')

async def __aenter__(self) -> ClientSession:
async def __a_enter__(self) -> ClientSession:
return await self.start()

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -150,10 +156,7 @@ async def _run(self):
"""Run the complete session context within a single task."""
try:
async with AsyncExitStack() as exit_stack:
transports = await asyncio.wait_for(
exit_stack.enter_async_context(self._client),
timeout=self._timeout,
)
transports = await exit_stack.enter_async_context(self._client)
# The streamable http client returns a GetSessionCallback in addition
# to the read/write MemoryObjectStreams needed to build the
# ClientSession. We limit to the first two values to be compatible
Expand Down
Loading