Skip to content

Commit

Permalink
Add type hints to base and authentication (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjmcgrath committed May 9, 2023
2 parents ea52e49 + b8a4a2c commit 38f65c2
Show file tree
Hide file tree
Showing 20 changed files with 384 additions and 186 deletions.
2 changes: 1 addition & 1 deletion EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
- [Connections](#connections)
- [Error handling](#error-handling)
- [Asynchronous environments](#asynchronous-environments)

## Authentication SDK

### ID token validation
Expand Down
40 changes: 30 additions & 10 deletions auth0/authentication/async_token_verifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""Token Verifier module"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from .. import TokenValidationError
from ..rest_async import AsyncRestClient
from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier

if TYPE_CHECKING:
from aiohttp import ClientSession
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey


class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
"""Async verifier for RSA signatures, which rely on public key certificates.
Expand All @@ -12,11 +20,11 @@ class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
algorithm (str, optional): The expected signing algorithm. Defaults to "RS256".
"""

def __init__(self, jwks_url, algorithm="RS256"):
def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None:
super().__init__(jwks_url, algorithm)
self._fetcher = AsyncJwksFetcher(jwks_url)

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.
Args:
Expand All @@ -32,7 +40,7 @@ async def _fetch_key(self, key_id=None):
key_id (str): The key's key id."""
return await self._fetcher.get_key(key_id)

async def verify_signature(self, token):
async def verify_signature(self, token) -> dict[str, Any]:
"""Verifies the signature of the given JSON web token.
Args:
Expand All @@ -57,11 +65,11 @@ class AsyncJwksFetcher(JwksFetcher):
cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_client = AsyncRestClient(None)

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.
Args:
Expand All @@ -70,7 +78,7 @@ def set_session(self, session):
"""
self._async_client.set_session(session)

async def _fetch_jwks(self, force=False):
async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.
When not, it will perform a network request to the jwks_url to obtain a fresh result
and update the cache value with it.
Expand All @@ -90,7 +98,7 @@ async def _fetch_jwks(self, force=False):
self._cache_is_fresh = False
return self._cache_value

async def get_key(self, key_id):
async def get_key(self, key_id: str) -> RSAPublicKey:
"""Obtains the JWK associated with the given key id.
Args:
Expand Down Expand Up @@ -126,7 +134,13 @@ class AsyncTokenVerifier(TokenVerifier):
Defaults to 60 seconds.
"""

def __init__(self, signature_verifier, issuer, audience, leeway=0):
def __init__(
self,
signature_verifier: AsyncAsymmetricSignatureVerifier,
issuer: str,
audience: str,
leeway: int = 0,
) -> None:
if not signature_verifier or not isinstance(
signature_verifier, AsyncAsymmetricSignatureVerifier
):
Expand All @@ -140,7 +154,7 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0):
self._sv = signature_verifier
self._clock = None # legacy testing requirement

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.
Args:
Expand All @@ -149,7 +163,13 @@ def set_session(self, session):
"""
self._sv.set_session(session)

async def verify(self, token, nonce=None, max_age=None, organization=None):
async def verify(
self,
token: str,
nonce: str | None = None,
max_age: int | None = None,
organization: str | None = None,
) -> dict[str, Any]:
"""Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.
Args:
Expand Down
46 changes: 33 additions & 13 deletions auth0/authentication/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

from typing import Any

from auth0.rest import RestClient, RestClientOptions
from auth0.types import RequestData, TimeoutType

from .client_authentication import add_client_authentication

Expand All @@ -21,15 +26,15 @@ class AuthenticationBase:

def __init__(
self,
domain,
client_id,
client_secret=None,
client_assertion_signing_key=None,
client_assertion_signing_alg=None,
telemetry=True,
timeout=5.0,
protocol="https",
):
domain: str,
client_id: str,
client_secret: str | None = None,
client_assertion_signing_key: str | None = None,
client_assertion_signing_alg: str | None = None,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
) -> None:
self.domain = domain
self.client_id = client_id
self.client_secret = client_secret
Expand All @@ -41,7 +46,7 @@ def __init__(
options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0),
)

def _add_client_authentication(self, payload):
def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]:
return add_client_authentication(
payload,
self.domain,
Expand All @@ -51,13 +56,28 @@ def _add_client_authentication(self, payload):
self.client_assertion_signing_alg,
)

def post(self, url, data=None, headers=None):
def post(
self,
url: str,
data: RequestData | None = None,
headers: dict[str, str] | None = None,
) -> Any:
return self.client.post(url, data=data, headers=headers)

