diff --git a/tests/test_session.py b/tests/test_session.py index 5447d10d..3b7ddd78 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -2,9 +2,10 @@ from unittest.mock import AsyncMock, Mock, patch import jwt from datetime import datetime, timezone +import concurrent.futures from tests.conftest import with_jwks_mock -from workos.session import AsyncSession, Session +from workos.session import AsyncSession, Session, _get_jwks_client from workos.types.user_management.authentication_response import ( RefreshTokenAuthenticationResponse, ) @@ -20,6 +21,12 @@ class SessionFixtures: + @pytest.fixture(autouse=True) + def clear_jwks_cache(self): + _get_jwks_client.cache_clear() + yield + _get_jwks_client.cache_clear() + @pytest.fixture def session_constants(self): # Generate RSA key pair for testing @@ -491,3 +498,43 @@ async def test_refresh_success_with_aud_claim( response = await session.refresh() assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + + +class TestJWKSCaching: + def test_jwks_client_caching_same_url(self): + url = "https://api.workos.com/sso/jwks/test" + + client1 = _get_jwks_client(url) + client2 = _get_jwks_client(url) + + # Should be the exact same instance + assert client1 is client2 + assert id(client1) == id(client2) + + def test_jwks_client_caching_different_urls(self): + url1 = "https://api.workos.com/sso/jwks/client1" + url2 = "https://api.workos.com/sso/jwks/client2" + + client1 = _get_jwks_client(url1) + client2 = _get_jwks_client(url2) + + # Should be different instances + assert client1 is not client2 + assert id(client1) != id(client2) + + def test_jwks_cache_thread_safety(self): + url = "https://api.workos.com/sso/jwks/thread_test" + clients = [] + + def get_client(): + return _get_jwks_client(url) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(get_client) for _ in range(10)] + clients = [future.result() for future in futures] + + first_client = clients[0] + for client in clients[1:]: + assert ( + client is first_client + ), "All concurrent calls should return the same instance" diff --git a/workos/session.py b/workos/session.py index 58d27f65..62aaae36 100644 --- a/workos/session.py +++ b/workos/session.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Protocol +from functools import lru_cache import json from typing import Any, Dict, Optional, Union, cast import jwt @@ -21,6 +22,11 @@ from workos.user_management import AsyncUserManagement, UserManagement +@lru_cache(maxsize=None) +def _get_jwks_client(jwks_url: str) -> PyJWKClient: + return PyJWKClient(jwks_url) + + class SessionModule(Protocol): user_management: "UserManagementModule" client_id: str @@ -46,7 +52,7 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ["RS256"] @@ -164,7 +170,7 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ["RS256"] @@ -254,7 +260,7 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ["RS256"]