Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 84 additions & 31 deletions flask_oidc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2014-2015, Erica Ehrhardt
# Copyright (c) 2016, Patrick Uiterwijk <[email protected]>
# Copyright (c) 2020, Aaron Olson <[email protected]>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -35,12 +36,13 @@

from six.moves.urllib.parse import urlencode
from flask import request, session, redirect, url_for, g, current_app, abort
from oauth2client.client import flow_from_clientsecrets, OAuth2WebServerFlow,\
from oauth2client.client import flow_from_clientsecrets, OAuth2WebServerFlow, \
AccessTokenRefreshError, OAuth2Credentials
import httplib2
from itsdangerous import JSONWebSignatureSerializer, BadSignature
from itsdangerous import JSONWebSignatureSerializer, BadSignature, \
TimedJSONWebSignatureSerializer, SignatureExpired

__all__ = ['OpenIDConnect', 'MemoryCredentials']
__all__ = ['OpenIDConnect', 'MemoryCredentials', 'MemoryTokens', 'MemoryBadTokens']

logger = logging.getLogger(__name__)

Expand All @@ -50,6 +52,7 @@ def _json_loads(content):
content = content.decode('utf-8')
return json.loads(content)


class MemoryCredentials(dict):
"""
Non-persistent local credentials store.
Expand All @@ -59,10 +62,29 @@ class MemoryCredentials(dict):
pass


class MemoryTokens(dict):
"""
Non-persistent local token cache.
Use this if you only have one app server, and don't mind having
to re-validate everyone's tokens after restart.
"""
pass


class MemoryBadTokens(dict):
"""
Non-persistent local responses cache for bad tokens.
Use this if you only have one app server, and don't mind having
to re-validate everyone's tokens after restart.
"""
pass


class DummySecretsCache(object):
"""
oauth2client secrets cache
"""

def __init__(self, client_secrets):
self.client_secrets = client_secrets

Expand All @@ -85,6 +107,7 @@ class ErrStr(str):
this ErrStr class, which are basic strings except for their bool() results:
they return False.
"""

def __nonzero__(self):
"""The py2 method for bool()."""
return False
Expand All @@ -101,19 +124,27 @@ class OpenIDConnect(object):
"""
The core OpenID Connect client object.
"""
def __init__(self, app=None, credentials_store=None, http=None, time=None,
urandom=None):
self.credentials_store = credentials_store\
if credentials_store is not None\

def __init__(self, app=None, credentials_store=None, http=None, tokens_store=None,
bad_tokens_store=None, time=None, urandom=None):
self.credentials_store = credentials_store \
if credentials_store is not None \
else MemoryCredentials()

self.tokens_store = tokens_store \
if tokens_store is not None \
else MemoryTokens()

self.bad_tokens_store = bad_tokens_store \
if bad_tokens_store is not None \
else MemoryBadTokens()

if http is not None:
warn('HTTP argument is deprecated and unused', DeprecationWarning)
if time is not None:
warn('time argument is deprecated and unused', DeprecationWarning)
if urandom is not None:
warn('urandom argument is deprecated and unused',
DeprecationWarning)
warn('urandom argument is deprecated and unused', DeprecationWarning)

# By default, we do not have a custom callback
self._custom_callback = None
Expand Down Expand Up @@ -176,6 +207,7 @@ def init_app(self, app):
cache=secrets_cache)
assert isinstance(self.flow, OAuth2WebServerFlow)

# TODO use configurable salt string instead of hardcoded value to improve security
# create signers using the Flask secret key
self.extra_data_serializer = JSONWebSignatureSerializer(
app.config['SECRET_KEY'], salt='flask-oidc-extra-data')
Expand Down Expand Up @@ -335,7 +367,6 @@ def _retrieve_userinfo(self, access_token=None):

return info


def get_cookie_id_token(self):
"""
.. deprecated:: 1.0
Expand All @@ -347,8 +378,7 @@ def get_cookie_id_token(self):

def _get_cookie_id_token(self):
try:
id_token_cookie = request.cookies.get(current_app.config[
'OIDC_ID_TOKEN_COOKIE_NAME'])
id_token_cookie = request.cookies.get(current_app.config['OIDC_ID_TOKEN_COOKIE_NAME'])
if not id_token_cookie:
# Do not error if we were unable to get the cookie.
# The user can debug this themselves.
Expand Down Expand Up @@ -486,12 +516,14 @@ def require_login(self, view_func):
.. versionadded:: 1.0
This was :func:`check` before.
"""

@wraps(view_func)
def decorated(*args, **kwargs):
if g.oidc_id_token is None:
return self.redirect_to_auth_server(request.url)
return view_func(*args, **kwargs)
return decorated

# Backwards compatibility
check = require_login
"""
Expand All @@ -502,11 +534,9 @@ def decorated(*args, **kwargs):
def require_keycloak_role(self, client, role):
"""
Function to check for a KeyCloak client role in JWT access token.

This is intended to be replaced with a more generic 'require this value
in token or claims' system, at which point backwards compatibility will
be added.

