From a87829ca65d851155ddaf2808905d8d190c743e7 Mon Sep 17 00:00:00 2001 From: Hojjat Ali Mohammadi Date: Mon, 11 Mar 2024 20:22:27 +0100 Subject: [PATCH] Fix conflicts --- flask_appbuilder/security/manager.py | 198 ++++++++++++++++++--------- 1 file changed, 131 insertions(+), 67 deletions(-) diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 9fbad38628..a25b63ae72 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -1,17 +1,17 @@ -import base64 import datetime -import json import logging import re -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from flask import Flask, g, session, url_for +from flask_appbuilder.exceptions import InvalidLoginAttempt, OAuthProviderUnknown from flask_babel import lazy_gettext as _ from flask_jwt_extended import current_user as current_user_jwt from flask_jwt_extended import JWTManager from flask_limiter import Limiter from flask_limiter.util import get_remote_address from flask_login import current_user, LoginManager +import jwt from werkzeug.security import check_password_hash, generate_password_hash from .api import SecurityApi @@ -54,6 +54,7 @@ LOGMSG_WAR_SEC_LOGIN_FAILED, LOGMSG_WAR_SEC_NO_USER, LOGMSG_WAR_SEC_NOLDAP_OBJ, + MICROSOFT_KEY_SET_URL, PERMISSION_PREFIX, ) @@ -257,6 +258,9 @@ def __init__(self, appbuilder): app.config.setdefault("AUTH_LDAP_LASTNAME_FIELD", "sn") app.config.setdefault("AUTH_LDAP_EMAIL_FIELD", "mail") + if self.auth_type == AUTH_REMOTE_USER: + app.config.setdefault("AUTH_REMOTE_USER_ENV_VAR", "REMOTE_USER") + # Rate limiting app.config.setdefault("AUTH_RATE_LIMITED", False) app.config.setdefault("AUTH_RATE_LIMIT", "10 per 20 second") @@ -264,12 +268,17 @@ def __init__(self, appbuilder): if self.auth_type == AUTH_OID: from flask_openid import OpenID + log.warning( + "AUTH_OID is deprecated and will be removed in version 5. " + "Migrate to other authentication methods." + ) self.oid = OpenID(app) + if self.auth_type == AUTH_OAUTH: from authlib.integrations.flask_client import OAuth self.oauth = OAuth(app) - self.oauth_remotes = dict() + self.oauth_remotes = {} for _provider in self.oauth_providers: provider_name = _provider["name"] log.debug("OAuth providers init %s", provider_name) @@ -414,6 +423,10 @@ def auth_user_registration_role(self) -> str: def auth_user_registration_role_jmespath(self) -> str: return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION_ROLE_JMESPATH"] + @property + def auth_remote_user_env_var(self) -> str: + return self.appbuilder.get_app.config["AUTH_REMOTE_USER_ENV_VAR"] + @property def auth_roles_mapping(self) -> Dict[str, List[str]]: return self.appbuilder.get_app.config["AUTH_ROLES_MAPPING"] @@ -517,7 +530,10 @@ def current_user(self): elif current_user_jwt: return current_user_jwt - def oauth_user_info_getter(self, f): + def oauth_user_info_getter( + self, + func: Callable[["BaseSecurityManager", str, Dict[str, Any]], Dict[str, Any]], + ): """ Decorator function to be the OAuth user info getter for all the providers, receives provider and response @@ -532,21 +548,11 @@ def my_oauth_user_info(sm, provider, response=None): if provider == 'github': me = sm.oauth_remotes[provider].get('user') return {'username': me.data.get('login')} - else: - return {} + return {} """ - def wraps(provider, response=None): - ret = f(self, provider, response=response) - # Checks if decorator is well behaved and returns a dict as supposed. - if not type(ret) == dict: - log.error( - "OAuth user info decorated function " - "did not returned a dict, but: %s", - type(ret), - ) - return {} - return ret + def wraps(provider: str, response: Dict[str, Any] = None) -> Dict[str, Any]: + return func(self, provider, response) self.oauth_user_info = wraps return wraps @@ -585,9 +591,11 @@ def set_oauth_session(self, provider, oauth_response): ) session["oauth_provider"] = provider - def get_oauth_user_info(self, provider, resp): + def get_oauth_user_info( + self, provider: str, resp: Dict[str, Any] + ) -> Dict[str, Any]: """ - Since there are different OAuth API's with different ways to + Since there are different OAuth APIs with different ways to retrieve user info """ # for GITHUB @@ -626,23 +634,16 @@ def get_oauth_user_info(self, provider, resp): "last_name": data.get("family_name", ""), "email": data.get("email", ""), } - # for Azure AD Tenant. Azure OAuth response contains - # JWT token which has user info. - # JWT token needs to be base64 decoded. - # https://docs.microsoft.com/en-us/azure/active-directory/develop/ - # active-directory-protocols-oauth-code if provider == "azure": - log.debug("Azure response received : %s", resp) - id_token = resp["id_token"] - log.debug(str(id_token)) - me = self._azure_jwt_token_parse(id_token) - log.debug("Parse JWT token : %s", me) + me = self._decode_and_validate_azure_jwt(resp["id_token"]) + log.debug("User info from Azure: %s", me) + # https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims return { - "name": me.get("name", ""), - "email": me["upn"], + # To keep backward compatibility with previous versions + # of FAB, we use upn if available, otherwise we use email + "email": me["upn"] if "upn" in me else me["email"], "first_name": me.get("given_name", ""), "last_name": me.get("family_name", ""), - "id": me["oid"], "username": me["oid"], "role_keys": me.get("roles", []), } @@ -661,15 +662,26 @@ def get_oauth_user_info(self, provider, resp): log.debug("User info from Okta: %s", data) if "error" not in data: return { - "username": "okta_" + data.get("sub", ""), + "username": f"{provider}_{data['sub']}", "first_name": data.get("given_name", ""), "last_name": data.get("family_name", ""), - "email": data.get("email", ""), + "email": data["email"], "role_keys": data.get("groups", []), } else: log.error(data.get("error_description")) - + return {} + # for Auth0 + if provider == "auth0": + data = self.appbuilder.sm.oauth_remotes[provider].userinfo() + log.debug("User info from Auth0: %s", data) + return { + "username": f"{provider}_{data['sub']}", + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data["email"], + "role_keys": data.get("groups", []), + } # for Keycloak if provider in ["keycloak", "keycloak_before_17"]: me = self.appbuilder.sm.oauth_remotes[provider].get( @@ -684,39 +696,80 @@ def get_oauth_user_info(self, provider, resp): "last_name": data.get("family_name", ""), "email": data.get("email", ""), } + # for Authentik + if provider == "authentik": + id_token = resp["id_token"] + me = self._get_authentik_token_info(id_token) + log.debug("User info from authentik: %s", me) + return { + "email": me["preferred_username"], + "first_name": me.get("given_name", ""), + "username": me["nickname"], + "role_keys": me.get("groups", []), + } - return {} - - def _azure_parse_jwt(self, id_token): - jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$" - matches = re.search(jwt_token_parts, id_token) - if not matches or len(matches.groups()) < 3: - log.error("Unable to parse token.") - return {} - return { - "header": matches.group(1), - "Payload": matches.group(2), - "Sig": matches.group(3), - } + raise OAuthProviderUnknown() - def _azure_jwt_token_parse(self, id_token): - jwt_split_token = self._azure_parse_jwt(id_token) - if not jwt_split_token: - return + def _get_microsoft_jwks(self) -> List[Dict[str, Any]]: + import requests - jwt_payload = jwt_split_token["Payload"] - # Prepare for base64 decoding - payload_b64_string = jwt_payload - payload_b64_string += "=" * (4 - ((len(jwt_payload) % 4))) - decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii")) + return requests.get(MICROSOFT_KEY_SET_URL).json() - if not decoded_payload: - log.error("Payload of id_token could not be base64 url decoded.") - return + def _decode_and_validate_azure_jwt(self, id_token: str) -> Dict[str, str]: + verify_signature = self.oauth_remotes["azure"].client_kwargs.get( + "verify_signature", False + ) + if verify_signature: + from authlib.jose import JsonWebKey, jwt as authlib_jwt + + keyset = JsonWebKey.import_key_set(self._get_microsoft_jwks()) + claims = authlib_jwt.decode(id_token, keyset) + claims.validate() + return claims + + return jwt.decode(id_token, options={"verify_signature": False}) + + def _get_authentik_jwks(self, jwks_url) -> dict: + import requests + + resp = requests.get(jwks_url) + if resp.status_code == 200: + return resp.json() + return False + + def _validate_jwt(self, id_token, jwks): + from authlib.jose import JsonWebKey, jwt as authlib_jwt + + keyset = JsonWebKey.import_key_set(jwks) + claims = authlib_jwt.decode(id_token, keyset) + claims.validate() + log.info("JWT token is validated") + return claims + + def _get_authentik_token_info(self, id_token): + me = jwt.decode(id_token, options={"verify_signature": False}) - jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8")) + verify_signature = self.oauth_remotes["authentik"].client_kwargs.get( + "verify_signature", True + ) + if verify_signature: + # Validate the token using authentik certificate + jwks_uri = self.oauth_remotes["authentik"].server_metadata.get("jwks_uri") + if jwks_uri: + jwks = self._get_authentik_jwks(jwks_uri) + if jwks: + return self._validate_jwt(id_token, jwks) + else: + log.error( + "jwks_uri not specified in OAuth Providers, " + "could not verify token signature" + ) + else: + # Return the token info without validating + log.warning("JWT token is not validated!") + return me - return jwt_decoded_payload + raise InvalidLoginAttempt("OAuth signature verify failed") def register_views(self): if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): @@ -1464,6 +1517,14 @@ def _has_view_access( # If it's not a builtin role check against database store roles return self.exist_permission_on_roles(view_name, permission_name, db_role_ids) + def get_oid_identity_url(self, provider_name: str) -> Optional[str]: + """ + Returns the OIDC identity provider URL + """ + for provider in self.openid_providers: + if provider.get("name") == provider_name: + return provider.get("url") + def get_user_roles(self, user) -> List[object]: """ Get current user roles, if user is not authenticated returns the public role @@ -2097,14 +2158,17 @@ def import_roles(self, path: str) -> None: raise NotImplementedError def load_user(self, pk): - return self.get_user_by_id(int(pk)) + user = self.get_user_by_id(int(pk)) + if user.is_active: + return user def load_user_jwt(self, _jwt_header, jwt_data): identity = jwt_data["sub"] user = self.load_user(identity) - # Set flask g.user to JWT user, we can't do it on before request - g.user = user - return user + if user.is_active: + # Set flask g.user to JWT user, we can't do it on before request + g.user = user + return user @staticmethod def before_request():