|
1 | 1 | """Tests for refactored OAuth client authentication implementation.""" |
2 | 2 |
|
3 | 3 | import base64 |
| 4 | +import contextlib |
4 | 5 | import time |
5 | 6 | from unittest import mock |
6 | 7 | from urllib.parse import parse_qs, quote, unquote, urlparse |
7 | 8 |
|
| 9 | +import anyio |
8 | 10 | import httpx |
9 | 11 | import pytest |
10 | 12 | from inline_snapshot import Is, snapshot |
@@ -2618,3 +2620,116 @@ async def callback_handler() -> tuple[str, str | None]: |
2618 | 2620 | await auth_flow.asend(final_response) |
2619 | 2621 | except StopAsyncIteration: |
2620 | 2622 | pass |
| 2623 | + |
| 2624 | + |
| 2625 | +class TestConcurrentRequestsDoNotDeadlock: |
| 2626 | + """Regression tests for #1326. |
| 2627 | +
|
| 2628 | + Ensures that ``OAuthClientProvider.async_auth_flow`` does not serialize |
| 2629 | + concurrent unrelated requests behind a long-running one (e.g. GET SSE |
| 2630 | + long-poll). The fix narrows ``context.lock`` to state mutation only; the |
| 2631 | + actual ``yield request`` runs outside any lock. |
| 2632 | + """ |
| 2633 | + |
| 2634 | + @pytest.mark.anyio |
| 2635 | + async def test_concurrent_request_not_blocked_by_pending_long_running_request( |
| 2636 | + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken |
| 2637 | + ): |
| 2638 | + """A second request must reach its yield while the first is still |
| 2639 | + suspended at its yield (= simulating a server-side long-poll). |
| 2640 | +
|
| 2641 | + Before this fix, ``async_auth_flow`` held ``context.lock`` across |
| 2642 | + ``yield request``. A GET SSE long-poll would therefore hold the lock |
| 2643 | + for the entire SSE lifetime, blocking any concurrent request waiting |
| 2644 | + on the same provider's lock and producing a multi-second stall. |
| 2645 | + """ |
| 2646 | + # Set up valid tokens so neither refresh (Phase 2) nor full OAuth |
| 2647 | + # flow (Phase 4) is triggered — we want to exercise the steady-state |
| 2648 | + # Phase 3 yield path that previously held the lock. |
| 2649 | + oauth_provider.context.current_tokens = valid_tokens |
| 2650 | + oauth_provider.context.token_expiry_time = time.time() + 1800 |
| 2651 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 2652 | + client_id="test_client_id", |
| 2653 | + client_secret="test_client_secret", |
| 2654 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 2655 | + ) |
| 2656 | + oauth_provider._initialized = True |
| 2657 | + |
| 2658 | + # Flow 1: simulate a slow request. Drive it to its yield, then |
| 2659 | + # deliberately do not send a response — it stays suspended at the |
| 2660 | + # yield, just like a GET SSE long-poll waiting for the next event. |
| 2661 | + slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp") |
| 2662 | + slow_flow = oauth_provider.async_auth_flow(slow_request) |
| 2663 | + yielded_slow = await slow_flow.__anext__() |
| 2664 | + assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token" |
| 2665 | + |
| 2666 | + # Flow 2: a concurrent request on the same provider. With the fix, |
| 2667 | + # context.lock is not held during Flow 1's yield, so Flow 2 reaches |
| 2668 | + # its yield almost immediately. Without the fix, this would block |
| 2669 | + # until Flow 1 receives a response — i.e., it would hit the timeout. |
| 2670 | + fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp") |
| 2671 | + fast_flow = oauth_provider.async_auth_flow(fast_request) |
| 2672 | + with anyio.fail_after(1.0): |
| 2673 | + yielded_fast = await fast_flow.__anext__() |
| 2674 | + assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token" |
| 2675 | + |
| 2676 | + # Clean up both generators in deterministic order. |
| 2677 | + with contextlib.suppress(StopAsyncIteration): |
| 2678 | + await fast_flow.asend(httpx.Response(200, request=yielded_fast)) |
| 2679 | + with contextlib.suppress(StopAsyncIteration): |
| 2680 | + await slow_flow.asend(httpx.Response(200, request=yielded_slow)) |
| 2681 | + |
| 2682 | + @pytest.mark.anyio |
| 2683 | + async def test_concurrent_token_refresh_is_single_flight( |
| 2684 | + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken |
| 2685 | + ): |
| 2686 | + """When concurrent requests both observe an expired token, only one |
| 2687 | + refresh request is sent: ``refresh_lock`` provides single-flight |
| 2688 | + semantics so the second waiter re-checks state and proceeds without |
| 2689 | + re-triggering refresh. |
| 2690 | + """ |
| 2691 | + # Mark the token as expired so the next auth_flow run enters Phase 2. |
| 2692 | + oauth_provider.context.current_tokens = valid_tokens |
| 2693 | + oauth_provider.context.token_expiry_time = time.time() - 100 # expired |
| 2694 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 2695 | + client_id="test_client_id", |
| 2696 | + client_secret="test_client_secret", |
| 2697 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 2698 | + ) |
| 2699 | + oauth_provider._initialized = True |
| 2700 | + |
| 2701 | + # Flow A: drive it to the refresh yield and suspend there. |
| 2702 | + request_a = httpx.Request("GET", "https://api.example.com/v1/mcp") |
| 2703 | + flow_a = oauth_provider.async_auth_flow(request_a) |
| 2704 | + refresh_a = await flow_a.__anext__() |
| 2705 | + assert "grant_type=refresh_token" in refresh_a.read().decode() |
| 2706 | + |
| 2707 | + # Complete Flow A's refresh with a fresh token. |
| 2708 | + refresh_response = httpx.Response( |
| 2709 | + 200, |
| 2710 | + content=( |
| 2711 | + b'{"access_token": "new_access_token", "token_type": "Bearer", ' |
| 2712 | + b'"expires_in": 3600, "refresh_token": "new_refresh_token"}' |
| 2713 | + ), |
| 2714 | + request=refresh_a, |
| 2715 | + ) |
| 2716 | + request_a_post = await flow_a.asend(refresh_response) |
| 2717 | + assert request_a_post.headers.get("Authorization") == "Bearer new_access_token" |
| 2718 | + |
| 2719 | + # Flow B starts after Flow A's refresh has completed. Because token |
| 2720 | + # state was updated under context.lock, Flow B observes the fresh |
| 2721 | + # token in Phase 1, skips Phase 2 entirely, and reaches its yield |
| 2722 | + # directly. No second refresh request is sent. |
| 2723 | + request_b = httpx.Request("POST", "https://api.example.com/v1/mcp") |
| 2724 | + flow_b = oauth_provider.async_auth_flow(request_b) |
| 2725 | + with anyio.fail_after(1.0): |
| 2726 | + request_b_yielded = await flow_b.__anext__() |
| 2727 | + assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token" |
| 2728 | + # Confirm Flow B yielded the original POST, not a refresh request. |
| 2729 | + assert request_b_yielded.method == "POST" |
| 2730 | + |
| 2731 | + # Clean up. |
| 2732 | + with contextlib.suppress(StopAsyncIteration): |
| 2733 | + await flow_b.asend(httpx.Response(200, request=request_b_yielded)) |
| 2734 | + with contextlib.suppress(StopAsyncIteration): |
| 2735 | + await flow_a.asend(httpx.Response(200, request=request_a_post)) |
0 commit comments