diff --git a/docs/tutorial/tutorial_01.rst b/docs/tutorial/tutorial_01.rst index 9f79e895f..5462a32fb 100644 --- a/docs/tutorial/tutorial_01.rst +++ b/docs/tutorial/tutorial_01.rst @@ -91,6 +91,10 @@ point your browser to http://localhost:8000/o/applications/ and add an Applicati specifies one of the verified redirection uris. For this tutorial, paste verbatim the value `https://www.getpostman.com/oauth2/callback` + * `Allowed origins`: Web applications use Cross-Origin Resource Sharing (CORS) to request resources from origins other than their own. + You can provide list of origins of web applications that will have access to the token endpoint of :term:`Authorization Server`. + This setting controls only token endpoint and it is not related with Django CORS Headers settings. + * `Client type`: this value affects the security level at which some communications between the client application and the authorization server are performed. For this tutorial choose *Confidential*. diff --git a/oauth2_provider/migrations/0010_application_allowed_origins.py b/oauth2_provider/migrations/0010_application_allowed_origins.py new file mode 100644 index 000000000..39ca9af8e --- /dev/null +++ b/oauth2_provider/migrations/0010_application_allowed_origins.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.5 on 2023-09-27 20:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("oauth2_provider", "0009_add_hash_client_secret"), + ] + + operations = [ + migrations.AddField( + model_name="application", + name="allowed_origins", + field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 649f0cd33..4d31d5e19 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -20,8 +20,7 @@ from .scopes import get_scopes_backend from .settings import oauth2_settings from .utils import jwk_from_pem -from .validators import RedirectURIValidator, WildcardSet - +from .validators import RedirectURIValidator, WildcardSet, URIValidator logger = logging.getLogger(__name__) @@ -132,7 +131,10 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=NO_ALGORITHM, blank=True) - + allowed_origins = models.TextField( + blank=True, + help_text=_("Allowed origins list to enable CORS, space separated"), + ) class Meta: abstract = True @@ -172,6 +174,14 @@ def post_logout_redirect_uri_allowed(self, uri): """ return redirect_to_uri_allowed(uri, self.post_logout_redirect_uris.split()) + def origin_allowed(self, origin): + """ + Checks if given origin is one of the items in :attr:`allowed_origins` string + + :param origin: Origin to check + """ + return self.allowed_origins and is_origin_allowed(origin, self.allowed_origins.split()) + def clean(self): from django.core.exceptions import ValidationError @@ -202,6 +212,13 @@ def clean(self): grant_type=self.authorization_grant_type ) ) + allowed_origins = self.allowed_origins.strip().split() + if allowed_origins: + # oauthlib allows only https scheme for CORS + validator = URIValidator({"https"}) + for uri in allowed_origins: + validator(uri) + if self.algorithm == AbstractApplication.RS256_ALGORITHM: if not oauth2_settings.OIDC_RSA_PRIVATE_KEY: raise ValidationError(_("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm")) @@ -777,3 +794,20 @@ def redirect_to_uri_allowed(uri, allowed_uris): return True return False + + +def is_origin_allowed(origin, allowed_origins): + """ + Checks if a given origin uri is allowed based on the provided allowed_origins configuration. + + :param origin: Origin URI to check + :param allowed_origins: A list of Origin URIs that are allowed + """ + + parsed_origin = urlparse(origin) + for allowed_origin in allowed_origins: + parsed_allowed_origin = urlparse(allowed_origin) + if (parsed_allowed_origin.scheme == parsed_origin.scheme + and parsed_allowed_origin.netloc == parsed_origin.netloc): + return True + return False diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index c99a8699b..401e9fc5c 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -75,6 +75,8 @@ def extract_headers(self, request): del headers["wsgi.errors"] if "HTTP_AUTHORIZATION" in headers: headers["Authorization"] = headers["HTTP_AUTHORIZATION"] + # Add Access-Control-Allow-Origin header to the token endpoint response for authentication code grant, if the origin is allowed by RequestValidator.is_origin_allowed. + # https://github.com/oauthlib/oauthlib/pull/791 if "HTTP_ORIGIN" in headers: headers["Origin"] = headers["HTTP_ORIGIN"] if request.is_secure(): diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index ae6b92813..6a4acc8e3 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -958,3 +958,12 @@ def get_userinfo_claims(self, request): def get_additional_claims(self, request): return {} + + def is_origin_allowed(self, client_id, origin, request, *args, **kwargs): + if request.client is None or not request.client.client_id: + return False + application = Application.objects.filter(client_id=request.client.client_id).first() + if application: + return application.origin_allowed(origin) + else: + return False diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index 9b5a8ffb6..b896c45e3 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -39,6 +39,7 @@ def get_form_class(self): "authorization_grant_type", "redirect_uris", "post_logout_redirect_uris", + "allowed_origins", "algorithm", ), ) @@ -99,6 +100,7 @@ def get_form_class(self): "authorization_grant_type", "redirect_uris", "post_logout_redirect_uris", + "allowed_origins", "algorithm", ), ) diff --git a/tests/conftest.py b/tests/conftest.py index d620c3f59..2cc3c3901 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,6 +108,7 @@ def application(): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, algorithm=Application.RS256_ALGORITHM, client_secret=CLEARTEXT_SECRET, + allowed_origins="https://example.com", ) diff --git a/tests/migrations/0005_basetestapplication_allowed_origins_and_more.py b/tests/migrations/0005_basetestapplication_allowed_origins_and_more.py new file mode 100644 index 000000000..fbc083a2b --- /dev/null +++ b/tests/migrations/0005_basetestapplication_allowed_origins_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.1.5 on 2023-09-27 22:25 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL), + ("tests", "0004_basetestapplication_hash_client_secret_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="basetestapplication", + name="allowed_origins", + field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"), + ), + migrations.AddField( + model_name="sampleapplication", + name="allowed_origins", + field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"), + ), + ] diff --git a/tests/test_cors.py b/tests/test_cors.py index 9d7260bc9..64f2a5fec 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -1,3 +1,4 @@ +import json from urllib.parse import parse_qs, urlparse import pytest @@ -6,18 +7,11 @@ from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.oauth2_validators import OAuth2Validator from . import presets from .utils import get_basic_auth_header -class CorsOAuth2Validator(OAuth2Validator): - def is_origin_allowed(self, client_id, origin, request, *args, **kwargs): - """Enable CORS in OAuthLib""" - return True - - Application = get_application_model() UserModel = get_user_model() @@ -50,10 +44,10 @@ def setUp(self): client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, client_secret=CLEARTEXT_SECRET, + allowed_origins=CLIENT_URI, ) self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"] - self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator def tearDown(self): self.application.delete() @@ -76,10 +70,42 @@ def test_cors_header(self): auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) auth_headers["HTTP_ORIGIN"] = CLIENT_URI response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + + content = json.loads(response.content.decode("utf-8")) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) + + token_request_data = { + "grant_type": "refresh_token", + "refresh_token": content["refresh_token"], + "scope": content["scope"], + } + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) - def test_no_cors_header(self): + def test_no_cors_header_origin_not_allowed(self): + """ + Test that /token endpoint does not have Access-Control-Allow-Origin + when request origin is not in Application.allowed_origins + """ + authorization_code = self._get_authorization_code() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CLIENT_URI, + } + + auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) + auth_headers["HTTP_ORIGIN"] = "another_example.org" + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + self.assertFalse(response.has_header("Access-Control-Allow-Origin")) + + def test_no_cors_header_no_origin(self): """ Test that /token endpoint does not have Access-Control-Allow-Origin """