Skip to content

fix: Add validation to ensure added auth token getters are used by the tool #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: anubhav-fix-authz-required
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ def add_auth_token_getters(
new_getters = types.MappingProxyType(
dict(self.__auth_service_token_getters, **auth_token_getters)
)
# create a read-only updated for params that are still required

# find the updated required authn params, authz tokens and the auth
# token getters used
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
identify_required_authn_params(
self.__required_authn_params,
Expand All @@ -292,10 +294,16 @@ def add_auth_token_getters(
)
)

# TODO: Add validation for used_auth_token_getters
# ensure no auth token getter provided remains unused
unused_auth = set(incoming_services) - used_auth_token_getters
if unused_auth:
raise ValueError(
f"Authentication source(s) `{', '.join(unused_auth)}` unused by tool `{self.__name__}`."
)

return self.__copy(
auth_service_token_getters=new_getters,
# create a read-only map for params that are still required
required_authn_params=types.MappingProxyType(new_req_authn_params),
required_authz_tokens=new_req_authz_tokens,
)
Expand Down
28 changes: 28 additions & 0 deletions packages/toolbox-core/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,34 @@ async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client):
):
authed_tool.add_auth_token_getters({AUTH_SERVICE: {}})

@pytest.mark.asyncio
async def test_add_auth_token_getters_missing_fail(self, tool_name, client):
"""
Tests that adding a missing auth token getter raises ValueError.
"""
AUTH_SERVICE = "xmy-auth-service"

tool = await client.load_tool(tool_name)

with pytest.raises(
ValueError,
match=f"Authentication source\(s\) \`{AUTH_SERVICE}\` unused by tool \`{tool_name}\`.",
):
tool.add_auth_token_getters({AUTH_SERVICE: {}})

@pytest.mark.asyncio
async def test_constructor_getters_missing_fail(self, tool_name, client):
"""
Tests that adding a missing auth token getter raises ValueError.
"""
AUTH_SERVICE = "xmy-auth-service"

with pytest.raises(
ValueError,
match=f"Validation failed for tool '{tool_name}': unused auth tokens: {AUTH_SERVICE}.",
):
await client.load_tool(tool_name, auth_token_getters={AUTH_SERVICE: {}})


class TestBoundParameter:

Expand Down
39 changes: 38 additions & 1 deletion packages/toolbox-core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
from typing import AsyncGenerator, Callable
from typing import AsyncGenerator, Callable, Mapping
from unittest.mock import AsyncMock, Mock

import pytest
Expand Down Expand Up @@ -92,6 +94,12 @@ def auth_header_key() -> str:
return "test-auth_token"


@pytest.fixture
def unused_auth_getters() -> dict[str, Callable[[], str]]:
"""Provides an auth getter for a service not required by sample_tool."""
return {"unused-auth-service": lambda: "unused-token-value"}


def test_create_func_docstring_one_param_real_schema():
"""
Tests create_func_docstring with one real ParameterSchema instance.
Expand Down Expand Up @@ -432,3 +440,32 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(

with pytest.raises(ValueError, match=expected_error_message):
tool_instance.add_auth_token_getters(new_auth_getters_causing_conflict)


def test_add_auth_token_getters_unused_token(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
unused_auth_getters: Mapping[str, Callable[[], str]],
):
"""
Tests ValueError when add_auth_token_getters is called with a getter for
an unused authentication service.
"""
tool_instance = ToolboxTool(
session=http_session,
base_url=TEST_BASE_URL,
name=TEST_TOOL_NAME,
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
)

expected_error_message = "Authentication source\(s\) \`unused-auth-service\` unused by tool \`sample_tool\`."

with pytest.raises(ValueError, match=expected_error_message):
tool_instance.add_auth_token_getters(unused_auth_getters)