Skip to content

Commit b9bf42a

Browse files
committed
chore(auth): type-safe auth context push
1 parent de67bd7 commit b9bf42a

2 files changed

Lines changed: 11 additions & 10 deletions

File tree

src/mcp/server/auth/middleware/auth_context.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextvars
2-
32
from contextvars import Token
3+
from typing import Any
44

55
from starlette.requests import Request
66
from starlette.types import ASGIApp, Receive, Scope, Send
@@ -23,7 +23,7 @@ def get_access_token() -> AccessToken | None:
2323
return auth_user.access_token if auth_user else None
2424

2525

26-
def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None:
26+
def push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None:
2727
"""Set auth context for the current task from an incoming request.
2828
2929
This is primarily used by server transports where request handlers may run
@@ -32,10 +32,7 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica
3232
if request is None:
3333
return None
3434
# Avoid Request.user, which asserts AuthenticationMiddleware is installed.
35-
user = None
36-
scope = getattr(request, "scope", None)
37-
if isinstance(scope, dict):
38-
user = scope.get("user")
35+
user: Any | None = request.scope.get("user")
3936
if user is None:
4037
try:
4138
user = getattr(request, "user", None)
@@ -46,7 +43,7 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica
4643
return None
4744

4845

49-
def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:
46+
def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:
5047
if token is None:
5148
return
5249
auth_context_var.reset(token)

src/mcp/server/lowlevel/server.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ async def main():
5353
from typing_extensions import TypeVar
5454

5555
from mcp import types
56-
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, _pop_auth_context, _push_auth_context_from_request
56+
from mcp.server.auth.middleware.auth_context import (
57+
AuthContextMiddleware,
58+
pop_auth_context,
59+
push_auth_context_from_request,
60+
)
5761
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
5862
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5963
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
@@ -497,11 +501,11 @@ async def _handle_request(
497501
close_sse_stream=close_sse_stream_cb,
498502
close_standalone_sse_stream=close_standalone_sse_stream_cb,
499503
)
500-
auth_token = _push_auth_context_from_request(request_data)
504+
auth_token = push_auth_context_from_request(request_data)
501505
try:
502506
response = await handler(ctx, req.params)
503507
finally:
504-
_pop_auth_context(auth_token)
508+
pop_auth_context(auth_token)
505509
except MCPError as err:
506510
response = err.error
507511
except anyio.get_cancelled_exc_class():

0 commit comments

Comments
 (0)