Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions langchain_mcp_adapters/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
"""

import asyncio
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from types import TracebackType
from typing import Any

from langchain_core.documents.base import Blob
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, ToolException
from mcp import ClientSession
from pydantic import ValidationError

from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
from langchain_mcp_adapters.hooks import Hooks
Expand Down Expand Up @@ -142,12 +143,22 @@ async def session(
await session.initialize()
yield session

async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
async def get_tools(
self,
*,
server_name: str | None = None,
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
handle_validation_error: (
bool | str | Callable[[ValidationError], str] | None
) = False,
) -> list[BaseTool]:
"""Get a list of all tools from all connected servers.

Args:
server_name: Optional name of the server to get tools from.
If None, all tools from all servers will be returned (default).
handle_tool_error: Optional error handler for tool execution errors.
handle_validation_error: Optional error handler for validation errors.

NOTE: a new session will be created for each tool call

Expand All @@ -168,6 +179,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
callbacks=self.callbacks,
server_name=server_name,
hooks=self.hooks,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)

all_tools: list[BaseTool] = []
Expand All @@ -180,6 +193,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
callbacks=self.callbacks,
server_name=name,
hooks=self.hooks,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
)
load_mcp_tool_tasks.append(load_mcp_tool_task)
Expand Down
19 changes: 18 additions & 1 deletion langchain_mcp_adapters/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
tools, handle tool execution, and manage tool conversion between the two formats.
"""

from collections.abc import Callable
from typing import Any, cast, get_args

from langchain_core.tools import (
Expand All @@ -25,7 +26,7 @@
TextContent,
)
from mcp.types import Tool as MCPTool
from pydantic import BaseModel, create_model
from pydantic import BaseModel, ValidationError, create_model

from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
from langchain_mcp_adapters.hooks import CallToolRequestSpec, Hooks, ToolHookContext
Expand Down Expand Up @@ -128,6 +129,10 @@ def convert_mcp_tool_to_langchain_tool(
callbacks: Callbacks | None = None,
hooks: Hooks | None = None,
server_name: str | None = None,
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
handle_validation_error: (
bool | str | Callable[[ValidationError], str] | None
) = False,
) -> BaseTool:
"""Convert an MCP tool to a LangChain tool.

Expand All @@ -141,6 +146,8 @@ def convert_mcp_tool_to_langchain_tool(
callbacks: Optional callbacks for handling notifications and events
hooks: Optional hooks for before/after tool call processing
server_name: Name of the server this tool belongs to
handle_tool_error: Optional error handler for tool execution errors.
handle_validation_error: Optional error handler for validation errors.

Returns:
a LangChain tool
Expand Down Expand Up @@ -259,6 +266,8 @@ async def call_tool(
coroutine=call_tool,
response_format="content_and_artifact",
metadata=metadata,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)


Expand All @@ -269,6 +278,10 @@ async def load_mcp_tools(
callbacks: Callbacks | None = None,
hooks: Hooks | None = None,
server_name: str | None = None,
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
handle_validation_error: (
bool | str | Callable[[ValidationError], str] | None
) = False,
) -> list[BaseTool]:
"""Load all available MCP tools and convert them to LangChain tools.

Expand All @@ -278,6 +291,8 @@ async def load_mcp_tools(
callbacks: Optional callbacks for handling notifications and events.
hooks: Optional hooks for before/after tool call processing.
server_name: Name of the server these tools belong to.
handle_tool_error: Optional error handler for tool execution errors.
handle_validation_error: Optional error handler for validation errors.

Returns:
List of LangChain tools. Tool annotations are returned as part
Expand Down Expand Up @@ -317,6 +332,8 @@ async def load_mcp_tools(
callbacks=callbacks,
hooks=hooks,
server_name=server_name,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
for tool in tools
]
Expand Down