From 3465ca5fb8de6e46d67f49f083cfc51626d4280e Mon Sep 17 00:00:00 2001 From: Dor Lugasi-Gal Date: Mon, 3 Apr 2023 05:41:33 +0000 Subject: [PATCH] added role_requirements.Any\All behaviour to the authorize decorator (#7) * added role_requirements.Any\All behavior to the authorized decorator * Added some docstrings * formatting * merge tests with parametrize * readme update --- README.md | 17 +++++++ aad_fastapi/aad_decorators.py | 23 +++++---- aad_fastapi/roles/all_role_validator.py | 12 +++++ aad_fastapi/roles/any_role_validator.py | 12 +++++ aad_fastapi/roles/role_requirement.py | 8 +++ aad_fastapi/roles/role_validator.py | 28 +++++++++++ aad_fastapi/roles/role_validator_interface.py | 8 +++ tests/conftest.py | 23 +++++++-- tests/test_authorize_decorator.py | 48 ++++++++++++++++++ tests/test_role_validator.py | 49 +++++++++++++++++++ 10 files changed, 214 insertions(+), 14 deletions(-) create mode 100644 aad_fastapi/roles/all_role_validator.py create mode 100644 aad_fastapi/roles/any_role_validator.py create mode 100644 aad_fastapi/roles/role_requirement.py create mode 100644 aad_fastapi/roles/role_validator.py create mode 100644 aad_fastapi/roles/role_validator_interface.py create mode 100644 tests/test_authorize_decorator.py create mode 100644 tests/test_role_validator.py diff --git a/README.md b/README.md index 3d1f89e..f402aea 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,23 @@ async def user_with_scope_and_roles( ``` +### Role Requirements and Multiple Roles +The `@authorize` decorator is capable of specifying multiple roles as a list of string, which will be checked against the user's roles. +If **all** the roles are found, the user is authorized. + +To require the user to have **at least one** of the specified roles, you can use the `role_requirement` parameter with the +value `RoleRequirement.ANY`: + +``` python +@app.get("/user_with_scope_and_roles_any") +@authorize("user_impersonation", roles=["admin","superuser"], role_requirement=RoleRequirement.ANY) +async def user_with_scope_and_roles_any( + request: Request, token=Depends(oauth2_scheme(options=api_options)) +): + # code here +``` +> **NOTE:** `RoleRequirement.ALL` is the default behavior and does not need to be specified. + ## Register your application within your Azure AD tenant There are two applications to register: diff --git a/aad_fastapi/aad_decorators.py b/aad_fastapi/aad_decorators.py index 99ee9d7..c9d5036 100644 --- a/aad_fastapi/aad_decorators.py +++ b/aad_fastapi/aad_decorators.py @@ -11,24 +11,26 @@ from .aad_options import AzureAdSettings from .aad_user import AadUser +from .roles.role_requirement import RoleRequirement +from .roles.role_validator import RoleValidator def authorize( scopes: typing.Union[str, typing.Sequence[str]] = None, roles: typing.Union[str, typing.Sequence[str]] = None, + role_requirement: RoleRequirement = RoleRequirement.ALL, ): - """authorize decorator. you can specify scopes and (or) roles""" + """ + Decorator to authorize a route + :param scopes: list of scopes + :param roles: list of roles to validate + :param role_requirement: role requirement (RoleRequirement.ALL or RoleRequirement.ANY) + """ def wrapper(endpoint): @wraps(endpoint) @requires(scopes) async def require_auth_endpoint(request: Request, *args, **kwargs): - def has_required_roles(user_roles: typing.Sequence[str]) -> bool: - for mandatory_role in mandatory_roles_list: - if mandatory_role not in user_roles: - return False - return True - # Check args mandatory_roles_list = [] if roles is not None: @@ -39,7 +41,10 @@ def has_required_roles(user_roles: typing.Sequence[str]) -> bool: user: AadUser = request.user user_roles_list = user.roles_id or [] - if len(mandatory_roles_list) > 0 and not has_required_roles(user_roles_list): + role_validator = RoleValidator(mandatory_roles_list, role_requirement) + if len(mandatory_roles_list) > 0 and not role_validator.validate_roles( + user_roles_list + ): raise HTTPException(status_code=403, detail="Unauthorized role") if inspect.iscoroutinefunction(endpoint): @@ -57,7 +62,7 @@ def has_required_roles(user_roles: typing.Sequence[str]) -> bool: def oauth2_scheme( options: AzureAdSettings = None, env_path: Optional[str] = None, **kwargs ): - """get the OAUTH2 schema used for API Authentication""" + """get the OAuth2 schema used for API Authentication""" if options is None: options = AzureAdSettings(_env_file=env_path) diff --git a/aad_fastapi/roles/all_role_validator.py b/aad_fastapi/roles/all_role_validator.py new file mode 100644 index 0000000..a624d7c --- /dev/null +++ b/aad_fastapi/roles/all_role_validator.py @@ -0,0 +1,12 @@ +import typing + +from aad_fastapi.roles.role_validator_interface import RoleValidatorInterface + + +class AllRoleValidator(RoleValidatorInterface): + """Validate that all mandatory roles are present in the user roles""" + + def validate_roles( + self, user_roles: typing.Sequence[str], mandatory_roles: typing.Sequence[str] + ) -> bool: + return all(mandatory_role in user_roles for mandatory_role in mandatory_roles) diff --git a/aad_fastapi/roles/any_role_validator.py b/aad_fastapi/roles/any_role_validator.py new file mode 100644 index 0000000..d8f3319 --- /dev/null +++ b/aad_fastapi/roles/any_role_validator.py @@ -0,0 +1,12 @@ +import typing + +from aad_fastapi.roles.role_validator_interface import RoleValidatorInterface + + +class AnyRoleValidator(RoleValidatorInterface): + """Validate that at least one mandatory role is present in the user roles""" + + def validate_roles( + self, user_roles: typing.Sequence[str], mandatory_roles: typing.Sequence[str] + ) -> bool: + return any(mandatory_role in user_roles for mandatory_role in mandatory_roles) diff --git a/aad_fastapi/roles/role_requirement.py b/aad_fastapi/roles/role_requirement.py new file mode 100644 index 0000000..f9af4e2 --- /dev/null +++ b/aad_fastapi/roles/role_requirement.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class RoleRequirement(Enum): + """Role requirement enum for authorization.""" + + ALL = "all" + ANY = "any" diff --git a/aad_fastapi/roles/role_validator.py b/aad_fastapi/roles/role_validator.py new file mode 100644 index 0000000..279e60b --- /dev/null +++ b/aad_fastapi/roles/role_validator.py @@ -0,0 +1,28 @@ +import typing + +from aad_fastapi.roles.all_role_validator import AllRoleValidator +from aad_fastapi.roles.any_role_validator import AnyRoleValidator +from aad_fastapi.roles.role_requirement import RoleRequirement + + +class RoleValidator: + """Role validator class""" + + _validators = { + RoleRequirement.ALL: AllRoleValidator, + RoleRequirement.ANY: AnyRoleValidator, + } + + def __init__( + self, mandatory_roles: typing.List[str], role_requirement: RoleRequirement + ): + self.mandatory_roles = mandatory_roles + self.role_requirement = role_requirement + self.validator_class = self._validators.get(role_requirement) + if self.validator_class is None: + raise ValueError(f"Invalid role requirement: {role_requirement}") + + def validate_roles(self, user_roles: typing.List[str]) -> bool: + """validate the user roles against the mandatory roles""" + validator = self.validator_class() + return validator.validate_roles(user_roles, self.mandatory_roles) diff --git a/aad_fastapi/roles/role_validator_interface.py b/aad_fastapi/roles/role_validator_interface.py new file mode 100644 index 0000000..dab111d --- /dev/null +++ b/aad_fastapi/roles/role_validator_interface.py @@ -0,0 +1,8 @@ +import typing + + +class RoleValidatorInterface: + def validate_roles( + self, user_roles: typing.Sequence[str], mandatory_roles: typing.Sequence[str] + ) -> bool: + """validate the user roles against the mandatory roles""" diff --git a/tests/conftest.py b/tests/conftest.py index 0a7dec8..8916fea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,17 @@ import sys import pytest -from aad_fastapi import AadBearerBackend, authorize, oauth2_scheme -from async_asgi_testclient import TestClient from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa from fastapi import FastAPI from fastapi.param_functions import Depends +from fastapi.testclient import TestClient from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request +from aad_fastapi import AadBearerBackend, authorize, oauth2_scheme +from aad_fastapi.roles.role_requirement import RoleRequirement + os.environ["CLIENT_ID"] = "01010101-aaaa-bbbb-acdf-020202020202" os.environ["TENANT_ID"] = "02020202-aaaa-erty-olki-020202020202" os.environ["DOMAIN"] = "contoso.onmicrosoft.com" @@ -49,7 +51,7 @@ def cert(): @pytest.fixture(scope="module") -def client(public_key): +def mock_test_client(public_key): # pre fill client id swagger_ui_init_oauth = { "clientId": os.environ.get("CLIENT_ID"), @@ -77,9 +79,20 @@ async def get_isauth_with_impersonation( ): return request.user - @app.get("/isauth_impersonation_roles") + @app.get("/isauth_impersonation_all_roles") @authorize("user_impersonation", ["Admin", "Contributor"]) - async def get_isauth_with_impersonation_and_roles( + async def get_isauth_with_impersonation_and_all_roles( + request: Request, token=Depends(oauth2_scheme()) + ): + return request.user + + @app.get("/isauth_impersonation_any_roles") + @authorize( + "user_impersonation", + ["Admin", "Contributor"], + role_requirement=RoleRequirement.ANY, + ) + async def get_isauth_with_impersonation_and_any_roles( request: Request, token=Depends(oauth2_scheme()) ): return request.user diff --git a/tests/test_authorize_decorator.py b/tests/test_authorize_decorator.py new file mode 100644 index 0000000..09e1506 --- /dev/null +++ b/tests/test_authorize_decorator.py @@ -0,0 +1,48 @@ +import json + +import pytest + +from aad_fastapi import AzureAdSettings +from tests.helpers import gen_payload + + +@pytest.mark.parametrize( + "roles, expected_status_code", + [ + (["Admin", "Contributor"], 200), + (["Admin"], 403), + (["Contributor"], 403), + ([], 403), + ], +) +def test_isauth_with_impersonation_and_all_roles( + mock_test_client, private_key, roles, expected_status_code +): + options = AzureAdSettings() + payload = gen_payload(options, private_key, roles=roles, scp=["user_impersonation"]) + token = json.loads(payload)["access_token"] + response = mock_test_client.get( + "/isauth_impersonation_all_roles", headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == expected_status_code + + +@pytest.mark.parametrize( + "roles,expected_status_code", + [ + (["Admin", "Contributor"], 200), + (["Admin"], 200), + (["Contributor"], 200), + ([], 403), + ], +) +def test_valid_access_token_with_any_roles( + mock_test_client, private_key, roles, expected_status_code +): + options = AzureAdSettings() + payload = gen_payload(options, private_key, roles=roles, scp=["user_impersonation"]) + token = json.loads(payload)["access_token"] + response = mock_test_client.get( + "/isauth_impersonation_any_roles", headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == expected_status_code diff --git a/tests/test_role_validator.py b/tests/test_role_validator.py new file mode 100644 index 0000000..b00c5a5 --- /dev/null +++ b/tests/test_role_validator.py @@ -0,0 +1,49 @@ +import pytest + +from aad_fastapi.roles.all_role_validator import AllRoleValidator +from aad_fastapi.roles.any_role_validator import AnyRoleValidator +from aad_fastapi.roles.role_requirement import RoleRequirement +from aad_fastapi.roles.role_validator import RoleValidator + + +@pytest.mark.parametrize( + "roles, requirement, user_roles, expected", + [ + (["admin", "editor"], RoleRequirement.ALL, ["admin", "editor"], True), + (["admin", "editor"], RoleRequirement.ALL, ["admin"], False), + (["admin", "editor"], RoleRequirement.ANY, ["admin"], True), + (["admin", "editor"], RoleRequirement.ANY, ["guest"], False), + ], +) +def test_role_validator(roles, requirement, user_roles, expected): + validator = RoleValidator(roles, requirement) + assert validator.validate_roles(user_roles) == expected + + +def test_role_validator_invalid_role_requirement(): + with pytest.raises(ValueError): + RoleValidator([], "invalid") + + +@pytest.mark.parametrize( + "mandatory_roles, user_roles, expected", + [ + (["admin", "editor"], ["admin", "editor"], True), + (["admin", "editor"], ["admin"], False), + ], +) +def test_all_role_validator(mandatory_roles, user_roles, expected): + validator = AllRoleValidator() + assert validator.validate_roles(user_roles, mandatory_roles) == expected + + +@pytest.mark.parametrize( + "mandatory_roles, user_roles, expected", + [ + (["admin", "editor"], ["admin"], True), + (["admin", "editor"], ["guest"], False), + ], +) +def test_any_role_validator(mandatory_roles, user_roles, expected): + validator = AnyRoleValidator() + assert validator.validate_roles(user_roles, mandatory_roles) == expected