Skip to content

Commit da18ca7

Browse files
authored
fix: pass connected client to proxy to reuse session (#88)
Check the docs on <https://gofastmcp.com/servers/proxy#session-isolation-&-concurrency>
1 parent 2d5267e commit da18ca7

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

mcp_proxy_for_aws/server.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import asyncio
2626
import httpx
2727
import logging
28+
from fastmcp import Client
2829
from fastmcp.server.middleware.error_handling import RetryMiddleware
2930
from fastmcp.server.middleware.logging import LoggingMiddleware
3031
from fastmcp.server.server import FastMCP
@@ -83,16 +84,16 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
8384
transport = create_transport_with_sigv4(
8485
args.endpoint, service, region, metadata, timeout, profile
8586
)
87+
async with Client(transport=transport) as client:
88+
# Create proxy with the transport
89+
proxy = FastMCP.as_proxy(client)
90+
add_logging_middleware(proxy, args.log_level)
91+
add_tool_filtering_middleware(proxy, args.read_only)
8692

87-
# Create proxy with the transport
88-
proxy = FastMCP.as_proxy(transport)
89-
add_logging_middleware(proxy, args.log_level)
90-
add_tool_filtering_middleware(proxy, args.read_only)
93+
if args.retries:
94+
add_retry_middleware(proxy, args.retries)
9195

92-
if args.retries:
93-
add_retry_middleware(proxy, args.retries)
94-
95-
await proxy.run_async()
96+
await proxy.run_async()
9697

9798

9899
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:

tests/unit/test_server.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests for the mcp-proxy-for-aws Server."""
1616

1717
import pytest
18+
from fastmcp.client.transports import ClientTransport
1819
from fastmcp.server.server import FastMCP
1920
from mcp_proxy_for_aws.server import (
2021
add_retry_middleware,
@@ -31,6 +32,7 @@
3132
class TestServer:
3233
"""Tests for the server module."""
3334

35+
@patch('mcp_proxy_for_aws.server.Client')
3436
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
3537
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
3638
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@@ -45,6 +47,7 @@ async def test_setup_mcp_mode(
4547
mock_determine_region,
4648
mock_as_proxy,
4749
mock_create_transport,
50+
mock_client_class,
4851
):
4952
"""Test that MCP mode is set up correctly."""
5053
# Arrange
@@ -68,9 +71,15 @@ async def test_setup_mcp_mode(
6871
mock_determine_service.return_value = 'test-service'
6972
mock_determine_region.return_value = 'us-east-1'
7073

71-
# Mock the transport and proxy
72-
mock_transport = Mock()
74+
# Mock the transport and client
75+
mock_transport = Mock(spec=ClientTransport)
7376
mock_create_transport.return_value = mock_transport
77+
78+
mock_client = Mock()
79+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
80+
mock_client.__aexit__ = AsyncMock(return_value=None)
81+
mock_client_class.return_value = mock_client
82+
7483
mock_proxy = Mock()
7584
mock_proxy.run_async = AsyncMock()
7685
mock_as_proxy.return_value = mock_proxy
@@ -90,11 +99,13 @@ async def test_setup_mcp_mode(
9099
assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata
91100
# call_args[0][4] is the Timeout object
92101
assert call_args[0][5] is None # profile
93-
mock_as_proxy.assert_called_once_with(mock_transport)
102+
mock_client_class.assert_called_once_with(transport=mock_transport)
103+
mock_as_proxy.assert_called_once_with(mock_client)
94104
mock_add_filtering.assert_called_once_with(mock_proxy, True)
95105
mock_add_retry.assert_called_once_with(mock_proxy, 1)
96106
mock_proxy.run_async.assert_called_once()
97107

108+
@patch('mcp_proxy_for_aws.server.Client')
98109
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
99110
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
100111
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@@ -107,6 +118,7 @@ async def test_setup_mcp_mode_no_retries(
107118
mock_determine_region,
108119
mock_as_proxy,
109120
mock_create_transport,
121+
mock_client_class,
110122
):
111123
"""Test that MCP mode setup without retries doesn't add retry middleware."""
112124
# Arrange
@@ -130,9 +142,15 @@ async def test_setup_mcp_mode_no_retries(
130142
mock_determine_service.return_value = 'test-service'
131143
mock_determine_region.return_value = 'us-east-1'
132144

133-
# Mock the transport and proxy
134-
mock_transport = Mock()
145+
# Mock the transport and client
146+
mock_transport = Mock(spec=ClientTransport)
135147
mock_create_transport.return_value = mock_transport
148+
149+
mock_client = Mock()
150+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
151+
mock_client.__aexit__ = AsyncMock(return_value=None)
152+
mock_client_class.return_value = mock_client
153+
136154
mock_proxy = Mock()
137155
mock_proxy.run_async = AsyncMock()
138156
mock_as_proxy.return_value = mock_proxy
@@ -155,10 +173,12 @@ async def test_setup_mcp_mode_no_retries(
155173
} # metadata
156174
# call_args[0][4] is the Timeout object
157175
assert call_args[0][5] == 'test-profile' # profile
158-
mock_as_proxy.assert_called_once_with(mock_transport)
176+
mock_client_class.assert_called_once_with(transport=mock_transport)
177+
mock_as_proxy.assert_called_once_with(mock_client)
159178
mock_add_filtering.assert_called_once_with(mock_proxy, False)
160179
mock_proxy.run_async.assert_called_once()
161180

181+
@patch('mcp_proxy_for_aws.server.Client')
162182
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
163183
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
164184
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@@ -171,6 +191,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
171191
mock_determine_region,
172192
mock_as_proxy,
173193
mock_create_transport,
194+
mock_client_class,
174195
):
175196
"""Test that AWS_REGION is automatically injected when no metadata is provided."""
176197
# Arrange
@@ -192,8 +213,14 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
192213
mock_determine_service.return_value = 'test-service'
193214
mock_determine_region.return_value = 'ap-southeast-1'
194215

195-
mock_transport = Mock()
216+
mock_transport = Mock(spec=ClientTransport)
196217
mock_create_transport.return_value = mock_transport
218+
219+
mock_client = Mock()
220+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
221+
mock_client.__aexit__ = AsyncMock(return_value=None)
222+
mock_client_class.return_value = mock_client
223+
197224
mock_proxy = Mock()
198225
mock_proxy.run_async = AsyncMock()
199226
mock_as_proxy.return_value = mock_proxy
@@ -207,6 +234,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
207234
metadata = call_args[0][3]
208235
assert metadata == {'AWS_REGION': 'ap-southeast-1'}
209236

237+
@patch('mcp_proxy_for_aws.server.Client')
210238
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
211239
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
212240
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@@ -219,6 +247,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
219247
mock_determine_region,
220248
mock_as_proxy,
221249
mock_create_transport,
250+
mock_client_class,
222251
):
223252
"""Test that AWS_REGION is injected even when other metadata is provided."""
224253
# Arrange
@@ -240,8 +269,14 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
240269
mock_determine_service.return_value = 'test-service'
241270
mock_determine_region.return_value = 'us-west-1'
242271

243-
mock_transport = Mock()
272+
mock_transport = Mock(spec=ClientTransport)
244273
mock_create_transport.return_value = mock_transport
274+
275+
mock_client = Mock()
276+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
277+
mock_client.__aexit__ = AsyncMock(return_value=None)
278+
mock_client_class.return_value = mock_client
279+
245280
mock_proxy = Mock()
246281
mock_proxy.run_async = AsyncMock()
247282
mock_as_proxy.return_value = mock_proxy

0 commit comments

Comments
 (0)