From b3a49f2e9fa3bdfe09809f51088fdfabe563901f Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Fri, 6 Oct 2023 13:43:00 +0100 Subject: [PATCH] fix azure and add test --- docs/security.rst | 11 ++- flask_appbuilder/security/manager.py | 96 ++++++++++++-------- tests/security/test_auth_oauth.py | 63 ++++++++++++- tests/security/test_base_security_manager.py | 11 --- 4 files changed, 128 insertions(+), 53 deletions(-) diff --git a/docs/security.rst b/docs/security.rst index 47d8c3395a..9545c21c11 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -305,6 +305,8 @@ Specify a list of OAUTH_PROVIDERS in **config.py** that you want to allow for yo "client_kwargs": { "scope": "User.read name preferred_username email profile upn", "resource": "AZURE_APPLICATION_ID", + # Optionally enforce signature JWT verification + "verify_signature": False }, "request_token_url": None, "access_token_url": "https://login.microsoftonline.com/AZURE_TENANT_ID/oauth2/token", @@ -346,8 +348,14 @@ You can give FlaskAppBuilder roles based on Oauth groups:: To customize the userinfo retrieval, you can create your own method like this:: + from flask_appbuilder.security.manager import UserInfo + @appbuilder.sm.oauth_user_info_getter - def my_user_info_getter(sm, provider, response=None): + def my_user_info_getter( + sm: SecurityManager, + provider: str, + response: Dict[str, Any] + ) -> UserInfo: if provider == "okta": me = sm.oauth_remotes[provider].get("userinfo") return { @@ -364,7 +372,6 @@ To customize the userinfo retrieval, you can create your own method like this:: "email": me.json().get("email"), "first_name": me.json().get("given_name", ""), "last_name": me.json().get("family_name", ""), - "id": me.json().get("sub", ""), "role_keys": ["User"], # set AUTH_ROLES_SYNC_AT_LOGIN = False } return {} diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 304554c16c..560c8d8fc1 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -1,10 +1,8 @@ import datetime -import json import logging import re -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypedDict, Union -from authlib.jose import JsonWebKey, jwt from flask import Flask, g, session, url_for from flask_babel import lazy_gettext as _ from flask_jwt_extended import current_user as current_user_jwt @@ -12,7 +10,7 @@ from flask_limiter import Limiter from flask_limiter.util import get_remote_address from flask_login import current_user, LoginManager -import requests +import jwt from werkzeug.security import check_password_hash, generate_password_hash from .api import SecurityApi @@ -62,6 +60,14 @@ log = logging.getLogger(__name__) +class UserInfo(TypedDict, total=False): + username: str + first_name: str + last_name: str + email: str + role_keys: List[str] + + class AbstractSecurityManager(BaseManager): """ Abstract SecurityManager class, declares all methods used by the @@ -271,7 +277,7 @@ def __init__(self, appbuilder): from authlib.integrations.flask_client import OAuth self.oauth = OAuth(app) - self.oauth_remotes = dict() + self.oauth_remotes = {} for _provider in self.oauth_providers: provider_name = _provider["name"] log.debug("OAuth providers init %s", provider_name) @@ -519,7 +525,9 @@ def current_user(self): elif current_user_jwt: return current_user_jwt - def oauth_user_info_getter(self, f): + def oauth_user_info_getter( + self, func: Callable[["BaseSecurityManager", str, Dict[str, Any]], UserInfo] + ): """ Decorator function to be the OAuth user info getter for all the providers, receives provider and response @@ -534,21 +542,11 @@ def my_oauth_user_info(sm, provider, response=None): if provider == 'github': me = sm.oauth_remotes[provider].get('user') return {'username': me.data.get('login')} - else: - return {} + return {} """ - def wraps(provider, response=None): - ret = f(self, provider, response=response) - # Checks if decorator is well behaved and returns a dict as supposed. - if not type(ret) == dict: - log.error( - "OAuth user info decorated function " - "did not returned a dict, but: %s", - type(ret), - ) - return {} - return ret + def wraps(provider: str, response: Dict[str, Any] = None) -> UserInfo: + return func(self, provider, response) self.oauth_user_info = wraps return wraps @@ -587,9 +585,9 @@ def set_oauth_session(self, provider, oauth_response): ) session["oauth_provider"] = provider - def get_oauth_user_info(self, provider, resp): + def get_oauth_user_info(self, provider: str, resp: Dict[str, Any]) -> UserInfo: """ - Since there are different OAuth API's with different ways to + Since there are different OAuth APIs with different ways to retrieve user info """ # for GITHUB @@ -629,9 +627,8 @@ def get_oauth_user_info(self, provider, resp): "email": data.get("email", ""), } if provider == "azure": - log.debug("Azure response received:\n%s", json.dumps(resp, indent=4)) me = self._decode_and_validate_azure_jwt(resp["id_token"]) - log.debug("Decoded JWT:\n%s", json.dumps(me, indent=4)) + log.debug("User info from Azure: %s", me) # https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims return { "email": me["email"], @@ -662,7 +659,9 @@ def get_oauth_user_info(self, provider, resp): } # for Keycloak if provider in ["keycloak", "keycloak_before_17"]: - me = self.appbuilder.sm.oauth_remotes[provider].get("openid-connect/userinfo") + me = self.appbuilder.sm.oauth_remotes[provider].get( + "openid-connect/userinfo" + ) me.raise_for_status() data = me.json() log.debug("User info from Keycloak: %s", data) @@ -674,13 +673,22 @@ def get_oauth_user_info(self, provider, resp): } return {} - def _decode_and_validate_azure_jwt(self, id_token): - keyset = JsonWebKey.import_key_set(requests.get(MICROSOFT_KEY_SET_URL).json()) - claims = jwt.decode(id_token, keyset) - claims.validate() - log.debug("Decoded JWT:\n%s", json.dumps(claims, indent=4)) + def _decode_and_validate_azure_jwt(self, id_token: str) -> Dict[str, str]: + verify_signature = self.oauth_remotes["azure"].client_kwargs.get( + "verify_signature", False + ) + if verify_signature: + from authlib.jose import JsonWebKey, jwt as authlib_jwt + import requests + + keyset = JsonWebKey.import_key_set( + requests.get(MICROSOFT_KEY_SET_URL).json() + ) + claims = authlib_jwt.decode(id_token, keyset) + claims.validate() + return claims - return claims + return jwt.decode(id_token, options={"verify_signature": False}) def register_views(self): if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): @@ -785,7 +793,9 @@ def register_views(self): label=_("Views/Menus"), category="Security", ) - if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True): + if self.appbuilder.app.config.get( + "FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True + ): self.appbuilder.add_view( self.permissionviewmodelview, "Permission on Views/Menus", @@ -958,7 +968,9 @@ def _search_ldap(self, ldap, con, username): except (IndexError, NameError): return None, None - def _ldap_calculate_user_roles(self, user_attributes: Dict[str, bytes]) -> List[str]: + def _ldap_calculate_user_roles( + self, user_attributes: Dict[str, bytes] + ) -> List[str]: user_role_objects = set() # apply AUTH_ROLES_MAPPING @@ -1023,7 +1035,9 @@ def _ldap_bind(ldap, con, dn: str, password: str) -> bool: return False @staticmethod - def ldap_extract(ldap_dict: Dict[str, bytes], field_name: str, fallback: str) -> str: + def ldap_extract( + ldap_dict: Dict[str, bytes], field_name: str, fallback: str + ) -> str: raw_value = ldap_dict.get(field_name, [bytes()]) # decode - if empty string, default to fallback, otherwise take first element return raw_value[0].decode("utf-8") or fallback @@ -1070,7 +1084,9 @@ def auth_user_ldap(self, username, password): if self.auth_ldap_tls_cacertdir: ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, self.auth_ldap_tls_cacertdir) if self.auth_ldap_tls_cacertfile: - ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, self.auth_ldap_tls_cacertfile) + ldap.set_option( + ldap.OPT_X_TLS_CACERTFILE, self.auth_ldap_tls_cacertfile + ) if self.auth_ldap_tls_certfile: ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.auth_ldap_tls_certfile) if self.auth_ldap_tls_keyfile: @@ -1489,7 +1505,9 @@ def _get_user_permission_view_menus( # Then check against database-stored roles pvms_names = [ pvm.view_menu.name - for pvm in self.find_roles_permission_view_menus(permission_name, db_role_ids) + for pvm in self.find_roles_permission_view_menus( + permission_name, db_role_ids + ) ] result.update(pvms_names) return result @@ -1636,7 +1654,9 @@ def _get_new_old_permissions(baseview) -> Dict: method_name ) # Actions do not get prefix when normally defined - if hasattr(baseview, "actions") and baseview.actions.get(old_permission_name): + if hasattr(baseview, "actions") and baseview.actions.get( + old_permission_name + ): permission_prefix = "" else: permission_prefix = PERMISSION_PREFIX @@ -1912,7 +1932,9 @@ def find_permission(self, name): """ raise NotImplementedError - def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[int]): + def find_roles_permission_view_menus( + self, permission_name: str, role_ids: List[int] + ): raise NotImplementedError def exist_permission_on_roles( diff --git a/tests/security/test_auth_oauth.py b/tests/security/test_auth_oauth.py index 5040dd4812..1b85776e61 100644 --- a/tests/security/test_auth_oauth.py +++ b/tests/security/test_auth_oauth.py @@ -6,6 +6,7 @@ from flask_appbuilder import AppBuilder, SQLA from flask_appbuilder.const import AUTH_OAUTH import jinja2 +import jwt from tests.const import USERNAME_ADMIN, USERNAME_READONLY from tests.fixtures.users import create_default_users @@ -24,9 +25,29 @@ def setUp(self): ) self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False self.app.config["AUTH_TYPE"] = AUTH_OAUTH - self.app.config[ - "OAUTH_PROVIDERS" - ] = [] # can be empty, because we dont use the external providers in tests + self.app.config["OAUTH_PROVIDERS"] = [ + { + "name": "azure", + "icon": "fa-windows", + "token_key": "access_token", + "remote_app": { + "client_id": "CLIENT_ID", + "client_secret": "SECRET", + "api_base_url": "https://login.microsoftonline.com/TENANT_ID/oauth2", + "client_kwargs": { + "scope": "User.Read name email profile", + "resource": "AZURE_APPLICATION_ID", + }, + "request_token_url": None, + "access_token_url": "https://login.microsoftonline.com/" + "AZURE_APPLICATION_ID/" + "oauth2/token", + "authorize_url": "https://login.microsoftonline.com/" + "AZURE_APPLICATION_ID/" + "oauth2/authorize", + }, + } + ] # start Database self.db = SQLA(self.app) @@ -437,3 +458,39 @@ def test__registered__jmespath_role__with_role_sync(self): # validate - user was given the correct roles self.assertListEqual(user.roles, [sm.find_role("User")]) + + def test_oauth_user_info_azure(self): + + self.appbuilder = AppBuilder(self.app, self.db.session) + claims = { + "aud": "test-aud", + "iss": "https://sts.windows.net/test/", + "iat": 7282182129, + "nbf": 7282182129, + "exp": 1000000000, + "amr": ["pwd"], + "email": "test@gmail.com", + "family_name": "user", + "given_name": "test", + "idp": "live.com", + "name": "Test user", + "oid": "b1a54a40-8dfa-4a6d-a2b8-f90b84d4b1df", + "unique_name": "live.com#test@gmail.com", + "ver": "1.0", + } + + # Create an unsigned JWT + unsigned_jwt = jwt.encode(claims, key=None, algorithm="none") + user_info = self.appbuilder.sm.get_oauth_user_info( + "azure", {"access_token": "", "id_token": unsigned_jwt} + ) + self.assertEqual( + user_info, + { + "email": "test@gmail.com", + "first_name": "test", + "last_name": "user", + "role_keys": [], + "username": "b1a54a40-8dfa-4a6d-a2b8-f90b84d4b1df", + }, + ) diff --git a/tests/security/test_base_security_manager.py b/tests/security/test_base_security_manager.py index 899506575e..dfb04ea1a8 100644 --- a/tests/security/test_base_security_manager.py +++ b/tests/security/test_base_security_manager.py @@ -1,10 +1,8 @@ import datetime -import json import unittest from unittest.mock import MagicMock, patch from flask_appbuilder.security.manager import BaseSecurityManager -from flask_appbuilder.security.manager import JsonWebKey, jwt JWTClaimsMock = MagicMock() @@ -71,12 +69,3 @@ def test_subsequent_unsuccessful_auth(self, mock1, mock2): self.assertEqual(user_mock.fail_login_count, 10) self.assertEqual(user_mock.last_login, None) self.assertTrue(bsm.update_user.called_once) - - @patch.object(JsonWebKey, "import_key_set", MagicMock()) - @patch.object(jwt, "decode", MagicMock(return_value=JWTClaimsMock)) - @patch.object(json, "dumps", MagicMock(return_value="DecodedExampleAzureJWT")) - def test_azure_jwt_validated(self, mock1, mock2): - bsm = BaseSecurityManager() - - bsm._decode_and_validate_azure_jwt("ExampleAzureJWT") - JWTClaimsMock.validate.assert_called()