Skip to content

Commit

Permalink
Fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
safeith committed Mar 11, 2024
1 parent 165d72d commit a87829c
Showing 1 changed file with 131 additions and 67 deletions.
198 changes: 131 additions & 67 deletions flask_appbuilder/security/manager.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import base64
import datetime
import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from flask import Flask, g, session, url_for
from flask_appbuilder.exceptions import InvalidLoginAttempt, OAuthProviderUnknown
from flask_babel import lazy_gettext as _
from flask_jwt_extended import current_user as current_user_jwt
from flask_jwt_extended import JWTManager
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_login import current_user, LoginManager
import jwt
from werkzeug.security import check_password_hash, generate_password_hash

from .api import SecurityApi
Expand Down Expand Up @@ -54,6 +54,7 @@
LOGMSG_WAR_SEC_LOGIN_FAILED,
LOGMSG_WAR_SEC_NO_USER,
LOGMSG_WAR_SEC_NOLDAP_OBJ,
MICROSOFT_KEY_SET_URL,
PERMISSION_PREFIX,
)

Expand Down Expand Up @@ -257,19 +258,27 @@ def __init__(self, appbuilder):
app.config.setdefault("AUTH_LDAP_LASTNAME_FIELD", "sn")
app.config.setdefault("AUTH_LDAP_EMAIL_FIELD", "mail")

if self.auth_type == AUTH_REMOTE_USER:
app.config.setdefault("AUTH_REMOTE_USER_ENV_VAR", "REMOTE_USER")

# Rate limiting
app.config.setdefault("AUTH_RATE_LIMITED", False)
app.config.setdefault("AUTH_RATE_LIMIT", "10 per 20 second")

if self.auth_type == AUTH_OID:
from flask_openid import OpenID

log.warning(
"AUTH_OID is deprecated and will be removed in version 5. "
"Migrate to other authentication methods."
)
self.oid = OpenID(app)

if self.auth_type == AUTH_OAUTH:
from authlib.integrations.flask_client import OAuth

self.oauth = OAuth(app)
self.oauth_remotes = dict()
self.oauth_remotes = {}
for _provider in self.oauth_providers:
provider_name = _provider["name"]
log.debug("OAuth providers init %s", provider_name)
Expand Down Expand Up @@ -414,6 +423,10 @@ def auth_user_registration_role(self) -> str:
def auth_user_registration_role_jmespath(self) -> str:
return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION_ROLE_JMESPATH"]

@property
def auth_remote_user_env_var(self) -> str:
return self.appbuilder.get_app.config["AUTH_REMOTE_USER_ENV_VAR"]

@property
def auth_roles_mapping(self) -> Dict[str, List[str]]:
return self.appbuilder.get_app.config["AUTH_ROLES_MAPPING"]
Expand Down Expand Up @@ -517,7 +530,10 @@ def current_user(self):
elif current_user_jwt:
return current_user_jwt

def oauth_user_info_getter(self, f):
def oauth_user_info_getter(
self,
func: Callable[["BaseSecurityManager", str, Dict[str, Any]], Dict[str, Any]],
):
"""
Decorator function to be the OAuth user info getter
for all the providers, receives provider and response
Expand All @@ -532,21 +548,11 @@ def my_oauth_user_info(sm, provider, response=None):
if provider == 'github':
me = sm.oauth_remotes[provider].get('user')
return {'username': me.data.get('login')}
else:
return {}
return {}
"""

def wraps(provider, response=None):
ret = f(self, provider, response=response)
# Checks if decorator is well behaved and returns a dict as supposed.
if not type(ret) == dict:
log.error(
"OAuth user info decorated function "
"did not returned a dict, but: %s",
type(ret),
)
return {}
return ret
def wraps(provider: str, response: Dict[str, Any] = None) -> Dict[str, Any]:
return func(self, provider, response)

self.oauth_user_info = wraps
return wraps
Expand Down Expand Up @@ -585,9 +591,11 @@ def set_oauth_session(self, provider, oauth_response):
)
session["oauth_provider"] = provider

def get_oauth_user_info(self, provider, resp):
def get_oauth_user_info(
self, provider: str, resp: Dict[str, Any]
) -> Dict[str, Any]:
"""
Since there are different OAuth API's with different ways to
Since there are different OAuth APIs with different ways to
retrieve user info
"""
# for GITHUB
Expand Down Expand Up @@ -626,23 +634,16 @@ def get_oauth_user_info(self, provider, resp):
"last_name": data.get("family_name", ""),
"email": data.get("email", ""),
}
# for Azure AD Tenant. Azure OAuth response contains
# JWT token which has user info.
# JWT token needs to be base64 decoded.
# https://docs.microsoft.com/en-us/azure/active-directory/develop/
# active-directory-protocols-oauth-code
if provider == "azure":
log.debug("Azure response received : %s", resp)
id_token = resp["id_token"]
log.debug(str(id_token))
me = self._azure_jwt_token_parse(id_token)
log.debug("Parse JWT token : %s", me)
me = self._decode_and_validate_azure_jwt(resp["id_token"])
log.debug("User info from Azure: %s", me)
# https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims
return {
"name": me.get("name", ""),
"email": me["upn"],
# To keep backward compatibility with previous versions
# of FAB, we use upn if available, otherwise we use email
"email": me["upn"] if "upn" in me else me["email"],
"first_name": me.get("given_name", ""),
"last_name": me.get("family_name", ""),
"id": me["oid"],
"username": me["oid"],
"role_keys": me.get("roles", []),
}
Expand All @@ -661,15 +662,26 @@ def get_oauth_user_info(self, provider, resp):
log.debug("User info from Okta: %s", data)
if "error" not in data:
return {
"username": "okta_" + data.get("sub", ""),
"username": f"{provider}_{data['sub']}",
"first_name": data.get("given_name", ""),
"last_name": data.get("family_name", ""),
"email": data.get("email", ""),
"email": data["email"],
"role_keys": data.get("groups", []),
}
else:
log.error(data.get("error_description"))

