55"""
66
77import logging
8- import time
98from json import JSONDecodeError
109
1110import httpx
1211import jwt
13- from pydantic import BaseModel , Field
12+ from pydantic import BaseModel , Field , ValidationError
1413from typing_extensions import NotRequired , Required , TypedDict
1514
1615from mcp .client .auth import OAuthClientProvider , OAuthFlowError , OAuthTokenError , TokenStorage
@@ -140,7 +139,12 @@ def id_jag(self) -> str:
140139
141140
142141class IDJAGClaims (BaseModel ):
143- """Claims structure for Identity Assertion JWT Authorization Grant."""
142+ """Claims structure for Identity Assertion JWT Authorization Grant.
143+
144+ Note: ``typ`` is sourced from the JWT *header* (not the payload) by
145+ ``decode_id_jag``. It is included here for convenience so callers
146+ can inspect the full ID-JAG structure from a single object.
147+ """
144148
145149 model_config = {"extra" : "allow" }
146150
@@ -174,17 +178,13 @@ class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
174178
175179 Concurrency & Thread Safety:
176180 - SAFE: Concurrent requests within a single asyncio event loop. Token
177- operations (including ``_id_jag`` / ``_id_jag_expiry``) are protected
178- by the parent class's ``OAuthContext.lock`` via ``async_auth_flow``.
181+ operations are protected by the parent class's ``OAuthContext.lock``
182+ via ``async_auth_flow``.
179183 - UNSAFE: Sharing a provider instance across multiple OS threads. Each
180184 thread must instantiate its own provider and event loop.
181185 - Note: Ensure any shared ``TokenStorage`` implementation is async-safe.
182186 """
183187
184- # Default ID-JAG expiry when IdP doesn't provide expires_in.
185- # 15 minutes is a conservative default for enterprise environments.
186- DEFAULT_ID_JAG_EXPIRY_SECONDS = 900
187-
188188 def __init__ (
189189 self ,
190190 server_url : str ,
@@ -195,7 +195,6 @@ def __init__(
195195 timeout : float = 300.0 ,
196196 idp_client_id : str | None = None ,
197197 idp_client_secret : str | None = None ,
198- default_id_jag_expiry : int = DEFAULT_ID_JAG_EXPIRY_SECONDS ,
199198 override_audience_with_issuer : bool = True ,
200199 ) -> None :
201200 """Initialize Enterprise Auth OAuth Client.
@@ -211,8 +210,6 @@ def __init__(
211210 idp_client_secret: Optional client secret registered with the IdP.
212211 Must be accompanied by ``idp_client_id``; providing a secret
213212 without an ID raises ``ValueError``.
214- default_id_jag_expiry: Fallback ID-JAG expiry in seconds if the IdP
215- omits ``expires_in`` (default: 900 s / 15 min).
216213 override_audience_with_issuer: If True (default), replaces the IdP
217214 audience with the discovered OAuth issuer URL. Set to False for
218215 federated identity setups where the audience must differ.
@@ -221,6 +218,13 @@ def __init__(
221218 ValueError: If ``idp_client_secret`` is provided without ``idp_client_id``.
222219 OAuthFlowError: If ``token_exchange_params`` fail validation.
223220 """
221+ # Validate pure parameters before creating any state (fail-fast)
222+ if idp_client_secret is not None and idp_client_id is None :
223+ raise ValueError (
224+ "idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
225+ )
226+ validate_token_exchange_params (token_exchange_params )
227+
224228 super ().__init__ (
225229 server_url = server_url ,
226230 client_metadata = client_metadata ,
@@ -233,19 +237,7 @@ def __init__(
233237 self ._subject_token = token_exchange_params .subject_token
234238 self .idp_client_id = idp_client_id
235239 self .idp_client_secret = idp_client_secret
236- self .default_id_jag_expiry = default_id_jag_expiry
237240 self .override_audience_with_issuer = override_audience_with_issuer
238- self ._id_jag : str | None = None
239- self ._id_jag_expiry : float | None = None
240-
241- # Fail-fast: secret without ID is almost certainly a misconfiguration
242- if idp_client_secret is not None and idp_client_id is None :
243- raise ValueError (
244- "idp_client_secret was provided without idp_client_id. Provide both together, or omit the secret."
245- )
246-
247- # Validate token exchange params at construction time
248- validate_token_exchange_params (token_exchange_params )
249241
250242 async def exchange_token_for_id_jag (
251243 self ,
@@ -266,7 +258,9 @@ async def exchange_token_for_id_jag(
266258
267259 audience = self .token_exchange_params .audience
268260 if self .override_audience_with_issuer :
269- if self .context .oauth_metadata and self .context .oauth_metadata .issuer :
261+ # OAuthMetadata.issuer is a required AnyHttpUrl field (RFC 8414),
262+ # so it is always non-None when oauth_metadata is present.
263+ if self .context .oauth_metadata :
270264 discovered_issuer = str (self .context .oauth_metadata .issuer )
271265 if audience != discovered_issuer :
272266 logger .warning (
@@ -288,7 +282,11 @@ async def exchange_token_for_id_jag(
288282 if self .token_exchange_params .scope and self .token_exchange_params .scope .strip ():
289283 token_data ["scope" ] = self .token_exchange_params .scope
290284
291- # Add IdP client authentication if provided
285+ # Add IdP client authentication if provided.
286+ # Sent as POST body parameters (not HTTP Basic) because this is the
287+ # IdP's token-exchange endpoint — most enterprise IdPs (Okta, Azure AD,
288+ # Ping) accept body credentials for token exchange. HTTP Basic is
289+ # allowed by RFC 6749 §2.3.1 but not universally required here.
292290 if self .idp_client_id is not None :
293291 token_data ["client_id" ] = self .idp_client_id
294292 if self .idp_client_secret is not None :
@@ -324,63 +322,60 @@ async def exchange_token_for_id_jag(
324322 logger .warning (f"Expected token_type 'N_A', got '{ token_response .token_type } '" )
325323
326324 logger .debug ("Successfully obtained ID-JAG" )
327- self ._id_jag = token_response .id_jag
328-
329- if token_response .expires_in :
330- self ._id_jag_expiry = time .time () + token_response .expires_in
331- else :
332- self ._id_jag_expiry = time .time () + self .default_id_jag_expiry
333- logger .debug (f"IdP omitted expires_in; using default of { self .default_id_jag_expiry } s for ID-JAG" )
334325
335326 return token_response .id_jag
336327
337328 except httpx .HTTPError as e :
338329 raise OAuthTokenError (f"HTTP error during token exchange: { e } " ) from e
330+ except ValidationError as e :
331+ raise OAuthTokenError ("Invalid token exchange response from IdP" ) from e
339332
340333 async def exchange_id_jag_for_access_token (
341334 self ,
342335 id_jag : str ,
343336 ) -> httpx .Request :
344- """Build JWT bearer grant request to exchange ID-JAG for access token (RFC 7523).
337+ """Build a JWT bearer grant request to exchange an ID-JAG for an access token (RFC 7523).
345338
346- Builds the request without executing it. HTTP execution and error parsing
347- are deferred to the parent class's `async_auth_flow` for consistency.
339+ This method only *builds* the ``httpx.Request``; it does not execute
340+ it. HTTP execution and error parsing are deferred to the parent
341+ class's ``async_auth_flow`` via ``_handle_token_response``.
342+
343+ Follows the same pattern as ``ClientCredentialsOAuthProvider._exchange_token_client_credentials``
344+ and ``RFC7523OAuthClientProvider._exchange_token_jwt_bearer``:
345+ use ``_get_token_endpoint()`` for the URL and ``prepare_token_auth()``
346+ for client authentication — no manual ``client_id`` injection or
347+ context swapping needed.
348348
349349 Args:
350- id_jag: The ID-JAG token
350+ id_jag: The ID-JAG token obtained from ``exchange_token_for_id_jag``
351351
352352 Returns:
353- httpx.Request for the JWT bearer grant
354-
355- Raises:
356- OAuthFlowError: If OAuth metadata not discovered
353+ An ``httpx.Request`` for the JWT bearer grant
357354 """
358355 logger .debug ("Building JWT bearer grant request for ID-JAG" )
359356
360- if not self .context .oauth_metadata or not self .context .oauth_metadata .token_endpoint :
361- raise OAuthFlowError ("MCP server token endpoint not discovered" )
362-
363- token_endpoint = str (self .context .oauth_metadata .token_endpoint )
364-
365- # Build as a plain dict — avoids the double-cast through TypedDict
366357 token_data : dict [str , str ] = {
367358 "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
368359 "assertion" : id_jag ,
369360 }
370361
371- # Add client_id to body. prepare_token_auth handles client_secret
372- # placement (Basic header vs. POST body) based on auth method.
373- if self .context .client_info :
374- if self .context .client_info .token_endpoint_auth_method is None :
375- self .context .client_info .token_endpoint_auth_method = "client_secret_basic"
376-
377- if self .context .client_info .client_id is not None :
378- token_data ["client_id" ] = self .context .client_info .client_id
379-
380362 headers : dict [str , str ] = {"Content-Type" : "application/x-www-form-urlencoded" }
363+
364+ # Delegate client authentication (client_secret_basic, client_secret_post,
365+ # or none) to the parent's context helper — same as every other grant type.
381366 token_data , headers = self .context .prepare_token_auth (token_data , headers )
382367
383- return httpx .Request ("POST" , token_endpoint , data = token_data , headers = headers )
368+ # Include resource parameter per RFC 8707 — same guard as every sibling provider
369+ if self .context .should_include_resource_param (self .context .protocol_version ):
370+ token_data ["resource" ] = self .context .get_resource_url ()
371+
372+ # Include scope if configured (may have been updated by parent's async_auth_flow
373+ # from the server's WWW-Authenticate header before _perform_authorization is called)
374+ if self .context .client_metadata .scope :
375+ token_data ["scope" ] = self .context .client_metadata .scope
376+
377+ token_url = self ._get_token_endpoint ()
378+ return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
384379
385380 async def _perform_authorization (self ) -> httpx .Request :
386381 """Perform enterprise authorization flow.
@@ -400,7 +395,6 @@ async def _perform_authorization(self) -> httpx.Request:
400395 # Step 1: Exchange IDP subject token for ID-JAG (RFC 8693)
401396 async with httpx .AsyncClient (timeout = self .context .timeout ) as client :
402397 id_jag = await self .exchange_token_for_id_jag (client )
403- self ._id_jag = id_jag
404398
405399 # Step 2: Build JWT bearer grant request (RFC 7523)
406400 jwt_bearer_request = await self .exchange_id_jag_for_access_token (id_jag )
@@ -412,23 +406,27 @@ async def refresh_with_new_id_token(self, new_id_token: str) -> None:
412406 """Refresh MCP server access tokens using a fresh ID token from the IdP.
413407
414408 Updates the subject token and clears cached state so that the next API
415- request triggers a full re-authentication.
409+ request triggers a full re-authentication. Acquires the context lock
410+ to prevent racing with an in-progress ``async_auth_flow``.
416411
417412 Note: OAuth metadata is not re-discovered. If the MCP server's OAuth
418413 configuration has changed, create a new provider instance instead.
419414
415+ Warning: This method is NOT safe to call from a different OS thread.
416+ Call it only from the same thread and event loop that owns this
417+ provider instance.
418+
420419 Args:
421420 new_id_token: Fresh ID token obtained from your enterprise IdP.
422421 """
423- logger .info ("Refreshing tokens with new ID token from IdP" )
424- # Update the mutable subject token (does NOT mutate the original params object)
425- self ._subject_token = new_id_token
422+ async with self .context .lock :
423+ logger .info ("Refreshing tokens with new ID token from IdP" )
424+ # Update the mutable subject token (does NOT mutate the original params object)
425+ self ._subject_token = new_id_token
426426
427- # Clear caches to force full re-exchange on next request
428- self ._id_jag = None
429- self ._id_jag_expiry = None
430- self .context .clear_tokens ()
431- logger .debug ("Token refresh prepared — will re-authenticate on next request" )
427+ # Clear tokens to force full re-exchange on next request
428+ self .context .clear_tokens ()
429+ logger .debug ("Token refresh prepared — will re-authenticate on next request" )
432430
433431
434432def decode_id_jag (id_jag : str ) -> IDJAGClaims :
@@ -452,7 +450,12 @@ def decode_id_jag(id_jag: str) -> IDJAGClaims:
452450def validate_token_exchange_params (
453451 params : TokenExchangeParameters ,
454452) -> None :
455- """Validate token exchange parameters.
453+ """Validate token exchange parameters beyond Pydantic field constraints.
454+
455+ Pydantic ``Field(...)`` rejects *missing* values but permits empty strings.
456+ This function adds:
457+ - Empty-string checks for ``subject_token``, ``audience``, ``resource``
458+ - Allow-list check for ``subject_token_type`` (id_token or saml2)
456459
457460 Args:
458461 params: Token exchange parameters to validate
@@ -469,8 +472,8 @@ def validate_token_exchange_params(
469472 if not params .resource :
470473 raise OAuthFlowError ("resource is required" )
471474
472- if params .subject_token_type not in [
475+ if params .subject_token_type not in {
473476 "urn:ietf:params:oauth:token-type:id_token" ,
474477 "urn:ietf:params:oauth:token-type:saml2" ,
475- ] :
478+ } :
476479 raise OAuthFlowError (f"Invalid subject_token_type: { params .subject_token_type } " )
0 commit comments