Skip to content

Commit 8deee3f

Browse files
feat(client): Add capability_extensions parameter to ClientSession
Add a new `capability_extensions` parameter to `ClientSession.__init__()` that allows clients to include additional capability fields in the initialize request. This enables clients to advertise protocol extensions (like `io.modelcontextprotocol/ui`) without having to override the `initialize()` method. Example usage: ```python session = ClientSession( read_stream, write_stream, capability_extensions={ "extensions": { "io.modelcontextprotocol/ui": { "mimeTypes": ["text/html;profile=mcp-app"] } } } ) ``` The extensions are merged into `ClientCapabilities` using Pydantic's extra fields feature (`model_config = {'extra': 'allow'}`).
1 parent dcc9b4f commit 8deee3f

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

src/mcp/client/session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(
121121
*,
122122
sampling_capabilities: types.SamplingCapability | None = None,
123123
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
124+
capability_extensions: dict[str, Any] | None = None,
124125
) -> None:
125126
super().__init__(
126127
read_stream,
@@ -143,6 +144,10 @@ def __init__(
143144
# Experimental: Task handlers (use defaults if not provided)
144145
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
145146

147+
# Capability extensions to include in initialize request
148+
# These are merged into ClientCapabilities using Pydantic's extra fields
149+
self._capability_extensions = capability_extensions or {}
150+
146151
async def initialize(self) -> types.InitializeResult:
147152
sampling = (
148153
(self._sampling_capabilities or types.SamplingCapability())
@@ -177,6 +182,7 @@ async def initialize(self) -> types.InitializeResult:
177182
experimental=None,
178183
roots=roots,
179184
tasks=self._task_handlers.build_capability(),
185+
**self._capability_extensions,
180186
),
181187
client_info=self._client_info,
182188
),

tests/client/test_session.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,81 @@ async def mock_server():
768768
await session.initialize()
769769

770770
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)
771+
772+
773+
@pytest.mark.anyio
774+
async def test_client_session_capability_extensions():
775+
"""Test that capability_extensions are included in the initialize request."""
776+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
777+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
778+
779+
received_capabilities = None
780+
781+
# Define capability extensions (e.g., UI extension)
782+
capability_extensions = {
783+
"extensions": {
784+
"io.modelcontextprotocol/ui": {
785+
"mimeTypes": ["text/html;profile=mcp-app"]
786+
}
787+
}
788+
}
789+
790+
async def mock_server():
791+
nonlocal received_capabilities
792+
793+
session_message = await client_to_server_receive.receive()
794+
jsonrpc_request = session_message.message
795+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
796+
request = ClientRequest.model_validate(
797+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
798+
)
799+
assert isinstance(request.root, InitializeRequest)
800+
received_capabilities = request.root.params.capabilities
801+
802+
result = ServerResult(
803+
InitializeResult(
804+
protocol_version=LATEST_PROTOCOL_VERSION,
805+
capabilities=ServerCapabilities(),
806+
server_info=Implementation(name="mock-server", version="0.1.0"),
807+
)
808+
)
809+
810+
async with server_to_client_send:
811+
await server_to_client_send.send(
812+
SessionMessage(
813+
JSONRPCMessage(
814+
JSONRPCResponse(
815+
jsonrpc="2.0",
816+
id=jsonrpc_request.root.id,
817+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
818+
)
819+
)
820+
)
821+
)
822+
# Receive initialized notification
823+
await client_to_server_receive.receive()
824+
825+
async with (
826+
ClientSession(
827+
server_to_client_receive,
828+
client_to_server_send,
829+
capability_extensions=capability_extensions,
830+
) as session,
831+
anyio.create_task_group() as tg,
832+
client_to_server_send,
833+
client_to_server_receive,
834+
server_to_client_send,
835+
server_to_client_receive,
836+
):
837+
tg.start_soon(mock_server)
838+
await session.initialize()
839+
840+
# Assert that the capability extensions were included in the request
841+
assert received_capabilities is not None
842+
# The extensions should be present via Pydantic's extra fields
843+
caps_dict = received_capabilities.model_dump()
844+
assert "extensions" in caps_dict
845+
assert "io.modelcontextprotocol/ui" in caps_dict["extensions"]
846+
assert caps_dict["extensions"]["io.modelcontextprotocol/ui"]["mimeTypes"] == [
847+
"text/html;profile=mcp-app"
848+
]

0 commit comments

Comments
 (0)