Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fresh-plums-cheer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@modelcontextprotocol/client': patch
---

Deduplicate concurrent OAuth refreshes for the same provider, authorization server, resource, and refresh token so parallel `auth()` callers reuse the in-flight refresh instead of replaying a rotating refresh token.
65 changes: 54 additions & 11 deletions packages/client/src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,46 @@ function isClientAuthMethod(method: string): method is ClientAuthMethod {

const AUTHORIZATION_CODE_RESPONSE_TYPE = 'code';
const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256';
const inFlightRefreshAuthorizations = new WeakMap<OAuthClientProvider, Map<string, Promise<AuthResult>>>();

function buildRefreshAuthorizationKey(authorizationServerUrl: string | URL, refreshToken: string, resource?: URL): string {
return JSON.stringify({
authorizationServerUrl: String(authorizationServerUrl),
refreshToken,
resource: resource?.toString()
});
}

async function runWithRefreshAuthorizationLock(
provider: OAuthClientProvider,
key: string,
runRefreshAuthorization: () => Promise<AuthResult>
): Promise<AuthResult> {
const existingRefreshAuthorization = inFlightRefreshAuthorizations.get(provider)?.get(key);
if (existingRefreshAuthorization) {
return await existingRefreshAuthorization;
}

let refreshesByKey = inFlightRefreshAuthorizations.get(provider);
if (!refreshesByKey) {
refreshesByKey = new Map();
inFlightRefreshAuthorizations.set(provider, refreshesByKey);
}

const refreshAuthorizationPromise = runRefreshAuthorization();
refreshesByKey.set(key, refreshAuthorizationPromise);
try {
return await refreshAuthorizationPromise;
} finally {
const currentRefreshesByKey = inFlightRefreshAuthorizations.get(provider);
if (currentRefreshesByKey?.get(key) === refreshAuthorizationPromise) {
currentRefreshesByKey.delete(key);
if (currentRefreshesByKey.size === 0) {
inFlightRefreshAuthorizations.delete(provider);
}
}
}
}

/**
* Determines the best client authentication method to use based on server support and client configuration.
Expand Down Expand Up @@ -731,18 +771,21 @@ async function authInternal(
// Handle token refresh or new authorization
if (tokens?.refresh_token) {
try {
// Attempt to refresh the token
const newTokens = await refreshAuthorization(authorizationServerUrl, {
metadata,
clientInformation,
refreshToken: tokens.refresh_token,
resource,
addClientAuthentication: provider.addClientAuthentication,
fetchFn
});
const refreshToken = tokens.refresh_token;
const refreshAuthorizationKey = buildRefreshAuthorizationKey(authorizationServerUrl, refreshToken, resource);
return await runWithRefreshAuthorizationLock(provider, refreshAuthorizationKey, async () => {
const newTokens = await refreshAuthorization(authorizationServerUrl, {
metadata,
clientInformation,
refreshToken,
resource,
addClientAuthentication: provider.addClientAuthentication,
fetchFn
});

await provider.saveTokens(newTokens);
return 'AUTHORIZED';
await provider.saveTokens(newTokens);
return 'AUTHORIZED';
});
} catch (error) {
// If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry.
if (!(error instanceof OAuthError) || error.code === OAuthErrorCode.ServerError) {
Expand Down
93 changes: 93 additions & 0 deletions packages/client/test/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2492,6 +2492,99 @@ describe('OAuth Authorization', () => {
expect(body.get('refresh_token')).toBe('refresh123');
});

it('deduplicates concurrent refreshes for the same provider and resource', async () => {
let releaseTokensReaders: (() => void) | undefined;
const tokensReady = new Promise<void>(resolve => {
releaseTokensReaders = resolve;
});
let tokensReaderCount = 0;
let resolveRefreshResponse: ((value: { ok: true; status: 200; json: () => Promise<OAuthTokens> }) => void) | undefined;
const refreshResponse = new Promise<{ ok: true; status: 200; json: () => Promise<OAuthTokens> }>(resolve => {
resolveRefreshResponse = resolve;
});
let refreshRequestCount = 0;

mockFetch.mockImplementation(url => {
const urlString = url.toString();

if (urlString.includes('/.well-known/oauth-protected-resource')) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
resource: 'https://api.example.com/mcp-server',
authorization_servers: ['https://auth.example.com']
})
});
}

if (urlString.includes('/.well-known/oauth-authorization-server')) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
response_types_supported: ['code'],
code_challenge_methods_supported: ['S256']
})
});
}

if (urlString.includes('/token')) {
refreshRequestCount++;
if (refreshRequestCount > 1) {
throw new Error('duplicate refresh request');
}
return refreshResponse;
}

return Promise.resolve({ ok: false, status: 404 });
});

(mockProvider.clientInformation as Mock).mockResolvedValue({
client_id: 'test-client',
client_secret: 'test-secret'
});
(mockProvider.tokens as Mock).mockImplementation(async () => {
tokensReaderCount++;
if (tokensReaderCount === 2) {
releaseTokensReaders?.();
}
await tokensReady;
return {
access_token: 'old-access',
refresh_token: 'refresh123'
};
});
(mockProvider.saveTokens as Mock).mockResolvedValue(undefined);

const authResults = Promise.all([
auth(mockProvider, { serverUrl: 'https://api.example.com/mcp-server' }),
auth(mockProvider, { serverUrl: 'https://api.example.com/mcp-server' })
]);

await vi.waitFor(() => {
expect(refreshRequestCount).toBe(1);
});

resolveRefreshResponse?.({
ok: true,
status: 200,
json: async () => ({
access_token: 'new-access123',
refresh_token: 'new-refresh456',
token_type: 'Bearer',
expires_in: 3600
})
});

await expect(authResults).resolves.toEqual(['AUTHORIZED', 'AUTHORIZED']);
expect(refreshRequestCount).toBe(1);
expect(mockProvider.saveTokens).toHaveBeenCalledTimes(1);
});

it('skips default PRM resource validation when custom validateResourceURL is provided', async () => {
const mockValidateResourceURL = vi.fn().mockResolvedValue(undefined);
const providerWithCustomValidation = {
Expand Down
Loading