Skip to content

fastmcp: allow passing Tool directly to .add_tool #699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.tools import Tool, ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.lowlevel.helper_types import ReadResourceContents
Expand Down Expand Up @@ -315,6 +315,10 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent
logger.error(f"Error reading resource {uri}: {e}")
raise ResourceError(str(e))

def add_tool_instance(self, tool: Tool) -> None:
"""Add a Tool instance to the server."""
self._tool_manager.add_tool_instance(tool)

def add_tool(
self,
fn: AnyFunction,
Expand Down
18 changes: 11 additions & 7 deletions src/mcp/server/fastmcp/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def list_tools(self) -> list[Tool]:
"""List all registered tools."""
return list(self._tools.values())

def add_tool_instance(self, tool: Tool) -> Tool:
"""Add a Tool instance to the server."""
existing = self._tools.get(tool.name)
if existing:
if self.warn_on_duplicate_tools:
logger.warning(f"Tool already exists: {tool.name}")
return existing
self._tools[tool.name] = tool
return tool

def add_tool(
self,
fn: Callable[..., Any],
Expand All @@ -42,13 +52,7 @@ def add_tool(
tool = Tool.from_function(
fn, name=name, description=description, annotations=annotations
)
existing = self._tools.get(tool.name)
if existing:
if self.warn_on_duplicate_tools:
logger.warning(f"Tool already exists: {tool.name}")
return existing
self._tools[tool.name] = tool
return tool
return self.add_tool_instance(tool)

async def call_tool(
self,
Expand Down
29 changes: 28 additions & 1 deletion tests/server/fastmcp/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.tools import Tool, ToolManager
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
from mcp.types import ToolAnnotations
Expand All @@ -31,6 +32,32 @@ def add(a: int, b: int) -> int:
assert tool.parameters["properties"]["a"]["type"] == "integer"
assert tool.parameters["properties"]["b"]["type"] == "integer"

def test_add_tool_instance(self):
manager = ToolManager()

def add(a: int, b: int) -> int:
return a + b

class AddArguments(ArgModelBase):
a: int
b: int

fn_metadata = FuncMetadata(arg_model=AddArguments)

original_tool = Tool(
name="add",
description="Add two numbers.",
fn=add,
fn_metadata=fn_metadata,
is_async=False,
parameters=AddArguments.model_json_schema(),
context_kwarg=None,
annotations=None,
)
manager.add_tool_instance(original_tool)
saved_tool = manager.get_tool("add")
assert saved_tool == original_tool

@pytest.mark.anyio
async def test_async_function(self):
"""Test registering and running an async function."""
Expand Down
Loading