Skip to content

feat(toolbox-langchain): Support per-invocation auth via RunnableConfig #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: anubhav-state-li
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions packages/toolbox-langchain/src/toolbox_langchain/async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Callable, Union

from deprecated import deprecated
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from toolbox_core.tool import ToolboxTool as ToolboxCoreTool
from toolbox_core.utils import params_to_pydantic_model
Expand Down Expand Up @@ -52,7 +53,11 @@ def __init__(
def _run(self, **kwargs: Any) -> str:
raise NotImplementedError("Synchronous methods not supported by async tools.")

async def _arun(self, **kwargs: Any) -> str:
async def _arun(
self,
config: RunnableConfig,
**kwargs: Any,
) -> str:
"""
The coroutine that invokes the tool with the given arguments.

Expand All @@ -63,7 +68,33 @@ async def _arun(self, **kwargs: Any) -> str:
A dictionary containing the parsed JSON response from the tool
invocation.
"""
return await self.__core_tool(**kwargs)
tool_to_run = self.__core_tool
if (
config
and "configurable" in config
and "auth_token_getters" in config["configurable"]
):
auth_token_getters = config["configurable"]["auth_token_getters"]
if auth_token_getters:

# The `add_auth_token_getters` method requires that all provided
# getters are used by the tool. To prevent validation errors,
# filter the incoming getters to include only those that this
# specific tool requires.
required_auth_keys = set(self.__core_tool._required_authz_tokens)
for auth_list in self.__core_tool._required_authn_params.values():
required_auth_keys.update(auth_list)
filtered_getters = {
k: v
for k, v in auth_token_getters.items()
if k in required_auth_keys
}
if filtered_getters:
tool_to_run = self.__core_tool.add_auth_token_getters(
filtered_getters
)

return await tool_to_run(**kwargs)

def add_auth_token_getters(
self, auth_token_getters: dict[str, Callable[[], str]]
Expand Down
69 changes: 63 additions & 6 deletions packages/toolbox-langchain/src/toolbox_langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

from asyncio import to_thread
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union

from deprecated import deprecated
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool
from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
from toolbox_core.utils import params_to_pydantic_model
Expand Down Expand Up @@ -73,11 +74,67 @@ def _client_headers(
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
return self.__core_tool._client_headers

def _run(self, **kwargs: Any) -> str:
return self.__core_tool(**kwargs)

async def _arun(self, **kwargs: Any) -> str:
return await to_thread(self.__core_tool, **kwargs)
def _run(
self,
config: RunnableConfig,
**kwargs: Any,
) -> str:
tool_to_run = self.__core_tool
if (
config
and "configurable" in config
and "auth_token_getters" in config["configurable"]
):
auth_token_getters = config["configurable"]["auth_token_getters"]
if auth_token_getters:

# The `add_auth_token_getters` method requires that all provided
# getters are used by the tool. To prevent validation errors,
# filter the incoming getters to include only those that this
# specific tool requires.
required_auth_keys = set(self.__core_tool._required_authz_tokens)
for auth_list in self.__core_tool._required_authn_params.values():
required_auth_keys.update(auth_list)
filtered_getters = {
k: v
for k, v in auth_token_getters.items()
if k in required_auth_keys
}
if filtered_getters:
tool_to_run = self.__core_tool.add_auth_token_getters(
filtered_getters
)

return tool_to_run(**kwargs)

async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str:
tool_to_run = self.__core_tool
if (
config
and "configurable" in config
and "auth_token_getters" in config["configurable"]
):
auth_token_getters = config["configurable"]["auth_token_getters"]
if auth_token_getters:

# The `add_auth_token_getters` method requires that all provided
# getters are used by the tool. To prevent validation errors,
# filter the incoming getters to include only those that this
# specific tool requires.
required_auth_keys = set(self.__core_tool._required_authz_tokens)
for auth_list in self.__core_tool._required_authn_params.values():
required_auth_keys.update(auth_list)
filtered_getters = {
k: v
for k, v in auth_token_getters.items()
if k in required_auth_keys
}
if filtered_getters:
tool_to_run = self.__core_tool.add_auth_token_getters(
filtered_getters
)

return await to_thread(tool_to_run, **kwargs)

def add_auth_token_getters(
self, auth_token_getters: dict[str, Callable[[], str]]
Expand Down
4 changes: 2 additions & 2 deletions packages/toolbox-langchain/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool):
expected_result = "sync_run_output"
mock_core_tool.return_value = expected_result

result = toolbox_tool._run(**kwargs_to_run)
result = toolbox_tool._run(**kwargs_to_run, config={})

assert result == expected_result
assert mock_core_tool.call_count == 1
Expand All @@ -294,7 +294,7 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func):

mock_to_thread_in_tools.side_effect = to_thread_side_effect

result = await toolbox_tool._arun(**kwargs_to_run)
result = await toolbox_tool._arun(**kwargs_to_run, config={})

assert result == expected_result
mock_to_thread_in_tools.assert_awaited_once_with(
Expand Down