diff --git a/LICENSE b/LICENSE index 656ceca8..67f8e417 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 WorkOS +Copyright (c) 2024 WorkOS Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..478712f7 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +flake8 +pytest==8.3.2 +pytest-asyncio==0.23.8 +pytest-cov==5.0.0 +six==1.16.0 +black==24.4.2 +twine==5.1.1 +mypy==1.12.0 +httpx>=0.27.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..beaf1927 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +httpx>=0.27.0 +pydantic==2.9.2 +PyJWT==2.9.0 +cryptography==43.0.3 diff --git a/setup.py b/setup.py index e86e5a92..103be545 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,12 @@ with open(os.path.join(base_dir, "workos", "__about__.py")) as f: exec(f.read(), about) + +def read_requirements(filename): + with open(filename) as f: + return [line.strip() for line in f if line.strip() and not line.startswith("#")] + + setup( name=about["__package_name__"], version=about["__version__"], @@ -27,19 +33,9 @@ ), zip_safe=False, license=about["__license__"], - install_requires=["httpx>=0.27.0", "pydantic==2.9.2"], + install_requires=read_requirements("requirements.txt"), extras_require={ - "dev": [ - "flake8", - "pytest==8.3.2", - "pytest-asyncio==0.23.8", - "pytest-cov==5.0.0", - "six==1.16.0", - "black==24.4.2", - "twine==5.1.1", - "mypy==1.12.0", - "httpx>=0.27.0", - ], + "dev": read_requirements("requirements-dev.txt"), ":python_version<'3.4'": ["enum34"], }, classifiers=[ diff --git a/tests/conftest.py b/tests/conftest.py index 81ef0ca8..7f9f058a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,10 @@ from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT +from jwt import PyJWKClient +from unittest.mock import Mock, patch +from functools import wraps + def _get_test_client_setup( http_client_class_name: str, @@ -302,3 +306,19 @@ def inner( assert request_kwargs["params"][param] == params[param] return inner + + +def with_jwks_mock(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Create mock JWKS client + mock_jwks = Mock(spec=PyJWKClient) + mock_signing_key = Mock() + mock_signing_key.key = kwargs["TEST_CONSTANTS"]["PUBLIC_KEY"] + mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key + + # Apply the mock + with patch("workos.session.PyJWKClient", return_value=mock_jwks): + return func(*args, **kwargs) + + return wrapper diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 00000000..fbb82717 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,308 @@ +import pytest +from unittest.mock import Mock, patch +import jwt +from jwt import PyJWKClient +from datetime import datetime, timezone + +from tests.conftest import with_jwks_mock +from workos.session import Session +from workos.types.user_management.authentication_response import ( + RefreshTokenAuthenticationResponse, +) +from workos.types.user_management.session import ( + AuthenticateWithSessionCookieFailureReason, + AuthenticateWithSessionCookieSuccessResponse, + RefreshWithSessionCookieErrorResponse, + RefreshWithSessionCookieSuccessResponse, +) +from workos.types.user_management.user import User + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +@pytest.fixture(scope="session") +def TEST_CONSTANTS(): + # Generate RSA key pair for testing + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + public_key = private_key.public_key() + + # Get the private key in PEM format + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return { + "COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=", + "SESSION_DATA": "session_data", + "CLIENT_ID": "client_123", + "USER_ID": "user_123", + "SESSION_ID": "session_123", + "ORGANIZATION_ID": "organization_123", + "CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)), + "PRIVATE_KEY": private_pem, + "PUBLIC_KEY": public_key, + "TEST_TOKEN": jwt.encode( + { + "sid": "session_123", + "org_id": "organization_123", + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, + "iat": int(datetime.now(timezone.utc).timestamp()), + }, + private_pem, + algorithm="RS256", + ), + } + + +@pytest.fixture +def mock_user_management(): + mock = Mock() + mock.get_jwks_url.return_value = ( + "https://api.workos.com/user_management/sso/jwks/client_123" + ) + + return mock + + +@with_jwks_mock +def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=TEST_CONSTANTS["SESSION_DATA"], + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + assert session.client_id == TEST_CONSTANTS["CLIENT_ID"] + assert session.cookie_password is not None + + +@with_jwks_mock +def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management): + with pytest.raises(ValueError, match="cookie_password is required"): + Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=TEST_CONSTANTS["SESSION_DATA"], + cookie_password="", + ) + + +@with_jwks_mock +def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=None, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + response = session.authenticate() + + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + ) + + +@with_jwks_mock +def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data="invalid_session_data", + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + response = session.authenticate() + + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + +@with_jwks_mock +def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): + invalid_session_data = Session.seal_data( + {"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=invalid_session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + response = session.authenticate() + + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + + +@with_jwks_mock +def test_authenticate_success(TEST_CONSTANTS, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=TEST_CONSTANTS["SESSION_DATA"], + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + # Mock the session data that would be unsealed + mock_session = { + "access_token": jwt.encode( + { + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, + "iat": int(datetime.now(timezone.utc).timestamp()), + }, + TEST_CONSTANTS["PRIVATE_KEY"], + algorithm="RS256", + ), + "user": { + "object": "user", + "id": TEST_CONSTANTS["USER_ID"], + "email": "user@example.com", + "email_verified": True, + "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + }, + "impersonator": None, + } + + # Mock the JWT payload that would be decoded + mock_jwt_payload = { + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + } + + with patch.object(Session, "unseal_data", return_value=mock_session), patch.object( + session, "_is_valid_jwt", return_value=True + ), patch("jwt.decode", return_value=mock_jwt_payload), patch.object( + session.jwks, + "get_signing_key_from_jwt", + return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]), + ): + response = session.authenticate() + + assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.session_id == TEST_CONSTANTS["SESSION_ID"] + assert response.organization_id == TEST_CONSTANTS["ORGANIZATION_ID"] + assert response.role == "admin" + assert response.permissions == ["read"] + assert response.entitlements == ["feature_1"] + assert response.user.id == TEST_CONSTANTS["USER_ID"] + assert response.impersonator is None + + +@with_jwks_mock +def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data="invalid_session_data", + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieErrorResponse) + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + +@with_jwks_mock +def test_refresh_success(TEST_CONSTANTS, mock_user_management): + test_user = { + "object": "user", + "id": TEST_CONSTANTS["USER_ID"], + "email": "user@example.com", + "first_name": "Test", + "last_name": "User", + "email_verified": True, + "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + } + + session_data = Session.seal_data( + {"refresh_token": "refresh_token_12345", "user": test_user}, + TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + mock_response = { + "access_token": TEST_CONSTANTS["TEST_TOKEN"], + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": test_user, + } + + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) + ) + + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + with patch.object(session, "_is_valid_jwt", return_value=True) as _: + with patch( + "jwt.decode", + return_value={ + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == test_user["id"] + + # Verify the refresh token was used correctly + mock_user_management.authenticate_with_refresh_token.assert_called_once_with( + refresh_token="refresh_token_12345", + organization_id=None, + session={ + "seal_session": True, + "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"], + }, + ) + + +def test_seal_data(TEST_CONSTANTS): + test_data = {"test": "data"} + sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"]) + assert isinstance(sealed, str) + + # Test unsealing + unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) + + assert unsealed == test_data + + +def test_unseal_invalid_data(TEST_CONSTANTS): + with pytest.raises(Exception): # Adjust exception type based on your implementation + Session.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]) diff --git a/tests/test_user_management.py b/tests/test_user_management.py index bdab7e34..ba49302f 100644 --- a/tests/test_user_management.py +++ b/tests/test_user_management.py @@ -60,9 +60,12 @@ def base_authentication_params(self): @pytest.fixture def mock_auth_refresh_token_response(self): + user = MockUser("user_01H7ZGXFP5C6BBQY6Z7277ZCT0").dict() + return { "access_token": "access_token_12345", "refresh_token": "refresh_token_12345", + "user": user, } @pytest.fixture diff --git a/workos/session.py b/workos/session.py new file mode 100644 index 00000000..fea062ca --- /dev/null +++ b/workos/session.py @@ -0,0 +1,196 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +import json +from typing import Any, Dict, Optional, Union, cast +import jwt +from jwt import PyJWKClient +from cryptography.fernet import Fernet + +from workos.types.user_management.authentication_response import ( + RefreshTokenAuthenticationResponse, +) +from workos.types.user_management.session import ( + AuthenticateWithSessionCookieFailureReason, + AuthenticateWithSessionCookieSuccessResponse, + AuthenticateWithSessionCookieErrorResponse, + RefreshWithSessionCookieErrorResponse, + RefreshWithSessionCookieSuccessResponse, +) + +if TYPE_CHECKING: + from workos.user_management import UserManagementModule + + +class Session: + def __init__( + self, + *, + user_management: "UserManagementModule", + client_id: str, + session_data: str, + cookie_password: str, + ) -> None: + # If the cookie password is not provided, throw an error + if cookie_password is None or cookie_password == "": + raise ValueError("cookie_password is required") + + self.user_management = user_management + self.client_id = client_id + self.session_data = session_data + self.cookie_password = cookie_password + + self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + + # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm + self.jwk_algorithms = ["RS256"] + + def authenticate( + self, + ) -> Union[ + AuthenticateWithSessionCookieSuccessResponse, + AuthenticateWithSessionCookieErrorResponse, + ]: + if self.session_data is None or self.session_data == "": + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED, + ) + + try: + session = self.unseal_data(self.session_data, self.cookie_password) + except Exception: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, + ) + + if not session.get("access_token", None): + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, + ) + + if not self._is_valid_jwt(session["access_token"]): + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, + ) + + signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) + decoded = jwt.decode( + session["access_token"], signing_key.key, algorithms=self.jwk_algorithms + ) + + return AuthenticateWithSessionCookieSuccessResponse( + authenticated=True, + session_id=decoded["sid"], + organization_id=decoded.get("org_id", None), + role=decoded.get("role", None), + permissions=decoded.get("permissions", None), + entitlements=decoded.get("entitlements", None), + user=session["user"], + impersonator=session.get("impersonator", None), + ) + + def refresh( + self, + *, + organization_id: Optional[str] = None, + cookie_password: Optional[str] = None, + ) -> Union[ + RefreshWithSessionCookieSuccessResponse, + RefreshWithSessionCookieErrorResponse, + ]: + cookie_password = ( + self.cookie_password if cookie_password is None else cookie_password + ) + + try: + session = self.unseal_data(self.session_data, cookie_password) + except Exception: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, + ) + + if not session.get("refresh_token", None) or not session.get("user", None): + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, + ) + + try: + auth_response = cast( + RefreshTokenAuthenticationResponse, + self.user_management.authenticate_with_refresh_token( + refresh_token=session["refresh_token"], + organization_id=organization_id, + session={"seal_session": True, "cookie_password": cookie_password}, + ), + ) + + self.session_data = str(auth_response.sealed_session) + self.cookie_password = ( + cookie_password if cookie_password is not None else self.cookie_password + ) + + signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token) + + decoded = jwt.decode( + auth_response.access_token, + signing_key.key, + algorithms=self.jwk_algorithms, + ) + + return RefreshWithSessionCookieSuccessResponse( + authenticated=True, + sealed_session=str(auth_response.sealed_session), + session_id=decoded["sid"], + organization_id=decoded.get("org_id", None), + role=decoded.get("role", None), + permissions=decoded.get("permissions", None), + entitlements=decoded.get("entitlements", None), + user=auth_response.user, + impersonator=auth_response.impersonator, + ) + except Exception as e: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, reason=str(e) + ) + + def get_logout_url(self) -> str: + auth_response = self.authenticate() + + if isinstance(auth_response, AuthenticateWithSessionCookieErrorResponse): + raise ValueError( + f"Failed to extract session ID for logout URL: {auth_response.reason}" + ) + + result = self.user_management.get_logout_url( + session_id=auth_response.session_id + ) + return str(result) + + def _is_valid_jwt(self, token: str) -> bool: + try: + signing_key = self.jwks.get_signing_key_from_jwt(token) + jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms) + return True + except jwt.exceptions.InvalidTokenError: + return False + + @staticmethod + def seal_data(data: Dict[str, Any], key: str) -> str: + fernet = Fernet(key) + # Encrypt and convert bytes to string + encrypted_bytes = fernet.encrypt(json.dumps(data).encode()) + return encrypted_bytes.decode("utf-8") + + @staticmethod + def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: + fernet = Fernet(key) + # Convert string back to bytes before decryption + encrypted_bytes = sealed_data.encode("utf-8") + decrypted_str = fernet.decrypt(encrypted_bytes).decode() + return cast(Dict[str, Any], json.loads(decrypted_str)) diff --git a/workos/types/user_management/authenticate_with_common.py b/workos/types/user_management/authenticate_with_common.py index af423e18..8adc9ee6 100644 --- a/workos/types/user_management/authenticate_with_common.py +++ b/workos/types/user_management/authenticate_with_common.py @@ -1,5 +1,6 @@ from typing import Literal, Union from typing_extensions import TypedDict +from workos.types.user_management.session import SessionConfig class AuthenticateWithBaseParameters(TypedDict): @@ -17,6 +18,7 @@ class AuthenticateWithCodeParameters(AuthenticateWithBaseParameters): code: str code_verifier: Union[str, None] grant_type: Literal["authorization_code"] + session: Union[SessionConfig, None] class AuthenticateWithMagicAuthParameters(AuthenticateWithBaseParameters): @@ -49,6 +51,7 @@ class AuthenticateWithRefreshTokenParameters(AuthenticateWithBaseParameters): refresh_token: str organization_id: Union[str, None] grant_type: Literal["refresh_token"] + session: Union[SessionConfig, None] AuthenticateWithParameters = Union[ diff --git a/workos/types/user_management/authentication_response.py b/workos/types/user_management/authentication_response.py index 6caa57d0..2fdab1b0 100644 --- a/workos/types/user_management/authentication_response.py +++ b/workos/types/user_management/authentication_response.py @@ -30,6 +30,7 @@ class AuthenticationResponse(_AuthenticationResponseBase): impersonator: Optional[Impersonator] = None organization_id: Optional[str] = None user: User + sealed_session: Optional[str] = None class AuthKitAuthenticationResponse(AuthenticationResponse): @@ -39,7 +40,7 @@ class AuthKitAuthenticationResponse(AuthenticationResponse): oauth_tokens: Optional[OAuthTokens] = None -class RefreshTokenAuthenticationResponse(_AuthenticationResponseBase): +class RefreshTokenAuthenticationResponse(AuthenticationResponse): """Representation of a WorkOS refresh token authentication response.""" pass diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py new file mode 100644 index 00000000..76739f9d --- /dev/null +++ b/workos/types/user_management/session.py @@ -0,0 +1,44 @@ +from typing import Optional, Sequence, TypedDict, Union +from enum import Enum +from typing_extensions import Literal +from workos.types.user_management.impersonator import Impersonator +from workos.types.user_management.user import User +from workos.types.workos_model import WorkOSModel + + +class AuthenticateWithSessionCookieFailureReason(Enum): + INVALID_JWT = "invalid_jwt" + INVALID_SESSION_COOKIE = "invalid_session_cookie" + NO_SESSION_COOKIE_PROVIDED = "no_session_cookie_provided" + + +class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): + authenticated: Literal[True] + session_id: str + organization_id: Optional[str] = None + role: Optional[str] = None + permissions: Optional[Sequence[str]] = None + user: User + impersonator: Optional[Impersonator] = None + entitlements: Optional[Sequence[str]] = None + + +class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): + authenticated: Literal[False] + reason: Union[AuthenticateWithSessionCookieFailureReason, str] + + +class RefreshWithSessionCookieSuccessResponse( + AuthenticateWithSessionCookieSuccessResponse +): + sealed_session: str + + +class RefreshWithSessionCookieErrorResponse(WorkOSModel): + authenticated: Literal[False] + reason: Union[AuthenticateWithSessionCookieFailureReason, str] + + +class SessionConfig(TypedDict, total=False): + seal_session: bool + cookie_password: str diff --git a/workos/user_management.py b/workos/user_management.py index c2fcc582..c0e94acf 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,5 +1,6 @@ -from typing import Optional, Protocol, Sequence, Set, Type +from typing import Optional, Protocol, Sequence, Set, Type, cast from workos._client_configuration import ClientConfiguration +from workos.session import Session from workos.types.list_resource import ( ListArgs, ListMetadata, @@ -43,6 +44,7 @@ UsersListFilters, ) from workos.types.user_management.password_hash_type import PasswordHashType +from workos.types.user_management.session import SessionConfig from workos.types.user_management.user_management_provider_type import ( UserManagementProviderType, ) @@ -109,6 +111,20 @@ class UserManagementModule(Protocol): _client_configuration: ClientConfiguration + def load_sealed_session( + self, *, sealed_session: str, cookie_password: str + ) -> SyncOrAsync[Session]: + """Load a sealed session and return the session data. + + Args: + sealed_session (str): The sealed session data to load. + cookie_password (str): The cookie password to use to decrypt the session data. + + Returns: + Session: The session module. + """ + ... + def get_user(self, user_id: str) -> SyncOrAsync[User]: """Get the details of an existing user. @@ -423,6 +439,7 @@ def authenticate_with_code( self, *, code: str, + session: Optional[SessionConfig] = None, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, @@ -431,6 +448,7 @@ def authenticate_with_code( Kwargs: code (str): The authorization value which was passed back as a query parameter in the callback to the Redirect URI. + session (SessionConfig): Configuration for the session. (Optional) code_verifier (str): The randomly generated string used to derive the code challenge that was passed to the authorization url as part of the PKCE flow. This parameter is required when the client secret is not present. (Optional) ip_address (str): The IP address of the request from the user who is attempting to authenticate. (Optional) @@ -534,6 +552,7 @@ def authenticate_with_refresh_token( self, *, refresh_token: str, + session: Optional[SessionConfig] = None, organization_id: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, @@ -542,6 +561,7 @@ def authenticate_with_refresh_token( Kwargs: refresh_token (str): The token associated to the user. + session (SessionConfig): Configuration for the session. (Optional) organization_id (str): The organization to issue the new access token for. (Optional) ip_address (str): The IP address of the request from the user who is attempting to authenticate. (Optional) user_agent (str): The user agent of the request from the user who is attempting to authenticate. (Optional) @@ -810,6 +830,16 @@ def __init__( self._client_configuration = client_configuration self._http_client = http_client + def load_sealed_session( + self, *, sealed_session: str, cookie_password: str + ) -> Session: + return Session( + user_management=self, + client_id=self._http_client.client_id, + session_data=sealed_session, + cookie_password=cookie_password, + ) + def get_user(self, user_id: str) -> User: response = self._http_client.request( USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET @@ -1019,7 +1049,16 @@ def _authenticate_with( json=json, ) - return response_model.model_validate(response) + response_data = dict(response) + + session = cast(Optional[SessionConfig], payload.get("session", None)) + + if session is not None and session.get("seal_session") is True: + response_data["sealed_session"] = Session.seal_data( + response_data, str(session.get("cookie_password")) + ) + + return response_model.model_validate(response_data) def authenticate_with_password( self, @@ -1043,16 +1082,25 @@ def authenticate_with_code( self, *, code: str, + session: Optional[SessionConfig] = None, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithCodeParameters = { "code": code, "grant_type": "authorization_code", "ip_address": ip_address, "user_agent": user_agent, "code_verifier": code_verifier, + "session": session, } return self._authenticate_with( @@ -1139,16 +1187,25 @@ def authenticate_with_refresh_token( self, *, refresh_token: str, + session: Optional[SessionConfig] = None, organization_id: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithRefreshTokenParameters = { "refresh_token": refresh_token, "organization_id": organization_id, "grant_type": "refresh_token", "ip_address": ip_address, "user_agent": user_agent, + "session": session, } return self._authenticate_with( @@ -1223,10 +1280,7 @@ def get_magic_auth(self, magic_auth_id: str) -> MagicAuth: return MagicAuth.model_validate(response) def create_magic_auth( - self, - *, - email: str, - invitation_token: Optional[str] = None, + self, *, email: str, invitation_token: Optional[str] = None ) -> MagicAuth: json = { "email": email, @@ -1385,6 +1439,11 @@ def __init__( self._client_configuration = client_configuration self._http_client = http_client + async def load_sealed_session( + self, *, sealed_session: str, cookie_password: str + ) -> Session: + raise NotImplementedError("Async load_sealed_session not implemented") + async def get_user(self, user_id: str) -> User: response = await self._http_client.request( USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET @@ -1595,7 +1654,16 @@ async def _authenticate_with( json=json, ) - return response_model.model_validate(response) + response_data = dict(response) + + session = cast(Optional[SessionConfig], payload.get("session", None)) + + if session is not None and session.get("seal_session") is True: + response_data["sealed_session"] = Session.seal_data( + response_data, str(session.get("cookie_password")) + ) + + return response_model.model_validate(response_data) async def authenticate_with_password( self, @@ -1621,16 +1689,25 @@ async def authenticate_with_code( self, *, code: str, + session: Optional[SessionConfig] = None, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithCodeParameters = { "code": code, "grant_type": "authorization_code", "ip_address": ip_address, "user_agent": user_agent, "code_verifier": code_verifier, + "session": session, } return await self._authenticate_with( @@ -1725,16 +1802,25 @@ async def authenticate_with_refresh_token( self, *, refresh_token: str, + session: Optional[SessionConfig] = None, organization_id: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithRefreshTokenParameters = { "refresh_token": refresh_token, "organization_id": organization_id, "grant_type": "refresh_token", "ip_address": ip_address, "user_agent": user_agent, + "session": session, } return await self._authenticate_with(