diff --git a/docs/settings.rst b/docs/settings.rst index f31aff533..a7cac94a1 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -63,6 +63,17 @@ assigned ports. Note that you may override ``Application.get_allowed_schemes()`` to set this on a per-application basis. +ALLOWED_SCHEMES +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Default: ``["https"]`` + +A list of schemes that the ``allowed_origins`` field will be validated against. +Setting this to ``["https"]`` only in production is strongly recommended. +Adding ``"http"`` to the list is considered to be safe only for local development and testing. +Note that `OAUTHLIB_INSECURE_TRANSPORT `_ +environment variable should be also set to allow http origins. + APPLICATION_MODEL ~~~~~~~~~~~~~~~~~ diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index c37057e49..e09b41664 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -20,8 +20,7 @@ from .scopes import get_scopes_backend from .settings import oauth2_settings from .utils import jwk_from_pem -from .validators import RedirectURIValidator, URIValidator, WildcardSet - +from .validators import RedirectURIValidator, URIValidator, WildcardSet, AllowedURIValidator logger = logging.getLogger(__name__) @@ -218,7 +217,7 @@ def clean(self): allowed_origins = self.allowed_origins.strip().split() if allowed_origins: # oauthlib allows only https scheme for CORS - validator = URIValidator({"https"}) + validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin") for uri in allowed_origins: validator(uri) @@ -808,6 +807,10 @@ def is_origin_allowed(origin, allowed_origins): """ parsed_origin = urlparse(origin) + + if parsed_origin.scheme not in oauth2_settings.ALLOWED_SCHEMES: + return False + for allowed_origin in allowed_origins: parsed_allowed_origin = urlparse(allowed_origin) if ( diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index aa7de7351..c5af9ebae 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -68,6 +68,7 @@ "REFRESH_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.RefreshTokenAdmin", "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], + "ALLOWED_SCHEMES": ["https"], "OIDC_ENABLED": False, "OIDC_ISS_ENDPOINT": "", "OIDC_USERINFO_ENDPOINT": "", diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index 6c8fa3839..9ecced631 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -31,6 +31,32 @@ def __call__(self, value): raise ValidationError("Redirect URIs must not contain fragments") +class AllowedURIValidator(URIValidator): + def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False): + """ + :params schemes: List of allowed schemes. E.g.: ["https"] + :params name: Name of the validater URI required for validation message. E.g.: "Origin" + :params allow_path: If URI can contain path part + :params allow_query: If URI can contain query part + :params allow_fragments: If URI can contain fragments part + """ + super().__init__(schemes=schemes) + self.name = name + self.allow_path = allow_path + self.allow_query = allow_query + self.allow_fragments = allow_fragments + + def __call__(self, value): + super().__call__(value) + value = force_str(value) + scheme, netloc, path, query, fragment = urlsplit(value) + if path and not self.allow_path: + raise ValidationError("{} URIs must not contain path".format(self.name)) + if query and not self.allow_query: + raise ValidationError("{} URIs must not contain query".format(self.name)) + if fragment and not self.allow_fragments: + raise ValidationError("{} URIs must not contain fragments".format(self.name)) + ## # WildcardSet is a special set that contains everything. # This is required in order to move validation of the scheme from diff --git a/tests/test_validators.py b/tests/test_validators.py index 0760e0290..d77e128a3 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -2,7 +2,7 @@ from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.validators import RedirectURIValidator +from oauth2_provider.validators import RedirectURIValidator, AllowedURIValidator @pytest.mark.usefixtures("oauth2_settings") @@ -36,6 +36,11 @@ def test_validate_custom_uri_scheme(self): # Check ValidationError not thrown validator(uri) + validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin") + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] @@ -61,3 +66,67 @@ def test_validate_bad_uris(self): for uri in bad_uris: with self.assertRaises(ValidationError): validator(uri) + + def test_validate_good_origin_uris(self): + """ + Test AllowedURIValidator validates origin URIs if they match requirements + """ + validator = AllowedURIValidator( + ["https"], + "Origin", + allow_path=False, + allow_query=False, + allow_fragments=False, + ) + good_uris = [ + "https://example.com", + "https://example.com:8080", + "https://example", + "https://localhost", + "https://1.1.1.1", + "https://127.0.0.1", + "https://255.255.255.255", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_validate_bad_origin_uris(self): + """ + Test AllowedURIValidator rejects origin URIs if they do not match requirements + """ + validator = AllowedURIValidator( + ["https"], + "Origin", + allow_path=False, + allow_query=False, + allow_fragments=False, + ) + bad_uris = [ + "http:/example.com", + "HTTP://localhost", + "HTTP://example.com", + "HTTP://example.com.", + "http://example.com/#fragment", + "123://example.com", + "http://fe80::1", + "git+ssh://example.com", + "my-scheme://example.com", + "uri-without-a-scheme", + "https://example.com/#fragment", + "good://example.com/#fragment", + " ", + "", + # Bad IPv6 URL, urlparse behaves differently for these + 'https://[">', + # Origin uri should not contain path, query of fragment parts + # https://www.rfc-editor.org/rfc/rfc6454#section-7.1 + "https:/example.com/", + "https:/example.com/test", + "https:/example.com/?q=test", + "https:/example.com/#test", + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri)