Skip to content

Commit bf03f32

Browse files
refactor: address PR review comments for Client class
- Make InMemoryTransport private by moving to _memory.py module - Improve exception handling in __aenter__: use AsyncExitStack.pop_all() pattern instead of manual try/except for cleaner resource cleanup - Add reentry check: raise RuntimeError if Client is entered twice - Add test for reentry behavior - Remove low-level transport documentation (InMemoryTransport is now private)
1 parent 1f66cd3 commit bf03f32

File tree

12 files changed

+30
-44
lines changed

12 files changed

+30
-44
lines changed

docs/testing.md

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,26 +69,3 @@ async def test_call_add_tool():
6969
1. If you are using `trio`, you should set `"trio"` as the `anyio_backend`. Check more information in the [anyio documentation](https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on).
7070

7171
There you go! You can now extend your tests to cover more scenarios.
72-
73-
## Advanced: Low-Level Transport Access
74-
75-
For tests that need direct access to the underlying `ClientSession`, use `InMemoryTransport`:
76-
77-
```python
78-
from mcp.client.session import ClientSession
79-
from mcp.client.transports import InMemoryTransport
80-
81-
from server import app
82-
83-
84-
async def test_with_low_level_access():
85-
transport = InMemoryTransport(app, raise_exceptions=True)
86-
async with transport.connect() as (read_stream, write_stream):
87-
async with ClientSession(read_stream, write_stream) as session:
88-
await session.initialize()
89-
# Direct access to the full ClientSession API
90-
result = await session.call_tool("add", {"a": 1, "b": 2})
91-
```
92-
93-
The `Client` class is built on top of `InMemoryTransport` and handles initialization automatically.
94-
Use `InMemoryTransport` directly only when you need fine-grained control over the session lifecycle.

src/mcp/client/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
from mcp.client.client import Client
44
from mcp.client.session import ClientSession
5-
from mcp.client.transports.memory import InMemoryTransport
65

76
__all__ = [
87
"Client",
98
"ClientSession",
10-
"InMemoryTransport",
119
]

src/mcp/client/client.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import AnyUrl
1010

1111
import mcp.types as types
12+
from mcp.client._memory import InMemoryTransport
1213
from mcp.client.session import (
1314
ClientSession,
1415
ElicitationFnT,
@@ -17,7 +18,6 @@
1718
MessageHandlerFnT,
1819
SamplingFnT,
1920
)
20-
from mcp.client.transports.memory import InMemoryTransport
2121
from mcp.server import Server
2222
from mcp.server.fastmcp import FastMCP
2323
from mcp.shared.session import ProgressFnT
@@ -92,16 +92,16 @@ def __init__(
9292

9393
async def __aenter__(self) -> Client:
9494
"""Enter the async context manager."""
95-
self._exit_stack = AsyncExitStack()
96-
await self._exit_stack.__aenter__()
95+
if self._session is not None:
96+
raise RuntimeError("Client is already entered; cannot reenter")
9797

98-
try:
98+
async with AsyncExitStack() as exit_stack:
9999
# Create transport and connect
100100
transport = InMemoryTransport(self._server, raise_exceptions=self._raise_exceptions)
101-
read_stream, write_stream = await self._exit_stack.enter_async_context(transport.connect())
101+
read_stream, write_stream = await exit_stack.enter_async_context(transport.connect())
102102

103103
# Create session
104-
self._session = await self._exit_stack.enter_async_context(
104+
self._session = await exit_stack.enter_async_context(
105105
ClientSession(
106106
read_stream=read_stream,
107107
write_stream=write_stream,
@@ -118,10 +118,9 @@ async def __aenter__(self) -> Client:
118118
# Initialize the session
119119
await self._session.initialize()
120120

121+
# Transfer ownership to self for __aexit__ to handle
122+
self._exit_stack = exit_stack.pop_all()
121123
return self
122-
except Exception:
123-
await self._exit_stack.__aexit__(None, None, None)
124-
raise
125124

126125
async def __aexit__(
127126
self,
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
"""Transport implementations for MCP clients."""
22

3-
from mcp.client.transports.memory import InMemoryTransport
4-
5-
__all__ = ["InMemoryTransport"]
3+
__all__: list[str] = []

tests/client/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def patched_create_streams():
125125
# Apply the patch for the duration of the test
126126
# Patch both locations since InMemoryTransport imports it directly
127127
with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams):
128-
with patch("mcp.client.transports.memory.create_client_server_memory_streams", patched_create_streams):
128+
with patch("mcp.client._memory.create_client_server_memory_streams", patched_create_streams):
129129
# Return a collection with helper methods
130130
def get_spy_collection() -> StreamSpyCollection:
131131
assert client_spy is not None, "client_spy was not initialized"

tests/client/test_client.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ def test_session_property_before_enter(app: FastMCP):
220220
_ = client.session
221221

222222

223+
async def test_reentry_raises_runtime_error(app: FastMCP):
224+
"""Test that reentering a client raises RuntimeError."""
225+
async with Client(app) as client:
226+
with pytest.raises(RuntimeError, match="Client is already entered"):
227+
await client.__aenter__()
228+
229+
223230
async def test_cleanup_on_init_failure(app: FastMCP):
224231
"""Test that resources are cleaned up if initialization fails."""
225232
with patch("mcp.client.client.ClientSession") as mock_session_class:
@@ -230,9 +237,16 @@ async def test_cleanup_on_init_failure(app: FastMCP):
230237
mock_session_class.return_value = mock_session
231238

232239
client = Client(app)
233-
with pytest.raises(RuntimeError, match="Session init failed"):
240+
with pytest.raises(BaseException) as exc_info:
234241
await client.__aenter__()
235242

243+
# The error should contain our message (may be wrapped in ExceptionGroup)
244+
exc = exc_info.value
245+
if hasattr(exc, "exceptions"): # ExceptionGroup
246+
assert any("Session init failed" in str(e) for e in exc.exceptions)
247+
else:
248+
assert "Session init failed" in str(exc)
249+
236250
# Verify the client is in a clean state (session should be None)
237251
assert client._session is None
238252

tests/client/test_list_methods_cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import pytest
44

55
import mcp.types as types
6+
from mcp.client._memory import InMemoryTransport
67
from mcp.client.session import ClientSession
7-
from mcp.client.transports import InMemoryTransport
88
from mcp.server import Server
99
from mcp.server.fastmcp import FastMCP
1010
from mcp.types import ListToolsRequest, ListToolsResult

tests/client/transports/test_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from mcp.client.transports.memory import InMemoryTransport
5+
from mcp.client._memory import InMemoryTransport
66
from mcp.server import Server
77
from mcp.server.fastmcp import FastMCP
88
from mcp.types import Resource

tests/server/test_cancel_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import pytest
77

88
import mcp.types as types
9+
from mcp.client._memory import InMemoryTransport
910
from mcp.client.session import ClientSession
10-
from mcp.client.transports import InMemoryTransport
1111
from mcp.server.lowlevel.server import Server
1212
from mcp.shared.exceptions import McpError
1313
from mcp.types import (

0 commit comments

Comments
 (0)