Skip to content

Commit

Permalink
Merge pull request #5 from skruger/improved_scopes
Browse files Browse the repository at this point in the history
Improve Oauth middleware
  • Loading branch information
skruger authored Jul 30, 2019
2 parents d6da6e3 + 071294b commit fb3bffa
Show file tree
Hide file tree
Showing 17 changed files with 234 additions and 23 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
dist: xenial
sudo: false
language: python
python:
Expand Down
2 changes: 1 addition & 1 deletion Vagrantfile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Vagrant.configure(2) do |config|
# Display the VirtualBox GUI when booting the machine
# vb.gui = true
# Customize the amount of memory on the VM:
vb.memory = "1024"
vb.memory = "2048"
end
#
# View the documentation for the provider you are using for more
Expand Down
9 changes: 8 additions & 1 deletion docs/changes.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
v 2.2
-----
* Improve Oauth2UserMiddleware
* Prevent SessionMiddleware from creating new sessions when using oauth tokens.
* Add OAuthRequiredMixin to allow scope enforcement

v 2.1
-----
* Fixed documentation links. Removed 2.0 package.

v 2.0
Expand All @@ -7,7 +14,7 @@ v 2.0

v 1.2
-----
Updated to make skopes configurable in the database and update for Django 1.7
Updated to make scopes configurable in the database and update for Django 1.7

v 1.0
-----
Expand Down
2 changes: 1 addition & 1 deletion provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.1"
__version__ = "2.2"
4 changes: 2 additions & 2 deletions provider/oauth2/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.contrib.auth import authenticate
from django.conf import settings
from django.utils.translation import ugettext as _
from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES
from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES, PUBLIC
from provider.forms import OAuthForm, OAuthValidationError
from provider.utils import now
from provider.oauth2.models import Client, Grant, RefreshToken, Scope
Expand Down Expand Up @@ -298,7 +298,7 @@ def clean(self):
except Client.DoesNotExist:
raise OAuthValidationError({'error': 'invalid_client'})

if client.client_type != 1: # public
if client.client_type != PUBLIC: # public
raise OAuthValidationError({'error': 'invalid_client'})

data['client'] = client
Expand Down
19 changes: 18 additions & 1 deletion provider/oauth2/middleware.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@

from django.conf import settings
from django.contrib import auth
from django.core.exceptions import ImproperlyConfigured
from django.utils.deprecation import MiddlewareMixin

from provider.oauth2.models import AccessToken

import logging
log = logging.getLogger(__name__)

class Oauth2UserMiddleware(object):

