Skip to content

Commit

Permalink
fix: azure oauth email claim
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar committed Oct 6, 2023
1 parent e30f170 commit 10e2d4b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 38 deletions.
6 changes: 3 additions & 3 deletions examples/oauth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@
"remote_app": {
"client_id": os.environ.get("AZURE_APPLICATION_ID"),
"client_secret": os.environ.get("AZURE_SECRET"),
"api_base_url": "https://login.microsoftonline.com/{AZURE_TENANT_ID}/oauth2",
"api_base_url": f"https://login.microsoftonline.com/{os.environ.get('AZURE_TENANT_ID')}/oauth2",
"client_kwargs": {
"scope": "User.read name preferred_username email profile upn",
"resource": os.environ.get("AZURE_APPLICATION_ID"),
},
"request_token_url": None,
"access_token_url": f"https://login.microsoftonline.com/"
f"{os.environ.get('AZURE_APPLICATION_ID')}/"
f"{os.environ.get('AZURE_TENANT_ID')}/"
"oauth2/token",
"authorize_url": f"https://login.microsoftonline.com/"
f"{os.environ.get('AZURE_APPLICATION_ID')}/"
f"{os.environ.get('AZURE_TENANT_ID')}/"
f"oauth2/authorize",
},
},
Expand Down
47 changes: 12 additions & 35 deletions flask_appbuilder/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,23 +626,17 @@ def get_oauth_user_info(self, provider, resp):
"last_name": data.get("family_name", ""),
"email": data.get("email", ""),
}
# for Azure AD Tenant. Azure OAuth response contains
# JWT token which has user info.
# JWT token needs to be base64 decoded.
# https://docs.microsoft.com/en-us/azure/active-directory/develop/
# active-directory-protocols-oauth-code
if provider == "azure":
log.debug("Azure response received : %s", resp)
id_token = resp["id_token"]
log.debug(str(id_token))
me = self._azure_jwt_token_parse(id_token)
log.debug("Parse JWT token : %s", me)
# Claims documentation
# https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims
return {
"name": me.get("name", ""),
"email": me["upn"],
"email": me["email"],
"first_name": me.get("given_name", ""),
"last_name": me.get("family_name", ""),
"id": me["oid"],
"username": me["oid"],
"role_keys": me.get("roles", []),
}
Expand All @@ -668,9 +662,7 @@ def get_oauth_user_info(self, provider, resp):
}
# for Keycloak
if provider in ["keycloak", "keycloak_before_17"]:
me = self.appbuilder.sm.oauth_remotes[provider].get(
"openid-connect/userinfo"
)
me = self.appbuilder.sm.oauth_remotes[provider].get("openid-connect/userinfo")
me.raise_for_status()
data = me.json()
log.debug("User info from Keycloak: %s", data)
Expand All @@ -680,8 +672,7 @@ def get_oauth_user_info(self, provider, resp):
"last_name": data.get("family_name", ""),
"email": data.get("email", ""),
}
else:
return {}
return {}

def _azure_parse_jwt(self, id_token):
jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
Expand Down Expand Up @@ -817,9 +808,7 @@ def register_views(self):
label=_("Views/Menus"),
category="Security",
)
if self.appbuilder.app.config.get(
"FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True
):
if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True):
self.appbuilder.add_view(
self.permissionviewmodelview,
"Permission on Views/Menus",
Expand Down Expand Up @@ -992,9 +981,7 @@ def _search_ldap(self, ldap, con, username):
except (IndexError, NameError):
return None, None

def _ldap_calculate_user_roles(
self, user_attributes: Dict[str, bytes]
) -> List[str]:
def _ldap_calculate_user_roles(self, user_attributes: Dict[str, bytes]) -> List[str]:
user_role_objects = set()

# apply AUTH_ROLES_MAPPING
Expand Down Expand Up @@ -1059,9 +1046,7 @@ def _ldap_bind(ldap, con, dn: str, password: str) -> bool:
return False

@staticmethod
def ldap_extract(
ldap_dict: Dict[str, bytes], field_name: str, fallback: str
) -> str:
def ldap_extract(ldap_dict: Dict[str, bytes], field_name: str, fallback: str) -> str:
raw_value = ldap_dict.get(field_name, [bytes()])
# decode - if empty string, default to fallback, otherwise take first element
return raw_value[0].decode("utf-8") or fallback
Expand Down Expand Up @@ -1108,9 +1093,7 @@ def auth_user_ldap(self, username, password):
if self.auth_ldap_tls_cacertdir:
ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, self.auth_ldap_tls_cacertdir)
if self.auth_ldap_tls_cacertfile:
ldap.set_option(
ldap.OPT_X_TLS_CACERTFILE, self.auth_ldap_tls_cacertfile
)
ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, self.auth_ldap_tls_cacertfile)
if self.auth_ldap_tls_certfile:
ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.auth_ldap_tls_certfile)
if self.auth_ldap_tls_keyfile:
Expand Down Expand Up @@ -1529,9 +1512,7 @@ def _get_user_permission_view_menus(
# Then check against database-stored roles
pvms_names = [
pvm.view_menu.name
for pvm in self.find_roles_permission_view_menus(
permission_name, db_role_ids
)
for pvm in self.find_roles_permission_view_menus(permission_name, db_role_ids)
]
result.update(pvms_names)
return result
Expand Down Expand Up @@ -1678,9 +1659,7 @@ def _get_new_old_permissions(baseview) -> Dict:
method_name
)
# Actions do not get prefix when normally defined
if hasattr(baseview, "actions") and baseview.actions.get(
old_permission_name
):
if hasattr(baseview, "actions") and baseview.actions.get(old_permission_name):
permission_prefix = ""
else:
permission_prefix = PERMISSION_PREFIX
Expand Down Expand Up @@ -1956,9 +1935,7 @@ def find_permission(self, name):
"""
raise NotImplementedError

def find_roles_permission_view_menus(
self, permission_name: str, role_ids: List[int]
):
def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[int]):
raise NotImplementedError

def exist_permission_on_roles(
Expand Down

0 comments on commit 10e2d4b

Please sign in to comment.