From a059ef11e1d9f54073b78d105054193b1804d5d4 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 03:35:47 +0530 Subject: [PATCH] fix(client): dedupe concurrent oauth refreshes --- .changeset/fresh-plums-cheer.md | 5 ++ packages/client/src/client/auth.ts | 65 ++++++++++++++--- packages/client/test/client/auth.test.ts | 93 ++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 11 deletions(-) create mode 100644 .changeset/fresh-plums-cheer.md diff --git a/.changeset/fresh-plums-cheer.md b/.changeset/fresh-plums-cheer.md new file mode 100644 index 000000000..3a765a726 --- /dev/null +++ b/.changeset/fresh-plums-cheer.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/client': patch +--- + +Deduplicate concurrent OAuth refreshes for the same provider, authorization server, resource, and refresh token so parallel `auth()` callers reuse the in-flight refresh instead of replaying a rotating refresh token. diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index 1a021be18..55504dcd5 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -384,6 +384,46 @@ function isClientAuthMethod(method: string): method is ClientAuthMethod { const AUTHORIZATION_CODE_RESPONSE_TYPE = 'code'; const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256'; +const inFlightRefreshAuthorizations = new WeakMap>>(); + +function buildRefreshAuthorizationKey(authorizationServerUrl: string | URL, refreshToken: string, resource?: URL): string { + return JSON.stringify({ + authorizationServerUrl: String(authorizationServerUrl), + refreshToken, + resource: resource?.toString() + }); +} + +async function runWithRefreshAuthorizationLock( + provider: OAuthClientProvider, + key: string, + runRefreshAuthorization: () => Promise +): Promise { + const existingRefreshAuthorization = inFlightRefreshAuthorizations.get(provider)?.get(key); + if (existingRefreshAuthorization) { + return await existingRefreshAuthorization; + } + + let refreshesByKey = inFlightRefreshAuthorizations.get(provider); + if (!refreshesByKey) { + refreshesByKey = new Map(); + inFlightRefreshAuthorizations.set(provider, refreshesByKey); + } + + const refreshAuthorizationPromise = runRefreshAuthorization(); + refreshesByKey.set(key, refreshAuthorizationPromise); + try { + return await refreshAuthorizationPromise; + } finally { + const currentRefreshesByKey = inFlightRefreshAuthorizations.get(provider); + if (currentRefreshesByKey?.get(key) === refreshAuthorizationPromise) { + currentRefreshesByKey.delete(key); + if (currentRefreshesByKey.size === 0) { + inFlightRefreshAuthorizations.delete(provider); + } + } + } +} /** * Determines the best client authentication method to use based on server support and client configuration. @@ -731,18 +771,21 @@ async function authInternal( // Handle token refresh or new authorization if (tokens?.refresh_token) { try { - // Attempt to refresh the token - const newTokens = await refreshAuthorization(authorizationServerUrl, { - metadata, - clientInformation, - refreshToken: tokens.refresh_token, - resource, - addClientAuthentication: provider.addClientAuthentication, - fetchFn - }); + const refreshToken = tokens.refresh_token; + const refreshAuthorizationKey = buildRefreshAuthorizationKey(authorizationServerUrl, refreshToken, resource); + return await runWithRefreshAuthorizationLock(provider, refreshAuthorizationKey, async () => { + const newTokens = await refreshAuthorization(authorizationServerUrl, { + metadata, + clientInformation, + refreshToken, + resource, + addClientAuthentication: provider.addClientAuthentication, + fetchFn + }); - await provider.saveTokens(newTokens); - return 'AUTHORIZED'; + await provider.saveTokens(newTokens); + return 'AUTHORIZED'; + }); } catch (error) { // If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry. if (!(error instanceof OAuthError) || error.code === OAuthErrorCode.ServerError) { diff --git a/packages/client/test/client/auth.test.ts b/packages/client/test/client/auth.test.ts index 8178df906..0a0f2c8e4 100644 --- a/packages/client/test/client/auth.test.ts +++ b/packages/client/test/client/auth.test.ts @@ -2492,6 +2492,99 @@ describe('OAuth Authorization', () => { expect(body.get('refresh_token')).toBe('refresh123'); }); + it('deduplicates concurrent refreshes for the same provider and resource', async () => { + let releaseTokensReaders: (() => void) | undefined; + const tokensReady = new Promise(resolve => { + releaseTokensReaders = resolve; + }); + let tokensReaderCount = 0; + let resolveRefreshResponse: ((value: { ok: true; status: 200; json: () => Promise }) => void) | undefined; + const refreshResponse = new Promise<{ ok: true; status: 200; json: () => Promise }>(resolve => { + resolveRefreshResponse = resolve; + }); + let refreshRequestCount = 0; + + mockFetch.mockImplementation(url => { + const urlString = url.toString(); + + if (urlString.includes('/.well-known/oauth-protected-resource')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: 'https://api.example.com/mcp-server', + authorization_servers: ['https://auth.example.com'] + }) + }); + } + + if (urlString.includes('/.well-known/oauth-authorization-server')) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + }); + } + + if (urlString.includes('/token')) { + refreshRequestCount++; + if (refreshRequestCount > 1) { + throw new Error('duplicate refresh request'); + } + return refreshResponse; + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + (mockProvider.clientInformation as Mock).mockResolvedValue({ + client_id: 'test-client', + client_secret: 'test-secret' + }); + (mockProvider.tokens as Mock).mockImplementation(async () => { + tokensReaderCount++; + if (tokensReaderCount === 2) { + releaseTokensReaders?.(); + } + await tokensReady; + return { + access_token: 'old-access', + refresh_token: 'refresh123' + }; + }); + (mockProvider.saveTokens as Mock).mockResolvedValue(undefined); + + const authResults = Promise.all([ + auth(mockProvider, { serverUrl: 'https://api.example.com/mcp-server' }), + auth(mockProvider, { serverUrl: 'https://api.example.com/mcp-server' }) + ]); + + await vi.waitFor(() => { + expect(refreshRequestCount).toBe(1); + }); + + resolveRefreshResponse?.({ + ok: true, + status: 200, + json: async () => ({ + access_token: 'new-access123', + refresh_token: 'new-refresh456', + token_type: 'Bearer', + expires_in: 3600 + }) + }); + + await expect(authResults).resolves.toEqual(['AUTHORIZED', 'AUTHORIZED']); + expect(refreshRequestCount).toBe(1); + expect(mockProvider.saveTokens).toHaveBeenCalledTimes(1); + }); + it('skips default PRM resource validation when custom validateResourceURL is provided', async () => { const mockValidateResourceURL = vi.fn().mockResolvedValue(undefined); const providerWithCustomValidation = {