Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions packages/client/src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ export class SSEClientTransport implements Transport {
private _fetchWithInit: FetchLike;
private _protocolVersion?: string;

private static readonly _INTERACTIVE_AUTH_STATE_PREFIX = 'mcp:oauth:interactive:sse:';

onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
Expand All @@ -96,6 +98,64 @@ export class SSEClientTransport implements Transport {
}
this._fetch = opts?.fetch;
this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit);

this._loadInteractiveAuthState();
}

private _interactiveAuthStateKey(): string {
return `${SSEClientTransport._INTERACTIVE_AUTH_STATE_PREFIX}${this._url.toString()}`;
}

private _saveInteractiveAuthState(): void {
if (typeof sessionStorage === 'undefined') {
return;
}

const state = {
resourceMetadataUrl: this._resourceMetadataUrl?.toString(),
scope: this._scope
};

try {
sessionStorage.setItem(this._interactiveAuthStateKey(), JSON.stringify(state));
} catch {
// Ignore storage failures (e.g. quota exceeded)
}
}

private _loadInteractiveAuthState(): void {
if (typeof sessionStorage === 'undefined') {
return;
}

try {
const raw = sessionStorage.getItem(this._interactiveAuthStateKey());
if (!raw) {
return;
}

const parsed = JSON.parse(raw) as { resourceMetadataUrl?: string; scope?: string };
if (parsed.resourceMetadataUrl) {
this._resourceMetadataUrl = new URL(parsed.resourceMetadataUrl);
}
if (parsed.scope) {
this._scope = parsed.scope;
}
} catch {
// Ignore malformed persisted state
}
}

private _clearInteractiveAuthState(): void {
if (typeof sessionStorage === 'undefined') {
return;
}

try {
sessionStorage.removeItem(this._interactiveAuthStateKey());
} catch {
// Ignore storage failures
}
}

private _last401Response?: Response;
Expand Down Expand Up @@ -137,6 +197,7 @@ export class SSEClientTransport implements Transport {
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);
this._resourceMetadataUrl = resourceMetadataUrl;
this._scope = scope;
this._saveInteractiveAuthState();
}
}

Expand Down Expand Up @@ -237,6 +298,8 @@ export class SSEClientTransport implements Transport {
if (result !== 'AUTHORIZED') {
throw new UnauthorizedError('Failed to authorize');
}

this._clearInteractiveAuthState();
}

async close(): Promise<void> {
Expand Down Expand Up @@ -272,6 +335,7 @@ export class SSEClientTransport implements Transport {
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);
this._resourceMetadataUrl = resourceMetadataUrl;
this._scope = scope;
this._saveInteractiveAuthState();
}

