Skip to content

Commit

Permalink
fix azure and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar committed Oct 6, 2023
1 parent 9b73caf commit b3a49f2
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 53 deletions.
11 changes: 9 additions & 2 deletions docs/security.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ Specify a list of OAUTH_PROVIDERS in **config.py** that you want to allow for yo
"client_kwargs": {
"scope": "User.read name preferred_username email profile upn",
"resource": "AZURE_APPLICATION_ID",
# Optionally enforce signature JWT verification
"verify_signature": False
},
"request_token_url": None,
"access_token_url": "https://login.microsoftonline.com/AZURE_TENANT_ID/oauth2/token",
Expand Down Expand Up @@ -346,8 +348,14 @@ You can give FlaskAppBuilder roles based on Oauth groups::

To customize the userinfo retrieval, you can create your own method like this::

from flask_appbuilder.security.manager import UserInfo

@appbuilder.sm.oauth_user_info_getter
def my_user_info_getter(sm, provider, response=None):
def my_user_info_getter(
sm: SecurityManager,
provider: str,
response: Dict[str, Any]
) -> UserInfo:
if provider == "okta":
me = sm.oauth_remotes[provider].get("userinfo")
return {
Expand All @@ -364,7 +372,6 @@ To customize the userinfo retrieval, you can create your own method like this::
"email": me.json().get("email"),
"first_name": me.json().get("given_name", ""),
"last_name": me.json().get("family_name", ""),
"id": me.json().get("sub", ""),
"role_keys": ["User"], # set AUTH_ROLES_SYNC_AT_LOGIN = False
}
return {}
Expand Down
96 changes: 59 additions & 37 deletions flask_appbuilder/security/manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import datetime
import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypedDict, Union

from authlib.jose import JsonWebKey, jwt
from flask import Flask, g, session, url_for
from flask_babel import lazy_gettext as _
from flask_jwt_extended import current_user as current_user_jwt
from flask_jwt_extended import JWTManager
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_login import current_user, LoginManager
import requests
import jwt
from werkzeug.security import check_password_hash, generate_password_hash

from .api import SecurityApi
Expand Down Expand Up @@ -62,6 +60,14 @@
log = logging.getLogger(__name__)


class UserInfo(TypedDict, total=False):
username: str
first_name: str
last_name: str
email: str
role_keys: List[str]