.. versionadded:: 1.5.0
"""
def wrapper(view_func):
Expand Down Expand Up @@ -652,8 +682,7 @@ def _is_id_token_valid(self, id_token):

# additional steps specific to our usage
if current_app.config['OIDC_GOOGLE_APPS_DOMAIN'] and \
id_token.get('hd') != current_app.config[
'OIDC_GOOGLE_APPS_DOMAIN']:
id_token.get('hd') != current_app.config['OIDC_GOOGLE_APPS_DOMAIN']:
logger.error('Invalid google apps domain')
return False

Expand All @@ -672,6 +701,7 @@ def custom_callback(self, view_func):
The custom OIDC callback will get the custom state field passed in with
redirect_to_auth_server.
"""

@wraps(view_func)
def decorated(*args, **kwargs):
plainreturn, data = self._process_callback('custom')
Expand Down Expand Up @@ -719,11 +749,10 @@ def _process_callback(self, statefield):
id_token = credentials.id_token
if not self._is_id_token_valid(id_token):
logger.debug("Invalid ID token")
if id_token.get('hd') != current_app.config[
'OIDC_GOOGLE_APPS_DOMAIN']:
if id_token.get('hd') != current_app.config['OIDC_GOOGLE_APPS_DOMAIN']:
return True, self._oidc_error(
"You must log in with an account from the {0} domain."
.format(current_app.config['OIDC_GOOGLE_APPS_DOMAIN']),
.format(current_app.config['OIDC_GOOGLE_APPS_DOMAIN']),
self.WRONG_GOOGLE_APPS_DOMAIN)
return True, self._oidc_error()

Expand Down Expand Up @@ -835,18 +864,21 @@ def _validate_token(self, token, scopes_required=None):
logger.debug('Token missed required scopes')

if (valid_token and has_required_scopes):
self.tokens_store[token] = token_info
g.oidc_token_info = token_info
return True

if not valid_token:
self.bad_tokens_store[token] = token
return 'Token required but invalid'
elif not has_required_scopes:
self.bad_tokens_store[token] = token
return 'Token does not have required scopes'
else:
self.bad_tokens_store[token] = token
return 'Something went wrong checking your token'

def accept_token(self, require_token=False, scopes_required=None,
render_errors=True):
def accept_token(self, require_token=False, scopes_required=None, render_errors=True):
"""
Use this to decorate view functions that should accept OAuth2 tokens,
this will most likely apply to API functions.
Expand Down Expand Up @@ -878,7 +910,7 @@ def wrapper(view_func):
def decorated(*args, **kwargs):
token = None
if 'Authorization' in request.headers and request.headers['Authorization'].startswith('Bearer '):
token = request.headers['Authorization'].split(None,1)[1].strip()
token = request.headers['Authorization'].split(None, 1)[1].strip()
if 'access_token' in request.form:
token = request.form['access_token']
elif 'access_token' in request.args:
Expand All @@ -897,6 +929,21 @@ def decorated(*args, **kwargs):
return decorated
return wrapper

def clear_tokens_store(self):
self.tokens_store = {}

def clear_bad_tokens_store(self):
self.bad_tokens_store = {}

def is_expired(self, token):
current_time = time.time()
cached_token = self.tokens_store[token]
if cached_token.get('exp'):
if current_time >= cached_token['exp']:
self.tokens_store.pop(token, False)
return True
return False

def _get_token_info(self, token):
# We hardcode to use client_secret_post, because that's what the Google
# oauth2client library defaults to
Expand All @@ -907,7 +954,7 @@ def _get_token_info(self, token):
if hint != 'none':
request['token_type_hint'] = hint

auth_method = current_app.config['OIDC_INTROSPECTION_AUTH_METHOD']
auth_method = current_app.config['OIDC_INTROSPECTION_AUTH_METHOD']
if (auth_method == 'client_secret_basic'):
basic_auth_string = '%s:%s' % (self.client_secrets['client_id'], self.client_secrets['client_secret'])
basic_auth_bytes = bytearray(basic_auth_string, 'utf-8')
Expand All @@ -916,11 +963,17 @@ def _get_token_info(self, token):
headers['Authorization'] = 'Bearer %s' % token
elif (auth_method == 'client_secret_post'):
request['client_id'] = self.client_secrets['client_id']
if self.client_secrets['client_secret'] is not None:
request['client_secret'] = self.client_secrets['client_secret']

resp, content = httplib2.Http().request(
self.client_secrets['token_introspection_uri'], 'POST',
urlencode(request), headers=headers)
# TODO: Cache this reply
return _json_loads(content)
request['client_secret'] = self.client_secrets['client_secret']

if self.bad_tokens_store and self.bad_tokens_store.get(token):
raise ValueError('Attempting to authenticate using token that failed validation recently. '
'Generate a new token and try again.')
if not self.tokens_store or not self.tokens_store.get(token) or self.is_expired(token):
resp, content_string = httplib2.Http().request(
self.client_secrets['token_introspection_uri'], 'POST',
urlencode(request), headers=headers)
content = _json_loads(content_string)
else:
# using cached token
content = self.tokens_store[token]
return content