diff --git a/tests/test_session.py b/tests/test_session.py index f68415c1..5447d10d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -236,9 +236,7 @@ def test_authenticate_success(self, session_constants, mock_user_management): "entitlements": ["feature_1"], } - with patch.object( - Session, "unseal_data", return_value=mock_session - ), patch.object(session, "_is_valid_jwt", return_value=True), patch( + with patch.object(Session, "unseal_data", return_value=mock_session), patch( "jwt.decode", return_value=mock_jwt_payload ), patch.object( session.jwks, @@ -324,22 +322,21 @@ def test_refresh_success(self, session_constants, mock_user_management): cookie_password=session_constants["COOKIE_PASSWORD"], ) - with patch.object(session, "_is_valid_jwt", return_value=True) as _: - with patch( - "jwt.decode", - return_value={ - "sid": session_constants["SESSION_ID"], - "org_id": session_constants["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - }, - ): - response = session.refresh() + with patch( + "jwt.decode", + return_value={ + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = session.refresh() - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.user.id == session_constants["TEST_USER"]["id"] + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == session_constants["TEST_USER"]["id"] # Verify the refresh token was used correctly mock_user_management.authenticate_with_refresh_token.assert_called_once_with( @@ -425,22 +422,21 @@ async def test_refresh_success(self, session_constants, mock_user_management): cookie_password=session_constants["COOKIE_PASSWORD"], ) - with patch.object(session, "_is_valid_jwt", return_value=True) as _: - with patch( - "jwt.decode", - return_value={ - "sid": session_constants["SESSION_ID"], - "org_id": session_constants["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - }, - ): - response = await session.refresh() + with patch( + "jwt.decode", + return_value={ + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = await session.refresh() - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.user.id == session_constants["TEST_USER"]["id"] + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == session_constants["TEST_USER"]["id"] # Verify the refresh token was used correctly mock_user_management.authenticate_with_refresh_token.assert_called_once_with( diff --git a/workos/session.py b/workos/session.py index 3a105081..58d27f65 100644 --- a/workos/session.py +++ b/workos/session.py @@ -77,20 +77,20 @@ def authenticate( reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) - if not self._is_valid_jwt(session["access_token"]): + try: + signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) + decoded = jwt.decode( + session["access_token"], + signing_key.key, + algorithms=self.jwk_algorithms, + options={"verify_aud": False}, + ) + except jwt.exceptions.InvalidTokenError: return AuthenticateWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, ) - signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) - decoded = jwt.decode( - session["access_token"], - signing_key.key, - algorithms=self.jwk_algorithms, - options={"verify_aud": False}, - ) - return AuthenticateWithSessionCookieSuccessResponse( authenticated=True, session_id=decoded["sid"], @@ -128,19 +128,6 @@ def get_logout_url(self, return_to: Optional[str] = None) -> str: ) return str(result) - def _is_valid_jwt(self, token: str) -> bool: - try: - signing_key = self.jwks.get_signing_key_from_jwt(token) - jwt.decode( - token, - signing_key.key, - algorithms=self.jwk_algorithms, - options={"verify_aud": False}, - ) - return True - except jwt.exceptions.InvalidTokenError: - return False - @staticmethod def seal_data(data: Dict[str, Any], key: str) -> str: fernet = Fernet(key)