def authenticated_post(self, url, data=None, headers=None):
def authenticated_post(
self,
url: str,
data: dict[str, Any],
headers: dict[str, str] | None = None,
) -> Any:
return self.client.post(
url, data=self._add_client_authentication(data), headers=headers
)

def get(self, url, params=None, headers=None):
def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
return self.client.get(url, params, headers)
28 changes: 17 additions & 11 deletions auth0/authentication/client_authentication.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from __future__ import annotations

import datetime
import uuid
from typing import Any

import jwt


def create_client_assertion_jwt(
domain, client_id, client_assertion_signing_key, client_assertion_signing_alg
):
domain: str,
client_id: str,
client_assertion_signing_key: str,
client_assertion_signing_alg: str | None,
) -> str:
"""Creates a JWT for the client_assertion field.
Args:
domain (str): The domain of your Auth0 tenant
client_id (str): Your application's client ID
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
client_assertion_signing_key (str): Private key used to sign the client assertion JWT
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')
Returns:
Expand All @@ -35,20 +41,20 @@ def create_client_assertion_jwt(


def add_client_authentication(
payload,
domain,
client_id,
client_secret,
client_assertion_signing_key,
client_assertion_signing_alg,
):
payload: dict[str, Any],
domain: str,
client_id: str,
client_secret: str | None,
client_assertion_signing_key: str | None,
client_assertion_signing_alg: str | None,
) -> dict[str, Any]:
"""Adds the client_assertion or client_secret fields to authenticate a payload.
Args:
payload (dict): The POST payload that needs additional fields to be authenticated.
domain (str): The domain of your Auth0 tenant
client_id (str): Your application's client ID
client_secret (str): Your application's client secret
client_secret (str, optional): Your application's client secret
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')
Expand Down
38 changes: 22 additions & 16 deletions auth0/authentication/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import warnings
from __future__ import annotations

from typing import Any

from .base import AuthenticationBase

Expand All @@ -12,17 +14,17 @@ class Database(AuthenticationBase):

def signup(
self,
email,
password,
connection,
username=None,
user_metadata=None,
given_name=None,
family_name=None,
name=None,
nickname=None,
picture=None,
):
email: str,
password: str,
connection: str,
username: str | None = None,
user_metadata: dict[str, Any] | None = None,
given_name: str | None = None,
family_name: str | None = None,
name: str | None = None,
nickname: str | None = None,
picture: str | None = None,
) -> dict[str, Any]:
"""Signup using email and password.
Args:
Expand Down Expand Up @@ -50,7 +52,7 @@ def signup(
See: https://auth0.com/docs/api/authentication#signup
"""
body = {
body: dict[str, Any] = {
"client_id": self.client_id,
"email": email,
"password": password,
Expand All @@ -71,11 +73,14 @@ def signup(
if picture:
body.update({"picture": picture})

return self.post(
data: dict[str, Any] = self.post(
f"{self.protocol}://{self.domain}/dbconnections/signup", data=body
)
return data

def change_password(self, email, connection, password=None):
def change_password(
self, email: str, connection: str, password: str | None = None
) -> str:
"""Asks to change a password for a given user.
email (str): The user's email address.
Expand All @@ -88,7 +93,8 @@ def change_password(self, email, connection, password=None):
"connection": connection,
}

return self.post(
data: str = self.post(
f"{self.protocol}://{self.domain}/dbconnections/change_password",
data=body,
)
return data
18 changes: 11 additions & 7 deletions auth0/authentication/delegated.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Any

from .base import AuthenticationBase


Expand All @@ -10,13 +14,13 @@ class Delegated(AuthenticationBase):

def get_token(
self,
target,
api_type,
grant_type,
id_token=None,
refresh_token=None,
scope="openid",
):
target: str,
api_type: str,
grant_type: str,
id_token: str | None = None,
refresh_token: str | None = None,
scope: str = "openid",
) -> Any:
"""Obtain a delegation token."""

if id_token and refresh_token:
Expand Down
6 changes: 4 additions & 2 deletions auth0/authentication/enterprise.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from .base import AuthenticationBase


Expand All @@ -9,7 +11,7 @@ class Enterprise(AuthenticationBase):
domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com)
"""

def saml_metadata(self):
def saml_metadata(self) -> Any:
"""Get SAML2.0 Metadata."""

return self.get(
Expand All @@ -18,7 +20,7 @@ def saml_metadata(self):
)
)

def wsfed_metadata(self):
def wsfed_metadata(self) -> Any:
"""Returns the WS-Federation Metadata."""

url = "{}://{}/wsfed/FederationMetadata/2007-06/FederationMetadata.xml"
Expand Down
Loading

0 comments on commit 38f65c2

Please sign in to comment.