Skip to content

Commit ae5050b

Browse files
committed
fix(oauth): preserve existing refresh_token when refresh response omits it (#2270)
1 parent 62eb08e commit ae5050b

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ async def _refresh_token(self) -> httpx.Request:
447447

448448
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
449449

450-
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
450+
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
451451
"""Handle token refresh response. Returns True if successful."""
452452
if response.status_code != 200:
453453
logger.warning(f"Token refresh failed: {response.status_code}")
@@ -458,6 +458,18 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p
458458
content = await response.aread()
459459
token_response = OAuthToken.model_validate_json(content)
460460

461+
# Per RFC 6749 Section 6, the authorization server MAY issue a new
462+
# refresh token. If the response omits one, preserve the existing
463+
# refresh token so subsequent refresh attempts remain possible.
464+
if (
465+
not token_response.refresh_token
466+
and self.context.current_tokens
467+
and self.context.current_tokens.refresh_token
468+
):
469+
token_response = token_response.model_copy(
470+
update={"refresh_token": self.context.current_tokens.refresh_token}
471+
)
472+
461473
self.context.current_tokens = token_response
462474
self.context.update_token_expiry(token_response)
463475
await self.context.storage.set_tokens(token_response)

tests/client/test_auth.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,109 @@ async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvide
711711
content = request.content.decode()
712712
assert "client_secret=" not in content
713713

714+
@pytest.mark.anyio
715+
async def test_handle_refresh_response_preserves_existing_refresh_token(
716+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
717+
):
718+
"""Test that the existing refresh_token is preserved when the server omits it.
719+
720+
Per RFC 6749 Section 6, the authorization server MAY issue a new refresh
721+
token in the refresh response. If it doesn't, the client should continue
722+
using the existing one.
723+
"""
724+
oauth_provider.context.current_tokens = valid_tokens
725+
726+
# Server response without refresh_token
727+
refresh_response = httpx.Response(
728+
200,
729+
content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}',
730+
request=httpx.Request("POST", "https://auth.example.com/token"),
731+
)
732+
733+
result = await oauth_provider._handle_refresh_response(refresh_response)
734+
735+
assert result is True
736+
assert oauth_provider.context.current_tokens is not None
737+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
738+
# Old refresh_token should be preserved
739+
assert oauth_provider.context.current_tokens.refresh_token == "test_refresh_token"
740+
741+
@pytest.mark.anyio
742+
async def test_handle_refresh_response_uses_new_refresh_token(
743+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
744+
):
745+
"""Test that a new refresh_token from the server replaces the old one."""
746+
oauth_provider.context.current_tokens = valid_tokens
747+
748+
# Server response with a new refresh_token (token rotation)
749+
refresh_response = httpx.Response(
750+
200,
751+
content=(
752+
b'{"access_token": "new_access_token", "token_type": "Bearer",'
753+
b' "expires_in": 3600, "refresh_token": "rotated_refresh_token"}'
754+
),
755+
request=httpx.Request("POST", "https://auth.example.com/token"),
756+
)
757+
758+
result = await oauth_provider._handle_refresh_response(refresh_response)
759+
760+
assert result is True
761+
assert oauth_provider.context.current_tokens is not None
762+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
763+
assert oauth_provider.context.current_tokens.refresh_token == "rotated_refresh_token"
764+
765+
@pytest.mark.anyio
766+
async def test_handle_refresh_response_no_prior_tokens(self, oauth_provider: OAuthClientProvider):
767+
"""Test refresh response when there are no prior tokens stored."""
768+
oauth_provider.context.current_tokens = None
769+
770+
refresh_response = httpx.Response(
771+
200,
772+
content=b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}',
773+
request=httpx.Request("POST", "https://auth.example.com/token"),
774+
)
775+
776+
result = await oauth_provider._handle_refresh_response(refresh_response)
777+
778+
assert result is True
779+
assert oauth_provider.context.current_tokens is not None
780+
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
781+
assert oauth_provider.context.current_tokens.refresh_token is None
782+
783+
@pytest.mark.anyio
784+
async def test_handle_refresh_response_failure(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
785+
"""Test that a non-200 refresh response clears tokens."""
786+
oauth_provider.context.current_tokens = valid_tokens
787+
788+
refresh_response = httpx.Response(
789+
401,
790+
content=b"Unauthorized",
791+
request=httpx.Request("POST", "https://auth.example.com/token"),
792+
)
793+
794+
result = await oauth_provider._handle_refresh_response(refresh_response)
795+
796+
assert result is False
797+
assert oauth_provider.context.current_tokens is None
798+
799+
@pytest.mark.anyio
800+
async def test_handle_refresh_response_invalid_json(
801+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
802+
):
803+
"""Test that an invalid response body clears tokens."""
804+
oauth_provider.context.current_tokens = valid_tokens
805+
806+
refresh_response = httpx.Response(
807+
200,
808+
content=b"not valid json",
809+
request=httpx.Request("POST", "https://auth.example.com/token"),
810+
)
811+
812+
result = await oauth_provider._handle_refresh_response(refresh_response)
813+
814+
assert result is False
815+
assert oauth_provider.context.current_tokens is None
816+
714817
@pytest.mark.anyio
715818
async def test_none_auth_method(self, oauth_provider: OAuthClientProvider):
716819
"""Test 'none' authentication method (public client)."""

0 commit comments

Comments
 (0)