class AbstractSecurityManager(BaseManager):
"""
Abstract SecurityManager class, declares all methods used by the
Expand Down Expand Up @@ -271,7 +277,7 @@ def __init__(self, appbuilder):
from authlib.integrations.flask_client import OAuth

self.oauth = OAuth(app)
self.oauth_remotes = dict()
self.oauth_remotes = {}
for _provider in self.oauth_providers:
provider_name = _provider["name"]
log.debug("OAuth providers init %s", provider_name)
Expand Down Expand Up @@ -519,7 +525,9 @@ def current_user(self):
elif current_user_jwt:
return current_user_jwt

def oauth_user_info_getter(self, f):
def oauth_user_info_getter(
self, func: Callable[["BaseSecurityManager", str, Dict[str, Any]], UserInfo]
):
"""
Decorator function to be the OAuth user info getter
for all the providers, receives provider and response
Expand All @@ -534,21 +542,11 @@ def my_oauth_user_info(sm, provider, response=None):
if provider == 'github':
me = sm.oauth_remotes[provider].get('user')
return {'username': me.data.get('login')}
else:
return {}
return {}
"""

def wraps(provider, response=None):
ret = f(self, provider, response=response)
# Checks if decorator is well behaved and returns a dict as supposed.
if not type(ret) == dict:
log.error(
"OAuth user info decorated function "
"did not returned a dict, but: %s",
type(ret),
)
return {}
return ret
def wraps(provider: str, response: Dict[str, Any] = None) -> UserInfo:
return func(self, provider, response)

Check warning on line 549 in flask_appbuilder/security/manager.py

View check run for this annotation

Codecov / codecov/patch

flask_appbuilder/security/manager.py#L548-L549

Added lines #L548 - L549 were not covered by tests

self.oauth_user_info = wraps
return wraps
Expand Down Expand Up @@ -587,9 +585,9 @@ def set_oauth_session(self, provider, oauth_response):
)
session["oauth_provider"] = provider

def get_oauth_user_info(self, provider, resp):
def get_oauth_user_info(self, provider: str, resp: Dict[str, Any]) -> UserInfo:
"""
Since there are different OAuth API's with different ways to
Since there are different OAuth APIs with different ways to
retrieve user info
"""
# for GITHUB
Expand Down Expand Up @@ -629,9 +627,8 @@ def get_oauth_user_info(self, provider, resp):
"email": data.get("email", ""),
}
if provider == "azure":
log.debug("Azure response received:\n%s", json.dumps(resp, indent=4))
me = self._decode_and_validate_azure_jwt(resp["id_token"])
log.debug("Decoded JWT:\n%s", json.dumps(me, indent=4))
log.debug("User info from Azure: %s", me)
# https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims
return {
"email": me["email"],
Expand Down Expand Up @@ -662,7 +659,9 @@ def get_oauth_user_info(self, provider, resp):
}
# for Keycloak
if provider in ["keycloak", "keycloak_before_17"]:
me = self.appbuilder.sm.oauth_remotes[provider].get("openid-connect/userinfo")
me = self.appbuilder.sm.oauth_remotes[provider].get(
"openid-connect/userinfo"
)
me.raise_for_status()
data = me.json()
log.debug("User info from Keycloak: %s", data)
Expand All @@ -674,13 +673,22 @@ def get_oauth_user_info(self, provider, resp):
}
return {}

Check warning on line 674 in flask_appbuilder/security/manager.py

View check run for this annotation

Codecov / codecov/patch

flask_appbuilder/security/manager.py#L674

Added line #L674 was not covered by tests

def _decode_and_validate_azure_jwt(self, id_token):
keyset = JsonWebKey.import_key_set(requests.get(MICROSOFT_KEY_SET_URL).json())
claims = jwt.decode(id_token, keyset)
claims.validate()
log.debug("Decoded JWT:\n%s", json.dumps(claims, indent=4))
def _decode_and_validate_azure_jwt(self, id_token: str) -> Dict[str, str]:
verify_signature = self.oauth_remotes["azure"].client_kwargs.get(
"verify_signature", False
)
if verify_signature:
from authlib.jose import JsonWebKey, jwt as authlib_jwt
import requests

Check warning on line 682 in flask_appbuilder/security/manager.py

View check run for this annotation

Codecov / codecov/patch

flask_appbuilder/security/manager.py#L681-L682

Added lines #L681 - L682 were not covered by tests

keyset = JsonWebKey.import_key_set(

Check warning on line 684 in flask_appbuilder/security/manager.py

View check run for this annotation

Codecov / codecov/patch

flask_appbuilder/security/manager.py#L684

Added line #L684 was not covered by tests
requests.get(MICROSOFT_KEY_SET_URL).json()
)
claims = authlib_jwt.decode(id_token, keyset)
claims.validate()
return claims

Check warning on line 689 in flask_appbuilder/security/manager.py

View check run for this annotation

Codecov / codecov/patch

flask_appbuilder/security/manager.py#L687-L689

Added lines #L687 - L689 were not covered by tests

return claims
return jwt.decode(id_token, options={"verify_signature": False})

def register_views(self):
if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True):
Expand Down Expand Up @@ -785,7 +793,9 @@ def register_views(self):
label=_("Views/Menus"),
category="Security",
)
if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True):
if self.appbuilder.app.config.get(
"FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True
):
self.appbuilder.add_view(
self.permissionviewmodelview,
"Permission on Views/Menus",
Expand Down Expand Up @@ -958,7 +968,9 @@ def _search_ldap(self, ldap, con, username):
except (IndexError, NameError):
return None, None

def _ldap_calculate_user_roles(self, user_attributes: Dict[str, bytes]) -> List[str]:
def _ldap_calculate_user_roles(
self, user_attributes: Dict[str, bytes]
) -> List[str]:
user_role_objects = set()

# apply AUTH_ROLES_MAPPING
Expand Down Expand Up @@ -1023,7 +1035,9 @@ def _ldap_bind(ldap, con, dn: str, password: str) -> bool:
return False

@staticmethod
def ldap_extract(ldap_dict: Dict[str, bytes], field_name: str, fallback: str) -> str:
def ldap_extract(
ldap_dict: Dict[str, bytes], field_name: str, fallback: str
) -> str:
raw_value = ldap_dict.get(field_name, [bytes()])
# decode - if empty string, default to fallback, otherwise take first element
return raw_value[0].decode("utf-8") or fallback
Expand Down Expand Up @@ -1070,7 +1084,9 @@ def auth_user_ldap(self, username, password):
if self.auth_ldap_tls_cacertdir:
ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, self.auth_ldap_tls_cacertdir)
if self.auth_ldap_tls_cacertfile:
ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, self.auth_ldap_tls_cacertfile)
ldap.set_option(
ldap.OPT_X_TLS_CACERTFILE, self.auth_ldap_tls_cacertfile
)
if self.auth_ldap_tls_certfile:
ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.auth_ldap_tls_certfile)
if self.auth_ldap_tls_keyfile:
Expand Down Expand Up @@ -1489,7 +1505,9 @@ def _get_user_permission_view_menus(
# Then check against database-stored roles
pvms_names = [
pvm.view_menu.name
for pvm in self.find_roles_permission_view_menus(permission_name, db_role_ids)
for pvm in self.find_roles_permission_view_menus(
permission_name, db_role_ids
)
]
result.update(pvms_names)
return result
Expand Down Expand Up @@ -1636,7 +1654,9 @@ def _get_new_old_permissions(baseview) -> Dict:
method_name
)
# Actions do not get prefix when normally defined
if hasattr(baseview, "actions") and baseview.actions.get(old_permission_name):
if hasattr(baseview, "actions") and baseview.actions.get(
old_permission_name
):
permission_prefix = ""
else:
permission_prefix = PERMISSION_PREFIX
Expand Down Expand Up @@ -1912,7 +1932,9 @@ def find_permission(self, name):
"""
raise NotImplementedError

