Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to base and authentication #472

Merged
merged 10 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
- [Connections](#connections)
- [Error handling](#error-handling)
- [Asynchronous environments](#asynchronous-environments)

## Authentication SDK

### ID token validation
Expand Down
40 changes: 30 additions & 10 deletions auth0/authentication/async_token_verifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""Token Verifier module"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from .. import TokenValidationError
from ..rest_async import AsyncRestClient
from .token_verifier import AsymmetricSignatureVerifier, JwksFetcher, TokenVerifier

if TYPE_CHECKING:
from aiohttp import ClientSession
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey


class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
"""Async verifier for RSA signatures, which rely on public key certificates.
Expand All @@ -12,11 +20,11 @@ class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):
algorithm (str, optional): The expected signing algorithm. Defaults to "RS256".
"""

def __init__(self, jwks_url, algorithm="RS256"):
def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None:
super().__init__(jwks_url, algorithm)
self._fetcher = AsyncJwksFetcher(jwks_url)

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

Args:
Expand All @@ -32,7 +40,7 @@ async def _fetch_key(self, key_id=None):
key_id (str): The key's key id."""
return await self._fetcher.get_key(key_id)

async def verify_signature(self, token):
async def verify_signature(self, token) -> dict[str, Any]:
"""Verifies the signature of the given JSON web token.

Args:
Expand All @@ -57,11 +65,11 @@ class AsyncJwksFetcher(JwksFetcher):
cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_client = AsyncRestClient(None)

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

Args:
Expand All @@ -70,7 +78,7 @@ def set_session(self, session):
"""
self._async_client.set_session(session)

async def _fetch_jwks(self, force=False):
async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.
When not, it will perform a network request to the jwks_url to obtain a fresh result
and update the cache value with it.
Expand All @@ -90,7 +98,7 @@ async def _fetch_jwks(self, force=False):
self._cache_is_fresh = False
return self._cache_value

async def get_key(self, key_id):
async def get_key(self, key_id: str) -> RSAPublicKey:
"""Obtains the JWK associated with the given key id.

Args:
Expand Down Expand Up @@ -126,7 +134,13 @@ class AsyncTokenVerifier(TokenVerifier):
Defaults to 60 seconds.
"""

def __init__(self, signature_verifier, issuer, audience, leeway=0):
def __init__(
self,
signature_verifier: AsyncAsymmetricSignatureVerifier,
issuer: str,
audience: str,
leeway: int = 0,
) -> None:
if not signature_verifier or not isinstance(
signature_verifier, AsyncAsymmetricSignatureVerifier
):
Expand All @@ -140,7 +154,7 @@ def __init__(self, signature_verifier, issuer, audience, leeway=0):
self._sv = signature_verifier
self._clock = None # legacy testing requirement

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

Args:
Expand All @@ -149,7 +163,13 @@ def set_session(self, session):
"""
self._sv.set_session(session)

async def verify(self, token, nonce=None, max_age=None, organization=None):
async def verify(
self,
token: str,
nonce: str | None = None,
max_age: int | None = None,
organization: str | None = None,
) -> dict[str, Any]:
"""Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.

Args:
Expand Down
46 changes: 33 additions & 13 deletions auth0/authentication/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

from typing import Any

from auth0.rest import RestClient, RestClientOptions
from auth0.types import RequestData, TimeoutType

from .client_authentication import add_client_authentication

Expand All @@ -21,15 +26,15 @@ class AuthenticationBase:

def __init__(
self,
domain,
client_id,
client_secret=None,
client_assertion_signing_key=None,
client_assertion_signing_alg=None,
telemetry=True,
timeout=5.0,
protocol="https",
):
domain: str,
client_id: str,
client_secret: str | None = None,
client_assertion_signing_key: str | None = None,
client_assertion_signing_alg: str | None = None,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
) -> None:
self.domain = domain
self.client_id = client_id
self.client_secret = client_secret
Expand All @@ -41,7 +46,7 @@ def __init__(
options=RestClientOptions(telemetry=telemetry, timeout=timeout, retries=0),
)

def _add_client_authentication(self, payload):
def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]:
return add_client_authentication(
payload,
self.domain,
Expand All @@ -51,13 +56,28 @@ def _add_client_authentication(self, payload):
self.client_assertion_signing_alg,
)

def post(self, url, data=None, headers=None):
def post(
self,
url: str,
data: RequestData | None = None,
headers: dict[str, str] | None = None,
) -> Any:
return self.client.post(url, data=data, headers=headers)

def authenticated_post(self, url, data=None, headers=None):
def authenticated_post(
self,
url: str,
data: dict[str, Any],
headers: dict[str, str] | None = None,
) -> Any:
return self.client.post(
url, data=self._add_client_authentication(data), headers=headers
)

def get(self, url, params=None, headers=None):
def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
return self.client.get(url, params, headers)
28 changes: 17 additions & 11 deletions auth0/authentication/client_authentication.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from __future__ import annotations

import datetime
import uuid
from typing import Any

import jwt


def create_client_assertion_jwt(
domain, client_id, client_assertion_signing_key, client_assertion_signing_alg
):
domain: str,
client_id: str,
client_assertion_signing_key: str,
client_assertion_signing_alg: str | None,
) -> str:
"""Creates a JWT for the client_assertion field.

Args:
domain (str): The domain of your Auth0 tenant
client_id (str): Your application's client ID
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
client_assertion_signing_key (str): Private key used to sign the client assertion JWT
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')

Returns:
Expand All @@ -35,20 +41,20 @@ def create_client_assertion_jwt(


def add_client_authentication(
payload,
domain,
client_id,
client_secret,
client_assertion_signing_key,
client_assertion_signing_alg,
):
payload: dict[str, Any],
domain: str,
client_id: str,
client_secret: str | None,
client_assertion_signing_key: str | None,
client_assertion_signing_alg: str | None,
) -> dict[str, Any]:
"""Adds the client_assertion or client_secret fields to authenticate a payload.

