diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..dd2aa46c --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,35 @@ +# Read the Docs configuration file for Sphinx projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + # You can also specify other tool versions: + # nodejs: "20" + # rust: "1.70" + # golang: "1.20" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/conf.py + # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs + # builder: "dirhtml" + # Fail on all warnings to avoid broken references + # fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +# python: +# install: +# - requirements: docs/requirements.txt diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..5519a401 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +ARG PYVERSION=3.9.19-bullseye + +FROM python:${PYVERSION} AS dev + +WORKDIR /app + +COPY requirements.txt /app/ + +RUN apt-get update \ + && apt-get install -q -y \ + jq \ + && apt-get clean + +RUN pip install -r requirements.txt + +FROM dev as prod + +COPY ./ /app/ + + diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 00000000..81655a8a --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,47 @@ +pipeline { + agent { + label "python" + } + stages { + stage('Virtualenv'){ + steps { + sh '/usr/bin/virtualenv toxtest -p /usr/bin/python3' + sh 'toxtest/bin/pip install tox==3.28.0 pathlib2' + } + } + stage('Test'){ + parallel { + stage('Unit Test Django 3.1'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{3.1}' + } + } + stage('Unit Test Django 3.2'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{3.2}' + } + } + stage('Unit Test Django 4.0'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{4.0}' + } + } + stage('Unit Test Django 4.1'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{4.1}' + } + } + stage('Unit Test Django 4.2'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{4.2}' + } + } + } + } + } + post { + cleanup { + cleanWs() + } + } +} diff --git a/README.rst b/README.rst index 1eb97cad..75568089 100644 --- a/README.rst +++ b/README.rst @@ -12,3 +12,10 @@ License ======= *django-oauth2* is a fork of *django-oauth2-provider* which is released under the MIT License. Please see the LICENSE file for details. + + +Packaging +========= + + $ python -m build + diff --git a/aws_identity_example.py b/aws_identity_example.py new file mode 100644 index 00000000..6bc65199 --- /dev/null +++ b/aws_identity_example.py @@ -0,0 +1,68 @@ +import os +import sys +import json + +from datetime import datetime +from urllib import request, error +import requests + +import boto3 +# aws-v4-signature==2.0 +from awsv4sign import generate_http11_header + +service = 'sts' +region = 'us-west-2' + +session = boto3.Session() +creds = session.get_credentials() +access_key = creds.access_key +secret_key = creds.secret_key +session_token = creds.token + +print(f"access_key: {access_key[:10]}") +print(f"secret_key: {secret_key[:10]}") +print(f"session_token: {session_token[:20]}") +print(f"profile: {os.environ.get('AWS_PROFILE')}") + +url = 'https://sts.{region}.amazonaws.com/'.format(region=region) +httpMethod = 'post' +canonicalHeaders = { + 'host': f'sts.{region}.amazonaws.com', + 'x-amz-date': datetime.utcnow().strftime('%Y%m%dT%H%M%SZ'), + 'content-type': 'application/x-www-form-urlencoded; charset=utf-8', +} +if session_token: + canonicalHeaders['x-amz-security-token'] = session_token + +payload_str = "Action=GetCallerIdentity&Version=2011-06-15" + +headers = generate_http11_header( + service, region, access_key, secret_key, + url, 'post', canonicalHeaders, {}, + '', payload_str +) + +token_request_args = { + "grant_type": "aws_identity", + "region": region, + "post_body": payload_str, + "headers_json": json.dumps(headers), +} +print(payload_str) +print(json.dumps(headers, indent=4)) + +req = request.Request("https://sts.us-west-2.amazonaws.com/", data=payload_str.encode('utf-8'), headers=headers, method='POST') +try: + response = request.urlopen(req) + print(f"Local request test result: {response.read()}") +except error.HTTPError as e: + print(f"HTTPError: {e}: {e.fp.read()}") + sys.exit(1) + +print("Attempting access_token grant request with same signed request:\n") + +token_response = requests.post("http://localhost:8000/oauth2/access_token", + data=token_request_args) +token_info = token_response.json() + +print(token_info) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..739aba32 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,24 @@ + +services: + test: + build: + context: . + target: dev + user: ${UID} + volumes: + - ${WORKSPACE:-.}:/app + environment: + - DJANGO_SETTINGS_MODULE=tests.settings + + web: + build: + context: . + target: dev + user: ${UID} + volumes: + - ${WORKSPACE:-.}:/app + ports: + - "8000:8000" + environment: + - DJANGO_SETTINGS_MODULE=tests.settings +# entrypoint: [ "python3", "manage.py", "runserver" ] diff --git a/docs/changes.rst b/docs/changes.rst index bc2f4904..b6d78eb9 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -1,3 +1,12 @@ +v 4.1 +----- +* Add aws_identity grant_type +* Update for Django 3.1-4.2 + +v 4.0 +----- +* Update for Django 3.0-4.1 + v 2.4 ----- * Add HTTP Authorization Bearer token support to Oauth2UserMiddleware diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 4308e6fc..23f7b956 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -35,7 +35,7 @@ Add :attr:`provider.oauth2.urls` to your root ``urls.py`` file. :: - url(r'^oauth2/', include('provider.oauth2.urls', namespace = 'oauth2')), + path('oauth2/', include(('provider.oauth2.urls', 'oauth2'))), .. note:: The namespace argument is required. @@ -92,6 +92,27 @@ in :rfc:`4`. .. note:: Remember that you should always use HTTPS for all your OAuth 2 requests otherwise you won't be secured. +Request an Access Token using AWS credentials +--------------------------------------------- + +The new aws_identity grant_type uses the parameters for a signed GetCallerIdentity +request to prove the caller's identity. + +Your client needs to submit a :attr:`POST` request to +:attr:`/oauth2/access_token` including the following parameters: + +* ``region`` - AWS Region +* ``post_body`` - The post body used for signing the request. Usually ``Action=GetCallerIdentity&Version=2011-06-15`` +* ``headers_json`` - The headers produced by the AWSv4 signing process + +The region value is used to produce the standard https://sts.(region).amazonaws.com/ url used to +make the GetCallerIdentity request. The URL is generated server side to reduce the risk of an +attack based on sending an improperly crafted full URL. + +The aws-v4-signature library implements awsv4sign.generate_http11_header(). An example is +presented in the root of the repository in aws_identity_examply.py. + + Integrate with Django Authentication #################################### diff --git a/provider/__init__.py b/provider/__init__.py index 080e846a..e29313b0 100644 --- a/provider/__init__.py +++ b/provider/__init__.py @@ -1 +1,2 @@ -__version__ = "3.2" +__version__ = "4.2" +# The major version is expected to follow the current django major version:q diff --git a/provider/oauth2/admin.py b/provider/oauth2/admin.py index d4711999..00b1c58a 100644 --- a/provider/oauth2/admin.py +++ b/provider/oauth2/admin.py @@ -23,9 +23,15 @@ class AuthorizedClientAdmin(admin.ModelAdmin): raw_id_fields = ('user',) +class AwsAccountAdmin(admin.ModelAdmin): + list_display = ('arn', 'client', 'max_token_lifetime') + raw_id_fields = ('acting_user',) + + admin.site.register(models.AccessToken, AccessTokenAdmin) admin.site.register(models.Grant, GrantAdmin) admin.site.register(models.Client, ClientAdmin) admin.site.register(models.AuthorizedClient, AuthorizedClientAdmin) +admin.site.register(models.AwsAccount, AwsAccountAdmin) admin.site.register(models.RefreshToken) admin.site.register(models.Scope) diff --git a/provider/oauth2/apps.py b/provider/oauth2/apps.py index c9c50344..73b1ae77 100644 --- a/provider/oauth2/apps.py +++ b/provider/oauth2/apps.py @@ -4,3 +4,6 @@ class Oauth2(AppConfig): name = 'provider.oauth2' label = 'oauth2' verbose_name = "Provider Oauth2" + + def ready(self): + import provider.oauth2.signals diff --git a/provider/oauth2/fixtures/test_oauth2.json b/provider/oauth2/fixtures/test_oauth2.json index c8905acc..bc6cf75f 100644 --- a/provider/oauth2/fixtures/test_oauth2.json +++ b/provider/oauth2/fixtures/test_oauth2.json @@ -73,6 +73,24 @@ "model": "auth.user", "pk": 2 }, + { + "fields": { + "date_joined": "2012-01-23 05:53:31", + "email": "", + "first_name": "", + "groups": [], + "is_active": true, + "is_staff": false, + "is_superuser": false, + "last_login": "2012-01-23 05:53:31", + "last_name": "", + "password": "sha1$0cf1b$d66589690edd96b410170fcae5cc2bdfb68821e7", + "user_permissions": [], + "username": "test-user-aws" + }, + "model": "auth.user", + "pk": 3 + }, { "fields": { "name": "basic", @@ -88,5 +106,19 @@ }, "model": "oauth2.scope", "pk": 2 + }, + { + "fields": { + "arn": "arn:aws:iam::123456789012:role/testrole", + "account_id": "123456789012", + "name": "testrole", + "general_type": "role", + "client": 2, + "autoprovision_user": false, + "acting_user": 3, + "scope": ["basic", "advanced"] + }, + "model": "oauth2.awsaccount", + "pk": 1 } ] diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index f51a3c9f..da438175 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -1,14 +1,20 @@ -from six import string_types +import logging +from io import StringIO +from urllib import request +from urllib.error import HTTPError +from xml.etree import ElementTree + from django import forms from django.contrib.auth import authenticate from django.conf import settings -from django.utils.translation import ugettext as _ +from django.utils.translation import gettext as _ from django.utils import timezone from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES, PUBLIC from provider.forms import OAuthForm, OAuthValidationError -from provider.utils import now +from provider.utils import now, ArnHelper from provider.oauth2.models import Client, Grant, RefreshToken, Scope +log = logging.getLogger('provider.oauth2') DEFAULT_SCOPE = getattr(settings, 'OAUTH2_DEFAULT_SCOPE', 'read') @@ -53,7 +59,7 @@ class ScopeModelChoiceField(forms.ModelMultipleChoiceField): # widget = forms.TextInput def to_python(self, value): - if isinstance(value, string_types): + if isinstance(value, str): return [s for s in value.split(' ') if s != ''] elif isinstance(value, list): value_list = list() @@ -311,6 +317,46 @@ def clean(self): return data +class AwsGrantForm(OAuthForm): + grant_type = forms.CharField(required=True) + region = forms.CharField(required=True) + post_body = forms.CharField(required=True) + headers_json = forms.JSONField(required=True) + + def clean_grant_type(self): + grant_type = self.cleaned_data.get('grant_type') + + if grant_type != 'aws_identity': + raise OAuthValidationError({'error': 'invalid_grant'}) + + return grant_type + + def clean(self): + region = self.cleaned_data['region'] + + sts_url = f"https://sts.{region}.amazonaws.com/" + + post_body = self.cleaned_data['post_body'] + headers_json = self.cleaned_data['headers_json'] + + req = request.Request(sts_url, data=post_body.encode('utf-8'), headers=headers_json, method='POST') + try: + response = request.urlopen(req) + except HTTPError as e: + log.info("Error calling GetCallerIdentity for aws_identity grant: %s", e) + raise OAuthValidationError({'error': 'invalid_grant'}) + + xmldata = response.read() + + et = ElementTree.parse(StringIO(xmldata.decode('utf-8'))) + root = et.getroot() + result = root.find('{https://sts.amazonaws.com/doc/2011-06-15/}GetCallerIdentityResult') + caller_arn = result.find('{https://sts.amazonaws.com/doc/2011-06-15/}Arn').text + self.cleaned_data['arn_string'] = caller_arn + self.cleaned_data['arn'] = ArnHelper(caller_arn) + return self.cleaned_data + + class PublicClientForm(OAuthForm): client_id = forms.CharField(required=True) grant_type = forms.CharField(required=True) diff --git a/provider/oauth2/migrations/0004_awsaccount.py b/provider/oauth2/migrations/0004_awsaccount.py new file mode 100644 index 00000000..c1d50ccc --- /dev/null +++ b/provider/oauth2/migrations/0004_awsaccount.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2 on 2024-08-07 19:03 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('oauth2', '0003_public_client_options'), + ] + + operations = [ + migrations.CreateModel( + name='AwsAccount', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('arn', models.CharField(help_text='AWS User or Role ARN', max_length=255, unique=True)), + ('general_type', models.CharField(blank=True, max_length=15, null=True)), + ('account_id', models.CharField(blank=True, max_length=12, null=True)), + ('name', models.CharField(blank=True, max_length=255, null=True)), + ('autoprovision_user', models.BooleanField(default=True, help_text='Automatically create acting user on first use')), + ('acting_user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to=settings.AUTH_USER_MODEL)), + ('max_token_lifetime', models.IntegerField(default=3600, blank=True, help_text="Maximum access token lifetime in seconds")), + ('client', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='oauth2.client')), + ('scope', models.ManyToManyField(help_text='Scopes to be applied to tokens', to='oauth2.scope')), + ], + options={ + 'db_table': 'oauth2_awsaccount', + 'unique_together': {('general_type', 'account_id', 'name')}, + }, + ), + ] diff --git a/provider/oauth2/models.py b/provider/oauth2/models.py index f782ff4d..778cb6e3 100644 --- a/provider/oauth2/models.py +++ b/provider/oauth2/models.py @@ -6,6 +6,7 @@ from django.db import models from django.conf import settings +from django.contrib.auth import get_user_model from provider import constants from provider.constants import CLIENT_TYPES from provider.utils import now, short_token, long_token, get_code_expiry @@ -264,3 +265,34 @@ def __unicode__(self): class Meta: app_label = 'oauth2' db_table = 'oauth2_refreshtoken' + + +class AwsAccount(models.Model): + arn = models.CharField(max_length=255, unique=True, help_text="AWS User or Role ARN") + general_type = models.CharField(max_length=15, blank=True, null=True) + account_id = models.CharField(max_length=12, blank=True, null=True) + name = models.CharField(max_length=255, blank=True, null=True) + + client = models.ForeignKey('Client', models.DO_NOTHING) + autoprovision_user = models.BooleanField(default=True, help_text="Automatically create acting user on first use") + acting_user = models.ForeignKey(settings.AUTH_USER_MODEL, models.DO_NOTHING, blank=True, null=True) + max_token_lifetime = models.IntegerField(default=3600, blank=True, help_text="Maximum access token lifetime in seconds") + scope = models.ManyToManyField("Scope", help_text="Scopes to be applied to tokens") + + class Meta: + app_label = 'oauth2' + db_table = 'oauth2_awsaccount' + unique_together = ( + ('general_type', 'account_id', 'name'), + ) + + def get_or_create_user(self): + if self.acting_user is not None: + return self.acting_user + + if self.autoprovision_user: + username = f"{self.name}_{self.general_type}_{self.account_id}" + User = get_user_model() + self.acting_user, _ = User.objects.get_or_create(username=username) + self.save() + return self.acting_user diff --git a/provider/oauth2/signals.py b/provider/oauth2/signals.py new file mode 100644 index 00000000..fe975f17 --- /dev/null +++ b/provider/oauth2/signals.py @@ -0,0 +1,19 @@ +from django.contrib.auth.models import User +from django.db.models.signals import pre_save +from django.dispatch import receiver + +from provider.utils import ArnHelper +from provider.oauth2.models import AwsAccount + + +@receiver(pre_save, sender=AwsAccount) +def awsaccount_pre_save(sender, instance, **kwargs): + arn = ArnHelper(instance.arn) + if instance.general_type != arn.general_type: + instance.general_type = arn.general_type + + if instance.name != arn.name: + instance.name = arn.name + + if instance.account_id != arn.account_id: + instance.account_id = arn.account_id diff --git a/provider/oauth2/tests/test_middleware.py b/provider/oauth2/tests/test_middleware.py index e3509e84..eace624c 100644 --- a/provider/oauth2/tests/test_middleware.py +++ b/provider/oauth2/tests/test_middleware.py @@ -1,5 +1,5 @@ import json -from six.moves.urllib_parse import urlparse +from urllib.parse import urlparse from django.shortcuts import reverse from django.http import QueryDict diff --git a/provider/oauth2/tests/test_models.py b/provider/oauth2/tests/test_models.py new file mode 100644 index 00000000..7d571672 --- /dev/null +++ b/provider/oauth2/tests/test_models.py @@ -0,0 +1,26 @@ + +from django.test import TestCase + +from provider.oauth2.models import Client, AwsAccount + + +class ModelTests(TestCase): + fixtures = ['test_oauth2'] + + def test_aws_account(self): + client = Client.objects.get(id=2) + + account = AwsAccount.objects.create( + client=client, + arn="arn:aws:iam::123456789012:user/imauser" + ) + + self.assertEqual(account.account_id, "123456789012") + self.assertEqual(account.name, "imauser") + self.assertEqual(account.general_type, "user") + + new_account = AwsAccount.objects.get(pk=account.pk) + + self.assertEqual(new_account.account_id, "123456789012") + self.assertEqual(new_account.name, "imauser") + self.assertEqual(new_account.general_type, "user") diff --git a/provider/oauth2/tests/test_views.py b/provider/oauth2/tests/test_views.py index 78da3b26..2461f01b 100644 --- a/provider/oauth2/tests/test_views.py +++ b/provider/oauth2/tests/test_views.py @@ -1,9 +1,9 @@ import base64 import json import datetime -from six.moves.urllib_parse import urlparse, parse_qs, quote +from urllib.parse import urlparse, parse_qs, quote -from unittest import SkipTest +from unittest.mock import patch from django.http import QueryDict from django.conf import settings from django.shortcuts import reverse @@ -15,7 +15,7 @@ from provider.templatetags.scope import scopes from provider.utils import now as date_now from provider.oauth2.forms import ClientForm -from provider.oauth2.models import Client, Grant, AccessToken, RefreshToken, AuthorizedClient +from provider.oauth2.models import Client, Grant, AccessToken, RefreshToken, AuthorizedClient, AwsAccount from provider.oauth2.backends import BasicClientBackend, RequestParamsClientBackend from provider.oauth2.backends import AccessTokenBackend @@ -52,6 +52,9 @@ def get_user(self): def get_password(self): return 'test' + def get_aws_role(self): + return AwsAccount.objects.get(id=1) + def _login_and_authorize(self, url_func=None): if url_func is None: def url_func(): @@ -508,6 +511,75 @@ def test_access_token_response_valid_token_type(self): token = self._login_authorize_get_token() self.assertEqual(token['token_type'], constants.TOKEN_TYPE, token) + @patch('urllib.request.urlopen') + def test_aws_grant_invalid_caller_identity(self, urlopen): + headers = { + "header1": "a", + "header2": "b", + } + post_body = "mypostbody" + + caller_identity_result = """ + + + arn:aws:iam::123456789012:user/myuser + AIDA27 + 123456789012 + + + 00000000-3558-43b5-8157-07d0769322b5 + + """.strip("\n ").encode('utf-8') + + urlopen.return_value.read.return_value = caller_identity_result + + urlopen.return_value.code = 200 + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'aws_identity', + 'region': "us-west-2", + 'post_body': post_body, + 'headers_json': json.dumps(headers), + }) + + self.assertEqual(400, response.status_code) + self.assertEqual('not_authorized', json.loads(response.content), + response.content) + + @patch('urllib.request.urlopen') + def test_aws_grant_valid_caller_identity(self, urlopen): + headers = { + "header1": "a", + "header2": "b", + } + post_body = "mypostbody" + + caller_identity_result = """ + + + arn:aws:iam::123456789012:assumed-role/testrole/testsession + AIDA27 + 123456789012 + + + 00000000-3558-43b5-8157-07d0769322b5 + + """.strip("\n ").encode('utf-8') + + urlopen.return_value.read.return_value = caller_identity_result + + urlopen.return_value.code = 200 + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'aws_identity', + 'region': "us-west-2", + 'post_body': post_body, + 'headers_json': json.dumps(headers), + }) + + self.assertEqual(200, response.status_code) + self.assertNotIn('refresh_token', json.loads(response.content)) + class AuthBackendTest(BaseOAuth2TestCase): fixtures = ['test_oauth2'] diff --git a/provider/oauth2/tests/urls.py b/provider/oauth2/tests/urls.py index 0eefd116..445ab379 100644 --- a/provider/oauth2/tests/urls.py +++ b/provider/oauth2/tests/urls.py @@ -1,4 +1,4 @@ -from django.conf.urls import url +from django.urls import path from django.http.response import JsonResponse from django.views.generic import View from django.contrib.auth.mixins import LoginRequiredMixin @@ -37,6 +37,6 @@ def get(self, request, *args, **kwargs): urlpatterns = [ - url('^badscope$', BadScopeView.as_view(), name='badscope'), - url('^user/(?P\d+)$', UserView.as_view(), name='user'), + path('badscope', BadScopeView.as_view(), name='badscope'), + path('user/', UserView.as_view(), name='user'), ] diff --git a/provider/oauth2/urls.py b/provider/oauth2/urls.py index 43abcc63..2a759219 100644 --- a/provider/oauth2/urls.py +++ b/provider/oauth2/urls.py @@ -35,22 +35,22 @@ from django.contrib.auth.decorators import login_required from django.views.decorators.csrf import csrf_exempt -from django.conf.urls import url, include +from django.urls import path from provider.oauth2 import views app_name = 'oauth2' urlpatterns = [ - url('^authorize/?$', + path('authorize', login_required(views.CaptureView.as_view()), name='capture'), - url('^authorize/confirm/?$', + path('authorize/confirm', login_required(views.AuthorizeView.as_view()), name='authorize'), - url('^redirect/?$', + path('redirect', login_required(views.RedirectView.as_view()), name='redirect'), - url('^access_token/?$', + path('access_token', csrf_exempt(views.AccessTokenView.as_view()), name='access_token'), ] diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index ccc9db86..aef9b600 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -1,13 +1,18 @@ from datetime import timedelta +import logging + from django.shortcuts import reverse from provider import constants from provider.views import CaptureViewBase, AuthorizeViewBase, RedirectViewBase from provider.views import AccessTokenViewBase, OAuthError -from provider.utils import now +from provider.utils import now, ArnHelper from provider.oauth2 import forms from provider.oauth2 import models from provider.oauth2 import backends +log = logging.getLogger('provider.oauth2') + + class CaptureView(CaptureViewBase): """ Implementation of :class:`provider.views.Capture`. @@ -115,6 +120,25 @@ def get_password_grant(self, request, data, client): raise OAuthError(form.errors) return form.cleaned_data + def get_aws_grant(self, request, data, _client): + form = forms.AwsGrantForm(data) + if not form.is_valid(): + raise OAuthError(form.errors) + data = form.cleaned_data + arn = data.get('arn') + try: + account = models.AwsAccount.objects.get( + account_id=arn.account_id, + general_type=arn.general_type, + name=arn.name, + ) + except models.AwsAccount.DoesNotExist: + log.info("No AwsAccount found for arn '%s'", arn.arn) + raise OAuthError("not_authorized") + + data['awsaccount'] = account + return data + def get_access_token(self, request, user, scope, client): try: # Attempt to fetch an existing access token. diff --git a/provider/tests/test_utils.py b/provider/tests/test_utils.py index da1ae28b..8de4527b 100644 --- a/provider/tests/test_utils.py +++ b/provider/tests/test_utils.py @@ -4,6 +4,66 @@ from django.test import TestCase +from provider.utils import ArnHelper, BadArn + class UtilsTestCase(TestCase): - pass + def test_arn_user_helper(self): + user_arn = "arn:aws:iam::123456789012:user/imauser" + + arn = ArnHelper(user_arn) + self.assertEqual(arn.account_id, "123456789012") + self.assertEqual(arn.type, "user") + self.assertEqual(arn.name, "imauser") + + def test_arn_user_equality(self): + user_arn = "arn:aws:iam::123456789012:user/imauser" + + arn = ArnHelper(user_arn) + + caller_identity_arn = ArnHelper("arn:aws:iam::123456789012:user/imauser") + + self.assertEqual(arn, caller_identity_arn) + + def test_arn_role_helper(self): + role_arn = "arn:aws:iam::123456789012:role/my-ec2-role" + arn = ArnHelper(role_arn) + + self.assertEqual(arn.account_id, "123456789012") + self.assertEqual(arn.type, "role") + self.assertEqual(arn.name, "my-ec2-role") + + def test_arn_role_caller_identity_helper(self): + role_arn = "arn:aws:sts::123456789012:assumed-role/my-ec2-role/sessionidentifier" + arn = ArnHelper(role_arn) + self.assertEqual(arn.account_id, "123456789012") + self.assertEqual(arn.type, "assumed-role") + self.assertEqual(arn.general_type, "role") + self.assertEqual(arn.name, "my-ec2-role") + self.assertEqual(arn.session, "sessionidentifier") + + def test_arn_role_equality(self): + role_arn = "arn:aws:iam::123456789012:role/my-ec2-role" + arn = ArnHelper(role_arn) + + caller_identity_arn = ArnHelper( + "arn:aws:sts::123456789012:assumed-role/my-ec2-role/sessionidentifier" + ) + + self.assertEqual(arn, caller_identity_arn) + + def test_invalid_arn_too_long(self): + with self.assertRaises(BadArn): + ArnHelper("arn:aws:iam::123456789012:role/my-ec2-role:invalidextra") + + def test_invalid_arn_too_short(self): + with self.assertRaises(BadArn): + ArnHelper("arn:aws:iam::123456789012") + + def test_invalid_arn_bad_prefix(self): + with self.assertRaises(BadArn): + ArnHelper("notarn:aws:iam::123456789012:role/my-ec2-role") + + def test_invalid_arn_bad_service(self): + with self.assertRaises(BadArn): + ArnHelper("arn:aws:s3::123456789012:role/my-ec2-role") diff --git a/provider/utils.py b/provider/utils.py index 25894901..3e5011b0 100644 --- a/provider/utils.py +++ b/provider/utils.py @@ -48,3 +48,38 @@ def get_code_expiry(): :attr:`datetime.timedelta` object. """ return now() + EXPIRE_CODE_DELTA + + +class BadArn(Exception): + pass + + +class ArnHelper: + def __init__(self, arn): + self.arn = arn + parts = arn.split(':') + if len(parts) != 6: + raise BadArn("Arn must have 6 parts") + if parts[:2] != ['arn', 'aws']: + raise BadArn("Arn must start with 'arn:aws:...'") + + if parts[2] not in ['iam', 'sts']: + raise BadArn("Arn must come from 'iam' or 'sts' service") + + self.service = parts[2] + self.account_id = parts[4] + self.entity_ref = parts[5] + entity_parts = self.entity_ref.split('/') + self.type = entity_parts[0] + self.general_type = self.type if self.type != "assumed-role" else "role" + self.name = entity_parts[1] + self.session = entity_parts[2] if len(entity_parts) > 2 else None + + def __eq__(self, other): + if not isinstance(other, ArnHelper): + return False + + if self.account_id == other.account_id and self.general_type == other.general_type and self.name == other.name: + return True + + return False diff --git a/provider/views.py b/provider/views.py index 94e82d2d..af8fba1e 100644 --- a/provider/views.py +++ b/provider/views.py @@ -2,14 +2,16 @@ import json -from six.moves.urllib_parse import urlparse, ParseResult +from urllib.parse import urlparse, ParseResult +from datetime import timedelta from django.http import HttpResponse from django.http import HttpResponseRedirect, QueryDict -from django.utils.translation import ugettext as _ +from django.utils.translation import gettext as _ from django.views.generic.base import TemplateView, View from django.core.exceptions import ObjectDoesNotExist -from provider.oauth2.models import Client, Scope +from provider.oauth2.models import Client +from provider.utils import now from provider import constants @@ -396,7 +398,7 @@ class AccessTokenViewBase(AuthUtilMixin, TemplateView): Authentication backends used to authenticate a particular client. """ - grant_types = ['authorization_code', 'refresh_token', 'password'] + grant_types = ['authorization_code', 'refresh_token', 'password', 'aws_identity'] """ The default grant types supported by this view. """ @@ -425,6 +427,14 @@ def get_password_grant(self, request, data, client): """ raise NotImplementedError + def get_aws_grant(self, request, data, client): + """ + Return a user associated with this request or an error dict. + + :return: ``tuple`` - ``(True or False, user or error_dict)`` + """ + raise NotImplementedError + def get_access_token(self, request, user, scope, client): """ Override to handle fetching of an existing access token. @@ -569,6 +579,16 @@ def password(self, request, data, client): return self.access_token_response(at) + def aws_identity(self, request, data, client): + data = self.get_aws_grant(request, data, client) + account = data.get('awsaccount') + scope = list(account.scope.all()) + + at = self.create_access_token(request, account.get_or_create_user(), scope, account.client) + at.expires = now() + timedelta(seconds=account.max_token_lifetime) + at.save() + return self.access_token_response(at) + def get_handler(self, grant_type): """ Return a function or method that is capable handling the ``grant_type`` @@ -581,6 +601,8 @@ def get_handler(self, grant_type): return self.refresh_token elif grant_type == 'password': return self.password + elif grant_type == 'aws_identity': + return self.aws_identity return None def get(self, request): @@ -614,7 +636,7 @@ def post(self, request): client = self.authenticate(request) - if client is None: + if client is None and grant_type != 'aws_identity': return self.error_response({'error': 'invalid_client'}) handler = self.get_handler(grant_type) diff --git a/requirements.txt b/requirements.txt index 7cf48965..b0235ead 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -Django==3.2 +Django==4.2 shortuuid==1.0.11 -six>=0.16.0 sqlparse>=0.4.3 diff --git a/setup.py b/setup.py index 1eb1e63b..7d605d7d 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ ], install_requires=[ "shortuuid>=1.0.11", - "six>=0.16.0", "sqlparse>=0.4.3", ], include_package_data=True, diff --git a/tests/urls.py b/tests/urls.py index ea504bc0..bb1bd5c9 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,10 +1,10 @@ -from django.conf.urls import url, include +from django.urls import path, include from django.contrib import admin admin.autodiscover() urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), - url(r'^tests/', include('provider.oauth2.tests.urls', namespace='tests')), + path('admin/', admin.site.urls), + path('oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + path('tests/', include('provider.oauth2.tests.urls', namespace='tests')), ] diff --git a/tox.ini b/tox.ini index 923b96f5..3c713f66 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] toxworkdir={env:TOX_WORK_DIR:.tox} downloadcache = {toxworkdir}/cache/ -envlist = py{3.8,3.9,3.10}-django{3.0,3.1,3.2,4.0,4.1} +envlist = py{3.8,3.9,3.10}-django{3.1,3.2,4.0,4.1,4.2} [testenv] setenv = @@ -15,11 +15,6 @@ python = 3.8: py3.8-django{3.0,3.1,3.2,4.0,4.1} -[testenv:py3.8-django3.0] -basepython = python3.8 -deps = Django>=3.0,<3.1 - {[testenv]deps} - [testenv:py3.8-django3.1] basepython = python3.8 deps = Django>=3.1,<3.2 @@ -39,3 +34,8 @@ deps = Django>=4.0,<4.1 basepython = python3.8 deps = Django>=4.1,<4.2 {[testenv]deps} + +[testenv:py3.8-django4.2] +basepython = python3.8 +deps = Django>=4.2,<5.0 + {[testenv]deps}