Skip to content

Commit 7c9276a

Browse files
committed
fix(auth): harden enterprise managed auth with RFC compliance fixes
Fix several correctness issues in EnterpriseAuthOAuthClientProvider found during post-review audit: - Add missing resource (RFC 8707) and scope params to JWT bearer grant, matching the pattern used by all sibling providers - Move constructor validation before super().__init__() to avoid orphaned OAuthContext on validation failure - Catch pydantic ValidationError in exchange_token_for_id_jag when IdP returns HTTP 200 with malformed JSON - Wrap refresh_with_new_id_token in context lock to prevent racing with concurrent async_auth_flow - Remove write-only dead state (_id_jag, _id_jag_expiry, default_id_jag_expiry, DEFAULT_ID_JAG_EXPIRY_SECONDS) - Use set literal instead of list for membership test in validate_token_exchange_params - Delegate client auth entirely to prepare_token_auth() instead of manually injecting client_id and mutating token_endpoint_auth_method - Use _get_token_endpoint() instead of requiring oauth_metadata - Improve docstrings for IDJAGClaims, validate_token_exchange_params, and concurrency notes
1 parent a67a98a commit 7c9276a

File tree

4 files changed

+375
-266
lines changed

4 files changed

+375
-266
lines changed

.github/actions/conformance/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ async def run_cross_app_access_complete_flow(server_url: str) -> None:
352352
if not idp_issuer:
353353
raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_issuer'")
354354

355-
# Extract base URL by stripping trailing /mcp path (Python 3.9+)
355+
# Extract base URL by stripping trailing /mcp path (Python 3.9+).
356+
# The conformance harness always serves MCP at <base>/mcp, so stripping
357+
# the suffix gives us the auth-server base URL for fallback defaults.
356358
base_url = server_url.removesuffix("/mcp")
357359
auth_issuer = context.get("auth_issuer", base_url)
358360
resource_id = context.get("resource_id", server_url)
@@ -375,7 +377,6 @@ async def run_cross_app_access_complete_flow(server_url: str) -> None:
375377
client_id = context.get("client_id")
376378
client_secret = context.get("client_secret")
377379

378-
# Create storage and pre-configure client info if credentials are provided
379380
storage = InMemoryTokenStorage()
380381

381382
# Create enterprise auth provider

.github/workflows/conformance.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ jobs:
3333
runs-on: ubuntu-latest
3434
continue-on-error: true
3535
steps:
36-
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
37-
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # v7.2.0
36+
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1
37+
- uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1
3838
with:
3939
enable-cache: true
4040
version: 0.9.5
41-
- uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
41+
- uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0
4242
with:
4343
node-version: 24
4444
- run: uv sync --frozen --all-extras --package mcp

src/mcp/client/auth/extensions/enterprise_managed_auth.py

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
"""
66

77
import logging
8-
import time
98
from json import JSONDecodeError
109

1110
import httpx
1211
import jwt
13-
from pydantic import BaseModel, Field
12+
from pydantic import BaseModel, Field, ValidationError
1413
from typing_extensions import NotRequired, Required, TypedDict
1514

1615
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
@@ -140,7 +139,12 @@ def id_jag(self) -> str:
140139

141140

142141
class IDJAGClaims(BaseModel):
143-
"""Claims structure for Identity Assertion JWT Authorization Grant."""
142+
"""Claims structure for Identity Assertion JWT Authorization Grant.
143+
144+
Note: ``typ`` is sourced from the JWT *header* (not the payload) by
145+
``decode_id_jag``. It is included here for convenience so callers
146+
can inspect the full ID-JAG structure from a single object.
147+
"""
144148

145149
model_config = {"extra": "allow"}
146150

@@ -174,17 +178,13 @@ class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
174178
175179
Concurrency & Thread Safety:
176180
- SAFE: Concurrent requests within a single asyncio event loop. Token
177-
operations (including ``_id_jag`` / ``_id_jag_expiry``) are protected
178-
by the parent class's ``OAuthContext.lock`` via ``async_auth_flow``.
181+
operations are protected by the parent class's ``OAuthContext.lock``
182+
via ``async_auth_flow``.
179183
- UNSAFE: Sharing a provider instance across multiple OS threads. Each
180184
thread must instantiate its own provider and event loop.
181185
- Note: Ensure any shared ``TokenStorage`` implementation is async-safe.
182186
"""
183187

