From 45ea0acc4550ed06ba611dc3a744e302e9fe9965 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 19 Feb 2023 12:48:24 +0100 Subject: [PATCH 01/10] Add `types` to rest clients and `utils.py` --- auth0/exceptions.py | 9 ++++--- auth0/rest.py | 59 ++++++++++++++++++++++++--------------------- auth0/rest_async.py | 31 +++++++++++++----------- auth0/types.py | 6 +++++ auth0/utils.py | 2 +- 5 files changed, 62 insertions(+), 45 deletions(-) create mode 100644 auth0/types.py diff --git a/auth0/exceptions.py b/auth0/exceptions.py index 7f9aa325..1329f5fd 100644 --- a/auth0/exceptions.py +++ b/auth0/exceptions.py @@ -1,16 +1,19 @@ +from __future__ import annotations +from typing import Any + class Auth0Error(Exception): - def __init__(self, status_code, error_code, message, content=None): + def __init__(self, status_code: int, error_code: str, message: str, content: Any | None = None) -> None: self.status_code = status_code self.error_code = error_code self.message = message self.content = content - def __str__(self): + def __str__(self) -> str: return f"{self.status_code}: {self.message}" class RateLimitError(Auth0Error): - def __init__(self, error_code, message, reset_at): + def __init__(self, error_code: str, message: str, reset_at: int) -> None: super().__init__(status_code=429, error_code=error_code, message=message) self.reset_at = reset_at diff --git a/auth0/rest.py b/auth0/rest.py index c84e5d7f..e7573082 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -1,13 +1,17 @@ +from __future__ import annotations import base64 import json import platform import sys from random import randint from time import sleep +from typing import Any, Mapping import requests from auth0.exceptions import Auth0Error, RateLimitError +from auth0.rest_async import RequestsResponse +from auth0.types import RequestData, TimeoutType UNKNOWN_ERROR = "a0.sdk.internal.unknown" @@ -32,7 +36,7 @@ class RestClientOptions: (defaults to 3) """ - def __init__(self, telemetry=None, timeout=None, retries=None): + def __init__(self, telemetry: bool | None = None, timeout: TimeoutType | None = None, retries: int | None = None) -> None: self.telemetry = True self.timeout = 5.0 self.retries = 3 @@ -51,6 +55,7 @@ class RestClient: """Provides simple methods for handling all RESTful api endpoints. Args: + jwt (str): The JWT to be used with the RestClient. telemetry (bool, optional): Enable or disable Telemetry (defaults to True) timeout (float or tuple, optional): Change the requests @@ -64,7 +69,7 @@ class RestClient: (defaults to 3) """ - def __init__(self, jwt, telemetry=True, timeout=5.0, options=None): + def __init__(self, jwt: str, telemetry: bool = True, timeout: TimeoutType = 5.0, options: RestClientOptions | None = None) -> None: if options is None: options = RestClientOptions(telemetry=telemetry, timeout=timeout) @@ -111,22 +116,22 @@ def __init__(self, jwt, telemetry=True, timeout=5.0, options=None): self.timeout = options.timeout # Returns a hard cap for the maximum number of retries allowed (10) - def MAX_REQUEST_RETRIES(self): + def MAX_REQUEST_RETRIES(self) -> int: return 10 # Returns the maximum amount of jitter to introduce in milliseconds (100ms) - def MAX_REQUEST_RETRY_JITTER(self): + def MAX_REQUEST_RETRY_JITTER(self) -> int: return 100 # Returns the maximum delay window allowed (1000ms) - def MAX_REQUEST_RETRY_DELAY(self): + def MAX_REQUEST_RETRY_DELAY(self) -> int: return 1000 # Returns the minimum delay window allowed (100ms) - def MIN_REQUEST_RETRY_DELAY(self): + def MIN_REQUEST_RETRY_DELAY(self) -> int: return 100 - def get(self, url, params=None, headers=None): + def get(self, url: str, params: dict[str, Any] | None = None, headers: dict[str, str] | None = None) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) @@ -162,7 +167,7 @@ def get(self, url, params=None, headers=None): # Return the final Response return self._process_response(response) - def post(self, url, data=None, headers=None): + def post(self, url: str, data: RequestData | None = None, headers: dict[str, str] | None = None) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) @@ -171,7 +176,7 @@ def post(self, url, data=None, headers=None): ) return self._process_response(response) - def file_post(self, url, data=None, files=None): + def file_post(self, url: str, data: RequestData | None = None, files: dict[str, Any] | None = None) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) @@ -180,7 +185,7 @@ def file_post(self, url, data=None, files=None): ) return self._process_response(response) - def patch(self, url, data=None): + def patch(self, url: str, data: RequestData | None = None) -> Any: headers = self.base_headers.copy() response = requests.patch( @@ -188,7 +193,7 @@ def patch(self, url, data=None): ) return self._process_response(response) - def put(self, url, data=None): + def put(self, url: str, data: RequestData | None = None) -> Any: headers = self.base_headers.copy() response = requests.put( @@ -196,7 +201,7 @@ def put(self, url, data=None): ) return self._process_response(response) - def delete(self, url, params=None, data=None): + def delete(self, url: str, params: dict[str, Any] | None = None, data: RequestData | None = None) -> Any: headers = self.base_headers.copy() response = requests.delete( @@ -208,7 +213,7 @@ def delete(self, url, params=None, data=None): ) return self._process_response(response) - def _calculate_wait(self, attempt): + def _calculate_wait(self, attempt: int) -> int: # Retry the request. Apply a exponential backoff for subsequent attempts, using this formula: # max(MIN_REQUEST_RETRY_DELAY, min(MAX_REQUEST_RETRY_DELAY, (100ms * (2 ** attempt - 1)) + random_between(1, MAX_REQUEST_RETRY_JITTER))) @@ -229,10 +234,10 @@ def _calculate_wait(self, attempt): return wait - def _process_response(self, response): + def _process_response(self, response: requests.Response) -> Any: return self._parse(response).content() - def _parse(self, response): + def _parse(self, response: requests.Response) -> Response: if not response.text: return EmptyResponse(response.status_code) try: @@ -242,12 +247,12 @@ def _parse(self, response): class Response: - def __init__(self, status_code, content, headers): + def __init__(self, status_code: int, content: Any, headers: Mapping[str, str]) -> None: self._status_code = status_code self._content = content self._headers = headers - def content(self): + def content(self) -> Any: if self._is_error(): if self._status_code == 429: reset_at = int(self._headers.get("x-ratelimit-reset", "-1")) @@ -272,7 +277,7 @@ def content(self): else: return self._content - def _is_error(self): + def _is_error(self) -> bool: return self._status_code is None or self._status_code >= 400 # Adding these methods to force implementation in subclasses because they are references in this parent class @@ -284,11 +289,11 @@ def _error_message(self): class JsonResponse(Response): - def __init__(self, response): + def __init__(self, response: requests.Response | RequestsResponse) -> None: content = json.loads(response.text) super().__init__(response.status_code, content, response.headers) - def _error_code(self): + def _error_code(self) -> str: if "errorCode" in self._content: return self._content.get("errorCode") elif "error" in self._content: @@ -298,7 +303,7 @@ def _error_code(self): else: return UNKNOWN_ERROR - def _error_message(self): + def _error_message(self) -> str: if "error_description" in self._content: return self._content.get("error_description") message = self._content.get("message", "") @@ -308,22 +313,22 @@ def _error_message(self): class PlainResponse(Response): - def __init__(self, response): + def __init__(self, response: requests.Response | RequestsResponse) -> None: super().__init__(response.status_code, response.text, response.headers) - def _error_code(self): + def _error_code(self) -> str: return UNKNOWN_ERROR - def _error_message(self): + def _error_message(self) -> str: return self._content class EmptyResponse(Response): - def __init__(self, status_code): + def __init__(self, status_code: int) -> None: super().__init__(status_code, "", {}) - def _error_code(self): + def _error_code(self) -> str: return UNKNOWN_ERROR - def _error_message(self): + def _error_message(self) -> str: return "" diff --git a/auth0/rest_async.py b/auth0/rest_async.py index c0fe02a3..9532c72d 100644 --- a/auth0/rest_async.py +++ b/auth0/rest_async.py @@ -1,13 +1,16 @@ +from __future__ import annotations import asyncio +from typing import Any import aiohttp from auth0.exceptions import RateLimitError +from auth0.types import RequestData -from .rest import EmptyResponse, JsonResponse, PlainResponse, RestClient +from .rest import Response, EmptyResponse, JsonResponse, PlainResponse, RestClient -def _clean_params(params): +def _clean_params(params: dict[Any, Any] | None) -> dict[Any, Any] | None: if params is None: return params return {k: v for k, v in params.items() if v is not None} @@ -30,7 +33,7 @@ class AsyncRestClient(RestClient): (defaults to 3) """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._session = None sock_connect, sock_read = ( @@ -42,13 +45,13 @@ def __init__(self, *args, **kwargs): sock_connect=sock_connect, sock_read=sock_read ) - def set_session(self, session): + def set_session(self, session: aiohttp.ClientSession) -> None: """Set Client Session to improve performance by reusing session. Session should be closed manually or within context manager. """ self._session = session - async def _request(self, *args, **kwargs): + async def _request(self, *args: Any, **kwargs: Any) -> Any: kwargs["headers"] = kwargs.get("headers", self.base_headers) kwargs["timeout"] = self.timeout if self._session is not None: @@ -61,7 +64,7 @@ async def _request(self, *args, **kwargs): async with session.request(*args, **kwargs) as response: return await self._process_response(response) - async def get(self, url, params=None, headers=None): + async def get(self, url: str, params: dict[str, Any] | None = None, headers: dict[str, str] | None = None) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) # Track the API request attempt number @@ -92,32 +95,32 @@ async def get(self, url, params=None, headers=None): # sleep() functions in seconds, so convert the milliseconds formula above accordingly await asyncio.sleep(wait / 1000) - async def post(self, url, data=None, headers=None): + async def post(self, url: str, data: RequestData | None = None, headers: dict[str, str] | None = None) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) return await self._request("post", url, json=data, headers=request_headers) - async def file_post(self, url, data=None, files=None): + async def file_post(self, url: str, data: dict[str, Any] | None = None, files: dict[str, Any] | None = None) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) return await self._request("post", url, data={**data, **files}, headers=headers) - async def patch(self, url, data=None): + async def patch(self, url: str, data: RequestData | None = None) -> Any: return await self._request("patch", url, json=data) - async def put(self, url, data=None): + async def put(self, url: str, data: RequestData | None = None) -> Any: return await self._request("put", url, json=data) - async def delete(self, url, params=None, data=None): + async def delete(self, url: str, params: dict[str, Any] | None = None, data: RequestData | None = None) -> Any: return await self._request( "delete", url, json=data, params=_clean_params(params) or {} ) - async def _process_response(self, response): + async def _process_response(self, response: aiohttp.ClientResponse) -> Any: parsed_response = await self._parse(response) return parsed_response.content() - async def _parse(self, response): + async def _parse(self, response: aiohttp.ClientResponse) -> Response: text = await response.text() requests_response = RequestsResponse(response, text) if not text: @@ -129,7 +132,7 @@ async def _parse(self, response): class RequestsResponse: - def __init__(self, response, text): + def __init__(self, response: aiohttp.ClientResponse, text: str) -> None: self.status_code = response.status self.headers = response.headers self.text = text diff --git a/auth0/types.py b/auth0/types.py new file mode 100644 index 00000000..ffed01c0 --- /dev/null +++ b/auth0/types.py @@ -0,0 +1,6 @@ +from __future__ import annotations +from typing import Any + +TimeoutType = float | tuple[float, float] + +RequestData = dict[str, Any] | list[Any] diff --git a/auth0/utils.py b/auth0/utils.py index a6909f04..807e9016 100644 --- a/auth0/utils.py +++ b/auth0/utils.py @@ -1,4 +1,4 @@ -def is_async_available(): +def is_async_available() -> bool: try: import asyncio From e48ab8ce076931cac2474cdeb4b9d938375e4bec Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 13 Mar 2023 20:12:30 +0100 Subject: [PATCH 02/10] Add types to token verifiers --- EXAMPLES.md | 2 +- auth0/authentication/async_token_verifier.py | 38 +++++++--- auth0/authentication/token_verifier.py | 77 +++++++++++++------- auth0/exceptions.py | 10 ++- auth0/rest.py | 48 ++++++++++-- auth0/rest_async.py | 31 ++++++-- auth0/types.py | 1 + 7 files changed, 158 insertions(+), 49 deletions(-) diff --git a/EXAMPLES.md b/EXAMPLES.md index a1695a78..959e0314 100644 --- a/EXAMPLES.md +++ b/EXAMPLES.md @@ -7,7 +7,7 @@ - [Connections](#connections) - [Error handling](#error-handling) - [Asynchronous environments](#asynchronous-environments) - + ## Authentication SDK ### ID token validation diff --git a/auth0/authentication/async_token_verifier.py b/auth0/authentication/async_token_verifier.py index 64b97e5e..b6d5fcee 100644 --- a/auth0/authentication/async_token_verifier.py +++ b/auth0/authentication/async_token_verifier.py @@ -1,8 +1,16 @@ """Token Verifier module""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + from .. import TokenValidationError from ..rest_async import AsyncRestClient from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier +if TYPE_CHECKING: + from aiohttp import ClientSession + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier): """Async verifier for RSA signatures, which rely on public key certificates. @@ -12,11 +20,11 @@ class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier): algorithm (str, optional): The expected signing algorithm. Defaults to "RS256". """ - def __init__(self, jwks_url, algorithm="RS256"): + def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None: super().__init__(jwks_url, algorithm) self._fetcher = AsyncJwksFetcher(jwks_url) - def set_session(self, session): + def set_session(self, session: ClientSession) -> None: """Set Client Session to improve performance by reusing session. Args: @@ -57,11 +65,11 @@ class AsyncJwksFetcher(JwksFetcher): cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._async_client = AsyncRestClient(None) - def set_session(self, session): + def set_session(self, session: ClientSession) -> None: """Set Client Session to improve performance by reusing session. Args: @@ -70,7 +78,7 @@ def set_session(self, session): """ self._async_client.set_session(session) - async def _fetch_jwks(self, force=False): + async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]: """Attempts to obtain the JWK set from the cache, as long as it's still valid. When not, it will perform a network request to the jwks_url to obtain a fresh result and update the cache value with it. @@ -90,7 +98,7 @@ async def _fetch_jwks(self, force=False): self._cache_is_fresh = False return self._cache_value - async def get_key(self, key_id): + async def get_key(self, key_id: str) -> RSAPublicKey: """Obtains the JWK associated with the given key id. Args: @@ -126,7 +134,13 @@ class AsyncTokenVerifier(TokenVerifier): Defaults to 60 seconds. """ - def __init__(self, signature_verifier, issuer, audience, leeway=0): + def __init__( + self, + signature_verifier: AsyncAsymmetricSignatureVerifier, + issuer: str, + audience: str, + leeway: int = 0, + ) -> None: if not signature_verifier or not isinstance( signature_verifier, AsyncAsymmetricSignatureVerifier ): @@ -140,7 +154,7 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0): self._sv = signature_verifier self._clock = None # legacy testing requirement - def set_session(self, session): + def set_session(self, session: ClientSession) -> None: """Set Client Session to improve performance by reusing session. Args: @@ -149,7 +163,13 @@ def set_session(self, session): """ self._sv.set_session(session) - async def verify(self, token, nonce=None, max_age=None, organization=None): + async def verify( + self, + token: str, + nonce: str | None = None, + max_age: int | None = None, + organization: str | None = None, + ) -> dict[str, Any]: """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec. Args: diff --git a/auth0/authentication/token_verifier.py b/auth0/authentication/token_verifier.py index 08331efc..8cec0e61 100644 --- a/auth0/authentication/token_verifier.py +++ b/auth0/authentication/token_verifier.py @@ -1,12 +1,18 @@ """Token Verifier module""" +from __future__ import annotations + import json import time +from typing import TYPE_CHECKING, Any, ClassVar import jwt import requests from auth0.exceptions import TokenValidationError +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + class SignatureVerifier: """Abstract class that will verify a given JSON web token's signature @@ -16,7 +22,7 @@ class SignatureVerifier: algorithm (str): The expected signing algorithm (e.g. RS256). """ - DISABLE_JWT_CHECKS = { + DISABLE_JWT_CHECKS: ClassVar[dict[str, bool]] = { "verify_signature": True, "verify_exp": False, "verify_nbf": False, @@ -28,12 +34,12 @@ class SignatureVerifier: "require_nbf": False, } - def __init__(self, algorithm): + def __init__(self, algorithm: str) -> None: if not algorithm or type(algorithm) != str: raise ValueError("algorithm must be specified.") self._algorithm = algorithm - def _fetch_key(self, key_id=None): + def _fetch_key(self, key_id: str | None = None) -> str | RSAPublicKey: """Obtains the key associated to the given key id. Must be implemented by subclasses. @@ -45,7 +51,7 @@ def _fetch_key(self, key_id=None): """ raise NotImplementedError - def _get_kid(self, token): + def _get_kid(self, token: str) -> str | None: """Gets the key id from the kid claim of the header of the token Args: @@ -72,7 +78,7 @@ def _get_kid(self, token): return header.get("kid", None) - def _decode_jwt(self, token, secret_or_certificate): + def _decode_jwt(self, token: str, secret_or_certificate: str) -> dict[str, Any]: """Verifies and decodes the given JSON web token with the given public key or shared secret. Args: @@ -94,7 +100,7 @@ def _decode_jwt(self, token, secret_or_certificate): raise TokenValidationError("Invalid token signature.") return decoded - def verify_signature(self, token): + def verify_signature(self, token: str) -> dict[str, Any]: """Verifies the signature of the given JSON web token. Args: @@ -118,11 +124,11 @@ class SymmetricSignatureVerifier(SignatureVerifier): algorithm (str, optional): The expected signing algorithm. Defaults to "HS256". """ - def __init__(self, shared_secret, algorithm="HS256"): + def __init__(self, shared_secret: str, algorithm: str = "HS256") -> None: super().__init__(algorithm) self._shared_secret = shared_secret - def _fetch_key(self, key_id=None): + def _fetch_key(self, key_id: str | None = None) -> str: return self._shared_secret @@ -135,20 +141,19 @@ class JwksFetcher: cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - CACHE_TTL = 600 # 10 min cache lifetime + CACHE_TTL: ClassVar[int] = 600 # 10 min cache lifetime - def __init__(self, jwks_url, cache_ttl=CACHE_TTL): + def __init__(self, jwks_url: str, cache_ttl: int = CACHE_TTL) -> None: self._jwks_url = jwks_url self._init_cache(cache_ttl) - return - def _init_cache(self, cache_ttl): - self._cache_value = {} + def _init_cache(self, cache_ttl: int) -> None: + self._cache_value: dict[str, RSAPublicKey] = {} self._cache_date = 0 self._cache_ttl = cache_ttl self._cache_is_fresh = False - def _cache_expired(self): + def _cache_expired(self) -> bool: """Checks if the cache is expired Returns: @@ -156,7 +161,7 @@ def _cache_expired(self): """ return self._cache_date + self._cache_ttl < time.time() - def _cache_jwks(self, jwks): + def _cache_jwks(self, jwks: dict[str, Any]) -> None: """Cache the response of the JWKS request Args: @@ -166,7 +171,7 @@ def _cache_jwks(self, jwks): self._cache_is_fresh = True self._cache_date = time.time() - def _fetch_jwks(self, force=False): + def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]: """Attempts to obtain the JWK set from the cache, as long as it's still valid. When not, it will perform a network request to the jwks_url to obtain a fresh result and update the cache value with it. @@ -178,7 +183,7 @@ def _fetch_jwks(self, force=False): self._cache_value = {} response = requests.get(self._jwks_url) if response.ok: - jwks = response.json() + jwks: dict[str, Any] = response.json() self._cache_jwks(jwks) return self._cache_value @@ -186,20 +191,22 @@ def _fetch_jwks(self, force=False): return self._cache_value @staticmethod - def _parse_jwks(jwks): + def _parse_jwks(jwks: dict[str, Any]) -> dict[str, RSAPublicKey]: """ Converts a JWK string representation into a binary certificate in PEM format. """ - keys = {} + keys: dict[str, RSAPublicKey] = {} for key in jwks["keys"]: # noinspection PyUnresolvedReferences # requirement already includes cryptography -> pyjwt[crypto] - rsa_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key)) + rsa_key: RSAPublicKey = jwt.algorithms.RSAAlgorithm.from_jwk( + json.dumps(key) + ) keys[key["kid"]] = rsa_key return keys - def get_key(self, key_id): + def get_key(self, key_id: str) -> RSAPublicKey: """Obtains the JWK associated with the given key id. Args: @@ -232,11 +239,11 @@ class AsymmetricSignatureVerifier(SignatureVerifier): cache_ttl (int, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - def __init__(self, jwks_url, algorithm="RS256", cache_ttl=JwksFetcher.CACHE_TTL): + def __init__(self, jwks_url: str, algorithm: str = "RS256", cache_ttl: int = JwksFetcher.CACHE_TTL) -> None: super().__init__(algorithm) self._fetcher = JwksFetcher(jwks_url, cache_ttl) - def _fetch_key(self, key_id=None): + def _fetch_key(self, key_id: str | None = None) -> RSAPublicKey: return self._fetcher.get_key(key_id) @@ -252,7 +259,13 @@ class TokenVerifier: Defaults to 60 seconds. """ - def __init__(self, signature_verifier, issuer, audience, leeway=0): + def __init__( + self, + signature_verifier: SignatureVerifier, + issuer: str, + audience: str, + leeway: int = 0, + ) -> None: if not signature_verifier or not isinstance( signature_verifier, SignatureVerifier ): @@ -266,7 +279,13 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0): self._sv = signature_verifier self._clock = None # visible for testing - def verify(self, token, nonce=None, max_age=None, organization=None): + def verify( + self, + token: str, + nonce: str | None = None, + max_age: int | None = None, + organization: str | None = None, + ) -> dict[str, Any]: """Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec. Args: @@ -296,7 +315,13 @@ def verify(self, token, nonce=None, max_age=None, organization=None): return payload - def _verify_payload(self, payload, nonce=None, max_age=None, organization=None): + def _verify_payload( + self, + payload: dict[str, Any], + nonce: str | None = None, + max_age: int | None = None, + organization: str | None = None, + ) -> None: # Issuer if "iss" not in payload or not isinstance(payload["iss"], str): raise TokenValidationError( diff --git a/auth0/exceptions.py b/auth0/exceptions.py index 1329f5fd..8515be04 100644 --- a/auth0/exceptions.py +++ b/auth0/exceptions.py @@ -1,8 +1,16 @@ from __future__ import annotations + from typing import Any + class Auth0Error(Exception): - def __init__(self, status_code: int, error_code: str, message: str, content: Any | None = None) -> None: + def __init__( + self, + status_code: int, + error_code: str, + message: str, + content: Any | None = None, + ) -> None: self.status_code = status_code self.error_code = error_code self.message = message diff --git a/auth0/rest.py b/auth0/rest.py index e7573082..87bf70d5 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -1,4 +1,5 @@ from __future__ import annotations + import base64 import json import platform @@ -36,7 +37,12 @@ class RestClientOptions: (defaults to 3) """ - def __init__(self, telemetry: bool | None = None, timeout: TimeoutType | None = None, retries: int | None = None) -> None: + def __init__( + self, + telemetry: bool | None = None, + timeout: TimeoutType | None = None, + retries: int | None = None, + ) -> None: self.telemetry = True self.timeout = 5.0 self.retries = 3 @@ -69,7 +75,13 @@ class RestClient: (defaults to 3) """ - def __init__(self, jwt: str, telemetry: bool = True, timeout: TimeoutType = 5.0, options: RestClientOptions | None = None) -> None: + def __init__( + self, + jwt: str, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + options: RestClientOptions | None = None, + ) -> None: if options is None: options = RestClientOptions(telemetry=telemetry, timeout=timeout) @@ -131,7 +143,12 @@ def MAX_REQUEST_RETRY_DELAY(self) -> int: def MIN_REQUEST_RETRY_DELAY(self) -> int: return 100 - def get(self, url: str, params: dict[str, Any] | None = None, headers: dict[str, str] | None = None) -> Any: + def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) @@ -167,7 +184,12 @@ def get(self, url: str, params: dict[str, Any] | None = None, headers: dict[str, # Return the final Response return self._process_response(response) - def post(self, url: str, data: RequestData | None = None, headers: dict[str, str] | None = None) -> Any: + def post( + self, + url: str, + data: RequestData | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) @@ -176,7 +198,12 @@ def post(self, url: str, data: RequestData | None = None, headers: dict[str, str ) return self._process_response(response) - def file_post(self, url: str, data: RequestData | None = None, files: dict[str, Any] | None = None) -> Any: + def file_post( + self, + url: str, + data: RequestData | None = None, + files: dict[str, Any] | None = None, + ) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) @@ -201,7 +228,12 @@ def put(self, url: str, data: RequestData | None = None) -> Any: ) return self._process_response(response) - def delete(self, url: str, params: dict[str, Any] | None = None, data: RequestData | None = None) -> Any: + def delete( + self, + url: str, + params: dict[str, Any] | None = None, + data: RequestData | None = None, + ) -> Any: headers = self.base_headers.copy() response = requests.delete( @@ -247,7 +279,9 @@ def _parse(self, response: requests.Response) -> Response: class Response: - def __init__(self, status_code: int, content: Any, headers: Mapping[str, str]) -> None: + def __init__( + self, status_code: int, content: Any, headers: Mapping[str, str] + ) -> None: self._status_code = status_code self._content = content self._headers = headers diff --git a/auth0/rest_async.py b/auth0/rest_async.py index 9532c72d..328de545 100644 --- a/auth0/rest_async.py +++ b/auth0/rest_async.py @@ -1,4 +1,5 @@ from __future__ import annotations + import asyncio from typing import Any @@ -7,7 +8,7 @@ from auth0.exceptions import RateLimitError from auth0.types import RequestData -from .rest import Response, EmptyResponse, JsonResponse, PlainResponse, RestClient +from .rest import EmptyResponse, JsonResponse, PlainResponse, Response, RestClient def _clean_params(params: dict[Any, Any] | None) -> dict[Any, Any] | None: @@ -64,7 +65,12 @@ async def _request(self, *args: Any, **kwargs: Any) -> Any: async with session.request(*args, **kwargs) as response: return await self._process_response(response) - async def get(self, url: str, params: dict[str, Any] | None = None, headers: dict[str, str] | None = None) -> Any: + async def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) # Track the API request attempt number @@ -95,12 +101,22 @@ async def get(self, url: str, params: dict[str, Any] | None = None, headers: dic # sleep() functions in seconds, so convert the milliseconds formula above accordingly await asyncio.sleep(wait / 1000) - async def post(self, url: str, data: RequestData | None = None, headers: dict[str, str] | None = None) -> Any: + async def post( + self, + url: str, + data: RequestData | None = None, + headers: dict[str, str] | None = None, + ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) return await self._request("post", url, json=data, headers=request_headers) - async def file_post(self, url: str, data: dict[str, Any] | None = None, files: dict[str, Any] | None = None) -> Any: + async def file_post( + self, + url: str, + data: dict[str, Any] | None = None, + files: dict[str, Any] | None = None, + ) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) return await self._request("post", url, data={**data, **files}, headers=headers) @@ -111,7 +127,12 @@ async def patch(self, url: str, data: RequestData | None = None) -> Any: async def put(self, url: str, data: RequestData | None = None) -> Any: return await self._request("put", url, json=data) - async def delete(self, url: str, params: dict[str, Any] | None = None, data: RequestData | None = None) -> Any: + async def delete( + self, + url: str, + params: dict[str, Any] | None = None, + data: RequestData | None = None, + ) -> Any: return await self._request( "delete", url, json=data, params=_clean_params(params) or {} ) diff --git a/auth0/types.py b/auth0/types.py index ffed01c0..0a83dba7 100644 --- a/auth0/types.py +++ b/auth0/types.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Any TimeoutType = float | tuple[float, float] From 777af97d1dd501dbee1033392be8b5d0e04d9489 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 13 Mar 2023 20:17:39 +0100 Subject: [PATCH 03/10] Fix circular import --- auth0/rest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/auth0/rest.py b/auth0/rest.py index 87bf70d5..6d963775 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -6,14 +6,16 @@ import sys from random import randint from time import sleep -from typing import Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping import requests from auth0.exceptions import Auth0Error, RateLimitError -from auth0.rest_async import RequestsResponse from auth0.types import RequestData, TimeoutType +if TYPE_CHECKING: + from auth0.rest_async import RequestsResponse + UNKNOWN_ERROR = "a0.sdk.internal.unknown" From 8729d321a2385ae1645c68ad98d51e7c35169e12 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 13 Mar 2023 22:43:51 +0100 Subject: [PATCH 04/10] Fix `types.py` compat --- auth0/types.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/auth0/types.py b/auth0/types.py index 0a83dba7..c1929cf2 100644 --- a/auth0/types.py +++ b/auth0/types.py @@ -1,7 +1,5 @@ -from __future__ import annotations +from typing import Any, Dict, List, Tuple, Union -from typing import Any +TimeoutType = Union[float, Tuple[float, float]] -TimeoutType = float | tuple[float, float] - -RequestData = dict[str, Any] | list[Any] +RequestData = Union[Dict[str, Any], List[Any]] From da56c337e556cf367d38bbed96a90b531edf5085 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:32:58 +0100 Subject: [PATCH 05/10] Add types to all `authentification` module --- auth0/authentication/base.py | 46 +++++++++++----- auth0/authentication/client_authentication.py | 24 +++++--- auth0/authentication/database.py | 30 +++++----- auth0/authentication/delegated.py | 18 +++--- auth0/authentication/enterprise.py | 6 +- auth0/authentication/get_token.py | 55 ++++++++++--------- auth0/authentication/passwordless.py | 10 +++- auth0/authentication/revoke_token.py | 4 +- auth0/authentication/social.py | 4 +- 9 files changed, 123 insertions(+), 74 deletions(-) diff --git a/auth0/authentication/base.py b/auth0/authentication/base.py index 4e6417b0..7948d06b 100644 --- a/auth0/authentication/base.py +++ b/auth0/authentication/base.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +from typing import Any + from auth0.rest import RestClient, RestClientOptions +from auth0.types import RequestData, TimeoutType from .client_authentication import add_client_authentication @@ -21,15 +26,15 @@ class AuthenticationBase: def __init__( self, - domain, - client_id, - client_secret=None, - client_assertion_signing_key=None, - client_assertion_signing_alg=None, - telemetry=True, - timeout=5.0, - protocol="https", - ): + domain: str, + client_id: str, + client_secret: str | None = None, + client_assertion_signing_key: str | None = None, + client_assertion_signing_alg: str | None = None, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + protocol: str = "https", + ) -> None: self.domain = domain self.client_id = client_id self.client_secret = client_secret @@ -41,7 +46,7 @@ def __init__( options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0), ) - def _add_client_authentication(self, payload): + def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]: return add_client_authentication( payload, self.domain, @@ -51,13 +56,28 @@ def _add_client_authentication(self, payload): self.client_assertion_signing_alg, ) - def post(self, url, data=None, headers=None): + def post( + self, + url: str, + data: RequestData | None = None, + headers: dict[str, str] | None = None, + ) -> Any: return self.client.post(url, data=data, headers=headers) - def authenticated_post(self, url, data=None, headers=None): + def authenticated_post( + self, + url: str, + data: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: return self.client.post( url, data=self._add_client_authentication(data), headers=headers ) - def get(self, url, params=None, headers=None): + def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: return self.client.get(url, params, headers) diff --git a/auth0/authentication/client_authentication.py b/auth0/authentication/client_authentication.py index 7ab742f9..d946e715 100644 --- a/auth0/authentication/client_authentication.py +++ b/auth0/authentication/client_authentication.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import datetime import uuid +from typing import Any import jwt def create_client_assertion_jwt( - domain, client_id, client_assertion_signing_key, client_assertion_signing_alg -): + domain: str, + client_id: str, + client_assertion_signing_key: str | None, + client_assertion_signing_alg: str | None, +) -> str: """Creates a JWT for the client_assertion field. Args: @@ -35,13 +41,13 @@ def create_client_assertion_jwt( def add_client_authentication( - payload, - domain, - client_id, - client_secret, - client_assertion_signing_key, - client_assertion_signing_alg, -): + payload: dict[str, Any], + domain: str, + client_id: str, + client_secret: str, + client_assertion_signing_key: str | None, + client_assertion_signing_alg: str | None, +) -> dict[str, Any]: """Adds the client_assertion or client_secret fields to authenticate a payload. Args: diff --git a/auth0/authentication/database.py b/auth0/authentication/database.py index c4691b27..83e3b6db 100644 --- a/auth0/authentication/database.py +++ b/auth0/authentication/database.py @@ -1,4 +1,6 @@ -import warnings +from __future__ import annotations + +from typing import Any from .base import AuthenticationBase @@ -12,17 +14,17 @@ class Database(AuthenticationBase): def signup( self, - email, - password, - connection, - username=None, - user_metadata=None, - given_name=None, - family_name=None, - name=None, - nickname=None, - picture=None, - ): + email: str, + password: str, + connection: str, + username: str | None = None, + user_metadata: dict[str, Any] | None = None, + given_name: str | None = None, + family_name: str | None = None, + name: str | None = None, + nickname: str | None = None, + picture: str | None = None, + ) -> Any: """Signup using email and password. Args: @@ -75,7 +77,9 @@ def signup( f"{self.protocol}://{self.domain}/dbconnections/signup", data=body ) - def change_password(self, email, connection, password=None): + def change_password( + self, email: str, connection: str, password: str | None = None + ) -> Any: """Asks to change a password for a given user. email (str): The user's email address. diff --git a/auth0/authentication/delegated.py b/auth0/authentication/delegated.py index 58ae4cb8..1266db11 100644 --- a/auth0/authentication/delegated.py +++ b/auth0/authentication/delegated.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from .base import AuthenticationBase @@ -10,13 +14,13 @@ class Delegated(AuthenticationBase): def get_token( self, - target, - api_type, - grant_type, - id_token=None, - refresh_token=None, - scope="openid", - ): + target: str, + api_type: str, + grant_type: str, + id_token: str | None = None, + refresh_token: str | None = None, + scope: str = "openid", + ) -> Any: """Obtain a delegation token.""" if id_token and refresh_token: diff --git a/auth0/authentication/enterprise.py b/auth0/authentication/enterprise.py index 0b4b3c5f..518d1001 100644 --- a/auth0/authentication/enterprise.py +++ b/auth0/authentication/enterprise.py @@ -1,3 +1,5 @@ +from typing import Any + from .base import AuthenticationBase @@ -9,7 +11,7 @@ class Enterprise(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def saml_metadata(self): + def saml_metadata(self) -> Any: """Get SAML2.0 Metadata.""" return self.get( @@ -18,7 +20,7 @@ def saml_metadata(self): ) ) - def wsfed_metadata(self): + def wsfed_metadata(self) -> Any: """Returns the WS-Federation Metadata.""" url = "{}://{}/wsfed/FederationMetadata/2007-06/FederationMetadata.xml" diff --git a/auth0/authentication/get_token.py b/auth0/authentication/get_token.py index 4986e55b..9de89291 100644 --- a/auth0/authentication/get_token.py +++ b/auth0/authentication/get_token.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +from typing import Any + from .base import AuthenticationBase -from .client_authentication import add_client_authentication class GetToken(AuthenticationBase): @@ -12,10 +15,10 @@ class GetToken(AuthenticationBase): def authorization_code( self, - code, - redirect_uri, - grant_type="authorization_code", - ): + code: str, + redirect_uri: str | None, + grant_type: str = "authorization_code", + ) -> Any: """Authorization code grant This is the OAuth 2.0 grant that regular web apps utilize in order @@ -47,11 +50,11 @@ def authorization_code( def authorization_code_pkce( self, - code_verifier, - code, - redirect_uri, - grant_type="authorization_code", - ): + code_verifier: str, + code: str, + redirect_uri: str | None, + grant_type: str = "authorization_code", + ) -> Any: """Authorization code pkce grant This is the OAuth 2.0 grant that mobile apps utilize in order to access an API. @@ -86,9 +89,9 @@ def authorization_code_pkce( def client_credentials( self, - audience, - grant_type="client_credentials", - ): + audience: str, + grant_type: str = "client_credentials", + ) -> Any: """Client credentials grant This is the OAuth 2.0 grant that server processes utilize in @@ -116,13 +119,13 @@ def client_credentials( def login( self, - username, - password, - scope=None, - realm=None, - audience=None, - grant_type="http://auth0.com/oauth/grant-type/password-realm", - ): + username: str, + password: str, + scope: str | None = None, + realm: str | None = None, + audience: str | None = None, + grant_type: str = "http://auth0.com/oauth/grant-type/password-realm", + ) -> Any: """Calls /oauth/token endpoint with password-realm grant type @@ -168,10 +171,10 @@ def login( def refresh_token( self, - refresh_token, - scope="", - grant_type="refresh_token", - ): + refresh_token: str, + scope: str = "", + grant_type: str = "refresh_token", + ) -> Any: """Calls /oauth/token endpoint with refresh token grant type Use this endpoint to refresh an access token, using the refresh token you got during authorization. @@ -199,7 +202,9 @@ def refresh_token( }, ) - def passwordless_login(self, username, otp, realm, scope, audience): + def passwordless_login( + self, username: str, otp: str, realm: str, scope: str, audience: str + ) -> Any: """Calls /oauth/token endpoint with http://auth0.com/oauth/grant-type/passwordless/otp grant type Once the verification code was received, login the user using this endpoint with their diff --git a/auth0/authentication/passwordless.py b/auth0/authentication/passwordless.py index 63d26b4d..9039c802 100644 --- a/auth0/authentication/passwordless.py +++ b/auth0/authentication/passwordless.py @@ -1,4 +1,6 @@ -import warnings +from __future__ import annotations + +from typing import Any from .base import AuthenticationBase @@ -11,7 +13,9 @@ class Passwordless(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def email(self, email, send="link", auth_params=None): + def email( + self, email: str, send: str = "link", auth_params: dict[str, str] | None = None + ) -> Any: """Start flow sending an email. Given the user email address, it will send an email with: @@ -48,7 +52,7 @@ def email(self, email, send="link", auth_params=None): f"{self.protocol}://{self.domain}/passwordless/start", data=data ) - def sms(self, phone_number): + def sms(self, phone_number: str) -> Any: """Start flow sending an SMS message. Given the user phone number, it will send an SMS with diff --git a/auth0/authentication/revoke_token.py b/auth0/authentication/revoke_token.py index ded6397b..29223d45 100644 --- a/auth0/authentication/revoke_token.py +++ b/auth0/authentication/revoke_token.py @@ -1,3 +1,5 @@ +from typing import Any + from .base import AuthenticationBase @@ -8,7 +10,7 @@ class RevokeToken(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def revoke_refresh_token(self, token): + def revoke_refresh_token(self, token: str) -> Any: """Revokes a Refresh Token if it has been compromised Each revocation request invalidates not only the specific token, but all other tokens diff --git a/auth0/authentication/social.py b/auth0/authentication/social.py index c2517038..dc9b6a3a 100644 --- a/auth0/authentication/social.py +++ b/auth0/authentication/social.py @@ -1,3 +1,5 @@ +from typing import Any + from .base import AuthenticationBase @@ -9,7 +11,7 @@ class Social(AuthenticationBase): domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com) """ - def login(self, access_token, connection, scope="openid"): + def login(self, access_token: str, connection: str, scope: str = "openid") -> Any: """Login using a social provider's access token Given the social provider's access_token and the connection specified, From f08c5763090385de3cdfb4cca41b3e922795c6d9 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 15 Mar 2023 15:05:33 +0100 Subject: [PATCH 06/10] Add types to missing parts and apply first mypy pass --- auth0/authentication/base.py | 2 +- auth0/authentication/client_authentication.py | 4 +-- auth0/authentication/database.py | 12 +++++---- auth0/authentication/passwordless.py | 2 +- auth0/authentication/token_verifier.py | 4 +-- auth0/authentication/users.py | 20 +++++++++------ auth0/rest.py | 25 ++++++------------- auth0/rest_async.py | 3 ++- mypy.ini | 8 ++++++ 9 files changed, 44 insertions(+), 36 deletions(-) create mode 100644 mypy.ini diff --git a/auth0/authentication/base.py b/auth0/authentication/base.py index 7948d06b..01c79d2e 100644 --- a/auth0/authentication/base.py +++ b/auth0/authentication/base.py @@ -67,7 +67,7 @@ def post( def authenticated_post( self, url: str, - data: dict[str, Any] | None = None, + data: dict[str, Any], headers: dict[str, str] | None = None, ) -> Any: return self.client.post( diff --git a/auth0/authentication/client_authentication.py b/auth0/authentication/client_authentication.py index d946e715..2b6b345f 100644 --- a/auth0/authentication/client_authentication.py +++ b/auth0/authentication/client_authentication.py @@ -44,7 +44,7 @@ def add_client_authentication( payload: dict[str, Any], domain: str, client_id: str, - client_secret: str, + client_secret: str | None, client_assertion_signing_key: str | None, client_assertion_signing_alg: str | None, ) -> dict[str, Any]: @@ -54,7 +54,7 @@ def add_client_authentication( payload (dict): The POST payload that needs additional fields to be authenticated. domain (str): The domain of your Auth0 tenant client_id (str): Your application's client ID - client_secret (str): Your application's client secret + client_secret (str, optional): Your application's client secret client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256') diff --git a/auth0/authentication/database.py b/auth0/authentication/database.py index 83e3b6db..9bfd6144 100644 --- a/auth0/authentication/database.py +++ b/auth0/authentication/database.py @@ -24,7 +24,7 @@ def signup( name: str | None = None, nickname: str | None = None, picture: str | None = None, - ) -> Any: + ) -> dict[str, Any]: """Signup using email and password. Args: @@ -52,7 +52,7 @@ def signup( See: https://auth0.com/docs/api/authentication#signup """ - body = { + body: dict[str, Any] = { "client_id": self.client_id, "email": email, "password": password, @@ -73,13 +73,14 @@ def signup( if picture: body.update({"picture": picture}) - return self.post( + data: dict[str, Any] = self.post( f"{self.protocol}://{self.domain}/dbconnections/signup", data=body ) + return data def change_password( self, email: str, connection: str, password: str | None = None - ) -> Any: + ) -> str: """Asks to change a password for a given user. email (str): The user's email address. @@ -92,7 +93,8 @@ def change_password( "connection": connection, } - return self.post( + data: str = self.post( f"{self.protocol}://{self.domain}/dbconnections/change_password", data=body, ) + return data diff --git a/auth0/authentication/passwordless.py b/auth0/authentication/passwordless.py index 9039c802..dc4ac1af 100644 --- a/auth0/authentication/passwordless.py +++ b/auth0/authentication/passwordless.py @@ -39,7 +39,7 @@ def email( auth_params (dict, optional): Parameters to append or override. """ - data = { + data: dict[str, Any] = { "client_id": self.client_id, "connection": "email", "email": email, diff --git a/auth0/authentication/token_verifier.py b/auth0/authentication/token_verifier.py index 8cec0e61..f6395085 100644 --- a/auth0/authentication/token_verifier.py +++ b/auth0/authentication/token_verifier.py @@ -113,7 +113,7 @@ def verify_signature(self, token: str) -> dict[str, Any]: kid = self._get_kid(token) secret_or_certificate = self._fetch_key(key_id=kid) - return self._decode_jwt(token, secret_or_certificate) + return self._decode_jwt(token, secret_or_certificate) # type: ignore[arg-type] class SymmetricSignatureVerifier(SignatureVerifier): @@ -149,7 +149,7 @@ def __init__(self, jwks_url: str, cache_ttl: int = CACHE_TTL) -> None: def _init_cache(self, cache_ttl: int) -> None: self._cache_value: dict[str, RSAPublicKey] = {} - self._cache_date = 0 + self._cache_date = 0.0 self._cache_ttl = cache_ttl self._cache_is_fresh = False diff --git a/auth0/authentication/users.py b/auth0/authentication/users.py index 255c90f6..9535edab 100644 --- a/auth0/authentication/users.py +++ b/auth0/authentication/users.py @@ -1,4 +1,9 @@ +from __future__ import annotations + +from typing import Any + from auth0.rest import RestClient, RestClientOptions +from auth0.types import TimeoutType class Users: @@ -13,11 +18,11 @@ class Users: def __init__( self, - domain, - telemetry=True, - timeout=5.0, - protocol="https", - ): + domain: str, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + protocol: str = "https", + ) -> None: self.domain = domain self.protocol = protocol self.client = RestClient( @@ -31,7 +36,7 @@ def __init__( domain (str): Your auth0 domain (e.g: username.auth0.com) """ - def userinfo(self, access_token): + def userinfo(self, access_token: str) -> dict[str, Any]: """Returns the user information based on the Auth0 access token. This endpoint will work only if openid was granted as a scope for the access_token. @@ -42,7 +47,8 @@ def userinfo(self, access_token): The user profile. """ - return self.client.get( + data: dict[str, Any] = self.client.get( url=f"{self.protocol}://{self.domain}/userinfo", headers={"Authorization": f"Bearer {access_token}"}, ) + return data diff --git a/auth0/rest.py b/auth0/rest.py index 6d963775..e1837669 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -41,29 +41,20 @@ class RestClientOptions: def __init__( self, - telemetry: bool | None = None, - timeout: TimeoutType | None = None, - retries: int | None = None, + telemetry: bool = True, + timeout: TimeoutType = 5.0, + retries: int = 3, ) -> None: - self.telemetry = True - self.timeout = 5.0 - self.retries = 3 - - if telemetry is not None: - self.telemetry = telemetry - - if timeout is not None: - self.timeout = timeout - - if retries is not None: - self.retries = retries + self.telemetry = telemetry + self.timeout = timeout + self.retries = retries class RestClient: """Provides simple methods for handling all RESTful api endpoints. Args: - jwt (str): The JWT to be used with the RestClient. + jwt (str, optional): The JWT to be used with the RestClient. telemetry (bool, optional): Enable or disable Telemetry (defaults to True) timeout (float or tuple, optional): Change the requests @@ -79,7 +70,7 @@ class RestClient: def __init__( self, - jwt: str, + jwt: str | None, telemetry: bool = True, timeout: TimeoutType = 5.0, options: RestClientOptions | None = None, diff --git a/auth0/rest_async.py b/auth0/rest_async.py index 328de545..53b92840 100644 --- a/auth0/rest_async.py +++ b/auth0/rest_async.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code=override from __future__ import annotations import asyncio @@ -36,7 +37,7 @@ class AsyncRestClient(RestClient): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._session = None + self._session: aiohttp.ClientSession | None = None sock_connect, sock_read = ( self.timeout if isinstance(self.timeout, tuple) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..f8bc1715 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +python_version = 3.7 + +[mypy-auth0.test.*,auth0.test_async.*] +ignore_errors = True + +[mypy-auth0.management.*] +ignore_errors = True From e6515b2db31805fad907336a686c991ae97426ae Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 19 Mar 2023 10:55:15 +0100 Subject: [PATCH 07/10] mypy fixes --- auth0/authentication/async_token_verifier.py | 2 +- auth0/authentication/client_authentication.py | 4 ++-- auth0/authentication/token_verifier.py | 6 ++++-- auth0/rest.py | 2 +- auth0/rest_async.py | 5 ++--- mypy.ini | 6 ++++++ 6 files changed, 16 insertions(+), 9 deletions(-) diff --git a/auth0/authentication/async_token_verifier.py b/auth0/authentication/async_token_verifier.py index b6d5fcee..058e493f 100644 --- a/auth0/authentication/async_token_verifier.py +++ b/auth0/authentication/async_token_verifier.py @@ -40,7 +40,7 @@ async def _fetch_key(self, key_id=None): key_id (str): The key's key id.""" return await self._fetcher.get_key(key_id) - async def verify_signature(self, token): + async def verify_signature(self, token) -> dict[str, Any]: """Verifies the signature of the given JSON web token. Args: diff --git a/auth0/authentication/client_authentication.py b/auth0/authentication/client_authentication.py index 2b6b345f..849058f4 100644 --- a/auth0/authentication/client_authentication.py +++ b/auth0/authentication/client_authentication.py @@ -10,7 +10,7 @@ def create_client_assertion_jwt( domain: str, client_id: str, - client_assertion_signing_key: str | None, + client_assertion_signing_key: str, client_assertion_signing_alg: str | None, ) -> str: """Creates a JWT for the client_assertion field. @@ -18,7 +18,7 @@ def create_client_assertion_jwt( Args: domain (str): The domain of your Auth0 tenant client_id (str): Your application's client ID - client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT + client_assertion_signing_key (str): Private key used to sign the client assertion JWT client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256') Returns: diff --git a/auth0/authentication/token_verifier.py b/auth0/authentication/token_verifier.py index f6395085..46254646 100644 --- a/auth0/authentication/token_verifier.py +++ b/auth0/authentication/token_verifier.py @@ -39,7 +39,7 @@ def __init__(self, algorithm: str) -> None: raise ValueError("algorithm must be specified.") self._algorithm = algorithm - def _fetch_key(self, key_id: str | None = None) -> str | RSAPublicKey: + def _fetch_key(self, key_id: str) -> str | RSAPublicKey: """Obtains the key associated to the given key id. Must be implemented by subclasses. @@ -111,6 +111,8 @@ def verify_signature(self, token: str) -> dict[str, Any]: or the token's signature doesn't match the calculated one. """ kid = self._get_kid(token) + if kid is None: + kid = "" secret_or_certificate = self._fetch_key(key_id=kid) return self._decode_jwt(token, secret_or_certificate) # type: ignore[arg-type] @@ -128,7 +130,7 @@ def __init__(self, shared_secret: str, algorithm: str = "HS256") -> None: super().__init__(algorithm) self._shared_secret = shared_secret - def _fetch_key(self, key_id: str | None = None) -> str: + def _fetch_key(self, key_id: str = "") -> str: return self._shared_secret diff --git a/auth0/rest.py b/auth0/rest.py index e1837669..41282b74 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -255,7 +255,7 @@ def _calculate_wait(self, attempt: int) -> int: wait = max(self.MIN_REQUEST_RETRY_DELAY(), wait) self._metrics["retries"] = attempt - self._metrics["backoff"].append(wait) + self._metrics["backoff"].append(wait) # type: ignore[attr-defined] return wait diff --git a/auth0/rest_async.py b/auth0/rest_async.py index 53b92840..183cfbb9 100644 --- a/auth0/rest_async.py +++ b/auth0/rest_async.py @@ -1,4 +1,3 @@ -# mypy: disable-error-code=override from __future__ import annotations import asyncio @@ -115,8 +114,8 @@ async def post( async def file_post( self, url: str, - data: dict[str, Any] | None = None, - files: dict[str, Any] | None = None, + data: dict[str, Any], + files: dict[str, Any], ) -> Any: headers = self.base_headers.copy() headers.pop("Content-Type", None) diff --git a/mypy.ini b/mypy.ini index f8bc1715..af08759b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,3 +6,9 @@ ignore_errors = True [mypy-auth0.management.*] ignore_errors = True + +[mypy-auth0.rest_async] +disable_error_code=override + +[mypy-auth0.authentication.async_token_verifier] +disable_error_code=override, misc, attr-defined From 38d55b1b70dcb136f2d063adcca46c08fadb7951 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 5 Apr 2023 21:34:55 +0200 Subject: [PATCH 08/10] Ignore Sphinx warning --- docs/source/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index b3cdbc2e..d364fdb1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -93,3 +93,6 @@ def find_version(*file_paths): # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = [] + +# Sphinx somehow can't find this one +nitpick_ignore = [("py:class", "RSAPublicKey")] From 1976d68126bb7d5e6b84089905650bb93259be6c Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Fri, 5 May 2023 22:48:25 +0200 Subject: [PATCH 09/10] black --- auth0/authentication/token_verifier.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/auth0/authentication/token_verifier.py b/auth0/authentication/token_verifier.py index 46254646..d203aec7 100644 --- a/auth0/authentication/token_verifier.py +++ b/auth0/authentication/token_verifier.py @@ -241,7 +241,12 @@ class AsymmetricSignatureVerifier(SignatureVerifier): cache_ttl (int, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. """ - def __init__(self, jwks_url: str, algorithm: str = "RS256", cache_ttl: int = JwksFetcher.CACHE_TTL) -> None: + def __init__( + self, + jwks_url: str, + algorithm: str = "RS256", + cache_ttl: int = JwksFetcher.CACHE_TTL, + ) -> None: super().__init__(algorithm) self._fetcher = JwksFetcher(jwks_url, cache_ttl) From b8a4a2c7d96ce510bf142a9f4323c9986f7f76b0 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Fri, 5 May 2023 22:49:27 +0200 Subject: [PATCH 10/10] Fix mypy error for `AsymmetricSignatureVerifier` --- auth0/authentication/token_verifier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auth0/authentication/token_verifier.py b/auth0/authentication/token_verifier.py index d203aec7..030eda27 100644 --- a/auth0/authentication/token_verifier.py +++ b/auth0/authentication/token_verifier.py @@ -44,7 +44,7 @@ def _fetch_key(self, key_id: str) -> str | RSAPublicKey: Must be implemented by subclasses. Args: - key_id (str, optional): The id of the key to fetch. + key_id (str): The id of the key to fetch. Returns: the key to use for verifying a cryptographic signature @@ -250,7 +250,7 @@ def __init__( super().__init__(algorithm) self._fetcher = JwksFetcher(jwks_url, cache_ttl) - def _fetch_key(self, key_id: str | None = None) -> RSAPublicKey: + def _fetch_key(self, key_id: str) -> RSAPublicKey: return self._fetcher.get_key(key_id)