diff --git a/docs/security.rst b/docs/security.rst index 53570ad376..47d8c3395a 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -90,7 +90,7 @@ This method will authenticate the user's credentials against an LDAP server. WARNING: To use LDAP you need to install `python-ldap `_. -For a typical Microsoft AD setup (where all users can preform LDAP searches):: +For a typical Microsoft AD setup (where all users can perform LDAP searches):: AUTH_TYPE = AUTH_LDAP AUTH_LDAP_SERVER = "ldap://ldap.example.com" diff --git a/flask_appbuilder/const.py b/flask_appbuilder/const.py index db3030cc9d..3ccb1bf33e 100644 --- a/flask_appbuilder/const.py +++ b/flask_appbuilder/const.py @@ -191,3 +191,9 @@ API_ADD_TITLE_RIS_KEY = "add_title" API_EDIT_TITLE_RIS_KEY = "edit_title" API_SHOW_TITLE_RIS_KEY = "show_title" + +# ----------------------------------- +# OAuth Provider Constants +# ----------------------------------- + +MICROSOFT_KEY_SET_URL = "https://login.microsoftonline.com/common/discovery/keys" diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 3c62498888..304554c16c 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -1,10 +1,10 @@ -import base64 import datetime import json import logging import re from typing import Any, Dict, List, Optional, Set, Tuple, Union +from authlib.jose import JsonWebKey, jwt from flask import Flask, g, session, url_for from flask_babel import lazy_gettext as _ from flask_jwt_extended import current_user as current_user_jwt @@ -12,6 +12,7 @@ from flask_limiter import Limiter from flask_limiter.util import get_remote_address from flask_login import current_user, LoginManager +import requests from werkzeug.security import check_password_hash, generate_password_hash from .api import SecurityApi @@ -54,6 +55,7 @@ LOGMSG_WAR_SEC_LOGIN_FAILED, LOGMSG_WAR_SEC_NO_USER, LOGMSG_WAR_SEC_NOLDAP_OBJ, + MICROSOFT_KEY_SET_URL, PERMISSION_PREFIX, ) @@ -627,11 +629,9 @@ def get_oauth_user_info(self, provider, resp): "email": data.get("email", ""), } if provider == "azure": - log.debug("Azure response received : %s", resp) - id_token = resp["id_token"] - me = self._azure_jwt_token_parse(id_token) - log.debug("Parse JWT token : %s", me) - # Claims documentation + log.debug("Azure response received:\n%s", json.dumps(resp, indent=4)) + me = self._decode_and_validate_azure_jwt(resp["id_token"]) + log.debug("Decoded JWT:\n%s", json.dumps(me, indent=4)) # https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims return { "email": me["email"], @@ -674,36 +674,13 @@ def get_oauth_user_info(self, provider, resp): } return {} - def _azure_parse_jwt(self, id_token): - jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$" - matches = re.search(jwt_token_parts, id_token) - if not matches or len(matches.groups()) < 3: - log.error("Unable to parse token.") - return {} - return { - "header": matches.group(1), - "Payload": matches.group(2), - "Sig": matches.group(3), - } - - def _azure_jwt_token_parse(self, id_token): - jwt_split_token = self._azure_parse_jwt(id_token) - if not jwt_split_token: - return - - jwt_payload = jwt_split_token["Payload"] - # Prepare for base64 decoding - payload_b64_string = jwt_payload - payload_b64_string += "=" * (4 - ((len(jwt_payload) % 4))) - decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii")) - - if not decoded_payload: - log.error("Payload of id_token could not be base64 url decoded.") - return - - jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8")) + def _decode_and_validate_azure_jwt(self, id_token): + keyset = JsonWebKey.import_key_set(requests.get(MICROSOFT_KEY_SET_URL).json()) + claims = jwt.decode(id_token, keyset) + claims.validate() + log.debug("Decoded JWT:\n%s", json.dumps(claims, indent=4)) - return jwt_decoded_payload + return claims def register_views(self): if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): @@ -943,7 +920,7 @@ def _search_ldap(self, ldap, con, username): if len(self.auth_roles_mapping) > 0: request_fields.append(self.auth_ldap_group_field) - # preform the LDAP search + # perform the LDAP search log.debug( "LDAP search for '%s' with fields %s in scope '%s'", filter_str, @@ -1120,7 +1097,7 @@ def auth_user_ldap(self, username, password): user_attributes = {} # Flow 1 - (Indirect Search Bind): - # - in this flow, special bind credentials are used to preform the + # - in this flow, special bind credentials are used to perform the # LDAP search # - in this flow, AUTH_LDAP_SEARCH must be set if self.auth_ldap_bind_user: @@ -1156,7 +1133,7 @@ def auth_user_ldap(self, username, password): # Flow 2 - (Direct Search Bind): # - in this flow, the credentials provided by the end-user are used - # to preform the LDAP search + # to perform the LDAP search # - in this flow, we only search LDAP if AUTH_LDAP_SEARCH is set # - features like AUTH_USER_REGISTRATION & AUTH_ROLES_SYNC_AT_LOGIN # will only work if AUTH_LDAP_SEARCH is set diff --git a/requirements-extra.txt b/requirements-extra.txt index c6ddb42e4f..89ea25cb92 100644 --- a/requirements-extra.txt +++ b/requirements-extra.txt @@ -4,6 +4,6 @@ mysqlclient==2.0.1 psycopg2-binary==2.9.6 pyodbc==4.0.35 requests==2.26.0 -Authlib==0.15.4 +Authlib==1.2.1 python-ldap==3.3.1 flask-openid==1.3.0 diff --git a/tests/security/test_base_security_manager.py b/tests/security/test_base_security_manager.py index 736b0b88d6..899506575e 100644 --- a/tests/security/test_base_security_manager.py +++ b/tests/security/test_base_security_manager.py @@ -1,8 +1,12 @@ import datetime +import json import unittest from unittest.mock import MagicMock, patch from flask_appbuilder.security.manager import BaseSecurityManager +from flask_appbuilder.security.manager import JsonWebKey, jwt + +JWTClaimsMock = MagicMock() @patch.object(BaseSecurityManager, "update_user") @@ -67,3 +71,12 @@ def test_subsequent_unsuccessful_auth(self, mock1, mock2): self.assertEqual(user_mock.fail_login_count, 10) self.assertEqual(user_mock.last_login, None) self.assertTrue(bsm.update_user.called_once) + + @patch.object(JsonWebKey, "import_key_set", MagicMock()) + @patch.object(jwt, "decode", MagicMock(return_value=JWTClaimsMock)) + @patch.object(json, "dumps", MagicMock(return_value="DecodedExampleAzureJWT")) + def test_azure_jwt_validated(self, mock1, mock2): + bsm = BaseSecurityManager() + + bsm._decode_and_validate_azure_jwt("ExampleAzureJWT") + JWTClaimsMock.validate.assert_called()