184-
# Default ID-JAG expiry when IdP doesn't provide expires_in.
185-
# 15 minutes is a conservative default for enterprise environments.
186-
DEFAULT_ID_JAG_EXPIRY_SECONDS = 900
187-
188188
def __init__(
189189
self,
190190
server_url: str,
@@ -195,7 +195,6 @@ def __init__(
195195
timeout: float = 300.0,
196196
idp_client_id: str | None = None,
197197
idp_client_secret: str | None = None,
198-
default_id_jag_expiry: int = DEFAULT_ID_JAG_EXPIRY_SECONDS,
199198
override_audience_with_issuer: bool = True,
200199
) -> None:
201200
"""Initialize Enterprise Auth OAuth Client.
@@ -211,8 +210,6 @@ def __init__(
211210
idp_client_secret: Optional client secret registered with the IdP.
212211
Must be accompanied by ``idp_client_id``; providing a secret
213212
without an ID raises ``ValueError``.
214-
default_id_jag_expiry: Fallback ID-JAG expiry in seconds if the IdP
215-
omits ``expires_in`` (default: 900 s / 15 min).
216213
override_audience_with_issuer: If True (default), replaces the IdP
217214
audience with the discovered OAuth issuer URL. Set to False for
218215
federated identity setups where the audience must differ.
@@ -221,6 +218,13 @@ def __init__(
221218
ValueError: If ``idp_client_secret`` is provided without ``idp_client_id``.
222219
OAuthFlowError: If ``token_exchange_params`` fail validation.
223220
"""
221+
# Validate pure parameters before creating any state (fail-fast)
222+
if idp_client_secret is not None and idp_client_id is None:
223+
raise ValueError(
224+
"idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
225+
)
226+
validate_token_exchange_params(token_exchange_params)
227+
224228
super().__init__(
225229
server_url=server_url,
226230
client_metadata=client_metadata,
@@ -233,19 +237,7 @@ def __init__(
233237
self._subject_token = token_exchange_params.subject_token
234238
self.idp_client_id = idp_client_id
235239
self.idp_client_secret = idp_client_secret
236-
self.default_id_jag_expiry = default_id_jag_expiry
237240
self.override_audience_with_issuer = override_audience_with_issuer
238-
self._id_jag: str | None = None
239-
self._id_jag_expiry: float | None = None
240-
241-
# Fail-fast: secret without ID is almost certainly a misconfiguration
242-
if idp_client_secret is not None and idp_client_id is None:
243-
raise ValueError(
244-
"idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
245-
)
246-
247-
# Validate token exchange params at construction time
248-
validate_token_exchange_params(token_exchange_params)
249241

250242
async def exchange_token_for_id_jag(
251243
self,
@@ -266,7 +258,9 @@ async def exchange_token_for_id_jag(
266258

267259
audience = self.token_exchange_params.audience
268260
if self.override_audience_with_issuer:
269-
if self.context.oauth_metadata and self.context.oauth_metadata.issuer:
261+
# OAuthMetadata.issuer is a required AnyHttpUrl field (RFC 8414),
262+
# so it is always non-None when oauth_metadata is present.
263+
if self.context.oauth_metadata:
270264
discovered_issuer = str(self.context.oauth_metadata.issuer)
271265
if audience != discovered_issuer:
272266
logger.warning(
@@ -288,7 +282,11 @@ async def exchange_token_for_id_jag(
288282
if self.token_exchange_params.scope and self.token_exchange_params.scope.strip():
289283
token_data["scope"] = self.token_exchange_params.scope
290284

291-
# Add IdP client authentication if provided
285+
# Add IdP client authentication if provided.
286+
# Sent as POST body parameters (not HTTP Basic) because this is the
287+
# IdP's token-exchange endpoint — most enterprise IdPs (Okta, Azure AD,
288+
# Ping) accept body credentials for token exchange. HTTP Basic is
289+
# allowed by RFC 6749 §2.3.1 but not universally required here.
292290
if self.idp_client_id is not None:
293291
token_data["client_id"] = self.idp_client_id
294292
if self.idp_client_secret is not None:
@@ -324,63 +322,60 @@ async def exchange_token_for_id_jag(
324322
logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'")
325323

326324
logger.debug("Successfully obtained ID-JAG")
327-
self._id_jag = token_response.id_jag
328-
329-
if token_response.expires_in:
330-
self._id_jag_expiry = time.time() + token_response.expires_in
331-
else:
332-
self._id_jag_expiry = time.time() + self.default_id_jag_expiry
333-
logger.debug(f"IdP omitted expires_in; using default of {self.default_id_jag_expiry}s for ID-JAG")
334325

335326
return token_response.id_jag
336327

337328
except httpx.HTTPError as e:
338329
raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e
330+
except ValidationError as e:
331+
raise OAuthTokenError("Invalid token exchange response from IdP") from e
339332

340333
async def exchange_id_jag_for_access_token(
341334
self,
342335
id_jag: str,
343336
) -> httpx.Request:
344-
"""Build JWT bearer grant request to exchange ID-JAG for access token (RFC 7523).
337+
"""Build a JWT bearer grant request to exchange an ID-JAG for an access token (RFC 7523).
345338
346-
Builds the request without executing it. HTTP execution and error parsing
347-
are deferred to the parent class's `async_auth_flow` for consistency.
339+
This method only *builds* the ``httpx.Request``; it does not execute
340+
it. HTTP execution and error parsing are deferred to the parent
341+
class's ``async_auth_flow`` via ``_handle_token_response``.
342+
343+
Follows the same pattern as ``ClientCredentialsOAuthProvider._exchange_token_client_credentials``
344+
and ``RFC7523OAuthClientProvider._exchange_token_jwt_bearer``:
345+
use ``_get_token_endpoint()`` for the URL and ``prepare_token_auth()``
346+
for client authentication — no manual ``client_id`` injection or
347+
context swapping needed.
348348
349349
Args:
350-
id_jag: The ID-JAG token
350+
id_jag: The ID-JAG token obtained from ``exchange_token_for_id_jag``
351351
352352
Returns:
353-
httpx.Request for the JWT bearer grant
354-
355-
Raises:
356-
OAuthFlowError: If OAuth metadata not discovered
353+
An ``httpx.Request`` for the JWT bearer grant
357354
"""
358355
logger.debug("Building JWT bearer grant request for ID-JAG")
359356

360-
if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint:
361-
raise OAuthFlowError("MCP server token endpoint not discovered")
362-
363-
token_endpoint = str(self.context.oauth_metadata.token_endpoint)
364-
365-
# Build as a plain dict — avoids the double-cast through TypedDict
366357
token_data: dict[str, str] = {
367358
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
368359
"assertion": id_jag,
369360
}
370361

371-
# Add client_id to body. prepare_token_auth handles client_secret
372-
# placement (Basic header vs. POST body) based on auth method.
373-
if self.context.client_info:
374-
if self.context.client_info.token_endpoint_auth_method is None:
375-
self.context.client_info.token_endpoint_auth_method = "client_secret_basic"
376-
377-
if self.context.client_info.client_id is not None:
378-
token_data["client_id"] = self.context.client_info.client_id
379-
380362
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
363+
364+
# Delegate client authentication (client_secret_basic, client_secret_post,
365+
# or none) to the parent's context helper — same as every other grant type.
381366
token_data, headers = self.context.prepare_token_auth(token_data, headers)
382367

383-
return httpx.Request("POST", token_endpoint, data=token_data, headers=headers)
368+
# Include resource parameter per RFC 8707 — same guard as every sibling provider
369+
if self.context.should_include_resource_param(self.context.protocol_version):
370+
token_data["resource"] = self.context.get_resource_url()
371+
372+
# Include scope if configured (may have been updated by parent's async_auth_flow
373+
# from the server's WWW-Authenticate header before _perform_authorization is called)
374+
if self.context.client_metadata.scope:
375+
token_data["scope"] = self.context.client_metadata.scope
376+
377+
token_url = self._get_token_endpoint()
378+
return httpx.Request("POST", token_url, data=token_data, headers=headers)
384379

385380
async def _perform_authorization(self) -> httpx.Request:
386381
"""Perform enterprise authorization flow.
@@ -400,7 +395,6 @@ async def _perform_authorization(self) -> httpx.Request:
400395
# Step 1: Exchange IDP subject token for ID-JAG (RFC 8693)
401396
async with httpx.AsyncClient(timeout=self.context.timeout) as client:
402397
id_jag = await self.exchange_token_for_id_jag(client)
403-
self._id_jag = id_jag
404398

405399
# Step 2: Build JWT bearer grant request (RFC 7523)
406400
jwt_bearer_request = await self.exchange_id_jag_for_access_token(id_jag)
@@ -412,23 +406,27 @@ async def refresh_with_new_id_token(self, new_id_token: str) -> None:
412406
"""Refresh MCP server access tokens using a fresh ID token from the IdP.
413407
414408
Updates the subject token and clears cached state so that the next API
415-
request triggers a full re-authentication.
409+
request triggers a full re-authentication. Acquires the context lock
410+
to prevent racing with an in-progress ``async_auth_flow``.
416411
417412
Note: OAuth metadata is not re-discovered. If the MCP server's OAuth
418413
configuration has changed, create a new provider instance instead.
419414
415+
Warning: This method is NOT safe to call from a different OS thread.
416+
Call it only from the same thread and event loop that owns this
417+
provider instance.
418+
420419
Args:
421420
new_id_token: Fresh ID token obtained from your enterprise IdP.
422421
"""
423-
logger.info("Refreshing tokens with new ID token from IdP")
424-
# Update the mutable subject token (does NOT mutate the original params object)
425-
self._subject_token = new_id_token
422+
async with self.context.lock:
423+
logger.info("Refreshing tokens with new ID token from IdP")
424+
# Update the mutable subject token (does NOT mutate the original params object)
425+
self._subject_token = new_id_token
426426

427-
# Clear caches to force full re-exchange on next request
428-
self._id_jag = None
429-
self._id_jag_expiry = None
430-
self.context.clear_tokens()
431-
logger.debug("Token refresh prepared — will re-authenticate on next request")
427+
# Clear tokens to force full re-exchange on next request
428+
self.context.clear_tokens()
429+
logger.debug("Token refresh prepared — will re-authenticate on next request")
432430

433431

434432
def decode_id_jag(id_jag: str) -> IDJAGClaims:
@@ -452,7 +450,12 @@ def decode_id_jag(id_jag: str) -> IDJAGClaims:
452450
def validate_token_exchange_params(
453451
params: TokenExchangeParameters,
454452
) -> None:
455-
"""Validate token exchange parameters.
453+
"""Validate token exchange parameters beyond Pydantic field constraints.
454+
455+
Pydantic ``Field(...)`` rejects *missing* values but permits empty strings.
456+
This function adds:
457+
- Empty-string checks for ``subject_token``, ``audience``, ``resource``
458+
- Allow-list check for ``subject_token_type`` (id_token or saml2)
456459
457460
Args:
458461
params: Token exchange parameters to validate
@@ -469,8 +472,8 @@ def validate_token_exchange_params(
469472
if not params.resource:
470473
raise OAuthFlowError("resource is required")
471474

472-
if params.subject_token_type not in [
475+
if params.subject_token_type not in {
473476
"urn:ietf:params:oauth:token-type:id_token",
474477
"urn:ietf:params:oauth:token-type:saml2",
475-
]:
478+
}:
476479
raise OAuthFlowError(f"Invalid subject_token_type: {params.subject_token_type}")

0 commit comments

Comments
 (0)