|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | from asyncio import to_thread
|
16 |
| -from typing import Any, Awaitable, Callable, Mapping, Sequence, Union |
| 16 | +from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union |
17 | 17 |
|
18 | 18 | from deprecated import deprecated
|
| 19 | +from langchain_core.runnables import RunnableConfig |
19 | 20 | from langchain_core.tools import BaseTool
|
20 | 21 | from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
|
21 | 22 | from toolbox_core.utils import params_to_pydantic_model
|
@@ -73,11 +74,67 @@ def _client_headers(
|
73 | 74 | ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
|
74 | 75 | return self.__core_tool._client_headers
|
75 | 76 |
|
76 |
| - def _run(self, **kwargs: Any) -> str: |
77 |
| - return self.__core_tool(**kwargs) |
78 |
| - |
79 |
| - async def _arun(self, **kwargs: Any) -> str: |
80 |
| - return await to_thread(self.__core_tool, **kwargs) |
| 77 | + def _run( |
| 78 | + self, |
| 79 | + config: RunnableConfig, |
| 80 | + **kwargs: Any, |
| 81 | + ) -> str: |
| 82 | + tool_to_run = self.__core_tool |
| 83 | + if ( |
| 84 | + config |
| 85 | + and "configurable" in config |
| 86 | + and "auth_token_getters" in config["configurable"] |
| 87 | + ): |
| 88 | + auth_token_getters = config["configurable"]["auth_token_getters"] |
| 89 | + if auth_token_getters: |
| 90 | + |
| 91 | + # The `add_auth_token_getters` method requires that all provided |
| 92 | + # getters are used by the tool. To prevent validation errors, |
| 93 | + # filter the incoming getters to include only those that this |
| 94 | + # specific tool requires. |
| 95 | + required_auth_keys = set(self.__core_tool._required_authz_tokens) |
| 96 | + for auth_list in self.__core_tool._required_authn_params.values(): |
| 97 | + required_auth_keys.update(auth_list) |
| 98 | + filtered_getters = { |
| 99 | + k: v |
| 100 | + for k, v in auth_token_getters.items() |
| 101 | + if k in required_auth_keys |
| 102 | + } |
| 103 | + if filtered_getters: |
| 104 | + tool_to_run = self.__core_tool.add_auth_token_getters( |
| 105 | + filtered_getters |
| 106 | + ) |
| 107 | + |
| 108 | + return tool_to_run(**kwargs) |
| 109 | + |
| 110 | + async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str: |
| 111 | + tool_to_run = self.__core_tool |
| 112 | + if ( |
| 113 | + config |
| 114 | + and "configurable" in config |
| 115 | + and "auth_token_getters" in config["configurable"] |
| 116 | + ): |
| 117 | + auth_token_getters = config["configurable"]["auth_token_getters"] |
| 118 | + if auth_token_getters: |
| 119 | + |
| 120 | + # The `add_auth_token_getters` method requires that all provided |
| 121 | + # getters are used by the tool. To prevent validation errors, |
| 122 | + # filter the incoming getters to include only those that this |
| 123 | + # specific tool requires. |
| 124 | + required_auth_keys = set(self.__core_tool._required_authz_tokens) |
| 125 | + for auth_list in self.__core_tool._required_authn_params.values(): |
| 126 | + required_auth_keys.update(auth_list) |
| 127 | + filtered_getters = { |
| 128 | + k: v |
| 129 | + for k, v in auth_token_getters.items() |
| 130 | + if k in required_auth_keys |
| 131 | + } |
| 132 | + if filtered_getters: |
| 133 | + tool_to_run = self.__core_tool.add_auth_token_getters( |
| 134 | + filtered_getters |
| 135 | + ) |
| 136 | + |
| 137 | + return await to_thread(tool_to_run, **kwargs) |
81 | 138 |
|
82 | 139 | def add_auth_token_getters(
|
83 | 140 | self, auth_token_getters: dict[str, Callable[[], str]]
|
|
0 commit comments