Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from mcp.shared.auth_utils import (
calculate_token_expiry,
calculate_token_refresh_time,
check_resource_allowed,
resource_url_from_server_url,
)
Expand Down Expand Up @@ -113,6 +114,9 @@ class OAuthContext:
# Token management
current_tokens: OAuthToken | None = None
token_expiry_time: float | None = None
# Jittered point (before hard expiry) at which to proactively refresh, so a fleet
# of connectors does not all refresh in the same window. See should_refresh_token.
token_refresh_time: float | None = None

# State
lock: anyio.Lock = field(default_factory=anyio.Lock)
Expand All @@ -123,11 +127,12 @@ def get_authorization_base_url(self, server_url: str) -> str:
return f"{parsed.scheme}://{parsed.netloc}"

def update_token_expiry(self, token: OAuthToken) -> None:
"""Update token expiry time using shared util function."""
"""Update token expiry and proactive-refresh times using shared util functions."""
self.token_expiry_time = calculate_token_expiry(token.expires_in)
self.token_refresh_time = calculate_token_refresh_time(token.expires_in)

def is_token_valid(self) -> bool:
"""Check if current token is valid."""
"""Check if current token is valid (i.e. usable, not past hard expiry)."""
return bool(
self.current_tokens
and self.current_tokens.access_token
Expand All @@ -138,10 +143,28 @@ def can_refresh_token(self) -> bool:
"""Check if token can be refreshed."""
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)

def should_refresh_token(self) -> bool:
"""Check if the token should be *proactively* refreshed.

Returns True when we hold refreshable tokens and have passed the jittered
proactive-refresh point (``token_refresh_time``), even if the token is still
technically valid. Refreshing slightly early -- and at a per-connector jittered
moment -- spreads a fleet's refreshes out instead of bunching them into the
same expiry window. Returns False when no refresh time is known (no expiry
info) so behavior degrades to the existing reactive path.
"""
return bool(
self.current_tokens
and self.can_refresh_token()
and self.token_refresh_time is not None
and time.time() >= self.token_refresh_time
)

def clear_tokens(self) -> None:
"""Clear current tokens."""
self.current_tokens = None
self.token_expiry_time = None
self.token_refresh_time = None

def get_resource_url(self) -> str:
"""Get resource URL for RFC 8707.
Expand Down Expand Up @@ -511,7 +534,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

if not self.context.is_token_valid() and self.context.can_refresh_token():
if (
not self.context.is_token_valid() or self.context.should_refresh_token()
) and self.context.can_refresh_token():
# Refresh either reactively (token already invalid) or proactively
# (past the jittered refresh point, before hard expiry).
# Try to refresh token
refresh_request = await self._refresh_token()
refresh_response = yield refresh_request
Expand Down
78 changes: 78 additions & 0 deletions src/mcp/shared/auth_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636)."""

import random
import time
from urllib.parse import urlparse, urlsplit, urlunsplit

Expand Down Expand Up @@ -78,3 +79,80 @@ def calculate_token_expiry(expires_in: int | str | None) -> float | None:
return None # pragma: no cover
# Defensive: handle servers that return expires_in as string
return time.time() + int(expires_in)


def calculate_token_refresh_time(
expires_in: int | str | None,
*,
refresh_fraction: float = 0.8,
max_jitter_seconds: float = 30.0,
jitter: float | None = None,
) -> float | None:
"""Calculate when a token should be *proactively* refreshed.

Reactive refresh (waiting until a token has already expired) means that, for a
fleet of OAuth-backed MCP connectors provisioned around the same time, every
token tends to expire inside the same narrow window. When they do, all of those
clients try to refresh simultaneously, producing a "thundering herd" of refresh
requests against the authorization server -- contention, rate limiting, and
spurious auth failures.

To avoid that, this returns a timestamp *before* hard expiry at which the token
should be refreshed:

refresh_at = now + expires_in * refresh_fraction - jitter

The jitter is always *subtracted* so it pulls the refresh point earlier and can
never push it past the hard-expiry boundary. Spreading each client's refresh
point by a small random amount means a fleet naturally desynchronizes instead of
refreshing in lockstep.

Args:
expires_in: Seconds until token expiration (may be a string from some servers).
refresh_fraction: Fraction of the token lifetime after which to refresh.
Defaults to 0.8 (refresh once 80% of the lifetime has elapsed).
max_jitter_seconds: Upper bound (in seconds) of the random jitter subtracted
from the refresh point. Defaults to 30s.
jitter: Optional explicit jitter value (seconds). When provided it is used
directly instead of drawing a random value, which keeps the function
deterministic and testable. When None, a value in
``[0, max_jitter_seconds]`` is drawn at random.

Returns:
Unix timestamp at which the token should be proactively refreshed, or None
if ``expires_in`` is None (no expiry information -> nothing to schedule).
The result is always in ``(now, hard_expiry]`` and never in the past.
"""
if expires_in is None:
return None

expires_in_seconds = int(expires_in)
now = time.time()
hard_expiry = now + expires_in_seconds

# Base proactive point: refresh once `refresh_fraction` of the lifetime elapsed.
refresh_at = now + expires_in_seconds * refresh_fraction

