diff --git a/mcp_proxy_for_aws/middleware/initialize_middleware.py b/mcp_proxy_for_aws/middleware/initialize_middleware.py new file mode 100644 index 0000000..06fa277 --- /dev/null +++ b/mcp_proxy_for_aws/middleware/initialize_middleware.py @@ -0,0 +1,31 @@ +import logging +import mcp.types as mt +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext +from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory +from typing_extensions import override + + +logger = logging.getLogger(__name__) + + +class InitializeMiddleware(Middleware): + """Intecept MCP initialize request and initialize the proxy client.""" + + def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None: + """Create a middleware with client factory.""" + super().__init__() + self._client_factory = client_factory + + @override + async def on_initialize( + self, + context: MiddlewareContext[mt.InitializeRequest], + call_next: CallNext[mt.InitializeRequest, None], + ) -> None: + try: + logger.debug('Received initialize request %s.', context.message) + self._client_factory.set_init_params(context.message) + return await call_next(context) + except Exception: + logger.exception('Initialize failed in middleware.') + raise diff --git a/mcp_proxy_for_aws/proxy.py b/mcp_proxy_for_aws/proxy.py new file mode 100644 index 0000000..b0b3bec --- /dev/null +++ b/mcp_proxy_for_aws/proxy.py @@ -0,0 +1,167 @@ +import httpx +import logging +from fastmcp import Client +from fastmcp.client.transports import ClientTransport +from fastmcp.exceptions import NotFoundError +from fastmcp.server.proxy import ClientFactoryT +from fastmcp.server.proxy import FastMCPProxy as _FastMCPProxy +from fastmcp.server.proxy import ProxyClient as _ProxyClient +from fastmcp.server.proxy import ProxyToolManager as _ProxyToolManager +from fastmcp.tools import Tool +from mcp import McpError +from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage +from typing import Any +from typing_extensions import override + + +logger = logging.getLogger(__name__) + + +class AWSProxyToolManager(_ProxyToolManager): + """Customized proxy tool manager that better suites our needs.""" + + def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): + """Initialize a proxy tool manager. + + Cached tools are set to None. + """ + super().__init__(client_factory, **kwargs) + self._cached_tools: dict[str, Tool] | None = None + + @override + async def get_tool(self, key: str) -> Tool: + """Return the tool from cached tools. + + This method is invoked when the client tries to call a tool. + + tool = self.get_tool(key) + tool.invoke(...) + + The parent class implementation always make a mcp call to list the tools. + Since the client already knows the name of the tools, list_tool is not necessary. + We are wasting a network call just to get the tools which were already listed. + + In case the server supports notifications/tools/listChanged, the `get_tools` method + will be called explicity , hence, we are not missing the change to the tool list. + """ + if self._cached_tools is None: + logger.debug('cached_tools not found, calling get_tools') + self._cached_tools = await self.get_tools() + if key in self._cached_tools: + return self._cached_tools[key] + raise NotFoundError(f'Tool {key!r} not found') + + @override + async def get_tools(self) -> dict[str, Tool]: + """Return list tools.""" + self._cached_tools = await super(AWSProxyToolManager, self).get_tools() + return self._cached_tools + + +class AWSMCPProxy(_FastMCPProxy): + """Customized MCP Proxy to better suite our needs.""" + + def __init__( + self, + *, + client_factory: ClientFactoryT | None = None, + **kwargs, + ): + """Initialize a client.""" + super().__init__(client_factory=client_factory, **kwargs) + self._tool_manager = AWSProxyToolManager( + client_factory=self.client_factory, + transformations=self._tool_manager.transformations, + ) + + +class AWSMCPProxyClient(_ProxyClient): + """Proxy client that handles HTTP errors when connection fails.""" + + def __init__(self, transport: ClientTransport, **kwargs): + """Constructor of AutoRefreshProxyCilent.""" + super().__init__(transport, **kwargs) + + @override + async def _connect(self): + """Enter as normal && initialize only once.""" + logger.debug('Connecting %s', self) + try: + result = await super(AWSMCPProxyClient, self)._connect() + logger.debug('Connected %s', self) + return result + except httpx.HTTPStatusError as http_error: + logger.exception('Connection failed') + response = http_error.response + try: + body = await response.aread() + jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root + except Exception: + logger.debug('HTTP error is not a valid MCP message.') + raise http_error + + if isinstance(jsonrpc_msg, JSONRPCError): + logger.debug('Converting HTTP error to MCP error %s', http_error) + # raising McpError so that the sdk can handle the exception properly + raise McpError(error=jsonrpc_msg.error) from http_error + else: + raise http_error + except RuntimeError: + try: + logger.warning('encountered runtime error, try force disconnect.') + await self._disconnect(force=True) + except Exception: + # _disconnect awaits on the session_task, + # which raises the timeout error that caused the client session to be terminated. + # the error is ignored as long as the counter is force set to 0. + # TODO: investigate how timeout error is handled by fastmcp and httpx + logger.exception('encountered another error, ignoring.') + return await self._connect() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """The MCP Proxy for AWS project is a proxy from stdio to http (sigv4). + + We want the client to remain connected until the stdio connection is closed. + + https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#stdio + + 1. close stdin + 2. terminate subprocess + + There is no equivalent of the streamble-http DELETE concept in stdio to terminate a session. + Hence the connection will be terminated only at program exit. + """ + pass + + +class AWSMCPProxyClientFactory: + """Client factory that returns a connected client.""" + + def __init__(self, transport: ClientTransport) -> None: + """Initialize a client factory with transport.""" + self._transport = transport + self._client: AWSMCPProxyClient | None = None + self._initialize_request: InitializeRequest | None = None + + def set_init_params(self, initialize_request: InitializeRequest): + """Set client init parameters.""" + self._initialize_request = initialize_request + + async def get_client(self) -> Client: + """Get client.""" + if self._client is None: + self._client = AWSMCPProxyClient(self._transport) + + return self._client + + async def __call__(self) -> Client: + """Implement the callable factory interface.""" + return await self.get_client() + + async def disconnect(self): + """Disconnect all the clients (no throw).""" + try: + if self._client: + await self._client._disconnect(force=True) + except Exception: + logger.exception('Failed to disconnect client.') diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 0056d4c..d49f0e4 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -23,26 +23,16 @@ """ import asyncio -import contextlib import httpx import logging -import sys -from fastmcp.client import ClientTransport from fastmcp.server.middleware.error_handling import RetryMiddleware from fastmcp.server.middleware.logging import LoggingMiddleware -from fastmcp.server.proxy import FastMCPProxy, ProxyClient from fastmcp.server.server import FastMCP -from mcp import McpError -from mcp.types import ( - CONNECTION_CLOSED, - ErrorData, - JSONRPCError, - JSONRPCMessage, - JSONRPCResponse, -) from mcp_proxy_for_aws.cli import parse_args from mcp_proxy_for_aws.logging_config import configure_logging +from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware +from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory from mcp_proxy_for_aws.utils import ( create_transport_with_sigv4, determine_aws_region, @@ -53,62 +43,9 @@ logger = logging.getLogger(__name__) -@contextlib.asynccontextmanager -async def _initialize_client(transport: ClientTransport): - """Handle the exceptions for during client initialize.""" - async with contextlib.AsyncExitStack() as stack: - try: - client = await stack.enter_async_context(ProxyClient(transport)) - except httpx.HTTPStatusError as http_error: - logger.error('HTTP Error during initialize %s', http_error) - response = http_error.response - try: - body = await response.aread() - jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root - if isinstance(jsonrpc_msg, (JSONRPCError, JSONRPCResponse)): - line = jsonrpc_msg.model_dump_json( - by_alias=True, - exclude_none=True, - ) - logger.debug('Writing the unhandled http error to stdout %s', http_error) - print(line, file=sys.stdout) - else: - logger.debug('Ignoring jsonrpc message type=%s', type(jsonrpc_msg)) - except Exception as _: - logger.debug('Cannot read HTTP response body') - raise http_error - except Exception as e: - cause = e.__cause__ - if isinstance(cause, McpError): - logger.error('MCP Error during initialize %s', cause.error) - jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=0, error=cause.error) - line = jsonrpc_error.model_dump_json( - by_alias=True, - exclude_none=True, - ) - else: - logger.error('Error during initialize %s', e) - jsonrpc_error = JSONRPCError( - jsonrpc='2.0', - id=0, - error=ErrorData( - code=CONNECTION_CLOSED, - message=str(e), - ), - ) - line = jsonrpc_error.model_dump_json( - by_alias=True, - exclude_none=True, - ) - print(line, file=sys.stdout) - raise e - logger.debug('Initialized MCP client') - yield client - - async def run_proxy(args) -> None: """Set up the server in MCP mode.""" - logger.info('Setting up server in MCP mode') + logger.info('Setting up mcp proxy server to %s', args.endpoint) # Validate and determine service service = determine_service_name(args.endpoint, args.service) @@ -134,7 +71,6 @@ async def run_proxy(args) -> None: metadata, profile, ) - logger.info('Running in MCP mode') timeout = httpx.Timeout( args.timeout, @@ -147,35 +83,29 @@ async def run_proxy(args) -> None: transport = create_transport_with_sigv4( args.endpoint, service, region, metadata, timeout, profile ) + client_factory = AWSMCPProxyClientFactory(transport) - async with _initialize_client(transport) as client: - - async def client_factory(): - nonlocal client - if not client.is_connected(): - logger.debug('Reinitialize client') - client = ProxyClient(transport) - await client._connect() - return client - - try: - proxy = FastMCPProxy( - client_factory=client_factory, - name='MCP Proxy for AWS', - instructions=( - 'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. ' - 'This proxy handles authentication and request routing to the appropriate backend services.' - ), - ) - add_logging_middleware(proxy, args.log_level) - add_tool_filtering_middleware(proxy, args.read_only) - - if args.retries: - add_retry_middleware(proxy, args.retries) - await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level) - except Exception as e: - logger.error('Cannot start proxy server: %s', e) - raise e + try: + proxy = AWSMCPProxy( + client_factory=client_factory, + name='MCP Proxy for AWS', + instructions=( + 'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. ' + 'This proxy handles authentication and request routing to the appropriate backend services.' + ), + ) + proxy.add_middleware(InitializeMiddleware(client_factory)) + add_logging_middleware(proxy, args.log_level) + add_tool_filtering_middleware(proxy, args.read_only) + + if args.retries: + add_retry_middleware(proxy, args.retries) + await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level) + except Exception as e: + logger.error('Cannot start proxy server: %s', e) + raise e + finally: + await client_factory.disconnect() def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: diff --git a/tests/unit/test_initialize_client.py b/tests/unit/test_initialize_client.py deleted file mode 100644 index a9568c6..0000000 --- a/tests/unit/test_initialize_client.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for _initialize_client error handling.""" - -import httpx -import pytest -from mcp import McpError -from mcp.types import ErrorData, JSONRPCError, JSONRPCResponse -from mcp_proxy_for_aws.server import _initialize_client -from unittest.mock import AsyncMock, Mock, patch - - -@pytest.mark.asyncio -async def test_successful_initialization(): - """Test successful client initialization.""" - mock_transport = Mock() - mock_client = Mock() - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) - mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None) - - async with _initialize_client(mock_transport) as client: - assert client == mock_client - - -@pytest.mark.asyncio -async def test_http_error_with_jsonrpc_error(capsys): - """Test HTTPStatusError with JSONRPCError response.""" - mock_transport = Mock() - error_data = ErrorData(code=-32600, message='Invalid Request') - jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=1, error=error_data) - - mock_response = Mock() - mock_response.aread = AsyncMock(return_value=jsonrpc_error.model_dump_json().encode()) - - http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) - - with pytest.raises(httpx.HTTPStatusError): - async with _initialize_client(mock_transport): - pass - - captured = capsys.readouterr() - assert 'Invalid Request' in captured.out - - -@pytest.mark.asyncio -async def test_http_error_with_jsonrpc_response(capsys): - """Test HTTPStatusError with JSONRPCResponse.""" - mock_transport = Mock() - jsonrpc_response = JSONRPCResponse(jsonrpc='2.0', id=1, result={'status': 'error'}) - - mock_response = Mock() - mock_response.aread = AsyncMock(return_value=jsonrpc_response.model_dump_json().encode()) - - http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) - - with pytest.raises(httpx.HTTPStatusError): - async with _initialize_client(mock_transport): - pass - - captured = capsys.readouterr() - assert '"result":{"status":"error"}' in captured.out - - -@pytest.mark.asyncio -async def test_http_error_with_invalid_json(): - """Test HTTPStatusError with invalid JSON response.""" - mock_transport = Mock() - - mock_response = Mock() - mock_response.aread = AsyncMock(return_value=b'invalid json') - - http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) - - with pytest.raises(httpx.HTTPStatusError): - async with _initialize_client(mock_transport): - pass - - -@pytest.mark.asyncio -async def test_http_error_with_non_jsonrpc_message(): - """Test HTTPStatusError with non-JSONRPCError/Response message.""" - mock_transport = Mock() - - mock_response = Mock() - mock_response.aread = AsyncMock(return_value=b'{"jsonrpc":"2.0","method":"test"}') - - http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) - - with pytest.raises(httpx.HTTPStatusError): - async with _initialize_client(mock_transport): - pass - - -@pytest.mark.asyncio -async def test_http_error_response_read_failure(): - """Test HTTPStatusError when response.aread() fails.""" - mock_transport = Mock() - - mock_response = Mock() - mock_response.aread = AsyncMock(side_effect=Exception('Read failed')) - - http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error) - - with pytest.raises(httpx.HTTPStatusError): - async with _initialize_client(mock_transport): - pass - - -@pytest.mark.asyncio -async def test_generic_error_with_mcp_error_cause(capsys): - """Test generic exception with McpError as cause.""" - mock_transport = Mock() - error_data = ErrorData(code=-32601, message='Method not found') - mcp_error = McpError(error_data) - generic_error = Exception('Wrapper error') - generic_error.__cause__ = mcp_error - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error) - - with pytest.raises(Exception): - async with _initialize_client(mock_transport): - pass - - captured = capsys.readouterr() - assert 'Method not found' in captured.out - assert '"code":-32601' in captured.out - - -@pytest.mark.asyncio -async def test_generic_error_without_mcp_error_cause(capsys): - """Test generic exception without McpError cause.""" - mock_transport = Mock() - generic_error = Exception('Generic error') - - with patch('mcp_proxy_for_aws.server.ProxyClient') as mock_client_class: - mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error) - - with pytest.raises(Exception): - async with _initialize_client(mock_transport): - pass - - captured = capsys.readouterr() - assert 'Generic error' in captured.out - assert '"code":-32000' in captured.out diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py new file mode 100644 index 0000000..eaba425 --- /dev/null +++ b/tests/unit/test_proxy.py @@ -0,0 +1,229 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for proxy module.""" + +import httpx +import pytest +from fastmcp.client.transports import ClientTransport +from fastmcp.exceptions import NotFoundError +from fastmcp.tools import Tool +from mcp import McpError +from mcp.types import ErrorData, InitializeRequest, JSONRPCError +from mcp_proxy_for_aws.proxy import ( + AWSMCPProxy, + AWSMCPProxyClient, + AWSMCPProxyClientFactory, + AWSProxyToolManager, +) +from unittest.mock import AsyncMock, Mock, patch + + +@pytest.mark.asyncio +async def test_tool_manager_get_tool_with_cache(): + """Test get_tool returns from cache when available.""" + mock_factory = Mock() + manager = AWSProxyToolManager(mock_factory) + mock_tool = Mock(spec=Tool) + manager._cached_tools = {'test_tool': mock_tool} + + result = await manager.get_tool('test_tool') + assert result == mock_tool + + +@pytest.mark.asyncio +async def test_tool_manager_get_tool_without_cache(): + """Test get_tool fetches tools when cache is empty.""" + mock_factory = Mock() + manager = AWSProxyToolManager(mock_factory) + mock_tool = Mock(spec=Tool) + + with patch.object(manager, 'get_tools', return_value={'test_tool': mock_tool}): + result = await manager.get_tool('test_tool') + assert result == mock_tool + assert manager._cached_tools == {'test_tool': mock_tool} + + +@pytest.mark.asyncio +async def test_tool_manager_get_tool_not_found(): + """Test get_tool raises NotFoundError when tool doesn't exist.""" + mock_factory = Mock() + manager = AWSProxyToolManager(mock_factory) + manager._cached_tools = {} + + with pytest.raises(NotFoundError, match="Tool 'missing_tool' not found"): + await manager.get_tool('missing_tool') + + +@pytest.mark.asyncio +async def test_tool_manager_get_tools_updates_cache(): + """Test get_tools updates the cache.""" + mock_factory = Mock() + manager = AWSProxyToolManager(mock_factory) + mock_tools = {'tool1': Mock(spec=Tool), 'tool2': Mock(spec=Tool)} + + with patch('mcp_proxy_for_aws.proxy._ProxyToolManager.get_tools', return_value=mock_tools): + result = await manager.get_tools() + assert result == mock_tools + assert manager._cached_tools == mock_tools + + +def test_proxy_initialization(): + """Test AWSMCPProxy initializes with custom tool manager.""" + mock_factory = Mock() + proxy = AWSMCPProxy(client_factory=mock_factory, name='test') + assert isinstance(proxy._tool_manager, AWSProxyToolManager) + + +@pytest.mark.asyncio +async def test_proxy_client_connect_success(): + """Test successful connection.""" + mock_transport = Mock(spec=ClientTransport) + client = AWSMCPProxyClient(mock_transport) + + with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', return_value='connected'): + result = await client._connect() + assert result == 'connected' + + +@pytest.mark.asyncio +async def test_proxy_client_connect_http_error_with_mcp_error(): + """Test connection failure with MCP error response.""" + mock_transport = Mock(spec=ClientTransport) + client = AWSMCPProxyClient(mock_transport) + + error_data = ErrorData(code=-32600, message='Invalid Request') + jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=1, error=error_data) + + mock_response = Mock() + mock_response.aread = AsyncMock(return_value=jsonrpc_error.model_dump_json().encode()) + + http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) + + with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error): + with pytest.raises(McpError) as exc_info: + await client._connect() + assert exc_info.value.error.code == -32600 + assert exc_info.value.error.message == 'Invalid Request' + + +@pytest.mark.asyncio +async def test_proxy_client_connect_http_error_non_mcp(): + """Test connection failure with non-MCP HTTP error.""" + mock_transport = Mock(spec=ClientTransport) + client = AWSMCPProxyClient(mock_transport) + + mock_response = Mock() + mock_response.aread = AsyncMock(return_value=b'Not a JSON-RPC message') + + http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) + + with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error): + with pytest.raises(httpx.HTTPStatusError): + await client._connect() + + +@pytest.mark.asyncio +async def test_proxy_client_aexit_does_not_disconnect(): + """Test __aexit__ does not disconnect the client.""" + mock_transport = Mock(spec=ClientTransport) + client = AWSMCPProxyClient(mock_transport) + + result = await client.__aexit__(None, None, None) + assert result is None + + +def test_client_factory_initialization(): + """Test factory initialization.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + assert factory._transport == mock_transport + assert factory._client is None + assert factory._initialize_request is None + + +def test_client_factory_set_init_params(): + """Test setting initialization parameters.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + mock_request = Mock(spec=InitializeRequest) + factory.set_init_params(mock_request) + + assert factory._initialize_request == mock_request + + +@pytest.mark.asyncio +async def test_client_factory_get_client_when_connected(): + """Test get_client returns existing client when connected.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + mock_client = Mock(spec=AWSMCPProxyClient) + factory._client = mock_client + + client = await factory.get_client() + assert client == mock_client + + +@pytest.mark.asyncio +async def test_client_factory_get_client_when_disconnected(): + """Test get_client creates new client when disconnected.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + client = await factory.get_client() + assert isinstance(client, AWSMCPProxyClient) + assert factory._client == client + + +@pytest.mark.asyncio +async def test_client_factory_callable_interface(): + """Test factory callable interface.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + client = await factory() + assert isinstance(client, AWSMCPProxyClient) + + +@pytest.mark.asyncio +async def test_client_factory_disconnect_all(): + """Test disconnect disconnects the client.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + mock_client = Mock() + mock_client._disconnect = AsyncMock() + factory._client = mock_client + + await factory.disconnect() + + mock_client._disconnect.assert_called_once_with(force=True) + + +@pytest.mark.asyncio +async def test_client_factory_disconnect_all_handles_exceptions(): + """Test disconnect handles exceptions gracefully.""" + mock_transport = Mock(spec=ClientTransport) + factory = AWSMCPProxyClientFactory(mock_transport) + + mock_client = Mock() + mock_client._disconnect = AsyncMock(side_effect=Exception('Disconnect failed')) + factory._client = mock_client + + await factory.disconnect() + + mock_client._disconnect.assert_called_once_with(force=True) diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 4432c4f..4269afd 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -30,9 +30,9 @@ class TestServer: """Tests for the server module.""" - @patch('mcp_proxy_for_aws.server.ProxyClient') + @patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCPProxy') + @patch('mcp_proxy_for_aws.server.AWSMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -43,9 +43,9 @@ async def test_setup_mcp_mode( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_fastmcp_proxy, + mock_aws_proxy, mock_create_transport, - mock_client_class, + mock_client_factory_class, ): """Test that MCP mode is set up correctly.""" # Arrange @@ -68,20 +68,18 @@ async def test_setup_mcp_mode( mock_determine_service.return_value = 'test-service' mock_determine_region.return_value = 'us-east-1' - # Mock the transport and client + # Mock the transport and client factory mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport - mock_client = Mock() - mock_client.initialize_result = None - mock_client.is_connected = Mock(return_value=True) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client_class.return_value = mock_client + mock_client_factory = Mock() + mock_client_factory.disconnect = AsyncMock() + mock_client_factory_class.return_value = mock_client_factory mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_fastmcp_proxy.return_value = mock_proxy + mock_proxy.add_middleware = Mock() + mock_aws_proxy.return_value = mock_proxy # Act await run_proxy(mock_args) @@ -98,17 +96,17 @@ async def test_setup_mcp_mode( assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata # call_args[0][4] is the Timeout object assert call_args[0][5] is None # profile - mock_client_class.assert_called_once_with(mock_transport) - mock_fastmcp_proxy.assert_called_once() + mock_client_factory_class.assert_called_once_with(mock_transport) + mock_aws_proxy.assert_called_once() mock_add_filtering.assert_called_once_with(mock_proxy, True) mock_add_retry.assert_called_once_with(mock_proxy, 1) mock_proxy.run_async.assert_called_once_with( transport='stdio', show_banner=False, log_level='INFO' ) - @patch('mcp_proxy_for_aws.server.ProxyClient') + @patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCPProxy') + @patch('mcp_proxy_for_aws.server.AWSMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -117,9 +115,9 @@ async def test_setup_mcp_mode_no_retries( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_fastmcp_proxy, + mock_aws_proxy, mock_create_transport, - mock_client_class, + mock_client_factory_class, ): """Test that MCP mode setup without retries doesn't add retry middleware.""" # Arrange @@ -142,20 +140,18 @@ async def test_setup_mcp_mode_no_retries( mock_determine_service.return_value = 'test-service' mock_determine_region.return_value = 'us-east-1' - # Mock the transport and client + # Mock the transport and client factory mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport - mock_client = Mock() - mock_client.initialize_result = None - mock_client.is_connected = Mock(return_value=True) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client_class.return_value = mock_client + mock_client_factory = Mock() + mock_client_factory.disconnect = AsyncMock() + mock_client_factory_class.return_value = mock_client_factory mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_fastmcp_proxy.return_value = mock_proxy + mock_proxy.add_middleware = Mock() + mock_aws_proxy.return_value = mock_proxy # Act await run_proxy(mock_args) @@ -175,16 +171,16 @@ async def test_setup_mcp_mode_no_retries( } # metadata # call_args[0][4] is the Timeout object assert call_args[0][5] == 'test-profile' # profile - mock_client_class.assert_called_once_with(mock_transport) - mock_fastmcp_proxy.assert_called_once() + mock_client_factory_class.assert_called_once_with(mock_transport) + mock_aws_proxy.assert_called_once() mock_add_filtering.assert_called_once_with(mock_proxy, False) mock_proxy.run_async.assert_called_once_with( transport='stdio', show_banner=False, log_level='INFO' ) - @patch('mcp_proxy_for_aws.server.ProxyClient') + @patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCPProxy') + @patch('mcp_proxy_for_aws.server.AWSMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -193,9 +189,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_fastmcp_proxy, + mock_aws_proxy, mock_create_transport, - mock_client_class, + mock_client_factory_class, ): """Test that AWS_REGION is automatically injected when no metadata is provided.""" # Arrange @@ -219,16 +215,14 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport - mock_client = Mock() - mock_client.initialize_result = None - mock_client.is_connected = Mock(return_value=True) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client_class.return_value = mock_client + mock_client_factory = Mock() + mock_client_factory.disconnect = AsyncMock() + mock_client_factory_class.return_value = mock_client_factory mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_fastmcp_proxy.return_value = mock_proxy + mock_proxy.add_middleware = Mock() + mock_aws_proxy.return_value = mock_proxy # Act await run_proxy(mock_args) @@ -239,9 +233,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( metadata = call_args[0][3] assert metadata == {'AWS_REGION': 'ap-southeast-1'} - @patch('mcp_proxy_for_aws.server.ProxyClient') + @patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') - @patch('mcp_proxy_for_aws.server.FastMCPProxy') + @patch('mcp_proxy_for_aws.server.AWSMCPProxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @patch('mcp_proxy_for_aws.server.determine_service_name') @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') @@ -250,9 +244,9 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( mock_add_filtering, mock_determine_service, mock_determine_region, - mock_fastmcp_proxy, + mock_aws_proxy, mock_create_transport, - mock_client_class, + mock_client_factory_class, ): """Test that AWS_REGION is injected even when other metadata is provided.""" # Arrange @@ -276,16 +270,14 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport - mock_client = Mock() - mock_client.initialize_result = None - mock_client.is_connected = Mock(return_value=True) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - mock_client_class.return_value = mock_client + mock_client_factory = Mock() + mock_client_factory.disconnect = AsyncMock() + mock_client_factory_class.return_value = mock_client_factory mock_proxy = Mock() mock_proxy.run_async = AsyncMock() - mock_fastmcp_proxy.return_value = mock_proxy + mock_proxy.add_middleware = Mock() + mock_aws_proxy.return_value = mock_proxy # Act await run_proxy(mock_args)