class Oauth2UserMiddleware(MiddlewareMixin):
"""
Middleware for using OAuth credentials to authenticate requests
Expand All @@ -32,6 +35,13 @@ def process_request(self, request):
" Insert 'django.contrib.auth.middleware.AuthenticationMiddleware'"
" before this Oauth2UserMiddleware class."
)
if 'django.contrib.auth.backends.RemoteUserBackend' not in settings.AUTHENTICATION_BACKENDS:
raise ImproperlyConfigured(
"Remote user authentication backend is required for this module to work."
" Insert 'django.contrib.auth.backends.RemoteUserBackend' into the"
" AUTHENTICATION_BACKENDS list in your settings."

)
try:
access_token_http = self._http_access_token(request)
access_token_get = request.GET.get('access_token', access_token_http)
Expand All @@ -49,6 +59,13 @@ def process_request(self, request):
user = auth.authenticate(remote_user=token.user.username)
auth.login(request, user)
request.oauth2_client = token.client
request.oauth2_token = token
except Exception as e:
log.error("Oauth2UserMiddleware encountered an exception! "
"{}: {}".format(e.__class__.__name__, e))

def process_response(self, request, response):
if hasattr(request, 'oauth2_token'):
# Set modified=False to prevent the session from being stored and the cookie from being sent
request.session.modified = False
return response
34 changes: 34 additions & 0 deletions provider/oauth2/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from django.utils.decorators import classonlymethod
from django.http.response import JsonResponse


class OAuthRegisteredScopes(object):
scopes = set()


class OAuthRequiredMixin(object):
accepted_oauth_scopes = []

@classonlymethod
def as_view(cls, *args, **kwargs):
for scope in cls.accepted_oauth_scopes:
OAuthRegisteredScopes.scopes.add(scope)

return super(OAuthRequiredMixin, cls).as_view()

def dispatch(self, request, *args, **kwargs):
scopes = list()
if hasattr(request, 'oauth2_token'):
scopes = set(request.oauth2_token.scope.all().values_list('name', flat=True))

if request.user.is_authenticated and scopes.intersection(self.accepted_oauth_scopes):
return super(OAuthRequiredMixin, self).dispatch(request, *args, **kwargs)

return JsonResponse(
{
'error': 'bad_access_token',
'accepted_scopes': sorted(self.accepted_oauth_scopes),
'token_scopes': sorted(scopes)
},
status=401
)
2 changes: 1 addition & 1 deletion provider/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __unicode__(self):
return self.redirect_uri

def get_default_token_expiry(self):
public = (self.client_type == 1)
public = (self.client_type == constants.PUBLIC)
return get_token_expiry(public)

class Meta:
Expand Down
Empty file.
97 changes: 97 additions & 0 deletions provider/oauth2/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
from six.moves.urllib_parse import urlparse

from django.shortcuts import reverse
from django.http import QueryDict

from provider.oauth2.models import Scope
from provider.oauth2.mixins import OAuthRegisteredScopes
from provider.oauth2.tests.test_views import BaseOAuth2TestCase


class MiddlewareTestCase(BaseOAuth2TestCase):
fixtures = ['test_oauth2.json']

def setUp(self):
if not Scope.objects.filter(name='read').exists():
Scope.objects.create(name='read')

def _login_authorize_get_token(self):
required_props = ['access_token', 'token_type']

self.login()
self._login_and_authorize()

response = self.client.get(self.redirect_url())
query = QueryDict(urlparse(response['Location']).query)
code = query['code']

response = self.client.post(self.access_token_url(), {
'grant_type': 'authorization_code',
'client_id': self.get_client().client_id,
'client_secret': self.get_client().client_secret,
'code': code})

self.assertEqual(200, response.status_code, response.content)

token = json.loads(response.content)

for prop in required_props:
self.assertIn(prop, token, "Access token response missing "
"required property: %s" % prop)

return token

def test_mixin_scopes(self):
self.assertIn('read', OAuthRegisteredScopes.scopes)

def test_no_token(self):
# user_url = self.live_server_url + reverse('tests:user', args=[self.get_user().pk])
# result = requests.get(user_url)

user_url = reverse('tests:user', args=[self.get_user().pk])
result = self.client.get(user_url)

self.assertEqual(result.status_code, 401)

def test_token_access(self):
self.login()
token_info = self._login_authorize_get_token()
token = token_info['access_token']

# Create a new client to ensure a clean session
oauth_client = self.client_class()

user_url = reverse('tests:user', args=[self.get_user().pk])
result = oauth_client.get(user_url, {'access_token': token})

self.assertEqual(result.status_code, 200)
result_json = result.json()
self.assertEqual(result_json.get('id'), self.get_user().pk)

def test_unauthorized_scope(self):
self.login()
token_info = self._login_authorize_get_token()
token = token_info['access_token']

badscope_url = reverse('tests:badscope')

oauth_client = self.client_class()

result = oauth_client.get(badscope_url, {'access_token': token})

self.assertEqual(result.status_code, 401)
result_json = result.json()
# self.assertEqual(result_json.get('id'), self.get_user().pk)

def test_no_stored_session(self):
self.login()
token_info = self._login_authorize_get_token()
token = token_info['access_token']

oauth_client = self.client_class()

user_url = reverse('tests:user', args=[self.get_user().pk])
result = oauth_client.get(user_url, {'access_token': token})

self.assertNotIn('sessionid', result.cookies)
19 changes: 10 additions & 9 deletions provider/oauth2/tests.py → provider/oauth2/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ def get_password(self):

def _login_and_authorize(self, url_func=None):
if url_func is None:
url_func = lambda: self.auth_url() + '?client_id={}&response_type=code&state=abc'.format(
self.get_client().client_id
)
def url_func():
return self.auth_url() + '?client_id={}&response_type=code&state=abc'.format(
self.get_client().client_id
)

response = self.client.get(url_func())
response = self.client.get(self.auth_url2())
Expand Down Expand Up @@ -344,7 +345,7 @@ def test_refreshing_an_access_token(self):

def test_password_grant_public(self):
c = self.get_client()
c.client_type = 1 # public
c.client_type = constants.PUBLIC
c.save()

response = self.client.post(self.access_token_url(), {
Expand All @@ -363,7 +364,7 @@ def test_password_grant_public(self):

def test_password_grant_confidential(self):
c = self.get_client()
c.client_type = 0 # confidential
c.client_type = constants.CONFIDENTIAL
c.save()

response = self.client.post(self.access_token_url(), {
Expand All @@ -379,7 +380,7 @@ def test_password_grant_confidential(self):

def test_password_grant_confidential_no_secret(self):
c = self.get_client()
c.client_type = 0 # confidential
c.client_type = constants.CONFIDENTIAL
c.save()

response = self.client.post(self.access_token_url(), {
Expand All @@ -393,7 +394,7 @@ def test_password_grant_confidential_no_secret(self):

def test_password_grant_invalid_password_public(self):
c = self.get_client()
c.client_type = 1 # public
c.client_type = constants.PUBLIC
c.save()

response = self.client.post(self.access_token_url(), {
Expand All @@ -408,7 +409,7 @@ def test_password_grant_invalid_password_public(self):

def test_password_grant_invalid_password_confidential(self):
c = self.get_client()
c.client_type = 0 # confidential
c.client_type = constants.CONFIDENTIAL
c.save()

response = self.client.post(self.access_token_url(), {
Expand Down Expand Up @@ -497,7 +498,7 @@ def test_client_form(self):
'name': 'TestName',
'url': 'http://127.0.0.1:8000',
'redirect_uri': 'http://localhost:8000/',
'client_type': constants.CLIENT_TYPES[0][0]})
'client_type': constants.CONFIDENTIAL})
self.assertTrue(form.is_valid())
form.save()

Expand Down
42 changes: 42 additions & 0 deletions provider/oauth2/tests/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from django.conf.urls import url
from django.http.response import JsonResponse
from django.views.generic import View
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.models import User
from django.shortcuts import get_object_or_404

from provider.oauth2.mixins import OAuthRequiredMixin

app_name = 'tests'


class UserView(OAuthRequiredMixin, LoginRequiredMixin, View):
accepted_oauth_scopes = ['read']

def get(self, request, *args, **kwargs):
user = get_object_or_404(User, pk=self.kwargs['pk'])
return JsonResponse(
{
'username': user.username,
'id': user.pk,
}
)


class BadScopeView(OAuthRequiredMixin, LoginRequiredMixin, View):
accepted_oauth_scopes = ['badscope']

def get(self, request, *args, **kwargs):
user = self.request.user
return JsonResponse(
{
'username': user.username,
'id': user.pk,
}
)


urlpatterns = [
url('^badscope$', BadScopeView.as_view(), name='badscope'),
url('^user/(?P<pk>\d+)$', UserView.as_view(), name='user'),
]
2 changes: 1 addition & 1 deletion provider/oauth2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_access_token(self, request, user, scope, client):
except models.AccessToken.DoesNotExist:
# None found... make a new one!
at = self.create_access_token(request, user, scope, client)
if client.client_type != 1:
if client.client_type != constants.PUBLIC:
self.create_refresh_token(request, user, scope, at, client)
return at

Expand Down
2 changes: 1 addition & 1 deletion provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def password(self, request, data, client):

at = self.create_access_token(request, user, scope, client)
# Public clients don't get refresh tokens
if client.client_type != 1:
if client.client_type != constants.PUBLIC:
rt = self.create_refresh_token(request, user, scope, at, client)

return self.access_token_response(at)
Expand Down
8 changes: 7 additions & 1 deletion tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
# 'django.template.context_processors.debug',
# 'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
# 'django.contrib.messages.context_processors.messages',
'django.contrib.messages.context_processors.messages',
],
},
},
Expand All @@ -81,10 +81,16 @@
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'provider.oauth2.middleware.Oauth2UserMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
)

AUTHENTICATION_BACKENDS = [
'django.contrib.auth.backends.RemoteUserBackend',
'django.contrib.auth.backends.ModelBackend',
]

PASSWORD_HASHERS = [
'django.contrib.auth.hashers.PBKDF2PasswordHasher',
'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
Expand Down
Loading

0 comments on commit fb3bffa

Please sign in to comment.