Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import httpx
import pytest
import asyncio
from functools import wraps

from tests.utils.client_configuration import ClientConfiguration
from tests.utils.list_resource import list_data_to_dicts, list_response_of
Expand All @@ -26,7 +28,6 @@

from jwt import PyJWKClient
from unittest.mock import Mock, patch
from functools import wraps


def _get_test_client_setup(
Expand Down Expand Up @@ -310,7 +311,19 @@ def inner(

def with_jwks_mock(func):
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
# Create mock JWKS client
mock_jwks = Mock(spec=PyJWKClient)
mock_signing_key = Mock()
mock_signing_key.key = kwargs["session_constants"]["PUBLIC_KEY"]
mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key

# Apply the mock
with patch("workos.session.PyJWKClient", return_value=mock_jwks):
return await func(*args, **kwargs)

@wraps(func)
def sync_wrapper(*args, **kwargs):
# Create mock JWKS client
mock_jwks = Mock(spec=PyJWKClient)
mock_signing_key = Mock()
Expand All @@ -321,4 +334,7 @@ def wrapper(*args, **kwargs):
with patch("workos.session.PyJWKClient", return_value=mock_jwks):
return func(*args, **kwargs)

return wrapper
# Return appropriate wrapper based on whether the function is async or not
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
12 changes: 7 additions & 5 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import jwt
from datetime import datetime, timezone

Expand Down Expand Up @@ -396,6 +396,7 @@ def test_refresh_success_with_aud_claim(


class TestAsyncSession(SessionFixtures):
@pytest.mark.asyncio
@with_jwks_mock
async def test_refresh_success(self, session_constants, mock_user_management):
session_data = AsyncSession.seal_data(
Expand All @@ -413,8 +414,8 @@ async def test_refresh_success(self, session_constants, mock_user_management):
"user": session_constants["TEST_USER"],
}

mock_user_management.authenticate_with_refresh_token.return_value = (
RefreshTokenAuthenticationResponse(**mock_response)
mock_user_management.authenticate_with_refresh_token = AsyncMock(
return_value=(RefreshTokenAuthenticationResponse(**mock_response))
)

session = AsyncSession(
Expand Down Expand Up @@ -451,6 +452,7 @@ async def test_refresh_success(self, session_constants, mock_user_management):
},
)

@pytest.mark.asyncio
@with_jwks_mock
async def test_refresh_success_with_aud_claim(
self, session_constants, mock_user_management
Expand Down Expand Up @@ -479,8 +481,8 @@ async def test_refresh_success_with_aud_claim(
"user": session_constants["TEST_USER"],
}

mock_user_management.authenticate_with_refresh_token.return_value = (
RefreshTokenAuthenticationResponse(**mock_response)
mock_user_management.authenticate_with_refresh_token = AsyncMock(
return_value=(RefreshTokenAuthenticationResponse(**mock_response))
)

session = AsyncSession(
Expand Down