-
Notifications
You must be signed in to change notification settings - Fork 25
Init client and show errors in all clients #108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9ae03db
7cb3809
0f5f2b4
347b2e6
c55e126
b20924f
08e3ef5
d6a04b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import logging | ||
| import mcp.types as mt | ||
| from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext | ||
| from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory | ||
| from typing_extensions import override | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class InitializeMiddleware(Middleware): | ||
| """Intecept MCP initialize request and initialize the proxy client.""" | ||
|
|
||
| def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None: | ||
| """Create a middleware with client factory.""" | ||
| super().__init__() | ||
| self._client_factory = client_factory | ||
|
|
||
| @override | ||
| async def on_initialize( | ||
| self, | ||
| context: MiddlewareContext[mt.InitializeRequest], | ||
| call_next: CallNext[mt.InitializeRequest, None], | ||
| ) -> None: | ||
| try: | ||
| logger.debug('Received initialize request %s.', context.message) | ||
| self._client_factory.set_init_params(context.message) | ||
| return await call_next(context) | ||
| except Exception: | ||
| logger.exception('Initialize failed in middleware.') | ||
| raise | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| import httpx | ||
| import logging | ||
| from fastmcp import Client | ||
| from fastmcp.client.transports import ClientTransport | ||
| from fastmcp.exceptions import NotFoundError | ||
| from fastmcp.server.proxy import ClientFactoryT | ||
| from fastmcp.server.proxy import FastMCPProxy as _FastMCPProxy | ||
| from fastmcp.server.proxy import ProxyClient as _ProxyClient | ||
| from fastmcp.server.proxy import ProxyToolManager as _ProxyToolManager | ||
| from fastmcp.tools import Tool | ||
| from mcp import McpError | ||
| from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage | ||
| from typing import Any | ||
| from typing_extensions import override | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class AWSProxyToolManager(_ProxyToolManager): | ||
| """Customized proxy tool manager that better suites our needs.""" | ||
|
|
||
| def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): | ||
| """Initialize a proxy tool manager. | ||
|
|
||
| Cached tools are set to None. | ||
| """ | ||
| super().__init__(client_factory, **kwargs) | ||
| self._cached_tools: dict[str, Tool] | None = None | ||
|
|
||
| @override | ||
| async def get_tool(self, key: str) -> Tool: | ||
| """Return the tool from cached tools. | ||
|
|
||
| This method is invoked when the client tries to call a tool. | ||
|
|
||
| tool = self.get_tool(key) | ||
| tool.invoke(...) | ||
|
|
||
| The parent class implementation always make a mcp call to list the tools. | ||
| Since the client already knows the name of the tools, list_tool is not necessary. | ||
| We are wasting a network call just to get the tools which were already listed. | ||
|
|
||
| In case the server supports notifications/tools/listChanged, the `get_tools` method | ||
| will be called explicity , hence, we are not missing the change to the tool list. | ||
| """ | ||
| if self._cached_tools is None: | ||
| logger.debug('cached_tools not found, calling get_tools') | ||
| self._cached_tools = await self.get_tools() | ||
| if key in self._cached_tools: | ||
| return self._cached_tools[key] | ||
| raise NotFoundError(f'Tool {key!r} not found') | ||
|
|
||
| @override | ||
| async def get_tools(self) -> dict[str, Tool]: | ||
| """Return list tools.""" | ||
| self._cached_tools = await super(AWSProxyToolManager, self).get_tools() | ||
| return self._cached_tools | ||
|
|
||
|
|
||
| class AWSMCPProxy(_FastMCPProxy): | ||
| """Customized MCP Proxy to better suite our needs.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| client_factory: ClientFactoryT | None = None, | ||
| **kwargs, | ||
| ): | ||
| """Initialize a client.""" | ||
| super().__init__(client_factory=client_factory, **kwargs) | ||
| self._tool_manager = AWSProxyToolManager( | ||
| client_factory=self.client_factory, | ||
| transformations=self._tool_manager.transformations, | ||
| ) | ||
|
|
||
|
|
||
| class AWSMCPProxyClient(_ProxyClient): | ||
| """Proxy client that handles HTTP errors when connection fails.""" | ||
|
|
||
| def __init__(self, transport: ClientTransport, **kwargs): | ||
| """Constructor of AutoRefreshProxyCilent.""" | ||
| super().__init__(transport, **kwargs) | ||
|
|
||
| @override | ||
| async def _connect(self): | ||
| """Enter as normal && initialize only once.""" | ||
| logger.debug('Connecting %s', self) | ||
| try: | ||
| result = await super(AWSMCPProxyClient, self)._connect() | ||
| logger.debug('Connected %s', self) | ||
| return result | ||
| except httpx.HTTPStatusError as http_error: | ||
| logger.exception('Connection failed') | ||
| response = http_error.response | ||
| try: | ||
| body = await response.aread() | ||
| jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root | ||
| except Exception: | ||
| logger.debug('HTTP error is not a valid MCP message.') | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to log the exception here? |
||
| raise http_error | ||
|
|
||
| if isinstance(jsonrpc_msg, JSONRPCError): | ||
| logger.debug('Converting HTTP error to MCP error %s', http_error) | ||
| # raising McpError so that the sdk can handle the exception properly | ||
| raise McpError(error=jsonrpc_msg.error) from http_error | ||
| else: | ||
| raise http_error | ||
| except RuntimeError: | ||
| try: | ||
| logger.warning('encountered runtime error, try force disconnect.') | ||
| await self._disconnect(force=True) | ||
| except Exception: | ||
| # _disconnect awaits on the session_task, | ||
| # which raises the timeout error that caused the client session to be terminated. | ||
| # the error is ignored as long as the counter is force set to 0. | ||
| # TODO: investigate how timeout error is handled by fastmcp and httpx | ||
| logger.exception('encountered another error, ignoring.') | ||
| return await self._connect() | ||
|
|
||
| async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. override?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not necessary imo, |
||
| """The MCP Proxy for AWS project is a proxy from stdio to http (sigv4). | ||
|
|
||
| We want the client to remain connected until the stdio connection is closed. | ||
|
|
||
| https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#stdio | ||
|
|
||
| 1. close stdin | ||
| 2. terminate subprocess | ||
|
|
||
| There is no equivalent of the streamble-http DELETE concept in stdio to terminate a session. | ||
| Hence the connection will be terminated only at program exit. | ||
| """ | ||
| pass | ||
|
|
||
|
|
||
| class AWSMCPProxyClientFactory: | ||
| """Client factory that returns a connected client.""" | ||
|
|
||
| def __init__(self, transport: ClientTransport) -> None: | ||
| """Initialize a client factory with transport.""" | ||
| self._transport = transport | ||
| self._client: AWSMCPProxyClient | None = None | ||
| self._initialize_request: InitializeRequest | None = None | ||
|
|
||
| def set_init_params(self, initialize_request: InitializeRequest): | ||
| """Set client init parameters.""" | ||
| self._initialize_request = initialize_request | ||
|
|
||
| async def get_client(self) -> Client: | ||
| """Get client.""" | ||
| if self._client is None: | ||
| self._client = AWSMCPProxyClient(self._transport) | ||
|
|
||
| return self._client | ||
|
|
||
| async def __call__(self) -> Client: | ||
| """Implement the callable factory interface.""" | ||
| return await self.get_client() | ||
|
|
||
| async def disconnect(self): | ||
| """Disconnect all the clients (no throw).""" | ||
| try: | ||
| if self._client: | ||
| await self._client._disconnect(force=True) | ||
| except Exception: | ||
| logger.exception('Failed to disconnect client.') | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,26 +23,16 @@ | |
| """ | ||
|
|
||
| import asyncio | ||
| import contextlib | ||
| import httpx | ||
| import logging | ||
| import sys | ||
| from fastmcp.client import ClientTransport | ||
| from fastmcp.server.middleware.error_handling import RetryMiddleware | ||
| from fastmcp.server.middleware.logging import LoggingMiddleware | ||
| from fastmcp.server.proxy import FastMCPProxy, ProxyClient | ||
| from fastmcp.server.server import FastMCP | ||
| from mcp import McpError | ||
| from mcp.types import ( | ||
| CONNECTION_CLOSED, | ||
| ErrorData, | ||
| JSONRPCError, | ||
| JSONRPCMessage, | ||
| JSONRPCResponse, | ||
| ) | ||
| from mcp_proxy_for_aws.cli import parse_args | ||
| from mcp_proxy_for_aws.logging_config import configure_logging | ||
| from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware | ||
| from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware | ||
| from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory | ||
| from mcp_proxy_for_aws.utils import ( | ||
| create_transport_with_sigv4, | ||
| determine_aws_region, | ||
|
|
@@ -53,62 +43,9 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @contextlib.asynccontextmanager | ||
| async def _initialize_client(transport: ClientTransport): | ||
| """Handle the exceptions for during client initialize.""" | ||
| async with contextlib.AsyncExitStack() as stack: | ||
| try: | ||
| client = await stack.enter_async_context(ProxyClient(transport)) | ||
| except httpx.HTTPStatusError as http_error: | ||
| logger.error('HTTP Error during initialize %s', http_error) | ||
| response = http_error.response | ||
| try: | ||
| body = await response.aread() | ||
| jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root | ||
| if isinstance(jsonrpc_msg, (JSONRPCError, JSONRPCResponse)): | ||
| line = jsonrpc_msg.model_dump_json( | ||
| by_alias=True, | ||
| exclude_none=True, | ||
| ) | ||
| logger.debug('Writing the unhandled http error to stdout %s', http_error) | ||
| print(line, file=sys.stdout) | ||
| else: | ||
| logger.debug('Ignoring jsonrpc message type=%s', type(jsonrpc_msg)) | ||
| except Exception as _: | ||
| logger.debug('Cannot read HTTP response body') | ||
| raise http_error | ||
| except Exception as e: | ||
| cause = e.__cause__ | ||
| if isinstance(cause, McpError): | ||
| logger.error('MCP Error during initialize %s', cause.error) | ||
| jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=0, error=cause.error) | ||
| line = jsonrpc_error.model_dump_json( | ||
| by_alias=True, | ||
| exclude_none=True, | ||
| ) | ||
| else: | ||
| logger.error('Error during initialize %s', e) | ||
| jsonrpc_error = JSONRPCError( | ||
| jsonrpc='2.0', | ||
| id=0, | ||
| error=ErrorData( | ||
| code=CONNECTION_CLOSED, | ||
| message=str(e), | ||
| ), | ||
| ) | ||
| line = jsonrpc_error.model_dump_json( | ||
| by_alias=True, | ||
| exclude_none=True, | ||
| ) | ||
| print(line, file=sys.stdout) | ||
| raise e | ||
| logger.debug('Initialized MCP client') | ||
| yield client | ||
|
|
||
|
|
||
| async def run_proxy(args) -> None: | ||
| """Set up the server in MCP mode.""" | ||
| logger.info('Setting up server in MCP mode') | ||
| logger.info('Setting up mcp proxy server to %s', args.endpoint) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you are already editing this, can you also log the current version of the proxy, should help with debugging when logs are shared. |
||
|
|
||
| # Validate and determine service | ||
| service = determine_service_name(args.endpoint, args.service) | ||
|
|
@@ -134,7 +71,6 @@ async def run_proxy(args) -> None: | |
| metadata, | ||
| profile, | ||
| ) | ||
| logger.info('Running in MCP mode') | ||
|
|
||
| timeout = httpx.Timeout( | ||
| args.timeout, | ||
|
|
@@ -147,35 +83,29 @@ async def run_proxy(args) -> None: | |
| transport = create_transport_with_sigv4( | ||
| args.endpoint, service, region, metadata, timeout, profile | ||
| ) | ||
| client_factory = AWSMCPProxyClientFactory(transport) | ||
|
|
||
| async with _initialize_client(transport) as client: | ||
|
|
||
| async def client_factory(): | ||
| nonlocal client | ||
| if not client.is_connected(): | ||
| logger.debug('Reinitialize client') | ||
| client = ProxyClient(transport) | ||
| await client._connect() | ||
| return client | ||
|
|
||
| try: | ||
| proxy = FastMCPProxy( | ||
| client_factory=client_factory, | ||
| name='MCP Proxy for AWS', | ||
| instructions=( | ||
| 'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. ' | ||
| 'This proxy handles authentication and request routing to the appropriate backend services.' | ||
| ), | ||
| ) | ||
| add_logging_middleware(proxy, args.log_level) | ||
| add_tool_filtering_middleware(proxy, args.read_only) | ||
|
|
||
| if args.retries: | ||
| add_retry_middleware(proxy, args.retries) | ||
| await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level) | ||
| except Exception as e: | ||
| logger.error('Cannot start proxy server: %s', e) | ||
| raise e | ||
| try: | ||
| proxy = AWSMCPProxy( | ||
| client_factory=client_factory, | ||
| name='MCP Proxy for AWS', | ||
| instructions=( | ||
| 'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. ' | ||
| 'This proxy handles authentication and request routing to the appropriate backend services.' | ||
| ), | ||
| ) | ||
| proxy.add_middleware(InitializeMiddleware(client_factory)) | ||
| add_logging_middleware(proxy, args.log_level) | ||
| add_tool_filtering_middleware(proxy, args.read_only) | ||
|
|
||
| if args.retries: | ||
| add_retry_middleware(proxy, args.retries) | ||
| await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level) | ||
| except Exception as e: | ||
| logger.error('Cannot start proxy server: %s', e) | ||
| raise e | ||
| finally: | ||
| await client_factory.disconnect() | ||
|
|
||
|
|
||
| def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want to log the exact error?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exceptionlogs the latest error automatically