if (this._authProvider.onUnauthorized && !isAuthRetry) {
Expand Down
64 changes: 64 additions & 0 deletions packages/client/src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ export class StreamableHTTPClientTransport implements Transport {
private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field
private _reconnectionTimeout?: ReturnType<typeof setTimeout>;

private static readonly _INTERACTIVE_AUTH_STATE_PREFIX = 'mcp:oauth:interactive:streamable-http:';

onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
Expand All @@ -172,6 +174,64 @@ export class StreamableHTTPClientTransport implements Transport {
this._sessionId = opts?.sessionId;
this._protocolVersion = opts?.protocolVersion;
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;

this._loadInteractiveAuthState();
}

private _interactiveAuthStateKey(): string {
return `${StreamableHTTPClientTransport._INTERACTIVE_AUTH_STATE_PREFIX}${this._url.toString()}`;
}

private _saveInteractiveAuthState(): void {
if (typeof sessionStorage === 'undefined') {
return;
}

const state = {
resourceMetadataUrl: this._resourceMetadataUrl?.toString(),
scope: this._scope
};

try {
sessionStorage.setItem(this._interactiveAuthStateKey(), JSON.stringify(state));
} catch {
// Ignore storage failures (e.g. quota exceeded)
}
}

private _loadInteractiveAuthState(): void {
if (typeof sessionStorage === 'undefined') {
return;
}

try {
const raw = sessionStorage.getItem(this._interactiveAuthStateKey());
if (!raw) {
return;
}

const parsed = JSON.parse(raw) as { resourceMetadataUrl?: string; scope?: string };
if (parsed.resourceMetadataUrl) {
this._resourceMetadataUrl = new URL(parsed.resourceMetadataUrl);
}
if (parsed.scope) {
this._scope = parsed.scope;
}
} catch {
// Ignore malformed persisted state
}
}

private _clearInteractiveAuthState(): void {
if (typeof sessionStorage === 'undefined') {
return;
}

try {
sessionStorage.removeItem(this._interactiveAuthStateKey());
} catch {
// Ignore storage failures
}
}

private async _commonHeaders(): Promise<Headers> {
Expand Down Expand Up @@ -223,6 +283,7 @@ export class StreamableHTTPClientTransport implements Transport {
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);
this._resourceMetadataUrl = resourceMetadataUrl;
this._scope = scope;
this._saveInteractiveAuthState();
}

if (this._authProvider.onUnauthorized && !isAuthRetry) {
Expand Down Expand Up @@ -455,6 +516,8 @@ export class StreamableHTTPClientTransport implements Transport {
if (result !== 'AUTHORIZED') {
throw new UnauthorizedError('Failed to authorize');
}

this._clearInteractiveAuthState();
}

async close(): Promise<void> {
Expand Down Expand Up @@ -516,6 +579,7 @@ export class StreamableHTTPClientTransport implements Transport {
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);
this._resourceMetadataUrl = resourceMetadataUrl;
this._scope = scope;
this._saveInteractiveAuthState();
}

if (this._authProvider.onUnauthorized && !isAuthRetry) {
Expand Down
59 changes: 59 additions & 0 deletions packages/client/test/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ describe('SSEClientTransport', () => {
let sendServerMessage: ((message: string) => void) | null = null;

beforeEach(async () => {
const sessionStorageStore = new Map<string, string>();
Object.defineProperty(globalThis, 'sessionStorage', {
configurable: true,
value: {
getItem: (key: string) => sessionStorageStore.get(key) ?? null,
setItem: (key: string, value: string) => {
sessionStorageStore.set(key, value);
},
removeItem: (key: string) => {
sessionStorageStore.delete(key);
}
}
});

// Reset state
lastServerRequest = null as unknown as IncomingMessage;
sendServerMessage = null;
Expand Down Expand Up @@ -111,6 +125,7 @@ describe('SSEClientTransport', () => {
await authServer.close();

vi.clearAllMocks();
delete (globalThis as { sessionStorage?: Storage }).sessionStorage;
});

describe('connection handling', () => {
Expand Down Expand Up @@ -1527,6 +1542,50 @@ describe('SSEClientTransport', () => {
// Global fetch should never have been called
expect(globalFetchSpy).not.toHaveBeenCalled();
});

it('persists interactive auth metadata across transport recreation before finishAuth', async () => {
const authProviderWithCode = createMockAuthProvider({
clientRegistered: true,
authorizationCode: 'test-auth-code'
});

const unauthorizedResponse = new Response(null, {
status: 401,
headers: {
'WWW-Authenticate': `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource", scope="calendar.read"`
}
});

const firstAuthProvider = {
token: vi.fn(async () => undefined)
};

const firstTransport = new SSEClientTransport(resourceBaseUrl, {
authProvider: firstAuthProvider,
fetch: vi.fn().mockResolvedValue(unauthorizedResponse)
});

// Skip EventSource startup; directly exercise POST 401 path where metadata is captured.
(firstTransport as unknown as { _endpoint: URL })._endpoint = new URL(resourceBaseUrl.href);
await expect(firstTransport.send({ jsonrpc: '2.0', method: 'ping', params: {}, id: '1' })).rejects.toThrow(UnauthorizedError);
await firstTransport.close();

const secondTransport = new SSEClientTransport(resourceBaseUrl, {
authProvider: authProviderWithCode,
fetch: customFetch
});

await secondTransport.finishAuth('test-auth-code');

expect(authProviderWithCode.saveTokens).toHaveBeenCalledWith({
access_token: 'new-access-token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'new-refresh-token'
});

await secondTransport.close();
});
});

describe('minimal AuthProvider (non-OAuth)', () => {
Expand Down
91 changes: 90 additions & 1 deletion packages/client/test/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core'
import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core';
import type { Mock, Mocked } from 'vitest';

import type { OAuthClientProvider } from '../../src/client/auth.js';
import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js';
import { UnauthorizedError } from '../../src/client/auth.js';
import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js';
import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js';
Expand All @@ -12,6 +12,20 @@ describe('StreamableHTTPClientTransport', () => {
let mockAuthProvider: Mocked<OAuthClientProvider>;

beforeEach(() => {
const sessionStorageStore = new Map<string, string>();
Object.defineProperty(globalThis, 'sessionStorage', {
configurable: true,
value: {
getItem: (key: string) => sessionStorageStore.get(key) ?? null,
setItem: (key: string, value: string) => {
sessionStorageStore.set(key, value);
},
removeItem: (key: string) => {
sessionStorageStore.delete(key);
}
}
});

mockAuthProvider = {
get redirectUrl() {
return 'http://localhost/callback';
Expand All @@ -34,6 +48,7 @@ describe('StreamableHTTPClientTransport', () => {
afterEach(async () => {
await transport.close().catch(() => {});
vi.clearAllMocks();
delete (globalThis as { sessionStorage?: Storage }).sessionStorage;
});

it('should send JSON-RPC messages via POST', async () => {
Expand Down Expand Up @@ -1441,6 +1456,80 @@ describe('StreamableHTTPClientTransport', () => {
// Global fetch should never have been called
expect(globalThis.fetch).not.toHaveBeenCalled();
});

it('persists interactive auth metadata across transport recreation before finishAuth', async () => {
const customFetch = vi
.fn()
// First transport send -> 401 with auth metadata
.mockResolvedValueOnce(
new Response(null, {
status: 401,
headers: {
'WWW-Authenticate':
'Bearer resource_metadata="http://localhost:1234/.well-known/oauth-protected-resource", scope="calendar.read"'
}
})
)
// Second transport finishAuth -> resource metadata discovery
.mockResolvedValueOnce(
Response.json({
authorization_servers: ['http://localhost:1234'],
resource: 'http://localhost:1234/mcp'
})
)
// auth server metadata discovery
.mockResolvedValueOnce(
Response.json({
issuer: 'http://localhost:1234',
authorization_endpoint: 'http://localhost:1234/authorize',
token_endpoint: 'http://localhost:1234/token',
response_types_supported: ['code'],
code_challenge_methods_supported: ['S256']
})
)
// authorization code exchange
.mockResolvedValueOnce(
Response.json({
access_token: 'new-access-token',
refresh_token: 'new-refresh-token',
token_type: 'Bearer',
expires_in: 3600
})
);

const firstAuthProvider: AuthProvider = {
token: vi.fn(async () => undefined)
};

const firstTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
authProvider: firstAuthProvider,
fetch: customFetch
});

await firstTransport.start();

await expect(firstTransport.send({ jsonrpc: '2.0', method: 'ping', params: {}, id: '1' } as JSONRPCMessage)).rejects.toThrow(
UnauthorizedError
);

await firstTransport.close();

const secondTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
authProvider: mockAuthProvider,
fetch: customFetch
});

await secondTransport.finishAuth('auth-code-after-redirect');

expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({
access_token: 'new-access-token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'new-refresh-token'
});

await secondTransport.close();
});
});

describe('SSE retry field handling', () => {
Expand Down
Loading