Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from unittest.mock import AsyncMock, Mock, patch
import jwt
from datetime import datetime, timezone
import concurrent.futures

from tests.conftest import with_jwks_mock
from workos.session import AsyncSession, Session
from workos.session import AsyncSession, Session, _get_jwks_client
from workos.types.user_management.authentication_response import (
RefreshTokenAuthenticationResponse,
)
Expand All @@ -20,6 +21,12 @@


class SessionFixtures:
@pytest.fixture(autouse=True)
def clear_jwks_cache(self):
_get_jwks_client.cache_clear()
yield
_get_jwks_client.cache_clear()

@pytest.fixture
def session_constants(self):
# Generate RSA key pair for testing
Expand Down Expand Up @@ -491,3 +498,43 @@ async def test_refresh_success_with_aud_claim(
response = await session.refresh()

assert isinstance(response, RefreshWithSessionCookieSuccessResponse)


class TestJWKSCaching:
def test_jwks_client_caching_same_url(self):
url = "https://api.workos.com/sso/jwks/test"

client1 = _get_jwks_client(url)
client2 = _get_jwks_client(url)

# Should be the exact same instance
assert client1 is client2
assert id(client1) == id(client2)

def test_jwks_client_caching_different_urls(self):
url1 = "https://api.workos.com/sso/jwks/client1"
url2 = "https://api.workos.com/sso/jwks/client2"

client1 = _get_jwks_client(url1)
client2 = _get_jwks_client(url2)

# Should be different instances
assert client1 is not client2
assert id(client1) != id(client2)

def test_jwks_cache_thread_safety(self):
url = "https://api.workos.com/sso/jwks/thread_test"
clients = []

def get_client():
return _get_jwks_client(url)

with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(get_client) for _ in range(10)]
clients = [future.result() for future in futures]

first_client = clients[0]
for client in clients[1:]:
assert (
client is first_client
), "All concurrent calls should return the same instance"
12 changes: 9 additions & 3 deletions workos/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Protocol

from functools import lru_cache
import json
from typing import Any, Dict, Optional, Union, cast
import jwt
Expand All @@ -21,6 +22,11 @@
from workos.user_management import AsyncUserManagement, UserManagement


@lru_cache(maxsize=None)
def _get_jwks_client(jwks_url: str) -> PyJWKClient:
return PyJWKClient(jwks_url)


class SessionModule(Protocol):
user_management: "UserManagementModule"
client_id: str
Expand All @@ -46,7 +52,7 @@ def __init__(
self.session_data = session_data
self.cookie_password = cookie_password

self.jwks = PyJWKClient(self.user_management.get_jwks_url())
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())

# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
self.jwk_algorithms = ["RS256"]
Expand Down Expand Up @@ -164,7 +170,7 @@ def __init__(
self.session_data = session_data
self.cookie_password = cookie_password

self.jwks = PyJWKClient(self.user_management.get_jwks_url())
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())

# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
self.jwk_algorithms = ["RS256"]
Expand Down Expand Up @@ -254,7 +260,7 @@ def __init__(
self.session_data = session_data
self.cookie_password = cookie_password

self.jwks = PyJWKClient(self.user_management.get_jwks_url())
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())

# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
self.jwk_algorithms = ["RS256"]
Expand Down