return {}
# for Auth0
if provider == "auth0":
data = self.appbuilder.sm.oauth_remotes[provider].userinfo()
log.debug("User info from Auth0: %s", data)
return {
"username": f"{provider}_{data['sub']}",
"first_name": data.get("given_name", ""),
"last_name": data.get("family_name", ""),
"email": data["email"],
"role_keys": data.get("groups", []),
}
# for Keycloak
if provider in ["keycloak", "keycloak_before_17"]:
me = self.appbuilder.sm.oauth_remotes[provider].get(
Expand All @@ -684,39 +696,80 @@ def get_oauth_user_info(self, provider, resp):
"last_name": data.get("family_name", ""),
"email": data.get("email", ""),
}
# for Authentik
if provider == "authentik":
id_token = resp["id_token"]
me = self._get_authentik_token_info(id_token)
log.debug("User info from authentik: %s", me)
return {
"email": me["preferred_username"],
"first_name": me.get("given_name", ""),
"username": me["nickname"],
"role_keys": me.get("groups", []),
}

return {}

def _azure_parse_jwt(self, id_token):
jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
matches = re.search(jwt_token_parts, id_token)
if not matches or len(matches.groups()) < 3:
log.error("Unable to parse token.")
return {}
return {
"header": matches.group(1),
"Payload": matches.group(2),
"Sig": matches.group(3),
}
raise OAuthProviderUnknown()

def _azure_jwt_token_parse(self, id_token):
jwt_split_token = self._azure_parse_jwt(id_token)
if not jwt_split_token:
return
def _get_microsoft_jwks(self) -> List[Dict[str, Any]]:
import requests

jwt_payload = jwt_split_token["Payload"]
# Prepare for base64 decoding
payload_b64_string = jwt_payload
payload_b64_string += "=" * (4 - ((len(jwt_payload) % 4)))
decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii"))
return requests.get(MICROSOFT_KEY_SET_URL).json()

if not decoded_payload:
log.error("Payload of id_token could not be base64 url decoded.")
return
def _decode_and_validate_azure_jwt(self, id_token: str) -> Dict[str, str]:
verify_signature = self.oauth_remotes["azure"].client_kwargs.get(
"verify_signature", False
)
if verify_signature:
from authlib.jose import JsonWebKey, jwt as authlib_jwt

keyset = JsonWebKey.import_key_set(self._get_microsoft_jwks())
claims = authlib_jwt.decode(id_token, keyset)
claims.validate()
return claims

return jwt.decode(id_token, options={"verify_signature": False})

def _get_authentik_jwks(self, jwks_url) -> dict:
import requests

resp = requests.get(jwks_url)
if resp.status_code == 200:
return resp.json()
return False

def _validate_jwt(self, id_token, jwks):
from authlib.jose import JsonWebKey, jwt as authlib_jwt

keyset = JsonWebKey.import_key_set(jwks)
claims = authlib_jwt.decode(id_token, keyset)
claims.validate()
log.info("JWT token is validated")
return claims

def _get_authentik_token_info(self, id_token):
me = jwt.decode(id_token, options={"verify_signature": False})

jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8"))
verify_signature = self.oauth_remotes["authentik"].client_kwargs.get(
"verify_signature", True
)
if verify_signature:
# Validate the token using authentik certificate
jwks_uri = self.oauth_remotes["authentik"].server_metadata.get("jwks_uri")
if jwks_uri:
jwks = self._get_authentik_jwks(jwks_uri)
if jwks:
return self._validate_jwt(id_token, jwks)
else:
log.error(
"jwks_uri not specified in OAuth Providers, "
"could not verify token signature"
)
else:
# Return the token info without validating
log.warning("JWT token is not validated!")
return me

return jwt_decoded_payload
raise InvalidLoginAttempt("OAuth signature verify failed")

def register_views(self):
if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True):
Expand Down Expand Up @@ -1464,6 +1517,14 @@ def _has_view_access(
# If it's not a builtin role check against database store roles
return self.exist_permission_on_roles(view_name, permission_name, db_role_ids)

def get_oid_identity_url(self, provider_name: str) -> Optional[str]:
"""
Returns the OIDC identity provider URL
"""
for provider in self.openid_providers:
if provider.get("name") == provider_name:
return provider.get("url")

def get_user_roles(self, user) -> List[object]:
"""
Get current user roles, if user is not authenticated returns the public role
Expand Down Expand Up @@ -2097,14 +2158,17 @@ def import_roles(self, path: str) -> None:
raise NotImplementedError

def load_user(self, pk):
return self.get_user_by_id(int(pk))
user = self.get_user_by_id(int(pk))
if user.is_active:
return user

def load_user_jwt(self, _jwt_header, jwt_data):
identity = jwt_data["sub"]
user = self.load_user(identity)
# Set flask g.user to JWT user, we can't do it on before request
g.user = user
return user
if user.is_active:
# Set flask g.user to JWT user, we can't do it on before request
g.user = user
return user

@staticmethod
def before_request():
Expand Down

0 comments on commit a87829c

Please sign in to comment.