Skip to content

Commit f02a5d5

Browse files
feat: add handle_tool_error and handle_validation_error to load_mcp_tools
1 parent cec7a56 commit f02a5d5

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

langchain_mcp_adapters/client.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
"""
66

77
import asyncio
8-
from collections.abc import AsyncIterator
8+
from collections.abc import AsyncIterator, Callable
99
from contextlib import asynccontextmanager
1010
from types import TracebackType
1111
from typing import Any
1212

1313
from langchain_core.documents.base import Blob
1414
from langchain_core.messages import AIMessage, HumanMessage
15-
from langchain_core.tools import BaseTool
15+
from langchain_core.tools import BaseTool, ToolException
1616
from mcp import ClientSession
17+
from pydantic import ValidationError
1718

1819
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
1920
from langchain_mcp_adapters.hooks import Hooks
@@ -142,12 +143,22 @@ async def session(
142143
await session.initialize()
143144
yield session
144145

145-
async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
146+
async def get_tools(
147+
self,
148+
*,
149+
server_name: str | None = None,
150+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
151+
handle_validation_error: (
152+
bool | str | Callable[[ValidationError], str] | None
153+
) = False,
154+
) -> list[BaseTool]:
146155
"""Get a list of all tools from all connected servers.
147156
148157
Args:
149158
server_name: Optional name of the server to get tools from.
150159
If None, all tools from all servers will be returned (default).
160+
handle_tool_error: Optional error handler for tool execution errors.
161+
handle_validation_error: Optional error handler for validation errors.
151162
152163
NOTE: a new session will be created for each tool call
153164
@@ -168,6 +179,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
168179
callbacks=self.callbacks,
169180
server_name=server_name,
170181
hooks=self.hooks,
182+
handle_tool_error=handle_tool_error,
183+
handle_validation_error=handle_validation_error,
171184
)
172185

173186
all_tools: list[BaseTool] = []
@@ -180,6 +193,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
180193
callbacks=self.callbacks,
181194
server_name=name,
182195
hooks=self.hooks,
196+
handle_tool_error=handle_tool_error,
197+
handle_validation_error=handle_validation_error,
183198
)
184199
)
185200
load_mcp_tool_tasks.append(load_mcp_tool_task)

langchain_mcp_adapters/tools.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
tools, handle tool execution, and manage tool conversion between the two formats.
55
"""
66

7+
from collections.abc import Callable
78
from typing import Any, cast, get_args
89

910
from langchain_core.tools import (
@@ -25,7 +26,7 @@
2526
TextContent,
2627
)
2728
from mcp.types import Tool as MCPTool
28-
from pydantic import BaseModel, create_model
29+
from pydantic import BaseModel, ValidationError, create_model
2930

3031
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
3132
from langchain_mcp_adapters.hooks import CallToolRequestSpec, Hooks, ToolHookContext
@@ -128,6 +129,10 @@ def convert_mcp_tool_to_langchain_tool(
128129
callbacks: Callbacks | None = None,
129130
hooks: Hooks | None = None,
130131
server_name: str | None = None,
132+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
133+
handle_validation_error: (
134+
bool | str | Callable[[ValidationError], str] | None
135+
) = False,
131136
) -> BaseTool:
132137
"""Convert an MCP tool to a LangChain tool.
133138
@@ -141,6 +146,8 @@ def convert_mcp_tool_to_langchain_tool(
141146
callbacks: Optional callbacks for handling notifications and events
142147
hooks: Optional hooks for before/after tool call processing
143148
server_name: Name of the server this tool belongs to
149+
handle_tool_error: Optional error handler for tool execution errors.
150+
handle_validation_error: Optional error handler for validation errors.
144151
145152
Returns:
146153
a LangChain tool
@@ -259,6 +266,8 @@ async def call_tool(
259266
coroutine=call_tool,
260267
response_format="content_and_artifact",
261268
metadata=metadata,
269+
handle_tool_error=handle_tool_error,
270+
handle_validation_error=handle_validation_error,
262271
)
263272

264273

@@ -269,6 +278,10 @@ async def load_mcp_tools(
269278
callbacks: Callbacks | None = None,
270279
hooks: Hooks | None = None,
271280
server_name: str | None = None,
281+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
282+
handle_validation_error: (
283+
bool | str | Callable[[ValidationError], str] | None
284+
) = False,
272285
) -> list[BaseTool]:
273286
"""Load all available MCP tools and convert them to LangChain tools.
274287
@@ -278,6 +291,8 @@ async def load_mcp_tools(
278291
callbacks: Optional callbacks for handling notifications and events.
279292
hooks: Optional hooks for before/after tool call processing.
280293
server_name: Name of the server these tools belong to.
294+
handle_tool_error: Optional error handler for tool execution errors.
295+
handle_validation_error: Optional error handler for validation errors.
281296
282297
Returns:
283298
List of LangChain tools. Tool annotations are returned as part
@@ -317,6 +332,8 @@ async def load_mcp_tools(
317332
callbacks=callbacks,
318333
hooks=hooks,
319334
server_name=server_name,
335+
handle_tool_error=handle_tool_error,
336+
handle_validation_error=handle_validation_error,
320337
)
321338
for tool in tools
322339
]

0 commit comments

Comments
 (0)