diff --git a/tests/conftest.py b/tests/conftest.py index 6faedbde..c89fc788 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,8 @@ import httpx import pytest +import asyncio +from functools import wraps from tests.utils.client_configuration import ClientConfiguration from tests.utils.list_resource import list_data_to_dicts, list_response_of @@ -26,7 +28,6 @@ from jwt import PyJWKClient from unittest.mock import Mock, patch -from functools import wraps def _get_test_client_setup( @@ -310,7 +311,19 @@ def inner( def with_jwks_mock(func): @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): + # Create mock JWKS client + mock_jwks = Mock(spec=PyJWKClient) + mock_signing_key = Mock() + mock_signing_key.key = kwargs["session_constants"]["PUBLIC_KEY"] + mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key + + # Apply the mock + with patch("workos.session.PyJWKClient", return_value=mock_jwks): + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): # Create mock JWKS client mock_jwks = Mock(spec=PyJWKClient) mock_signing_key = Mock() @@ -321,4 +334,7 @@ def wrapper(*args, **kwargs): with patch("workos.session.PyJWKClient", return_value=mock_jwks): return func(*args, **kwargs) - return wrapper + # Return appropriate wrapper based on whether the function is async or not + if asyncio.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper diff --git a/tests/test_session.py b/tests/test_session.py index b2fb654e..f68415c1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import jwt from datetime import datetime, timezone @@ -396,6 +396,7 @@ def test_refresh_success_with_aud_claim( class TestAsyncSession(SessionFixtures): + @pytest.mark.asyncio @with_jwks_mock async def test_refresh_success(self, session_constants, mock_user_management): session_data = AsyncSession.seal_data( @@ -413,8 +414,8 @@ async def test_refresh_success(self, session_constants, mock_user_management): "user": session_constants["TEST_USER"], } - mock_user_management.authenticate_with_refresh_token.return_value = ( - RefreshTokenAuthenticationResponse(**mock_response) + mock_user_management.authenticate_with_refresh_token = AsyncMock( + return_value=(RefreshTokenAuthenticationResponse(**mock_response)) ) session = AsyncSession( @@ -451,6 +452,7 @@ async def test_refresh_success(self, session_constants, mock_user_management): }, ) + @pytest.mark.asyncio @with_jwks_mock async def test_refresh_success_with_aud_claim( self, session_constants, mock_user_management @@ -479,8 +481,8 @@ async def test_refresh_success_with_aud_claim( "user": session_constants["TEST_USER"], } - mock_user_management.authenticate_with_refresh_token.return_value = ( - RefreshTokenAuthenticationResponse(**mock_response) + mock_user_management.authenticate_with_refresh_token = AsyncMock( + return_value=(RefreshTokenAuthenticationResponse(**mock_response)) ) session = AsyncSession(