Skip to content

fix: Support API-Key for MCP Tool authentication #1673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions src/google/adk/tools/mcp_tool/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from __future__ import annotations

import base64
import json
import logging
from typing import Optional

from fastapi.openapi.models import APIKeyIn
from google.genai.types import FunctionDeclaration
from google.oauth2.credentials import Credentials
from typing_extensions import override

from .._gemini_schema_util import _to_gemini_schema
Expand Down Expand Up @@ -58,6 +57,9 @@ class MCPTool(BaseAuthenticatedTool):

Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
call the tool.

Note: For API key authentication, only header-based API keys are supported.
Query and cookie-based API keys will result in authentication errors.
"""

def __init__(
Expand Down Expand Up @@ -134,7 +136,19 @@ async def _run_async_impl(
async def _get_headers(
self, tool_context: ToolContext, credential: AuthCredential
) -> Optional[dict[str, str]]:
headers = None
"""Extracts authentication headers from credentials.

Args:
tool_context: The tool context of the current invocation.
credential: The authentication credential to process.

Returns:
Dictionary of headers to add to the request, or None if no auth.

Raises:
ValueError: If API key authentication is configured for non-header location.
"""
headers: Optional[dict[str, str]] = None
if credential:
if credential.oauth2:
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
Expand Down Expand Up @@ -167,10 +181,33 @@ async def _get_headers(
)
}
elif credential.api_key:
# For API keys, we'll add them as headers since MCP typically uses header-based auth
# The specific header name would depend on the API, using a common default
# TODO Allow user to specify the header name for API keys.
headers = {"X-API-Key": credential.api_key}
if (
not self._credentials_manager
or not self._credentials_manager._auth_config
):
error_msg = (
"Cannot find corresponding auth scheme for API key credential"
f" {credential}"
)
logger.error(error_msg)
raise ValueError(error_msg)
elif (
self._credentials_manager._auth_config.auth_scheme.in_
!= APIKeyIn.header
):
error_msg = (
"MCPTool only supports header-based API key authentication."
" Configured location:"
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
)
logger.error(error_msg)
raise ValueError(error_msg)
else:
headers = {
self._credentials_manager._auth_config.auth_scheme.name: (
credential.api_key
)
}
elif credential.service_account:
# Service accounts should be exchanged for access tokens before reaching this point
logger.warning(
Expand Down
211 changes: 205 additions & 6 deletions tests/unittests/tools/mcp_tool/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.

import sys
from typing import Any
from typing import Dict
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
Expand Down Expand Up @@ -268,8 +266,102 @@ async def test_get_headers_http_basic(self):
assert headers == {"Authorization": f"Basic {expected_encoded}"}

@pytest.mark.asyncio
async def test_get_headers_api_key(self):
"""Test header generation for API Key credentials."""
async def test_get_headers_api_key_with_valid_header_scheme(self):
"""Test header generation for API Key credentials with header-based auth scheme."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for header-based API key
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.header,
"name": "X-Custom-API-Key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)

tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, auth_credential)

assert headers == {"X-Custom-API-Key": "my_api_key"}

@pytest.mark.asyncio
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
"""Test that API Key with query-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)

tool_context = Mock(spec=ToolContext)

with pytest.raises(
ValueError,
match="MCPTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)

@pytest.mark.asyncio
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
"""Test that API Key with cookie-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for cookie-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.cookie,
"name": "session_id",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)

tool_context = Mock(spec=ToolContext)

with pytest.raises(
ValueError,
match="MCPTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)

@pytest.mark.asyncio
async def test_get_headers_api_key_without_auth_config_raises_error(self):
"""Test that API Key without auth config raises ValueError."""
# Create tool without auth scheme/config
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
Expand All @@ -278,11 +370,37 @@ async def test_get_headers_api_key(self):
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool_context = Mock(spec=ToolContext)

with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await tool._get_headers(tool_context, credential)

@pytest.mark.asyncio
async def test_get_headers_api_key_without_credentials_manager_raises_error(
self,
):
"""Test that API Key without credentials manager raises ValueError."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

# Manually set credentials manager to None to simulate error condition
tool._credentials_manager = None

credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)

assert headers == {"X-API-Key": "my_api_key"}
with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await tool._get_headers(tool_context, credential)

@pytest.mark.asyncio
async def test_get_headers_no_credential(self):
Expand Down Expand Up @@ -318,6 +436,48 @@ async def test_get_headers_service_account(self):
# Should return None as service account credentials are not supported for direct header generation
assert headers is None

@pytest.mark.asyncio
async def test_run_async_impl_with_api_key_header_auth(self):
"""Test running tool with API key header authentication end-to-end."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for header-based API key
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.header,
"name": "X-Service-API-Key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="test_service_key"
)

tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)

# Mock the session response
expected_response = {"result": "authenticated_success"}
self.mock_session.call_tool = AsyncMock(return_value=expected_response)

tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}

result = await tool._run_async_impl(
args=args, tool_context=tool_context, credential=auth_credential
)

assert result == expected_response
# Check that headers were passed correctly with custom API key header
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
headers = call_args[1]["headers"]
assert headers == {"X-Service-API-Key": "test_service_key"}

@pytest.mark.asyncio
async def test_run_async_impl_retry_decorator(self):
"""Test that the retry decorator is applied correctly."""
Expand Down Expand Up @@ -350,6 +510,45 @@ async def test_get_headers_http_custom_scheme(self):

assert headers == {"Authorization": "custom custom_token"}

@pytest.mark.asyncio
async def test_get_headers_api_key_error_logging(self):
"""Test that API key errors are logged correctly."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType

# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)

tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)

tool_context = Mock(spec=ToolContext)

# Test with logging
with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger:
with pytest.raises(ValueError):
await tool._get_headers(tool_context, auth_credential)

# Verify error was logged
mock_logger.error.assert_called_once()
logged_message = mock_logger.error.call_args[0][0]
assert (
"MCPTool only supports header-based API key authentication"
in logged_message
)

def test_init_validation(self):
"""Test that initialization validates required parameters."""
# This test ensures that the MCPTool properly handles its dependencies
Expand Down