# Cap the jitter so it can never reach back before `now`, which matters for very
# short TTLs (e.g. expires_in smaller than max_jitter_seconds). The window we are
# allowed to pull earlier into is (refresh_at - now); never jitter more than that.
available_window = refresh_at - now
effective_max_jitter = min(max_jitter_seconds, max(available_window, 0.0))

if jitter is None:
applied_jitter = random.uniform(0, effective_max_jitter)
else:
# Clamp an injected jitter into the valid range to preserve invariants.
applied_jitter = min(max(jitter, 0.0), effective_max_jitter)

refresh_at -= applied_jitter

# Final guard: keep the result strictly within (now, hard_expiry]. For tiny or
# zero TTLs this collapses gracefully toward `now` rather than going negative or
# past the hard-expiry boundary.
if refresh_at < now:
refresh_at = now
if refresh_at > hard_expiry:
refresh_at = hard_expiry

return refresh_at
133 changes: 133 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,50 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
context = oauth_provider.context
context.current_tokens = valid_tokens
context.token_expiry_time = time.time() + 1800
context.token_refresh_time = time.time() + 1440

# Clear tokens
context.clear_tokens()

# Verify cleared
assert context.current_tokens is None
assert context.token_expiry_time is None
assert context.token_refresh_time is None

@pytest.mark.anyio
async def test_should_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
"""Test should_refresh_token() proactive-refresh logic."""
context = oauth_provider.context

# No tokens at all -> never proactively refresh.
assert not context.should_refresh_token()

context.current_tokens = valid_tokens
context.client_info = OAuthClientInformationFull(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

# Token still hard-valid AND before the jittered refresh point -> no refresh.
context.token_expiry_time = time.time() + 1800
context.token_refresh_time = time.time() + 600
assert context.is_token_valid()
assert not context.should_refresh_token()

# Token still hard-valid but we have passed the proactive refresh point -> refresh.
context.token_refresh_time = time.time() - 1
assert context.is_token_valid()
assert context.should_refresh_token()

# No refresh time known (e.g. server gave no expiry) -> fall back to reactive only.
context.token_refresh_time = None
assert not context.should_refresh_token()

# Past the refresh point but no refresh token -> cannot proactively refresh.
context.token_refresh_time = time.time() - 1
context.current_tokens.refresh_token = None
assert not context.should_refresh_token()


class TestOAuthFlow:
Expand Down Expand Up @@ -506,6 +543,102 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
except StopAsyncIteration:
pass # Expected - generator should complete

@pytest.mark.anyio
async def test_async_auth_flow_proactively_refreshes_when_past_jitter_window(
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
"""async_auth_flow refreshes proactively past the jittered window.

The token is still hard-valid (is_token_valid() is True), but we are past the
proactive refresh point, so the flow should yield a refresh request *before*
sending the original request -- spreading fleet refreshes out instead of
waiting for hard expiry.
"""
context = oauth_provider.context
context.current_tokens = valid_tokens
context.client_info = OAuthClientInformationFull(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
oauth_provider._initialized = True

# Token is still valid for a while, but we are past the proactive refresh point.
context.token_expiry_time = time.time() + 1800
context.token_refresh_time = time.time() - 1
assert context.is_token_valid()
assert context.should_refresh_token()

test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
auth_flow = oauth_provider.async_auth_flow(test_request)

# First yielded request must be a proactive refresh, not the original request.
refresh_request = await auth_flow.__anext__()
assert refresh_request.method == "POST"
assert str(refresh_request.url) == "https://api.example.com/token"
refresh_content = refresh_request.content.decode()
assert "grant_type=refresh_token" in refresh_content
assert "refresh_token=test_refresh_token" in refresh_content

# Provide a successful refresh response with fresh tokens.
refresh_response = httpx.Response(
200,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=refresh_request,
)

# After a successful refresh, the original request is sent with the new token.
actual_request = await auth_flow.asend(refresh_response)
assert actual_request.headers["Authorization"] == "Bearer new_access_token"
assert str(actual_request.url) == "https://api.example.com/v1/mcp"

# New proactive-refresh point should have been scheduled in the future.
assert context.token_refresh_time is not None
assert context.token_refresh_time > time.time()

# Close out the generator with a final success response.
final_response = httpx.Response(200, request=actual_request)
try:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass # Expected - generator completes

@pytest.mark.anyio
async def test_async_auth_flow_skips_refresh_before_jitter_window(
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
"""A fresh token (before the proactive window) is used directly, no refresh."""
context = oauth_provider.context
context.current_tokens = valid_tokens
context.client_info = OAuthClientInformationFull(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
oauth_provider._initialized = True

# Token valid and well before the proactive refresh point.
context.token_expiry_time = time.time() + 1800
context.token_refresh_time = time.time() + 600
assert not context.should_refresh_token()

test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
auth_flow = oauth_provider.async_auth_flow(test_request)

# First (and only auth-related) yielded request is the original request itself.
actual_request = await auth_flow.__anext__()
assert actual_request.headers["Authorization"] == "Bearer test_access_token"
assert str(actual_request.url) == "https://api.example.com/v1/mcp"

final_response = httpx.Response(200, request=actual_request)
try:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass # Expected - generator completes

@pytest.mark.anyio
async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider):
"""Test successful metadata response handling."""
Expand Down
Loading
Loading