22
33import io
44from io import TextIOWrapper
5+ from typing import cast
56
67import anyio
78import pytest
@@ -43,6 +44,16 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ
4344 url_adapter .validate_python (arguments ["url" ])
4445 return types .CallToolResult (content = [types .TextContent (type = "text" , text = "ok" )])
4546
47+ ctx = cast (ServerRequestContext , None )
48+ list_tools_result = await handle_list_tools (ctx , None )
49+ assert list_tools_result .tools [0 ].name == "fetch"
50+
51+ valid_tool_call_result = await handle_call_tool (
52+ ctx ,
53+ types .CallToolRequestParams (name = "fetch" , arguments = {"url" : "https://example.com" }),
54+ )
55+ assert valid_tool_call_result .content == [types .TextContent (type = "text" , text = "ok" )]
56+
4657 server = Server ("test-server" , on_list_tools = handle_list_tools , on_call_tool = handle_call_tool )
4758
4859 raw_stdin = io .BytesIO (
@@ -61,8 +72,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ
6172
6273 stdout .flush ()
6374 responses = [
64- jsonrpc_message_adapter .validate_json (line )
65- for line in raw_stdout .getvalue ().decode ("utf-8" ).splitlines ()
75+ jsonrpc_message_adapter .validate_json (line ) for line in raw_stdout .getvalue ().decode ("utf-8" ).splitlines ()
6676 ]
6777
6878 assert len (responses ) == 2
@@ -86,6 +96,8 @@ async def fetch(url: str) -> str:
8696 await anyio .sleep (0.1 )
8797 return str (TypeAdapter (AnyHttpUrl ).validate_python (url ))
8898
99+ assert await fetch ("https://example.com" ) == "https://example.com/"
100+
89101 raw_stdin = io .BytesIO (
90102 b'{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}\n '
91103 b'{"jsonrpc":"2.0","method":"notifications/initialized"}\n '
@@ -107,8 +119,7 @@ async def fetch(url: str) -> str:
107119
108120 stdout .flush ()
109121 responses = [
110- jsonrpc_message_adapter .validate_json (line )
111- for line in raw_stdout .getvalue ().decode ("utf-8" ).splitlines ()
122+ jsonrpc_message_adapter .validate_json (line ) for line in raw_stdout .getvalue ().decode ("utf-8" ).splitlines ()
112123 ]
113124
114125 assert responses
0 commit comments