Skip to content

Commit

Permalink
added role_requirements.Any\All behaviour to the authorize decorator (#7
Browse files Browse the repository at this point in the history
)

* added role_requirements.Any\All behavior to the authorized decorator
* Added some docstrings
* formatting
* merge tests with parametrize
* readme update
  • Loading branch information
dorlugasigal authored Apr 3, 2023
1 parent 563bb4f commit 3465ca5
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 14 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,23 @@ async def user_with_scope_and_roles(
```


### Role Requirements and Multiple Roles
The `@authorize` decorator is capable of specifying multiple roles as a list of string, which will be checked against the user's roles.
If **all** the roles are found, the user is authorized.

To require the user to have **at least one** of the specified roles, you can use the `role_requirement` parameter with the
value `RoleRequirement.ANY`:

``` python
@app.get("/user_with_scope_and_roles_any")
@authorize("user_impersonation", roles=["admin","superuser"], role_requirement=RoleRequirement.ANY)
async def user_with_scope_and_roles_any(
request: Request, token=Depends(oauth2_scheme(options=api_options))
):
# code here
```
> **NOTE:** `RoleRequirement.ALL` is the default behavior and does not need to be specified.
## Register your application within your Azure AD tenant

There are two applications to register:
Expand Down
23 changes: 14 additions & 9 deletions aad_fastapi/aad_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,26 @@

from .aad_options import AzureAdSettings
from .aad_user import AadUser
from .roles.role_requirement import RoleRequirement
from .roles.role_validator import RoleValidator


def authorize(
scopes: typing.Union[str, typing.Sequence[str]] = None,
roles: typing.Union[str, typing.Sequence[str]] = None,
role_requirement: RoleRequirement = RoleRequirement.ALL,
):
"""authorize decorator. you can specify scopes and (or) roles"""
"""
Decorator to authorize a route
:param scopes: list of scopes
:param roles: list of roles to validate
:param role_requirement: role requirement (RoleRequirement.ALL or RoleRequirement.ANY)
"""

def wrapper(endpoint):
@wraps(endpoint)
@requires(scopes)
async def require_auth_endpoint(request: Request, *args, **kwargs):
def has_required_roles(user_roles: typing.Sequence[str]) -> bool:
for mandatory_role in mandatory_roles_list:
if mandatory_role not in user_roles:
return False
return True

# Check args
mandatory_roles_list = []
if roles is not None:
Expand All @@ -39,7 +41,10 @@ def has_required_roles(user_roles: typing.Sequence[str]) -> bool:
user: AadUser = request.user
user_roles_list = user.roles_id or []

if len(mandatory_roles_list) > 0 and not has_required_roles(user_roles_list):
role_validator = RoleValidator(mandatory_roles_list, role_requirement)
if len(mandatory_roles_list) > 0 and not role_validator.validate_roles(
user_roles_list
):
raise HTTPException(status_code=403, detail="Unauthorized role")

if inspect.iscoroutinefunction(endpoint):
Expand All @@ -57,7 +62,7 @@ def has_required_roles(user_roles: typing.Sequence[str]) -> bool:
def oauth2_scheme(
options: AzureAdSettings = None, env_path: Optional[str] = None, **kwargs
):
"""get the OAUTH2 schema used for API Authentication"""
"""get the OAuth2 schema used for API Authentication"""

if options is None:
options = AzureAdSettings(_env_file=env_path)
Expand Down
12 changes: 12 additions & 0 deletions aad_fastapi/roles/all_role_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typing

from aad_fastapi.roles.role_validator_interface import RoleValidatorInterface


class AllRoleValidator(RoleValidatorInterface):
"""Validate that all mandatory roles are present in the user roles"""

def validate_roles(
self, user_roles: typing.Sequence[str], mandatory_roles: typing.Sequence[str]
) -> bool:
return all(mandatory_role in user_roles for mandatory_role in mandatory_roles)
12 changes: 12 additions & 0 deletions aad_fastapi/roles/any_role_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typing

from aad_fastapi.roles.role_validator_interface import RoleValidatorInterface


class AnyRoleValidator(RoleValidatorInterface):
"""Validate that at least one mandatory role is present in the user roles"""

def validate_roles(
self, user_roles: typing.Sequence[str], mandatory_roles: typing.Sequence[str]
) -> bool:
return any(mandatory_role in user_roles for mandatory_role in mandatory_roles)
8 changes: 8 additions & 0 deletions aad_fastapi/roles/role_requirement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import Enum


class RoleRequirement(Enum):
"""Role requirement enum for authorization."""

ALL = "all"
ANY = "any"
28 changes: 28 additions & 0 deletions aad_fastapi/roles/role_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import typing

from aad_fastapi.roles.all_role_validator import AllRoleValidator
from aad_fastapi.roles.any_role_validator import AnyRoleValidator
from aad_fastapi.roles.role_requirement import RoleRequirement


class RoleValidator:
"""Role validator class"""

_validators = {
RoleRequirement.ALL: AllRoleValidator,
RoleRequirement.ANY: AnyRoleValidator,
}

def __init__(
self, mandatory_roles: typing.List[str], role_requirement: RoleRequirement
):
self.mandatory_roles = mandatory_roles
self.role_requirement = role_requirement
self.validator_class = self._validators.get(role_requirement)
if self.validator_class is None:
raise ValueError(f"Invalid role requirement: {role_requirement}")

def validate_roles(self, user_roles: typing.List[str]) -> bool:
"""validate the user roles against the mandatory roles"""
validator = self.validator_class()
return validator.validate_roles(user_roles, self.mandatory_roles)
8 changes: 8 additions & 0 deletions aad_fastapi/roles/role_validator_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typing


class RoleValidatorInterface:
def validate_roles(
self, user_roles: typing.Sequence[str], mandatory_roles: typing.Sequence[str]
) -> bool:
"""validate the user roles against the mandatory roles"""
23 changes: 18 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
import sys

import pytest
from aad_fastapi import AadBearerBackend, authorize, oauth2_scheme
from async_asgi_testclient import TestClient
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import FastAPI
from fastapi.param_functions import Depends
from fastapi.testclient import TestClient
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request

from aad_fastapi import AadBearerBackend, authorize, oauth2_scheme
from aad_fastapi.roles.role_requirement import RoleRequirement

os.environ["CLIENT_ID"] = "01010101-aaaa-bbbb-acdf-020202020202"
os.environ["TENANT_ID"] = "02020202-aaaa-erty-olki-020202020202"
os.environ["DOMAIN"] = "contoso.onmicrosoft.com"
Expand Down Expand Up @@ -49,7 +51,7 @@ def cert():


@pytest.fixture(scope="module")
def client(public_key):
def mock_test_client(public_key):
# pre fill client id
swagger_ui_init_oauth = {
"clientId": os.environ.get("CLIENT_ID"),
Expand Down Expand Up @@ -77,9 +79,20 @@ async def get_isauth_with_impersonation(
):
return request.user

@app.get("/isauth_impersonation_roles")
@app.get("/isauth_impersonation_all_roles")
@authorize("user_impersonation", ["Admin", "Contributor"])
async def get_isauth_with_impersonation_and_roles(
async def get_isauth_with_impersonation_and_all_roles(
request: Request, token=Depends(oauth2_scheme())
):
return request.user

@app.get("/isauth_impersonation_any_roles")
@authorize(
"user_impersonation",
["Admin", "Contributor"],
role_requirement=RoleRequirement.ANY,
)
async def get_isauth_with_impersonation_and_any_roles(
request: Request, token=Depends(oauth2_scheme())
):
return request.user
Expand Down
48 changes: 48 additions & 0 deletions tests/test_authorize_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json

import pytest

from aad_fastapi import AzureAdSettings
from tests.helpers import gen_payload


@pytest.mark.parametrize(
"roles, expected_status_code",
[
(["Admin", "Contributor"], 200),
(["Admin"], 403),
(["Contributor"], 403),
([], 403),
],
)
def test_isauth_with_impersonation_and_all_roles(
mock_test_client, private_key, roles, expected_status_code
):
options = AzureAdSettings()
payload = gen_payload(options, private_key, roles=roles, scp=["user_impersonation"])
token = json.loads(payload)["access_token"]
response = mock_test_client.get(
"/isauth_impersonation_all_roles", headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == expected_status_code


@pytest.mark.parametrize(
"roles,expected_status_code",
[
(["Admin", "Contributor"], 200),
(["Admin"], 200),
(["Contributor"], 200),
([], 403),
],
)
def test_valid_access_token_with_any_roles(
mock_test_client, private_key, roles, expected_status_code
):
options = AzureAdSettings()
payload = gen_payload(options, private_key, roles=roles, scp=["user_impersonation"])
token = json.loads(payload)["access_token"]
response = mock_test_client.get(
"/isauth_impersonation_any_roles", headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == expected_status_code
49 changes: 49 additions & 0 deletions tests/test_role_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

from aad_fastapi.roles.all_role_validator import AllRoleValidator
from aad_fastapi.roles.any_role_validator import AnyRoleValidator
from aad_fastapi.roles.role_requirement import RoleRequirement
from aad_fastapi.roles.role_validator import RoleValidator


@pytest.mark.parametrize(
"roles, requirement, user_roles, expected",
[
(["admin", "editor"], RoleRequirement.ALL, ["admin", "editor"], True),
(["admin", "editor"], RoleRequirement.ALL, ["admin"], False),
(["admin", "editor"], RoleRequirement.ANY, ["admin"], True),
(["admin", "editor"], RoleRequirement.ANY, ["guest"], False),
],
)
def test_role_validator(roles, requirement, user_roles, expected):
validator = RoleValidator(roles, requirement)
assert validator.validate_roles(user_roles) == expected


def test_role_validator_invalid_role_requirement():
with pytest.raises(ValueError):
RoleValidator([], "invalid")


@pytest.mark.parametrize(
"mandatory_roles, user_roles, expected",
[
(["admin", "editor"], ["admin", "editor"], True),
(["admin", "editor"], ["admin"], False),
],
)
def test_all_role_validator(mandatory_roles, user_roles, expected):
validator = AllRoleValidator()
assert validator.validate_roles(user_roles, mandatory_roles) == expected


@pytest.mark.parametrize(
"mandatory_roles, user_roles, expected",
[
(["admin", "editor"], ["admin"], True),
(["admin", "editor"], ["guest"], False),
],
)
def test_any_role_validator(mandatory_roles, user_roles, expected):
validator = AnyRoleValidator()
assert validator.validate_roles(user_roles, mandatory_roles) == expected

0 comments on commit 3465ca5

Please sign in to comment.