diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index fee763c3..971ebd24 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -15,6 +15,7 @@ from typing import Any, Callable, Union from deprecated import deprecated +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_core.utils import params_to_pydantic_model @@ -52,7 +53,11 @@ def __init__( def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") - async def _arun(self, **kwargs: Any) -> str: + async def _arun( + self, + config: RunnableConfig, + **kwargs: Any, + ) -> str: """ The coroutine that invokes the tool with the given arguments. @@ -63,7 +68,33 @@ async def _arun(self, **kwargs: Any) -> str: A dictionary containing the parsed JSON response from the tool invocation. """ - return await self.__core_tool(**kwargs) + tool_to_run = self.__core_tool + if ( + config + and "configurable" in config + and "auth_token_getters" in config["configurable"] + ): + auth_token_getters = config["configurable"]["auth_token_getters"] + if auth_token_getters: + + # The `add_auth_token_getters` method requires that all provided + # getters are used by the tool. To prevent validation errors, + # filter the incoming getters to include only those that this + # specific tool requires. + required_auth_keys = set(self.__core_tool._required_authz_tokens) + for auth_list in self.__core_tool._required_authn_params.values(): + required_auth_keys.update(auth_list) + filtered_getters = { + k: v + for k, v in auth_token_getters.items() + if k in required_auth_keys + } + if filtered_getters: + tool_to_run = self.__core_tool.add_auth_token_getters( + filtered_getters + ) + + return await tool_to_run(**kwargs) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index e03b37f8..34654882 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -13,9 +13,10 @@ # limitations under the License. from asyncio import to_thread -from typing import Any, Awaitable, Callable, Mapping, Sequence, Union +from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union from deprecated import deprecated +from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool from toolbox_core.utils import params_to_pydantic_model @@ -73,11 +74,67 @@ def _client_headers( ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: return self.__core_tool._client_headers - def _run(self, **kwargs: Any) -> str: - return self.__core_tool(**kwargs) - - async def _arun(self, **kwargs: Any) -> str: - return await to_thread(self.__core_tool, **kwargs) + def _run( + self, + config: RunnableConfig, + **kwargs: Any, + ) -> str: + tool_to_run = self.__core_tool + if ( + config + and "configurable" in config + and "auth_token_getters" in config["configurable"] + ): + auth_token_getters = config["configurable"]["auth_token_getters"] + if auth_token_getters: + + # The `add_auth_token_getters` method requires that all provided + # getters are used by the tool. To prevent validation errors, + # filter the incoming getters to include only those that this + # specific tool requires. + required_auth_keys = set(self.__core_tool._required_authz_tokens) + for auth_list in self.__core_tool._required_authn_params.values(): + required_auth_keys.update(auth_list) + filtered_getters = { + k: v + for k, v in auth_token_getters.items() + if k in required_auth_keys + } + if filtered_getters: + tool_to_run = self.__core_tool.add_auth_token_getters( + filtered_getters + ) + + return tool_to_run(**kwargs) + + async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str: + tool_to_run = self.__core_tool + if ( + config + and "configurable" in config + and "auth_token_getters" in config["configurable"] + ): + auth_token_getters = config["configurable"]["auth_token_getters"] + if auth_token_getters: + + # The `add_auth_token_getters` method requires that all provided + # getters are used by the tool. To prevent validation errors, + # filter the incoming getters to include only those that this + # specific tool requires. + required_auth_keys = set(self.__core_tool._required_authz_tokens) + for auth_list in self.__core_tool._required_authn_params.values(): + required_auth_keys.update(auth_list) + filtered_getters = { + k: v + for k, v in auth_token_getters.items() + if k in required_auth_keys + } + if filtered_getters: + tool_to_run = self.__core_tool.add_auth_token_getters( + filtered_getters + ) + + return await to_thread(tool_to_run, **kwargs) def add_auth_token_getters( self, auth_token_getters: dict[str, Callable[[], str]] diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 90fddf4b..6a6b6fdb 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -273,7 +273,7 @@ def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool): expected_result = "sync_run_output" mock_core_tool.return_value = expected_result - result = toolbox_tool._run(**kwargs_to_run) + result = toolbox_tool._run(**kwargs_to_run, config={}) assert result == expected_result assert mock_core_tool.call_count == 1 @@ -294,7 +294,7 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func): mock_to_thread_in_tools.side_effect = to_thread_side_effect - result = await toolbox_tool._arun(**kwargs_to_run) + result = await toolbox_tool._arun(**kwargs_to_run, config={}) assert result == expected_result mock_to_thread_in_tools.assert_awaited_once_with(