Skip to content

Commit 32a5e6d

Browse files
committed
merge fix
1 parent fe0e62e commit 32a5e6d

File tree

4 files changed

+11
-11
lines changed

4 files changed

+11
-11
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import httpx
2020
from pydantic import BaseModel, Field, ValidationError
2121

22-
from mcp.client.auth import OAuthFlowError, OAuthTokenError
22+
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
2323
from mcp.client.auth.utils import (
2424
build_oauth_authorization_server_metadata_discovery_urls,
2525
build_protected_resource_metadata_discovery_urls,
@@ -193,7 +193,7 @@ def prepare_token_auth(
193193

194194
auth_method = self.client_info.token_endpoint_auth_method
195195

196-
if auth_method == "client_secret_basic" and self.client_info.client_secret:
196+
if auth_method == "client_secret_basic" and self.client_info.client_id and self.client_info.client_secret:
197197
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
198198
encoded_id = quote(self.client_info.client_id, safe="")
199199
encoded_secret = quote(self.client_info.client_secret, safe="")
@@ -426,7 +426,7 @@ async def _refresh_token(self) -> httpx.Request:
426426
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
427427
raise OAuthTokenError("No refresh token available") # pragma: no cover
428428

429-
if not self.context.client_info:
429+
if not self.context.client_info or not self.context.client_info.client_id:
430430
raise OAuthTokenError("No client info available") # pragma: no cover
431431

432432
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
@@ -435,7 +435,7 @@ async def _refresh_token(self) -> httpx.Request:
435435
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
436436
token_url = urljoin(auth_base_url, "/token")
437437

438-
refresh_data = {
438+
refresh_data: dict[str, str] = {
439439
"grant_type": "refresh_token",
440440
"refresh_token": self.context.current_tokens.refresh_token,
441441
"client_id": self.context.client_info.client_id,

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
106106
# arguments to bytes.
107107
if not hmac.compare_digest(
108108
client.client_secret.encode(), request_client_secret.encode()
109-
): # pragma: no cover
110-
raise AuthenticationError("Invalid client_secret")
109+
):
110+
raise AuthenticationError("Invalid client_secret") # pragma: no cover
111111

112112
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
113113
raise AuthenticationError("Client secret has expired") # pragma: no cover

src/mcp/shared/auth.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class OAuthClientMetadata(BaseModel):
4343

4444
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
4545
# supported auth methods for the token endpoint
46-
token_endpoint_auth_method: Literal[
47-
"none", "client_secret_post", "client_secret_basic", "private_key_jwt"
48-
] | None = None
46+
token_endpoint_auth_method: (
47+
Literal["none", "client_secret_post", "client_secret_basic", "private_key_jwt"] | None
48+
) = None
4949
# supported grant_types of this implementation
5050
grant_types: list[
5151
Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str

tests/client/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvid
707707
token_endpoint_auth_method="client_secret_basic",
708708
)
709709

710-
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
710+
request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier")
711711

712712
# Should use basic auth (registered method)
713713
assert "Authorization" in request.headers
@@ -784,7 +784,7 @@ async def test_none_auth_method(self, oauth_provider: OAuthClientProvider):
784784
token_endpoint_auth_method="none",
785785
)
786786

787-
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
787+
request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier")
788788

789789
# Should NOT have Authorization header
790790
assert "Authorization" not in request.headers

0 commit comments

Comments
 (0)