diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 310fc48f1..af4616cae 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -15,12 +15,11 @@ from __future__ import annotations import base64 -import json import logging from typing import Optional +from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration -from google.oauth2.credentials import Credentials from typing_extensions import override from .._gemini_schema_util import _to_gemini_schema @@ -58,6 +57,9 @@ class MCPTool(BaseAuthenticatedTool): Internally, the tool initializes from a MCP Tool, and uses the MCP Session to call the tool. + + Note: For API key authentication, only header-based API keys are supported. + Query and cookie-based API keys will result in authentication errors. """ def __init__( @@ -134,7 +136,19 @@ async def _run_async_impl( async def _get_headers( self, tool_context: ToolContext, credential: AuthCredential ) -> Optional[dict[str, str]]: - headers = None + """Extracts authentication headers from credentials. + + Args: + tool_context: The tool context of the current invocation. + credential: The authentication credential to process. + + Returns: + Dictionary of headers to add to the request, or None if no auth. + + Raises: + ValueError: If API key authentication is configured for non-header location. + """ + headers: Optional[dict[str, str]] = None if credential: if credential.oauth2: headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"} @@ -167,10 +181,33 @@ async def _get_headers( ) } elif credential.api_key: - # For API keys, we'll add them as headers since MCP typically uses header-based auth - # The specific header name would depend on the API, using a common default - # TODO Allow user to specify the header name for API keys. - headers = {"X-API-Key": credential.api_key} + if ( + not self._credentials_manager + or not self._credentials_manager._auth_config + ): + error_msg = ( + "Cannot find corresponding auth scheme for API key credential" + f" {credential}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + elif ( + self._credentials_manager._auth_config.auth_scheme.in_ + != APIKeyIn.header + ): + error_msg = ( + "MCPTool only supports header-based API key authentication." + " Configured location:" + f" {self._credentials_manager._auth_config.auth_scheme.in_}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + else: + headers = { + self._credentials_manager._auth_config.auth_scheme.name: ( + credential.api_key + ) + } elif credential.service_account: # Service accounts should be exchanged for access tokens before reaching this point logger.warning( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 82e3f2234..d32593749 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -13,8 +13,6 @@ # limitations under the License. import sys -from typing import Any -from typing import Dict from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -268,8 +266,102 @@ async def test_get_headers_http_basic(self): assert headers == {"Authorization": f"Basic {expected_encoded}"} @pytest.mark.asyncio - async def test_get_headers_api_key(self): - """Test header generation for API Key credentials.""" + async def test_get_headers_api_key_with_valid_header_scheme(self): + """Test header generation for API Key credentials with header-based auth scheme.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for header-based API key + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.header, + "name": "X-Custom-API-Key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + headers = await tool._get_headers(tool_context, auth_credential) + + assert headers == {"X-Custom-API-Key": "my_api_key"} + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_query_scheme_raises_error(self): + """Test that API Key with query-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="MCPTool only supports header-based API key authentication", + ): + await tool._get_headers(tool_context, auth_credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_with_cookie_scheme_raises_error(self): + """Test that API Key with cookie-based auth scheme raises ValueError.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for cookie-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.cookie, + "name": "session_id", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="MCPTool only supports header-based API key authentication", + ): + await tool._get_headers(tool_context, auth_credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_auth_config_raises_error(self): + """Test that API Key without auth config raises ValueError.""" + # Create tool without auth scheme/config tool = MCPTool( mcp_tool=self.mock_mcp_tool, mcp_session_manager=self.mock_session_manager, @@ -278,11 +370,37 @@ async def test_get_headers_api_key(self): credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" ) + tool_context = Mock(spec=ToolContext) + + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await tool._get_headers(tool_context, credential) + + @pytest.mark.asyncio + async def test_get_headers_api_key_without_credentials_manager_raises_error( + self, + ): + """Test that API Key without credentials manager raises ValueError.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + # Manually set credentials manager to None to simulate error condition + tool._credentials_manager = None + + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) tool_context = Mock(spec=ToolContext) - headers = await tool._get_headers(tool_context, credential) - assert headers == {"X-API-Key": "my_api_key"} + with pytest.raises( + ValueError, + match="Cannot find corresponding auth scheme for API key credential", + ): + await tool._get_headers(tool_context, credential) @pytest.mark.asyncio async def test_get_headers_no_credential(self): @@ -318,6 +436,48 @@ async def test_get_headers_service_account(self): # Should return None as service account credentials are not supported for direct header generation assert headers is None + @pytest.mark.asyncio + async def test_run_async_impl_with_api_key_header_auth(self): + """Test running tool with API key header authentication end-to-end.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for header-based API key + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.header, + "name": "X-Service-API-Key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="test_service_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + # Mock the session response + expected_response = {"result": "authenticated_success"} + self.mock_session.call_tool = AsyncMock(return_value=expected_response) + + tool_context = Mock(spec=ToolContext) + args = {"param1": "test_value"} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=auth_credential + ) + + assert result == expected_response + # Check that headers were passed correctly with custom API key header + self.mock_session_manager.create_session.assert_called_once() + call_args = self.mock_session_manager.create_session.call_args + headers = call_args[1]["headers"] + assert headers == {"X-Service-API-Key": "test_service_key"} + @pytest.mark.asyncio async def test_run_async_impl_retry_decorator(self): """Test that the retry decorator is applied correctly.""" @@ -350,6 +510,45 @@ async def test_get_headers_http_custom_scheme(self): assert headers == {"Authorization": "custom custom_token"} + @pytest.mark.asyncio + async def test_get_headers_api_key_error_logging(self): + """Test that API key errors are logged correctly.""" + from fastapi.openapi.models import APIKey + from fastapi.openapi.models import APIKeyIn + from google.adk.auth.auth_schemes import AuthSchemeType + + # Create auth scheme for query-based API key (not supported) + auth_scheme = APIKey(**{ + "type": AuthSchemeType.apiKey, + "in": APIKeyIn.query, + "name": "api_key", + }) + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key" + ) + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + auth_scheme=auth_scheme, + auth_credential=auth_credential, + ) + + tool_context = Mock(spec=ToolContext) + + # Test with logging + with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger: + with pytest.raises(ValueError): + await tool._get_headers(tool_context, auth_credential) + + # Verify error was logged + mock_logger.error.assert_called_once() + logged_message = mock_logger.error.call_args[0][0] + assert ( + "MCPTool only supports header-based API key authentication" + in logged_message + ) + def test_init_validation(self): """Test that initialization validates required parameters.""" # This test ensures that the MCPTool properly handles its dependencies