Skip to content

Commit 3643b62

Browse files
committed
Reimplement test_client_tool_call_with_meta to goes through all the protocol phases
1 parent 701611d commit 3643b62

File tree

1 file changed

+72
-17
lines changed

1 file changed

+72
-17
lines changed

tests/client/test_session.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -507,38 +507,93 @@ async def mock_server():
507507
@pytest.mark.anyio
508508
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
509509
async def test_client_tool_call_with_meta(meta: dict[str, Any] | None):
510-
"""Test that client tool call requests can include metadata."""
510+
"""Test that client tool call requests can include metadata"""
511511
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
512512
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
513513

514+
mocked_tool = types.Tool(name="sample_tool", inputSchema={})
515+
514516
async def mock_server():
517+
# Receive initialization request from client
515518
session_message = await client_to_server_receive.receive()
516519
jsonrpc_request = session_message.message
517520
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
521+
request = ClientRequest.model_validate(
522+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
523+
)
524+
assert isinstance(request.root, InitializeRequest)
518525

519-
assert jsonrpc_request.root.method == "tools/call"
526+
result = ServerResult(
527+
InitializeResult(
528+
protocolVersion=LATEST_PROTOCOL_VERSION,
529+
capabilities=ServerCapabilities(),
530+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
531+
)
532+
)
533+
534+
# Answer initialization request
535+
await server_to_client_send.send(
536+
SessionMessage(
537+
JSONRPCMessage(
538+
JSONRPCResponse(
539+
jsonrpc="2.0",
540+
id=jsonrpc_request.root.id,
541+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
542+
)
543+
)
544+
)
545+
)
520546

521-
if meta is not None:
522-
assert jsonrpc_request.root.params
523-
assert "_meta" in jsonrpc_request.root.params
524-
assert jsonrpc_request.root.params["_meta"] == meta
547+
# Receive initialized notification
548+
await client_to_server_receive.receive()
549+
550+
# Wait for the client to send a 'tools/call' request
551+
session_message = await client_to_server_receive.receive()
552+
jsonrpc_request = session_message.message
553+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
554+
555+
assert jsonrpc_request.root.method == "tools/call"
525556

526557
result = ServerResult(
527558
CallToolResult(content=[TextContent(type="text", text="Called successfully")], isError=False)
528559
)
529560

530-
async with server_to_client_send:
531-
await server_to_client_send.send(
532-
SessionMessage(
533-
JSONRPCMessage(
534-
JSONRPCResponse(
535-
jsonrpc="2.0",
536-
id=jsonrpc_request.root.id,
537-
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
538-
)
561+
# Send the tools/call result
562+
await server_to_client_send.send(
563+
SessionMessage(
564+
JSONRPCMessage(
565+
JSONRPCResponse(
566+
jsonrpc="2.0",
567+
id=jsonrpc_request.root.id,
568+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
539569
)
540570
)
541571
)
572+
)
573+
574+
# Wait for the tools/list request from the client
575+
# The client requires this step to validate the tool output schema
576+
session_message = await client_to_server_receive.receive()
577+
jsonrpc_request = session_message.message
578+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
579+
580+
assert jsonrpc_request.root.method == "tools/list"
581+
582+
result = types.ListToolsResult(tools=[mocked_tool])
583+
584+
await server_to_client_send.send(
585+
SessionMessage(
586+
JSONRPCMessage(
587+
JSONRPCResponse(
588+
jsonrpc="2.0",
589+
id=jsonrpc_request.root.id,
590+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
591+
)
592+
)
593+
)
594+
)
595+
596+
server_to_client_send.close()
542597

543598
async with (
544599
ClientSession(
@@ -553,6 +608,6 @@ async def mock_server():
553608
):
554609
tg.start_soon(mock_server)
555610

556-
session._tool_output_schemas["sample_tool"] = None
611+
await session.initialize()
557612

558-
await session.call_tool(name="sample_tool", arguments={"foo": "bar"}, meta=meta)
613+
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)

0 commit comments

Comments
 (0)