Skip to content

Commit 0f50eb0

Browse files
committed
feat(toolbox-langchain): Implement self-authenticated tools
1 parent a304b28 commit 0f50eb0

File tree

2 files changed

+96
-8
lines changed

2 files changed

+96
-8
lines changed

packages/toolbox-langchain/src/toolbox_langchain/async_tools.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Any, Callable, Union
1616

1717
from deprecated import deprecated
18+
from langchain_core.runnables import RunnableConfig
1819
from langchain_core.tools import BaseTool
1920
from toolbox_core.tool import ToolboxTool as ToolboxCoreTool
2021
from toolbox_core.utils import params_to_pydantic_model
@@ -52,7 +53,11 @@ def __init__(
5253
def _run(self, **kwargs: Any) -> str:
5354
raise NotImplementedError("Synchronous methods not supported by async tools.")
5455

55-
async def _arun(self, **kwargs: Any) -> str:
56+
async def _arun(
57+
self,
58+
config: RunnableConfig,
59+
**kwargs: Any,
60+
) -> str:
5661
"""
5762
The coroutine that invokes the tool with the given arguments.
5863
@@ -63,7 +68,33 @@ async def _arun(self, **kwargs: Any) -> str:
6368
A dictionary containing the parsed JSON response from the tool
6469
invocation.
6570
"""
66-
return await self.__core_tool(**kwargs)
71+
tool_to_run = self.__core_tool
72+
if (
73+
config
74+
and "configurable" in config
75+
and "auth_token_getters" in config["configurable"]
76+
):
77+
auth_token_getters = config["configurable"]["auth_token_getters"]
78+
if auth_token_getters:
79+
80+
# The `add_auth_token_getters` method requires that all provided
81+
# getters are used by the tool. To prevent validation errors,
82+
# filter the incoming getters to include only those that this
83+
# specific tool requires.
84+
required_auth_keys = set(self.__core_tool._required_authz_tokens)
85+
for auth_list in self.__core_tool._required_authn_params.values():
86+
required_auth_keys.update(auth_list)
87+
filtered_getters = {
88+
k: v
89+
for k, v in auth_token_getters.items()
90+
if k in required_auth_keys
91+
}
92+
if filtered_getters:
93+
tool_to_run = self.__core_tool.add_auth_token_getters(
94+
filtered_getters
95+
)
96+
97+
return await tool_to_run(**kwargs)
6798

6899
def add_auth_token_getters(
69100
self, auth_token_getters: dict[str, Callable[[], str]]

packages/toolbox-langchain/src/toolbox_langchain/tools.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# limitations under the License.
1414

1515
from asyncio import to_thread
16-
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union
16+
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
1717

1818
from deprecated import deprecated
19+
from langchain_core.runnables import RunnableConfig
1920
from langchain_core.tools import BaseTool
2021
from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
2122
from toolbox_core.utils import params_to_pydantic_model
@@ -73,11 +74,67 @@ def _client_headers(
7374
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
7475
return self.__core_tool._client_headers
7576

76-
def _run(self, **kwargs: Any) -> str:
77-
return self.__core_tool(**kwargs)
78-
79-
async def _arun(self, **kwargs: Any) -> str:
80-
return await to_thread(self.__core_tool, **kwargs)
77+
def _run(
78+
self,
79+
config: RunnableConfig,
80+
**kwargs: Any,
81+
) -> str:
82+
tool_to_run = self.__core_tool
83+
if (
84+
config
85+
and "configurable" in config
86+
and "auth_token_getters" in config["configurable"]
87+
):
88+
auth_token_getters = config["configurable"]["auth_token_getters"]
89+
if auth_token_getters:
90+
91+
# The `add_auth_token_getters` method requires that all provided
92+
# getters are used by the tool. To prevent validation errors,
93+
# filter the incoming getters to include only those that this
94+
# specific tool requires.
95+
required_auth_keys = set(self.__core_tool._required_authz_tokens)
96+
for auth_list in self.__core_tool._required_authn_params.values():
97+
required_auth_keys.update(auth_list)
98+
filtered_getters = {
99+
k: v
100+
for k, v in auth_token_getters.items()
101+
if k in required_auth_keys
102+
}
103+
if filtered_getters:
104+
tool_to_run = self.__core_tool.add_auth_token_getters(
105+
filtered_getters
106+
)
107+
108+
return tool_to_run(**kwargs)
109+
110+
async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str:
111+
tool_to_run = self.__core_tool
112+
if (
113+
config
114+
and "configurable" in config
115+
and "auth_token_getters" in config["configurable"]
116+
):
117+
auth_token_getters = config["configurable"]["auth_token_getters"]
118+
if auth_token_getters:
119+
120+
# The `add_auth_token_getters` method requires that all provided
121+
# getters are used by the tool. To prevent validation errors,
122+
# filter the incoming getters to include only those that this
123+
# specific tool requires.
124+
required_auth_keys = set(self.__core_tool._required_authz_tokens)
125+
for auth_list in self.__core_tool._required_authn_params.values():
126+
required_auth_keys.update(auth_list)
127+
filtered_getters = {
128+
k: v
129+
for k, v in auth_token_getters.items()
130+
if k in required_auth_keys
131+
}
132+
if filtered_getters:
133+
tool_to_run = self.__core_tool.add_auth_token_getters(
134+
filtered_getters
135+
)
136+
137+
return await to_thread(tool_to_run, **kwargs)
81138

82139
def add_auth_token_getters(
83140
self, auth_token_getters: dict[str, Callable[[], str]]

0 commit comments

Comments
 (0)