From d469f8bfa67cdc4514a229ebd64b7372f8fe11b4 Mon Sep 17 00:00:00 2001 From: Sanchay Harneja Date: Thu, 1 May 2025 13:37:05 -0700 Subject: [PATCH 1/3] Avoid decoding jwt twice Currently the Session::authenticate() function (which runs on every request and consumes CPU cycles) is decoding the jwt twice unnecessarily. This small refactor fixes that --- workos/session.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/workos/session.py b/workos/session.py index 3a105081..d9214279 100644 --- a/workos/session.py +++ b/workos/session.py @@ -77,20 +77,20 @@ def authenticate( reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) - if not self._is_valid_jwt(session["access_token"]): + try: + signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) + decoded = jwt.decode( + token, + signing_key.key, + algorithms=self.jwk_algorithms, + options={"verify_aud": False}, + ) + except jwt.exceptions.InvalidTokenError: return AuthenticateWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, ) - signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) - decoded = jwt.decode( - session["access_token"], - signing_key.key, - algorithms=self.jwk_algorithms, - options={"verify_aud": False}, - ) - return AuthenticateWithSessionCookieSuccessResponse( authenticated=True, session_id=decoded["sid"], @@ -128,19 +128,6 @@ def get_logout_url(self, return_to: Optional[str] = None) -> str: ) return str(result) - def _is_valid_jwt(self, token: str) -> bool: - try: - signing_key = self.jwks.get_signing_key_from_jwt(token) - jwt.decode( - token, - signing_key.key, - algorithms=self.jwk_algorithms, - options={"verify_aud": False}, - ) - return True - except jwt.exceptions.InvalidTokenError: - return False - @staticmethod def seal_data(data: Dict[str, Any], key: str) -> str: fernet = Fernet(key) From 8bba538b5dbcbc3ecb98040b54cbde78674c3480 Mon Sep 17 00:00:00 2001 From: Giovanni Carvelli Date: Fri, 2 May 2025 09:43:26 -0400 Subject: [PATCH 2/3] Fix typo --- workos/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workos/session.py b/workos/session.py index d9214279..58d27f65 100644 --- a/workos/session.py +++ b/workos/session.py @@ -80,7 +80,7 @@ def authenticate( try: signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) decoded = jwt.decode( - token, + session["access_token"], signing_key.key, algorithms=self.jwk_algorithms, options={"verify_aud": False}, From 75e1befb29ea90d6e8116dcd13f3495b38d2a422 Mon Sep 17 00:00:00 2001 From: Giovanni Carvelli Date: Fri, 2 May 2025 10:05:32 -0400 Subject: [PATCH 3/3] remove unnecessary mock --- tests/test_session.py | 62 ++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index f68415c1..5447d10d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -236,9 +236,7 @@ def test_authenticate_success(self, session_constants, mock_user_management): "entitlements": ["feature_1"], } - with patch.object( - Session, "unseal_data", return_value=mock_session - ), patch.object(session, "_is_valid_jwt", return_value=True), patch( + with patch.object(Session, "unseal_data", return_value=mock_session), patch( "jwt.decode", return_value=mock_jwt_payload ), patch.object( session.jwks, @@ -324,22 +322,21 @@ def test_refresh_success(self, session_constants, mock_user_management): cookie_password=session_constants["COOKIE_PASSWORD"], ) - with patch.object(session, "_is_valid_jwt", return_value=True) as _: - with patch( - "jwt.decode", - return_value={ - "sid": session_constants["SESSION_ID"], - "org_id": session_constants["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - }, - ): - response = session.refresh() + with patch( + "jwt.decode", + return_value={ + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = session.refresh() - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.user.id == session_constants["TEST_USER"]["id"] + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == session_constants["TEST_USER"]["id"] # Verify the refresh token was used correctly mock_user_management.authenticate_with_refresh_token.assert_called_once_with( @@ -425,22 +422,21 @@ async def test_refresh_success(self, session_constants, mock_user_management): cookie_password=session_constants["COOKIE_PASSWORD"], ) - with patch.object(session, "_is_valid_jwt", return_value=True) as _: - with patch( - "jwt.decode", - return_value={ - "sid": session_constants["SESSION_ID"], - "org_id": session_constants["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - }, - ): - response = await session.refresh() + with patch( + "jwt.decode", + return_value={ + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + }, + ): + response = await session.refresh() - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.user.id == session_constants["TEST_USER"]["id"] + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == session_constants["TEST_USER"]["id"] # Verify the refresh token was used correctly mock_user_management.authenticate_with_refresh_token.assert_called_once_with(