Skip to content

Commit

Permalink
Added Allowed Origins application setting
Browse files Browse the repository at this point in the history
  • Loading branch information
akanstantsinau authored and dopry committed Oct 17, 2023
1 parent 70074b7 commit 4d38e4e
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 12 deletions.
4 changes: 4 additions & 0 deletions docs/tutorial/tutorial_01.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ point your browser to http://localhost:8000/o/applications/ and add an Applicati
specifies one of the verified redirection uris. For this tutorial, paste verbatim the value
`https://www.getpostman.com/oauth2/callback`

* `Allowed origins`: Web applications use Cross-Origin Resource Sharing (CORS) to request resources from origins other than their own.
You can provide list of origins of web applications that will have access to the token endpoint of :term:`Authorization Server`.
This setting controls only token endpoint and it is not related with Django CORS Headers settings.

* `Client type`: this value affects the security level at which some communications between the client application and
the authorization server are performed. For this tutorial choose *Confidential*.

Expand Down
18 changes: 18 additions & 0 deletions oauth2_provider/migrations/0010_application_allowed_origins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.1.5 on 2023-09-27 20:15

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("oauth2_provider", "0009_add_hash_client_secret"),
]

operations = [
migrations.AddField(
model_name="application",
name="allowed_origins",
field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"),
),
]
40 changes: 37 additions & 3 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from .scopes import get_scopes_backend
from .settings import oauth2_settings
from .utils import jwk_from_pem
from .validators import RedirectURIValidator, WildcardSet

from .validators import RedirectURIValidator, WildcardSet, URIValidator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -132,7 +131,10 @@ class AbstractApplication(models.Model):
created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True)
algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=NO_ALGORITHM, blank=True)

allowed_origins = models.TextField(
blank=True,
help_text=_("Allowed origins list to enable CORS, space separated"),
)
class Meta:
abstract = True

Expand Down Expand Up @@ -172,6 +174,14 @@ def post_logout_redirect_uri_allowed(self, uri):
"""
return redirect_to_uri_allowed(uri, self.post_logout_redirect_uris.split())

def origin_allowed(self, origin):
"""
Checks if given origin is one of the items in :attr:`allowed_origins` string
:param origin: Origin to check
"""
return self.allowed_origins and is_origin_allowed(origin, self.allowed_origins.split())

def clean(self):
from django.core.exceptions import ValidationError

Expand Down Expand Up @@ -202,6 +212,13 @@ def clean(self):
grant_type=self.authorization_grant_type
)
)
allowed_origins = self.allowed_origins.strip().split()
if allowed_origins:
# oauthlib allows only https scheme for CORS
validator = URIValidator({"https"})
for uri in allowed_origins:
validator(uri)

if self.algorithm == AbstractApplication.RS256_ALGORITHM:
if not oauth2_settings.OIDC_RSA_PRIVATE_KEY:
raise ValidationError(_("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm"))
Expand Down Expand Up @@ -777,3 +794,20 @@ def redirect_to_uri_allowed(uri, allowed_uris):
return True

return False


def is_origin_allowed(origin, allowed_origins):
"""
Checks if a given origin uri is allowed based on the provided allowed_origins configuration.
:param origin: Origin URI to check
:param allowed_origins: A list of Origin URIs that are allowed
"""

parsed_origin = urlparse(origin)
for allowed_origin in allowed_origins:
parsed_allowed_origin = urlparse(allowed_origin)
if (parsed_allowed_origin.scheme == parsed_origin.scheme
and parsed_allowed_origin.netloc == parsed_origin.netloc):
return True
return False
2 changes: 2 additions & 0 deletions oauth2_provider/oauth2_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def extract_headers(self, request):
del headers["wsgi.errors"]
if "HTTP_AUTHORIZATION" in headers:
headers["Authorization"] = headers["HTTP_AUTHORIZATION"]
# Add Access-Control-Allow-Origin header to the token endpoint response for authentication code grant, if the origin is allowed by RequestValidator.is_origin_allowed.
# https://github.com/oauthlib/oauthlib/pull/791
if "HTTP_ORIGIN" in headers:
headers["Origin"] = headers["HTTP_ORIGIN"]
if request.is_secure():
Expand Down
9 changes: 9 additions & 0 deletions oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,3 +958,12 @@ def get_userinfo_claims(self, request):

def get_additional_claims(self, request):
return {}

def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
if request.client is None or not request.client.client_id:
return False
application = Application.objects.filter(client_id=request.client.client_id).first()
if application:
return application.origin_allowed(origin)
else:
return False
2 changes: 2 additions & 0 deletions oauth2_provider/views/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def get_form_class(self):
"authorization_grant_type",
"redirect_uris",
"post_logout_redirect_uris",
"allowed_origins",
"algorithm",
),
)
Expand Down Expand Up @@ -99,6 +100,7 @@ def get_form_class(self):
"authorization_grant_type",
"redirect_uris",
"post_logout_redirect_uris",
"allowed_origins",
"algorithm",
),
)
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def application():
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
algorithm=Application.RS256_ALGORITHM,
client_secret=CLEARTEXT_SECRET,
allowed_origins="https://example.com",
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 4.1.5 on 2023-09-27 22:25

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
migrations.swappable_dependency(settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL),
("tests", "0004_basetestapplication_hash_client_secret_and_more"),
]

operations = [
migrations.AddField(
model_name="basetestapplication",
name="allowed_origins",
field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"),
),
migrations.AddField(
model_name="sampleapplication",
name="allowed_origins",
field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"),
),
]
44 changes: 35 additions & 9 deletions tests/test_cors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from urllib.parse import parse_qs, urlparse

import pytest
Expand All @@ -6,18 +7,11 @@
from django.urls import reverse

from oauth2_provider.models import get_application_model
from oauth2_provider.oauth2_validators import OAuth2Validator

from . import presets
from .utils import get_basic_auth_header


class CorsOAuth2Validator(OAuth2Validator):
def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
"""Enable CORS in OAuthLib"""
return True


Application = get_application_model()
UserModel = get_user_model()

Expand Down Expand Up @@ -50,10 +44,10 @@ def setUp(self):
client_type=Application.CLIENT_CONFIDENTIAL,
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
client_secret=CLEARTEXT_SECRET,
allowed_origins=CLIENT_URI,
)

self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"]
self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator

def tearDown(self):
self.application.delete()
Expand All @@ -76,10 +70,42 @@ def test_cors_header(self):
auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET)
auth_headers["HTTP_ORIGIN"] = CLIENT_URI
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)

content = json.loads(response.content.decode("utf-8"))

self.assertEqual(response.status_code, 200)
self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI)

token_request_data = {
"grant_type": "refresh_token",
"refresh_token": content["refresh_token"],
"scope": content["scope"],
}
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI)

def test_no_cors_header(self):
def test_no_cors_header_origin_not_allowed(self):
"""
Test that /token endpoint does not have Access-Control-Allow-Origin
when request origin is not in Application.allowed_origins
"""
authorization_code = self._get_authorization_code()

# exchange authorization code for a valid access token
token_request_data = {
"grant_type": "authorization_code",
"code": authorization_code,
"redirect_uri": CLIENT_URI,
}

auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET)
auth_headers["HTTP_ORIGIN"] = "another_example.org"
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
self.assertEqual(response.status_code, 200)
self.assertFalse(response.has_header("Access-Control-Allow-Origin"))

def test_no_cors_header_no_origin(self):
"""
Test that /token endpoint does not have Access-Control-Allow-Origin
"""
Expand Down

0 comments on commit 4d38e4e

Please sign in to comment.