def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[int]):
def find_roles_permission_view_menus(
self, permission_name: str, role_ids: List[int]
):
raise NotImplementedError

def exist_permission_on_roles(
Expand Down
63 changes: 60 additions & 3 deletions tests/security/test_auth_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flask_appbuilder import AppBuilder, SQLA
from flask_appbuilder.const import AUTH_OAUTH
import jinja2
import jwt
from tests.const import USERNAME_ADMIN, USERNAME_READONLY
from tests.fixtures.users import create_default_users

Expand All @@ -24,9 +25,29 @@ def setUp(self):
)
self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
self.app.config["AUTH_TYPE"] = AUTH_OAUTH
self.app.config[
"OAUTH_PROVIDERS"
] = [] # can be empty, because we dont use the external providers in tests
self.app.config["OAUTH_PROVIDERS"] = [
{
"name": "azure",
"icon": "fa-windows",
"token_key": "access_token",
"remote_app": {
"client_id": "CLIENT_ID",
"client_secret": "SECRET",
"api_base_url": "https://login.microsoftonline.com/TENANT_ID/oauth2",
"client_kwargs": {
"scope": "User.Read name email profile",
"resource": "AZURE_APPLICATION_ID",
},
"request_token_url": None,
"access_token_url": "https://login.microsoftonline.com/"
"AZURE_APPLICATION_ID/"
"oauth2/token",
"authorize_url": "https://login.microsoftonline.com/"
"AZURE_APPLICATION_ID/"
"oauth2/authorize",
},
}
]

# start Database
self.db = SQLA(self.app)
Expand Down Expand Up @@ -437,3 +458,39 @@ def test__registered__jmespath_role__with_role_sync(self):

# validate - user was given the correct roles
self.assertListEqual(user.roles, [sm.find_role("User")])

def test_oauth_user_info_azure(self):

self.appbuilder = AppBuilder(self.app, self.db.session)
claims = {
"aud": "test-aud",
"iss": "https://sts.windows.net/test/",
"iat": 7282182129,
"nbf": 7282182129,
"exp": 1000000000,
"amr": ["pwd"],
"email": "[email protected]",
"family_name": "user",
"given_name": "test",
"idp": "live.com",
"name": "Test user",
"oid": "b1a54a40-8dfa-4a6d-a2b8-f90b84d4b1df",
"unique_name": "live.com#[email protected]",
"ver": "1.0",
}

# Create an unsigned JWT
unsigned_jwt = jwt.encode(claims, key=None, algorithm="none")
user_info = self.appbuilder.sm.get_oauth_user_info(
"azure", {"access_token": "", "id_token": unsigned_jwt}
)
self.assertEqual(
user_info,
{
"email": "[email protected]",
"first_name": "test",
"last_name": "user",
"role_keys": [],
"username": "b1a54a40-8dfa-4a6d-a2b8-f90b84d4b1df",
},
)
11 changes: 0 additions & 11 deletions tests/security/test_base_security_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import datetime
import json
import unittest
from unittest.mock import MagicMock, patch

from flask_appbuilder.security.manager import BaseSecurityManager
from flask_appbuilder.security.manager import JsonWebKey, jwt

JWTClaimsMock = MagicMock()

Expand Down Expand Up @@ -71,12 +69,3 @@ def test_subsequent_unsuccessful_auth(self, mock1, mock2):
self.assertEqual(user_mock.fail_login_count, 10)
self.assertEqual(user_mock.last_login, None)
self.assertTrue(bsm.update_user.called_once)

@patch.object(JsonWebKey, "import_key_set", MagicMock())
@patch.object(jwt, "decode", MagicMock(return_value=JWTClaimsMock))
@patch.object(json, "dumps", MagicMock(return_value="DecodedExampleAzureJWT"))
def test_azure_jwt_validated(self, mock1, mock2):
bsm = BaseSecurityManager()

bsm._decode_and_validate_azure_jwt("ExampleAzureJWT")
JWTClaimsMock.validate.assert_called()

0 comments on commit b3a49f2

Please sign in to comment.