Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from typing import Union

import anyio
from pydantic import BaseModel
import httpx
from pydantic import BaseModel, ConfigDict

try:
from mcp import ClientSession
Expand Down Expand Up @@ -99,13 +100,16 @@ class StreamableHTTPConnectionParams(BaseModel):
Streamable HTTP server.
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
when the connection is closed.
httpx_client: httpx.AsyncClient to use for the connection.
"""

url: str
headers: dict[str, Any] | None = None
timeout: float = 5.0
sse_read_timeout: float = 60 * 5.0
terminate_on_close: bool = True
httpx_client: httpx.AsyncClient | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)


def retry_on_closed_resource(func):
Expand Down Expand Up @@ -277,15 +281,28 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
sse_read_timeout=self._connection_params.sse_read_timeout,
)
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
client = streamablehttp_client(
url=self._connection_params.url,
headers=merged_headers,
timeout=timedelta(seconds=self._connection_params.timeout),
sse_read_timeout=timedelta(
seconds=self._connection_params.sse_read_timeout
),
terminate_on_close=self._connection_params.terminate_on_close,
)
if self._connection_params.httpx_client:
client = streamablehttp_client(
url=self._connection_params.url,
headers=merged_headers,
timeout=timedelta(seconds=self._connection_params.timeout),
sse_read_timeout=timedelta(
seconds=self._connection_params.sse_read_timeout
),
terminate_on_close=self._connection_params.terminate_on_close,
httpx_client_factory=self._connection_params.httpx_client,
)
else:
client = streamablehttp_client(
url=self._connection_params.url,
headers=merged_headers,
timeout=timedelta(seconds=self._connection_params.timeout),
sse_read_timeout=timedelta(
seconds=self._connection_params.sse_read_timeout
),
terminate_on_close=self._connection_params.terminate_on_close,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's significant code duplication between the if and else blocks. All arguments to streamablehttp_client are the same except for httpx_client_factory. You can make the code more concise and easier to maintain by creating a dictionary of common arguments and conditionally adding httpx_client_factory before calling the function with keyword argument unpacking.

      kwargs = {
          "url": self._connection_params.url,
          "headers": merged_headers,
          "timeout": timedelta(seconds=self._connection_params.timeout),
          "sse_read_timeout": timedelta(
              seconds=self._connection_params.sse_read_timeout
          ),
          "terminate_on_close": self._connection_params.terminate_on_close,
      }
      if self._connection_params.httpx_client:
        kwargs["httpx_client_factory"] = self._connection_params.httpx_client
      client = streamablehttp_client(**kwargs)


else:
raise ValueError(
'Unable to initialize connection. Connection should be'
Expand Down
13 changes: 13 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import Mock
from unittest.mock import patch

import httpx
import pytest

# Skip all tests in this module if Python version is less than 3.10
Expand Down Expand Up @@ -143,6 +144,18 @@ def test_init_with_streamable_http_params(self):
manager = MCPSessionManager(http_params)

assert manager._connection_params == http_params
assert manager._connection_params.httpx_client is None

def test_init_with_streamable_http_params_with_httpx_client(self):
"""Test initialization with StreamableHTTPConnectionParams."""
client = httpx.AsyncClient()
http_params = StreamableHTTPConnectionParams(
url="https://example.com/mcp", timeout=15.0, httpx_client=client
)
manager = MCPSessionManager(http_params)

assert manager._connection_params == http_params
assert manager._connection_params.httpx_client == client

def test_generate_session_key_stdio(self):
"""Test session key generation for stdio connections."""
Expand Down
Loading