@@ -711,6 +711,109 @@ async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvide
711711 content = request .content .decode ()
712712 assert "client_secret=" not in content
713713
714+ @pytest .mark .anyio
715+ async def test_handle_refresh_response_preserves_existing_refresh_token (
716+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
717+ ):
718+ """Test that the existing refresh_token is preserved when the server omits it.
719+
720+ Per RFC 6749 Section 6, the authorization server MAY issue a new refresh
721+ token in the refresh response. If it doesn't, the client should continue
722+ using the existing one.
723+ """
724+ oauth_provider .context .current_tokens = valid_tokens
725+
726+ # Server response without refresh_token
727+ refresh_response = httpx .Response (
728+ 200 ,
729+ content = b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}' ,
730+ request = httpx .Request ("POST" , "https://auth.example.com/token" ),
731+ )
732+
733+ result = await oauth_provider ._handle_refresh_response (refresh_response )
734+
735+ assert result is True
736+ assert oauth_provider .context .current_tokens is not None
737+ assert oauth_provider .context .current_tokens .access_token == "new_access_token"
738+ # Old refresh_token should be preserved
739+ assert oauth_provider .context .current_tokens .refresh_token == "test_refresh_token"
740+
741+ @pytest .mark .anyio
742+ async def test_handle_refresh_response_uses_new_refresh_token (
743+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
744+ ):
745+ """Test that a new refresh_token from the server replaces the old one."""
746+ oauth_provider .context .current_tokens = valid_tokens
747+
748+ # Server response with a new refresh_token (token rotation)
749+ refresh_response = httpx .Response (
750+ 200 ,
751+ content = (
752+ b'{"access_token": "new_access_token", "token_type": "Bearer",'
753+ b' "expires_in": 3600, "refresh_token": "rotated_refresh_token"}'
754+ ),
755+ request = httpx .Request ("POST" , "https://auth.example.com/token" ),
756+ )
757+
758+ result = await oauth_provider ._handle_refresh_response (refresh_response )
759+
760+ assert result is True
761+ assert oauth_provider .context .current_tokens is not None
762+ assert oauth_provider .context .current_tokens .access_token == "new_access_token"
763+ assert oauth_provider .context .current_tokens .refresh_token == "rotated_refresh_token"
764+
765+ @pytest .mark .anyio
766+ async def test_handle_refresh_response_no_prior_tokens (self , oauth_provider : OAuthClientProvider ):
767+ """Test refresh response when there are no prior tokens stored."""
768+ oauth_provider .context .current_tokens = None
769+
770+ refresh_response = httpx .Response (
771+ 200 ,
772+ content = b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}' ,
773+ request = httpx .Request ("POST" , "https://auth.example.com/token" ),
774+ )
775+
776+ result = await oauth_provider ._handle_refresh_response (refresh_response )
777+
778+ assert result is True
779+ assert oauth_provider .context .current_tokens is not None
780+ assert oauth_provider .context .current_tokens .access_token == "new_access_token"
781+ assert oauth_provider .context .current_tokens .refresh_token is None
782+
783+ @pytest .mark .anyio
784+ async def test_handle_refresh_response_failure (self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
785+ """Test that a non-200 refresh response clears tokens."""
786+ oauth_provider .context .current_tokens = valid_tokens
787+
788+ refresh_response = httpx .Response (
789+ 401 ,
790+ content = b"Unauthorized" ,
791+ request = httpx .Request ("POST" , "https://auth.example.com/token" ),
792+ )
793+
794+ result = await oauth_provider ._handle_refresh_response (refresh_response )
795+
796+ assert result is False
797+ assert oauth_provider .context .current_tokens is None
798+
799+ @pytest .mark .anyio
800+ async def test_handle_refresh_response_invalid_json (
801+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
802+ ):
803+ """Test that an invalid response body clears tokens."""
804+ oauth_provider .context .current_tokens = valid_tokens
805+
806+ refresh_response = httpx .Response (
807+ 200 ,
808+ content = b"not valid json" ,
809+ request = httpx .Request ("POST" , "https://auth.example.com/token" ),
810+ )
811+
812+ result = await oauth_provider ._handle_refresh_response (refresh_response )
813+
814+ assert result is False
815+ assert oauth_provider .context .current_tokens is None
816+
714817 @pytest .mark .anyio
715818 async def test_none_auth_method (self , oauth_provider : OAuthClientProvider ):
716819 """Test 'none' authentication method (public client)."""
0 commit comments