From 45c5a0cd15da0f997ea57826230a7bfb5a80e7a4 Mon Sep 17 00:00:00 2001 From: Kripa Dev Date: Sun, 29 Mar 2026 20:11:39 +0530 Subject: [PATCH] client: persist interactive OAuth metadata across redirects --- packages/client/src/client/sse.ts | 64 +++++++++++++ packages/client/src/client/streamableHttp.ts | 64 +++++++++++++ packages/client/test/client/sse.test.ts | 59 ++++++++++++ .../client/test/client/streamableHttp.test.ts | 91 ++++++++++++++++++- 4 files changed, 277 insertions(+), 1 deletion(-) diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index f441e9cdb..4203f11fe 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -78,6 +78,8 @@ export class SSEClientTransport implements Transport { private _fetchWithInit: FetchLike; private _protocolVersion?: string; + private static readonly _INTERACTIVE_AUTH_STATE_PREFIX = 'mcp:oauth:interactive:sse:'; + onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; @@ -96,6 +98,64 @@ export class SSEClientTransport implements Transport { } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); + + this._loadInteractiveAuthState(); + } + + private _interactiveAuthStateKey(): string { + return `${SSEClientTransport._INTERACTIVE_AUTH_STATE_PREFIX}${this._url.toString()}`; + } + + private _saveInteractiveAuthState(): void { + if (typeof sessionStorage === 'undefined') { + return; + } + + const state = { + resourceMetadataUrl: this._resourceMetadataUrl?.toString(), + scope: this._scope + }; + + try { + sessionStorage.setItem(this._interactiveAuthStateKey(), JSON.stringify(state)); + } catch { + // Ignore storage failures (e.g. quota exceeded) + } + } + + private _loadInteractiveAuthState(): void { + if (typeof sessionStorage === 'undefined') { + return; + } + + try { + const raw = sessionStorage.getItem(this._interactiveAuthStateKey()); + if (!raw) { + return; + } + + const parsed = JSON.parse(raw) as { resourceMetadataUrl?: string; scope?: string }; + if (parsed.resourceMetadataUrl) { + this._resourceMetadataUrl = new URL(parsed.resourceMetadataUrl); + } + if (parsed.scope) { + this._scope = parsed.scope; + } + } catch { + // Ignore malformed persisted state + } + } + + private _clearInteractiveAuthState(): void { + if (typeof sessionStorage === 'undefined') { + return; + } + + try { + sessionStorage.removeItem(this._interactiveAuthStateKey()); + } catch { + // Ignore storage failures + } } private _last401Response?: Response; @@ -137,6 +197,7 @@ export class SSEClientTransport implements Transport { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + this._saveInteractiveAuthState(); } } @@ -237,6 +298,8 @@ export class SSEClientTransport implements Transport { if (result !== 'AUTHORIZED') { throw new UnauthorizedError('Failed to authorize'); } + + this._clearInteractiveAuthState(); } async close(): Promise { @@ -272,6 +335,7 @@ export class SSEClientTransport implements Transport { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + this._saveInteractiveAuthState(); } if (this._authProvider.onUnauthorized && !isAuthRetry) { diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 3d45b60e9..dcbbf324b 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -152,6 +152,8 @@ export class StreamableHTTPClientTransport implements Transport { private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; + private static readonly _INTERACTIVE_AUTH_STATE_PREFIX = 'mcp:oauth:interactive:streamable-http:'; + onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; @@ -172,6 +174,64 @@ export class StreamableHTTPClientTransport implements Transport { this._sessionId = opts?.sessionId; this._protocolVersion = opts?.protocolVersion; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; + + this._loadInteractiveAuthState(); + } + + private _interactiveAuthStateKey(): string { + return `${StreamableHTTPClientTransport._INTERACTIVE_AUTH_STATE_PREFIX}${this._url.toString()}`; + } + + private _saveInteractiveAuthState(): void { + if (typeof sessionStorage === 'undefined') { + return; + } + + const state = { + resourceMetadataUrl: this._resourceMetadataUrl?.toString(), + scope: this._scope + }; + + try { + sessionStorage.setItem(this._interactiveAuthStateKey(), JSON.stringify(state)); + } catch { + // Ignore storage failures (e.g. quota exceeded) + } + } + + private _loadInteractiveAuthState(): void { + if (typeof sessionStorage === 'undefined') { + return; + } + + try { + const raw = sessionStorage.getItem(this._interactiveAuthStateKey()); + if (!raw) { + return; + } + + const parsed = JSON.parse(raw) as { resourceMetadataUrl?: string; scope?: string }; + if (parsed.resourceMetadataUrl) { + this._resourceMetadataUrl = new URL(parsed.resourceMetadataUrl); + } + if (parsed.scope) { + this._scope = parsed.scope; + } + } catch { + // Ignore malformed persisted state + } + } + + private _clearInteractiveAuthState(): void { + if (typeof sessionStorage === 'undefined') { + return; + } + + try { + sessionStorage.removeItem(this._interactiveAuthStateKey()); + } catch { + // Ignore storage failures + } } private async _commonHeaders(): Promise { @@ -223,6 +283,7 @@ export class StreamableHTTPClientTransport implements Transport { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + this._saveInteractiveAuthState(); } if (this._authProvider.onUnauthorized && !isAuthRetry) { @@ -455,6 +516,8 @@ export class StreamableHTTPClientTransport implements Transport { if (result !== 'AUTHORIZED') { throw new UnauthorizedError('Failed to authorize'); } + + this._clearInteractiveAuthState(); } async close(): Promise { @@ -516,6 +579,7 @@ export class StreamableHTTPClientTransport implements Transport { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + this._saveInteractiveAuthState(); } if (this._authProvider.onUnauthorized && !isAuthRetry) { diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index b0b9588f0..df0a3ee8e 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -36,6 +36,20 @@ describe('SSEClientTransport', () => { let sendServerMessage: ((message: string) => void) | null = null; beforeEach(async () => { + const sessionStorageStore = new Map(); + Object.defineProperty(globalThis, 'sessionStorage', { + configurable: true, + value: { + getItem: (key: string) => sessionStorageStore.get(key) ?? null, + setItem: (key: string, value: string) => { + sessionStorageStore.set(key, value); + }, + removeItem: (key: string) => { + sessionStorageStore.delete(key); + } + } + }); + // Reset state lastServerRequest = null as unknown as IncomingMessage; sendServerMessage = null; @@ -111,6 +125,7 @@ describe('SSEClientTransport', () => { await authServer.close(); vi.clearAllMocks(); + delete (globalThis as { sessionStorage?: Storage }).sessionStorage; }); describe('connection handling', () => { @@ -1527,6 +1542,50 @@ describe('SSEClientTransport', () => { // Global fetch should never have been called expect(globalFetchSpy).not.toHaveBeenCalled(); }); + + it('persists interactive auth metadata across transport recreation before finishAuth', async () => { + const authProviderWithCode = createMockAuthProvider({ + clientRegistered: true, + authorizationCode: 'test-auth-code' + }); + + const unauthorizedResponse = new Response(null, { + status: 401, + headers: { + 'WWW-Authenticate': `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource", scope="calendar.read"` + } + }); + + const firstAuthProvider = { + token: vi.fn(async () => undefined) + }; + + const firstTransport = new SSEClientTransport(resourceBaseUrl, { + authProvider: firstAuthProvider, + fetch: vi.fn().mockResolvedValue(unauthorizedResponse) + }); + + // Skip EventSource startup; directly exercise POST 401 path where metadata is captured. + (firstTransport as unknown as { _endpoint: URL })._endpoint = new URL(resourceBaseUrl.href); + await expect(firstTransport.send({ jsonrpc: '2.0', method: 'ping', params: {}, id: '1' })).rejects.toThrow(UnauthorizedError); + await firstTransport.close(); + + const secondTransport = new SSEClientTransport(resourceBaseUrl, { + authProvider: authProviderWithCode, + fetch: customFetch + }); + + await secondTransport.finishAuth('test-auth-code'); + + expect(authProviderWithCode.saveTokens).toHaveBeenCalledWith({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }); + + await secondTransport.close(); + }); }); describe('minimal AuthProvider (non-OAuth)', () => { diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 55bf79a50..720d593ff 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -2,7 +2,7 @@ import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core' import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import type { Mock, Mocked } from 'vitest'; -import type { OAuthClientProvider } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -12,6 +12,20 @@ describe('StreamableHTTPClientTransport', () => { let mockAuthProvider: Mocked; beforeEach(() => { + const sessionStorageStore = new Map(); + Object.defineProperty(globalThis, 'sessionStorage', { + configurable: true, + value: { + getItem: (key: string) => sessionStorageStore.get(key) ?? null, + setItem: (key: string, value: string) => { + sessionStorageStore.set(key, value); + }, + removeItem: (key: string) => { + sessionStorageStore.delete(key); + } + } + }); + mockAuthProvider = { get redirectUrl() { return 'http://localhost/callback'; @@ -34,6 +48,7 @@ describe('StreamableHTTPClientTransport', () => { afterEach(async () => { await transport.close().catch(() => {}); vi.clearAllMocks(); + delete (globalThis as { sessionStorage?: Storage }).sessionStorage; }); it('should send JSON-RPC messages via POST', async () => { @@ -1441,6 +1456,80 @@ describe('StreamableHTTPClientTransport', () => { // Global fetch should never have been called expect(globalThis.fetch).not.toHaveBeenCalled(); }); + + it('persists interactive auth metadata across transport recreation before finishAuth', async () => { + const customFetch = vi + .fn() + // First transport send -> 401 with auth metadata + .mockResolvedValueOnce( + new Response(null, { + status: 401, + headers: { + 'WWW-Authenticate': + 'Bearer resource_metadata="http://localhost:1234/.well-known/oauth-protected-resource", scope="calendar.read"' + } + }) + ) + // Second transport finishAuth -> resource metadata discovery + .mockResolvedValueOnce( + Response.json({ + authorization_servers: ['http://localhost:1234'], + resource: 'http://localhost:1234/mcp' + }) + ) + // auth server metadata discovery + .mockResolvedValueOnce( + Response.json({ + issuer: 'http://localhost:1234', + authorization_endpoint: 'http://localhost:1234/authorize', + token_endpoint: 'http://localhost:1234/token', + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'] + }) + ) + // authorization code exchange + .mockResolvedValueOnce( + Response.json({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + token_type: 'Bearer', + expires_in: 3600 + }) + ); + + const firstAuthProvider: AuthProvider = { + token: vi.fn(async () => undefined) + }; + + const firstTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + authProvider: firstAuthProvider, + fetch: customFetch + }); + + await firstTransport.start(); + + await expect(firstTransport.send({ jsonrpc: '2.0', method: 'ping', params: {}, id: '1' } as JSONRPCMessage)).rejects.toThrow( + UnauthorizedError + ); + + await firstTransport.close(); + + const secondTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + authProvider: mockAuthProvider, + fetch: customFetch + }); + + await secondTransport.finishAuth('auth-code-after-redirect'); + + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }); + + await secondTransport.close(); + }); }); describe('SSE retry field handling', () => {