Args:
payload (dict): The POST payload that needs additional fields to be authenticated.
domain (str): The domain of your Auth0 tenant
client_id (str): Your application's client ID
client_secret (str): Your application's client secret
client_secret (str, optional): Your application's client secret
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')

Expand Down
38 changes: 22 additions & 16 deletions auth0/authentication/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import warnings
from __future__ import annotations

from typing import Any

from .base import AuthenticationBase

Expand All @@ -12,17 +14,17 @@ class Database(AuthenticationBase):

def signup(
self,
email,
password,
connection,
username=None,
user_metadata=None,
given_name=None,
family_name=None,
name=None,
nickname=None,
picture=None,
):
email: str,
password: str,
connection: str,
username: str | None = None,
user_metadata: dict[str, Any] | None = None,
given_name: str | None = None,
family_name: str | None = None,
name: str | None = None,
nickname: str | None = None,
picture: str | None = None,
) -> dict[str, Any]:
"""Signup using email and password.

Args:
Expand Down Expand Up @@ -50,7 +52,7 @@ def signup(

See: https://auth0.com/docs/api/authentication#signup
"""
body = {
body: dict[str, Any] = {
"client_id": self.client_id,
"email": email,
"password": password,
Expand All @@ -71,11 +73,14 @@ def signup(
if picture:
body.update({"picture": picture})

return self.post(
data: dict[str, Any] = self.post(
f"{self.protocol}://{self.domain}/dbconnections/signup", data=body
)
return data

def change_password(self, email, connection, password=None):
def change_password(
self, email: str, connection: str, password: str | None = None
) -> str:
"""Asks to change a password for a given user.

email (str): The user's email address.
Expand All @@ -88,7 +93,8 @@ def change_password(self, email, connection, password=None):
"connection": connection,
}

return self.post(
data: str = self.post(
f"{self.protocol}://{self.domain}/dbconnections/change_password",
data=body,
)
return data
18 changes: 11 additions & 7 deletions auth0/authentication/delegated.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Any

from .base import AuthenticationBase


Expand All @@ -10,13 +14,13 @@ class Delegated(AuthenticationBase):

def get_token(
self,
target,
api_type,
grant_type,
id_token=None,
refresh_token=None,
scope="openid",
):
target: str,
api_type: str,
grant_type: str,
id_token: str | None = None,
refresh_token: str | None = None,
scope: str = "openid",
) -> Any:
"""Obtain a delegation token."""

if id_token and refresh_token:
Expand Down
6 changes: 4 additions & 2 deletions auth0/authentication/enterprise.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from .base import AuthenticationBase


Expand All @@ -9,7 +11,7 @@ class Enterprise(AuthenticationBase):
domain (str): Your auth0 domain (e.g: my-domain.us.auth0.com)
"""

def saml_metadata(self):
def saml_metadata(self) -> Any:
"""Get SAML2.0 Metadata."""

return self.get(
Expand All @@ -18,7 +20,7 @@ def saml_metadata(self):
)
)

def wsfed_metadata(self):
def wsfed_metadata(self) -> Any:
"""Returns the WS-Federation Metadata."""

url = "{}://{}/wsfed/FederationMetadata/2007-06/FederationMetadata.xml"
Expand Down
Loading