Skip to content

Commit 425161d

Browse files
committed
fix(oauth): narrow async_auth_flow lock scope to avoid blocking long-poll requests
Closes #2847. OAuthContext.lock is an anyio.Lock, which records task identity at acquire() and enforces same-task release(). async_auth_flow held this lock across yield points; when httpx drives the generator from a different task during concurrent OAuth connections, release() raises 'RuntimeError: The current task is not holding this lock'. Narrows the lock scope so no HTTP yield (long-poll GET SSE, token-refresh round trips) runs while holding context.lock, plus a single-flight refresh_lock with a re-check under the lock. Keeps trio portability (no asyncio.Lock swap). Salvage of #2660 by @peisuke, rebased onto current main.
1 parent 4472428 commit 425161d

2 files changed

Lines changed: 348 additions & 18 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,18 @@ class OAuthContext:
120120
token_expiry_time: float | None = None
121121

122122
# State
123+
#
124+
# `lock` guards short-lived reads/writes of provider state (initialization
125+
# flag, token cache mutation, protocol_version assignment). It is held only
126+
# while mutating state and is released before any HTTP request is yielded
127+
# so a long-running request (e.g. GET SSE long-poll) does not block
128+
# unrelated concurrent requests.
129+
#
130+
# `refresh_lock` provides single-flight semantics for token refresh: only
131+
# one concurrent refresh fires; other waiters block on this lock, then
132+
# re-check the token cache and proceed without re-refreshing.
123133
lock: anyio.Lock = field(default_factory=anyio.Lock)
134+
refresh_lock: anyio.Lock = field(default_factory=anyio.Lock)
124135

125136
def get_authorization_base_url(self, server_url: str) -> str:
126137
"""Extract base URL by removing path component."""
@@ -475,7 +486,7 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool:
475486
await self.context.storage.set_tokens(token_response)
476487

477488
return True
478-
except ValidationError: # pragma: no cover
489+
except ValidationError:
479490
logger.exception("Invalid refresh response")
480491
self.context.clear_tokens()
481492
return False
@@ -511,29 +522,80 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
511522
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
512523

513524
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
514-
"""HTTPX auth flow integration."""
525+
"""HTTPX auth flow integration.
526+
527+
Lock scope:
528+
``self.context.lock`` is held only while reading/mutating provider
529+
state. The actual HTTP request yield (which may be a long-poll GET
530+
SSE stream) runs outside any lock so concurrent unrelated requests
531+
are not blocked. ``self.context.refresh_lock`` provides
532+
single-flight semantics for token refresh.
533+
"""
534+
# === Phase 1: state read + refresh decision (brief context.lock) ===
535+
needs_refresh = False
515536
async with self.context.lock:
516537
if not self._initialized:
517538
await self._initialize()
518539

519540
# Capture protocol version from request headers
520541
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
521542

522-
if not self.context.is_token_valid() and self.context.can_refresh_token():
523-
# Try to refresh token
524-
refresh_request = await self._refresh_token()
525-
refresh_response = yield refresh_request
526-
527-
if not await self._handle_refresh_response(refresh_response):
528-
# Refresh failed, need full re-authentication
529-
self._initialized = False
530-
531-
if self.context.is_token_valid():
532-
self._add_auth_header(request)
533-
534-
response = yield request
535-
536-
if response.status_code == 401:
543+
# pragma: no branch — coverage.py on Python 3.10/3.11 (sys.settrace
544+
# backend) cannot reliably track both arms of compound boolean
545+
# predicates inside an ``async with`` block in an async generator.
546+
# Python 3.12+ (sys.monitoring) handles this correctly; the pragmas
547+
# below are workarounds for the legacy backend only.
548+
if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no branch
549+
needs_refresh = True
550+
551+
# === Phase 2: single-flight token refresh (yield outside context.lock) ===
552+
if needs_refresh:
553+
async with self.context.refresh_lock:
554+
# Re-check under context.lock: another coroutine may already have
555+
# refreshed while we were waiting on refresh_lock.
556+
refresh_request: httpx.Request | None = None
557+
async with self.context.lock:
558+
if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no branch
559+
refresh_request = await self._refresh_token()
560+
if refresh_request is not None: # pragma: no branch
561+
# yield runs outside any lock so a long network round trip
562+
# does not block unrelated concurrent requests.
563+
refresh_response = yield refresh_request
564+
async with self.context.lock:
565+
if not await self._handle_refresh_response(refresh_response): # pragma: no branch
566+
# Refresh failed; fall through to 401 handling below.
567+
self._initialized = False
568+
569+
# === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
570+
if self.context.is_token_valid():
571+
self._add_auth_header(request)
572+
573+
# Capture the access token actually used to send this request so the
574+
# 401 handler below can detect a token change made by a concurrent
575+
# request while this one was in flight.
576+
sent_access_token = self.context.current_tokens.access_token if self.context.current_tokens else None
577+
578+
response = yield request
579+
580+
# === Phase 4: 401 / 403 full OAuth flow ===
581+
# NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
582+
# token exchange) under context.lock. This is the existing behavior and
583+
# is acceptable because the 401 path is exceptional and not concurrent
584+
# with steady-state traffic. A future refactor could narrow the lock
585+
# here in the same pattern as Phase 1-2.
586+
if response.status_code == 401:
587+
async with self.context.lock:
588+
# Concurrency guard: while this request was in flight, another
589+
# request holding ``context.lock`` may have already completed a
590+
# token refresh or a full re-authorization. If the stored access
591+
# token changed since we sent this request, the 401 is stale -
592+
# retry once with the new token instead of running a second,
593+
# duplicate ``authorization_code`` exchange.
594+
current_access_token = self.context.current_tokens.access_token if self.context.current_tokens else None
595+
if current_access_token is not None and current_access_token != sent_access_token:
596+
self._add_auth_header(request)
597+
yield request
598+
return
537599
# Perform full OAuth flow
538600
try:
539601
# OAuth flow must be inline due to generator constraints
@@ -652,7 +714,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
652714
# Retry with new tokens
653715
self._add_auth_header(request)
654716
yield request
655-
elif response.status_code == 403:
717+
elif response.status_code == 403:
718+
async with self.context.lock:
656719
# Step 1: Extract error field from WWW-Authenticate header
657720
error = extract_field_from_www_auth(response, "error")
658721

0 commit comments

Comments
 (0)