Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions mcp_proxy_for_aws/middleware/initialize_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import logging
import mcp.types as mt
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory
from typing_extensions import override


logger = logging.getLogger(__name__)


class InitializeMiddleware(Middleware):
"""Intecept MCP initialize request and initialize the proxy client."""

def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None:
"""Create a middleware with client factory."""
super().__init__()
self._client_factory = client_factory

@override
async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, None],
) -> None:
try:
logger.debug('Received initialize request %s.', context.message)
self._client_factory.set_init_params(context.message)
return await call_next(context)
except Exception:
logger.exception('Initialize failed in middleware.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to log the exact error?

Copy link
Contributor Author

@wzxxing wzxxing Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exception logs the latest error automatically

raise
167 changes: 167 additions & 0 deletions mcp_proxy_for_aws/proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import httpx
import logging
from fastmcp import Client
from fastmcp.client.transports import ClientTransport
from fastmcp.exceptions import NotFoundError
from fastmcp.server.proxy import ClientFactoryT
from fastmcp.server.proxy import FastMCPProxy as _FastMCPProxy
from fastmcp.server.proxy import ProxyClient as _ProxyClient
from fastmcp.server.proxy import ProxyToolManager as _ProxyToolManager
from fastmcp.tools import Tool
from mcp import McpError
from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage
from typing import Any
from typing_extensions import override


logger = logging.getLogger(__name__)


class AWSProxyToolManager(_ProxyToolManager):
"""Customized proxy tool manager that better suites our needs."""

def __init__(self, client_factory: ClientFactoryT, **kwargs: Any):
"""Initialize a proxy tool manager.

Cached tools are set to None.
"""
super().__init__(client_factory, **kwargs)
self._cached_tools: dict[str, Tool] | None = None

@override
async def get_tool(self, key: str) -> Tool:
"""Return the tool from cached tools.

This method is invoked when the client tries to call a tool.

tool = self.get_tool(key)
tool.invoke(...)

The parent class implementation always make a mcp call to list the tools.
Since the client already knows the name of the tools, list_tool is not necessary.
We are wasting a network call just to get the tools which were already listed.

In case the server supports notifications/tools/listChanged, the `get_tools` method
will be called explicity , hence, we are not missing the change to the tool list.
"""
if self._cached_tools is None:
logger.debug('cached_tools not found, calling get_tools')
self._cached_tools = await self.get_tools()
if key in self._cached_tools:
return self._cached_tools[key]
raise NotFoundError(f'Tool {key!r} not found')

@override
async def get_tools(self) -> dict[str, Tool]:
"""Return list tools."""
self._cached_tools = await super(AWSProxyToolManager, self).get_tools()
return self._cached_tools


class AWSMCPProxy(_FastMCPProxy):
"""Customized MCP Proxy to better suite our needs."""

def __init__(
self,
*,
client_factory: ClientFactoryT | None = None,
**kwargs,
):
"""Initialize a client."""
super().__init__(client_factory=client_factory, **kwargs)
self._tool_manager = AWSProxyToolManager(
client_factory=self.client_factory,
transformations=self._tool_manager.transformations,
)


class AWSMCPProxyClient(_ProxyClient):
"""Proxy client that handles HTTP errors when connection fails."""

def __init__(self, transport: ClientTransport, **kwargs):
"""Constructor of AutoRefreshProxyCilent."""
super().__init__(transport, **kwargs)

@override
async def _connect(self):
"""Enter as normal && initialize only once."""
logger.debug('Connecting %s', self)
try:
result = await super(AWSMCPProxyClient, self)._connect()
logger.debug('Connected %s', self)
return result
except httpx.HTTPStatusError as http_error:
logger.exception('Connection failed')
response = http_error.response
try:
body = await response.aread()
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
except Exception:
logger.debug('HTTP error is not a valid MCP message.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to log the exception here?

raise http_error

if isinstance(jsonrpc_msg, JSONRPCError):
logger.debug('Converting HTTP error to MCP error %s', http_error)
# raising McpError so that the sdk can handle the exception properly
raise McpError(error=jsonrpc_msg.error) from http_error
else:
raise http_error
except RuntimeError:
try:
logger.warning('encountered runtime error, try force disconnect.')
await self._disconnect(force=True)
except Exception:
# _disconnect awaits on the session_task,
# which raises the timeout error that caused the client session to be terminated.
# the error is ignored as long as the counter is force set to 0.
# TODO: investigate how timeout error is handled by fastmcp and httpx
logger.exception('encountered another error, ignoring.')
return await self._connect()

async def __aexit__(self, exc_type, exc_val, exc_tb):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not necessary imo, __aexit__ is async context manager protocol method, not a method specific to the parent class.

"""The MCP Proxy for AWS project is a proxy from stdio to http (sigv4).

We want the client to remain connected until the stdio connection is closed.

https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#stdio

1. close stdin
2. terminate subprocess

There is no equivalent of the streamble-http DELETE concept in stdio to terminate a session.
Hence the connection will be terminated only at program exit.
"""
pass


class AWSMCPProxyClientFactory:
"""Client factory that returns a connected client."""

def __init__(self, transport: ClientTransport) -> None:
"""Initialize a client factory with transport."""
self._transport = transport
self._client: AWSMCPProxyClient | None = None
self._initialize_request: InitializeRequest | None = None

def set_init_params(self, initialize_request: InitializeRequest):
"""Set client init parameters."""
self._initialize_request = initialize_request

async def get_client(self) -> Client:
"""Get client."""
if self._client is None:
self._client = AWSMCPProxyClient(self._transport)

return self._client

async def __call__(self) -> Client:
"""Implement the callable factory interface."""
return await self.get_client()

async def disconnect(self):
"""Disconnect all the clients (no throw)."""
try:
if self._client:
await self._client._disconnect(force=True)
except Exception:
logger.exception('Failed to disconnect client.')
120 changes: 25 additions & 95 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,16 @@
"""

import asyncio
import contextlib
import httpx
import logging
import sys
from fastmcp.client import ClientTransport
from fastmcp.server.middleware.error_handling import RetryMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.proxy import FastMCPProxy, ProxyClient
from fastmcp.server.server import FastMCP
from mcp import McpError
from mcp.types import (
CONNECTION_CLOSED,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCResponse,
)
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
from mcp_proxy_for_aws.utils import (
create_transport_with_sigv4,
determine_aws_region,
Expand All @@ -53,62 +43,9 @@
logger = logging.getLogger(__name__)


@contextlib.asynccontextmanager
async def _initialize_client(transport: ClientTransport):
"""Handle the exceptions for during client initialize."""
async with contextlib.AsyncExitStack() as stack:
try:
client = await stack.enter_async_context(ProxyClient(transport))
except httpx.HTTPStatusError as http_error:
logger.error('HTTP Error during initialize %s', http_error)
response = http_error.response
try:
body = await response.aread()
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
if isinstance(jsonrpc_msg, (JSONRPCError, JSONRPCResponse)):
line = jsonrpc_msg.model_dump_json(
by_alias=True,
exclude_none=True,
)
logger.debug('Writing the unhandled http error to stdout %s', http_error)
print(line, file=sys.stdout)
else:
logger.debug('Ignoring jsonrpc message type=%s', type(jsonrpc_msg))
except Exception as _:
logger.debug('Cannot read HTTP response body')
raise http_error
except Exception as e:
cause = e.__cause__
if isinstance(cause, McpError):
logger.error('MCP Error during initialize %s', cause.error)
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=0, error=cause.error)
line = jsonrpc_error.model_dump_json(
by_alias=True,
exclude_none=True,
)
else:
logger.error('Error during initialize %s', e)
jsonrpc_error = JSONRPCError(
jsonrpc='2.0',
id=0,
error=ErrorData(
code=CONNECTION_CLOSED,
message=str(e),
),
)
line = jsonrpc_error.model_dump_json(
by_alias=True,
exclude_none=True,
)
print(line, file=sys.stdout)
raise e
logger.debug('Initialized MCP client')
yield client


async def run_proxy(args) -> None:
"""Set up the server in MCP mode."""
logger.info('Setting up server in MCP mode')
logger.info('Setting up mcp proxy server to %s', args.endpoint)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are already editing this, can you also log the current version of the proxy, should help with debugging when logs are shared.


# Validate and determine service
service = determine_service_name(args.endpoint, args.service)
Expand All @@ -134,7 +71,6 @@ async def run_proxy(args) -> None:
metadata,
profile,
)
logger.info('Running in MCP mode')

timeout = httpx.Timeout(
args.timeout,
Expand All @@ -147,35 +83,29 @@ async def run_proxy(args) -> None:
transport = create_transport_with_sigv4(
args.endpoint, service, region, metadata, timeout, profile
)
client_factory = AWSMCPProxyClientFactory(transport)

async with _initialize_client(transport) as client:

async def client_factory():
nonlocal client
if not client.is_connected():
logger.debug('Reinitialize client')
client = ProxyClient(transport)
await client._connect()
return client

try:
proxy = FastMCPProxy(
client_factory=client_factory,
name='MCP Proxy for AWS',
instructions=(
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
'This proxy handles authentication and request routing to the appropriate backend services.'
),
)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

if args.retries:
add_retry_middleware(proxy, args.retries)
await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level)
except Exception as e:
logger.error('Cannot start proxy server: %s', e)
raise e
try:
proxy = AWSMCPProxy(
client_factory=client_factory,
name='MCP Proxy for AWS',
instructions=(
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
'This proxy handles authentication and request routing to the appropriate backend services.'
),
)
proxy.add_middleware(InitializeMiddleware(client_factory))
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

if args.retries:
add_retry_middleware(proxy, args.retries)
await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level)
except Exception as e:
logger.error('Cannot start proxy server: %s', e)
raise e
finally:
await client_factory.disconnect()


def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
Expand Down
Loading
Loading