From 37d761d5a3a7c5ebe1b9f58634e98f1fa5a2dceb Mon Sep 17 00:00:00 2001 From: Dan Dorman Date: Thu, 31 Jul 2025 13:25:13 -0600 Subject: [PATCH 1/3] Cache JWKS clients per URL --- tests/test_session.py | 31 ++++++++++++++++++++++++++++++- workos/session.py | 24 +++++++++++++++++++++--- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index 5447d10d..54fbc7c5 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from tests.conftest import with_jwks_mock -from workos.session import AsyncSession, Session +from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache from workos.types.user_management.authentication_response import ( RefreshTokenAuthenticationResponse, ) @@ -20,6 +20,12 @@ class SessionFixtures: + @pytest.fixture(autouse=True) + def clear_jwks_cache(self): + _jwks_cache._clients.clear() + yield + _jwks_cache._clients.clear() + @pytest.fixture def session_constants(self): # Generate RSA key pair for testing @@ -491,3 +497,26 @@ async def test_refresh_success_with_aud_claim( response = await session.refresh() assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + + +class TestJWKSCaching: + def test_jwks_client_caching_same_url(self): + url = "https://api.workos.com/sso/jwks/test" + + client1 = _get_jwks_client(url) + client2 = _get_jwks_client(url) + + # Should be the exact same instance + assert client1 is client2 + assert id(client1) == id(client2) + + def test_jwks_client_caching_different_urls(self): + url1 = "https://api.workos.com/sso/jwks/client1" + url2 = "https://api.workos.com/sso/jwks/client2" + + client1 = _get_jwks_client(url1) + client2 = _get_jwks_client(url2) + + # Should be different instances + assert client1 is not client2 + assert id(client1) != id(client2) diff --git a/workos/session.py b/workos/session.py index 58d27f65..d4d78f2d 100644 --- a/workos/session.py +++ b/workos/session.py @@ -21,6 +21,24 @@ from workos.user_management import AsyncUserManagement, UserManagement +class _JWKSClientCache: + def __init__(self) -> None: + self._clients: Dict[str, PyJWKClient] = {} + + def get_client(self, jwks_url: str) -> PyJWKClient: + if jwks_url not in self._clients: + self._clients[jwks_url] = PyJWKClient(jwks_url) + return self._clients[jwks_url] + + +# Module-level cache instance +_jwks_cache = _JWKSClientCache() + + +def _get_jwks_client(jwks_url: str) -> PyJWKClient: + return _jwks_cache.get_client(jwks_url) + + class SessionModule(Protocol): user_management: "UserManagementModule" client_id: str @@ -46,7 +64,7 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ["RS256"] @@ -164,7 +182,7 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ["RS256"] @@ -254,7 +272,7 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ["RS256"] From e76eb9a4fbd5b4a8df587de7c87e349ebecdea48 Mon Sep 17 00:00:00 2001 From: Dan Dorman Date: Thu, 31 Jul 2025 15:55:33 -0600 Subject: [PATCH 2/3] Make JWKS client cache threadsafe --- tests/test_session.py | 22 ++++++++++++++++++++-- workos/session.py | 24 +++++++++++++++++++++--- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index 54fbc7c5..9c323dd8 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, Mock, patch import jwt from datetime import datetime, timezone +import concurrent.futures from tests.conftest import with_jwks_mock from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache @@ -22,9 +23,9 @@ class SessionFixtures: @pytest.fixture(autouse=True) def clear_jwks_cache(self): - _jwks_cache._clients.clear() + _jwks_cache.clear() yield - _jwks_cache._clients.clear() + _jwks_cache.clear() @pytest.fixture def session_constants(self): @@ -520,3 +521,20 @@ def test_jwks_client_caching_different_urls(self): # Should be different instances assert client1 is not client2 assert id(client1) != id(client2) + + def test_jwks_cache_thread_safety(self): + url = "https://api.workos.com/sso/jwks/thread_test" + clients = [] + + def get_client(): + return _get_jwks_client(url) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(get_client) for _ in range(10)] + clients = [future.result() for future in futures] + + first_client = clients[0] + for client in clients[1:]: + assert ( + client is first_client + ), "All concurrent calls should return the same instance" diff --git a/workos/session.py b/workos/session.py index d4d78f2d..3e9f47a6 100644 --- a/workos/session.py +++ b/workos/session.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Protocol import json +import threading from typing import Any, Dict, Optional, Union, cast import jwt from jwt import PyJWKClient @@ -24,11 +25,28 @@ class _JWKSClientCache: def __init__(self) -> None: self._clients: Dict[str, PyJWKClient] = {} + self._lock = threading.Lock() def get_client(self, jwks_url: str) -> PyJWKClient: - if jwks_url not in self._clients: - self._clients[jwks_url] = PyJWKClient(jwks_url) - return self._clients[jwks_url] + if jwks_url in self._clients: + return self._clients[jwks_url] + + with self._lock: + if jwks_url in self._clients: + return self._clients[jwks_url] + + client = PyJWKClient(jwks_url) + self._clients[jwks_url] = client + return client + + def clear(self) -> None: + """Intended primarily for test cleanup and manual cache invalidation. + + Warning: If called concurrently with get_client(), some newly created + clients might be lost due to lock acquisition ordering. + """ + with self._lock: + self._clients.clear() # Module-level cache instance From 7346f1587453046722f0ed9dbdcbf8fdb422acb6 Mon Sep 17 00:00:00 2001 From: Dan Dorman Date: Thu, 31 Jul 2025 16:04:30 -0600 Subject: [PATCH 3/3] Switch to lru_cache --- tests/test_session.py | 8 ++++---- workos/session.py | 36 +++--------------------------------- 2 files changed, 7 insertions(+), 37 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index 9c323dd8..3b7ddd78 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -5,7 +5,7 @@ import concurrent.futures from tests.conftest import with_jwks_mock -from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache +from workos.session import AsyncSession, Session, _get_jwks_client from workos.types.user_management.authentication_response import ( RefreshTokenAuthenticationResponse, ) @@ -23,9 +23,9 @@ class SessionFixtures: @pytest.fixture(autouse=True) def clear_jwks_cache(self): - _jwks_cache.clear() + _get_jwks_client.cache_clear() yield - _jwks_cache.clear() + _get_jwks_client.cache_clear() @pytest.fixture def session_constants(self): @@ -521,7 +521,7 @@ def test_jwks_client_caching_different_urls(self): # Should be different instances assert client1 is not client2 assert id(client1) != id(client2) - + def test_jwks_cache_thread_safety(self): url = "https://api.workos.com/sso/jwks/thread_test" clients = [] diff --git a/workos/session.py b/workos/session.py index 3e9f47a6..62aaae36 100644 --- a/workos/session.py +++ b/workos/session.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Protocol +from functools import lru_cache import json -import threading from typing import Any, Dict, Optional, Union, cast import jwt from jwt import PyJWKClient @@ -22,39 +22,9 @@ from workos.user_management import AsyncUserManagement, UserManagement -class _JWKSClientCache: - def __init__(self) -> None: - self._clients: Dict[str, PyJWKClient] = {} - self._lock = threading.Lock() - - def get_client(self, jwks_url: str) -> PyJWKClient: - if jwks_url in self._clients: - return self._clients[jwks_url] - - with self._lock: - if jwks_url in self._clients: - return self._clients[jwks_url] - - client = PyJWKClient(jwks_url) - self._clients[jwks_url] = client - return client - - def clear(self) -> None: - """Intended primarily for test cleanup and manual cache invalidation. - - Warning: If called concurrently with get_client(), some newly created - clients might be lost due to lock acquisition ordering. - """ - with self._lock: - self._clients.clear() - - -# Module-level cache instance -_jwks_cache = _JWKSClientCache() - - +@lru_cache(maxsize=None) def _get_jwks_client(jwks_url: str) -> PyJWKClient: - return _jwks_cache.get_client(jwks_url) + return PyJWKClient(jwks_url) class SessionModule(Protocol):