Skip to content

feat: Track utilized auth services #203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: anubhav-re-bind-error
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __parse_tool(

authn_params = identify_required_authn_params(
authn_params, auth_token_getters.keys()
)
)[0]

tool = ToolboxTool(
session=self.__session,
Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def add_auth_token_getters(
new_req_authn_params = types.MappingProxyType(
identify_required_authn_params(
self.__required_authn_params, auth_token_getters.keys()
)
)[0]
)

return self.__copy(
Expand Down
29 changes: 20 additions & 9 deletions packages/toolbox-core/src/toolbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,40 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -

def identify_required_authn_params(
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
) -> dict[str, list[str]]:
) -> tuple[dict[str, list[str]], set[str]]:
"""
Identifies authentication parameters that are still required; because they
are not covered by the provided `auth_service_names`.
are not covered by the provided `auth_service_names`, and also returns a
set of all authentication services that were found to be matching.

Args:
req_authn_params: A mapping of parameter names to sets of required
req_authn_params: A mapping of parameter names to lists of required
authentication services.
auth_service_names: An iterable of authentication service names for which
token getters are available.

Returns:
A new dictionary representing the subset of required authentication parameters
that are not covered by the provided `auth_service_names`.
A tuple containing:
- A new dictionary representing the subset of required
authentication parameters that are not covered by the provided
`auth_service_names`.
- A list of authentication service names from `auth_service_names`
that were found to satisfy at least one parameter's requirements.
"""
required_params = {} # params that are still required with provided auth_services
required_params: dict[str, list[str]] = {}
used_services: set[str] = set()

for param, services in req_authn_params.items():
# if we don't have a token_getter for any of the services required by the param,
# the param is still required
required = not any(s in services for s in auth_service_names)
if required:
matched_services = [s for s in services if s in auth_service_names]

if matched_services:
used_services.update(matched_services)
else:
required_params[param] = services
return required_params

return required_params, used_services


def params_to_pydantic_model(
Expand Down
36 changes: 24 additions & 12 deletions packages/toolbox-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def test_identify_required_authn_params_none_required():
req_authn_params = {}
auth_service_names = ["service_a", "service_b"]
expected = {}
assert (
identify_required_authn_params(req_authn_params, auth_service_names) == expected
expected_used = set()
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
expected,
expected_used,
)


Expand All @@ -100,8 +102,10 @@ def test_identify_required_authn_params_all_covered():
}
auth_service_names = ["service_a", "service_b"]
expected = {}
assert (
identify_required_authn_params(req_authn_params, auth_service_names) == expected
expected_used = set(auth_service_names)
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
expected,
expected_used,
)


Expand All @@ -118,8 +122,10 @@ def test_identify_required_authn_params_some_covered():
"token_d": ["service_d"],
"token_e": ["service_e", "service_f"],
}
assert (
identify_required_authn_params(req_authn_params, auth_service_names) == expected
expected_used = set(auth_service_names)
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
expected,
expected_used,
)


Expand All @@ -134,8 +140,10 @@ def test_identify_required_authn_params_none_covered():
"token_d": ["service_d"],
"token_e": ["service_e", "service_f"],
}
assert (
identify_required_authn_params(req_authn_params, auth_service_names) == expected
expected_used = set()
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
expected,
expected_used,
)


Expand All @@ -150,8 +158,10 @@ def test_identify_required_authn_params_no_available_services():
"token_a": ["service_a"],
"token_b": ["service_b", "service_c"],
}
assert (
identify_required_authn_params(req_authn_params, auth_service_names) == expected
expected_used = set()
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
expected,
expected_used,
)


Expand All @@ -164,8 +174,10 @@ def test_identify_required_authn_params_empty_services_for_param():
expected = {
"token_x": [],
}
assert (
identify_required_authn_params(req_authn_params, auth_service_names) == expected
expected_used = set()
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
expected,
expected_used,
)


Expand Down