diff --git a/client/src/lib/__tests__/auth.test.ts b/client/src/lib/__tests__/auth.test.ts index 03c503d81..cf498571e 100644 --- a/client/src/lib/__tests__/auth.test.ts +++ b/client/src/lib/__tests__/auth.test.ts @@ -1,5 +1,6 @@ -import { discoverScopes } from "../auth"; +import { discoverScopes, revokeTokens } from "../auth"; import { discoverAuthorizationServerMetadata } from "@modelcontextprotocol/sdk/client/auth.js"; +import { SESSION_KEYS, getServerSpecificKey } from "../constants"; jest.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ discoverAuthorizationServerMetadata: jest.fn(), @@ -156,3 +157,183 @@ describe("discoverScopes", () => { }, ); }); + +describe("revokeTokens", () => { + const serverUrl = "https://example.com"; + const revocationEndpoint = "https://test.com/revoke"; + const metadataWithRevocation = { + ...baseMetadata, + revocation_endpoint: revocationEndpoint, + }; + + const seedTokens = (tokens: { + access_token: string; + token_type?: string; + refresh_token?: string; + }) => { + sessionStorage.setItem( + getServerSpecificKey(SESSION_KEYS.TOKENS, serverUrl), + JSON.stringify({ token_type: "Bearer", ...tokens }), + ); + }; + + const seedClientInfo = ( + client_id: string, + { isPreregistered = false } = {}, + ) => { + const key = getServerSpecificKey( + isPreregistered + ? SESSION_KEYS.PREREGISTERED_CLIENT_INFORMATION + : SESSION_KEYS.CLIENT_INFORMATION, + serverUrl, + ); + sessionStorage.setItem(key, JSON.stringify({ client_id })); + }; + + const parseRevokeBody = (call: [unknown, RequestInit | undefined]) => { + const init = call[1]; + return new URLSearchParams(init?.body as string); + }; + + let warnSpy: jest.SpyInstance; + let debugSpy: jest.SpyInstance; + + beforeEach(() => { + jest.clearAllMocks(); + sessionStorage.clear(); + warnSpy = jest.spyOn(console, "warn").mockImplementation(() => {}); + debugSpy = jest.spyOn(console, "debug").mockImplementation(() => {}); + }); + + afterEach(() => { + warnSpy.mockRestore(); + debugSpy.mockRestore(); + }); + + it("posts refresh_token to revocation_endpoint when available, includes client_id", async () => { + mockDiscoverAuth.mockResolvedValue(metadataWithRevocation); + seedTokens({ access_token: "at-123", refresh_token: "rt-456" }); + seedClientInfo("client-xyz"); + const fetchFn = jest + .fn, [RequestInfo | URL, RequestInit?]>() + .mockResolvedValue(new Response(null, { status: 200 })); + + await revokeTokens({ serverUrl, fetchFn }); + + expect(fetchFn).toHaveBeenCalledTimes(1); + const [url, init] = fetchFn.mock.calls[0]; + expect(url).toBe(revocationEndpoint); + expect(init?.method).toBe("POST"); + expect((init?.headers as Record)["Content-Type"]).toBe( + "application/x-www-form-urlencoded", + ); + const body = parseRevokeBody(fetchFn.mock.calls[0]); + expect(body.get("token")).toBe("rt-456"); + expect(body.get("token_type_hint")).toBe("refresh_token"); + expect(body.get("client_id")).toBe("client-xyz"); + }); + + it("prefers preregistered client_id over dynamic", async () => { + mockDiscoverAuth.mockResolvedValue(metadataWithRevocation); + seedTokens({ access_token: "at-123", refresh_token: "rt-456" }); + seedClientInfo("dynamic-client"); + seedClientInfo("preregistered-client", { isPreregistered: true }); + const fetchFn = jest + .fn, [RequestInfo | URL, RequestInit?]>() + .mockResolvedValue(new Response(null, { status: 200 })); + + await revokeTokens({ serverUrl, fetchFn }); + + const body = parseRevokeBody(fetchFn.mock.calls[0]); + expect(body.get("client_id")).toBe("preregistered-client"); + }); + + it("falls back to access_token when no refresh_token is present", async () => { + mockDiscoverAuth.mockResolvedValue(metadataWithRevocation); + seedTokens({ access_token: "at-only" }); + const fetchFn = jest + .fn, [RequestInfo | URL, RequestInit?]>() + .mockResolvedValue(new Response(null, { status: 200 })); + + await revokeTokens({ serverUrl, fetchFn }); + + expect(fetchFn).toHaveBeenCalledTimes(1); + const body = parseRevokeBody(fetchFn.mock.calls[0]); + expect(body.get("token")).toBe("at-only"); + expect(body.get("token_type_hint")).toBe("access_token"); + expect(body.get("client_id")).toBeNull(); + }); + + it("no-ops when no tokens are stored", async () => { + const fetchFn = jest.fn< + Promise, + [RequestInfo | URL, RequestInit?] + >(); + + await revokeTokens({ serverUrl, fetchFn }); + + expect(fetchFn).not.toHaveBeenCalled(); + expect(mockDiscoverAuth).not.toHaveBeenCalled(); + }); + + it("no-ops when AS metadata has no revocation_endpoint", async () => { + mockDiscoverAuth.mockResolvedValue(baseMetadata); + seedTokens({ access_token: "at-123", refresh_token: "rt-456" }); + const fetchFn = jest.fn< + Promise, + [RequestInfo | URL, RequestInit?] + >(); + + await revokeTokens({ serverUrl, fetchFn }); + + expect(fetchFn).not.toHaveBeenCalled(); + }); + + it("swallows fetch rejection and logs a warning", async () => { + mockDiscoverAuth.mockResolvedValue(metadataWithRevocation); + seedTokens({ access_token: "at-123", refresh_token: "rt-456" }); + const fetchFn = jest + .fn, [RequestInfo | URL, RequestInit?]>() + .mockRejectedValue(new Error("network down")); + + await expect(revokeTokens({ serverUrl, fetchFn })).resolves.toBeUndefined(); + expect(warnSpy).toHaveBeenCalledWith( + "Token revocation failed (best-effort):", + expect.any(Error), + ); + }); + + it("treats non-2xx response as a soft failure without throwing", async () => { + mockDiscoverAuth.mockResolvedValue(metadataWithRevocation); + seedTokens({ access_token: "at-123", refresh_token: "rt-456" }); + const fetchFn = jest + .fn, [RequestInfo | URL, RequestInit?]>() + .mockResolvedValue( + new Response("nope", { status: 400, statusText: "Bad Request" }), + ); + + await expect(revokeTokens({ serverUrl, fetchFn })).resolves.toBeUndefined(); + expect(warnSpy).toHaveBeenCalledWith( + expect.stringContaining("Token revocation responded 400"), + ); + }); + + it("uses the provided fetchFn, not the global fetch", async () => { + mockDiscoverAuth.mockResolvedValue(metadataWithRevocation); + seedTokens({ access_token: "at-123", refresh_token: "rt-456" }); + const globalFetchSpy = jest + .spyOn(globalThis, "fetch") + .mockResolvedValue(new Response(null, { status: 200 })); + const fetchFn = jest + .fn, [RequestInfo | URL, RequestInit?]>() + .mockResolvedValue(new Response(null, { status: 200 })); + + try { + await revokeTokens({ serverUrl, fetchFn }); + expect(fetchFn).toHaveBeenCalledTimes(1); + expect(globalFetchSpy).not.toHaveBeenCalled(); + } finally { + globalFetchSpy.mockRestore(); + } + }); +}); diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index f0fc2fc4b..68a578cef 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -129,6 +129,92 @@ export const clearScopeFromSessionStorage = (serverUrl: string) => { sessionStorage.removeItem(key); }; +/** + * Best-effort RFC 7009 token revocation. Called on user-initiated disconnect so + * the authorization server can invalidate the access/refresh token rather than + * waiting for natural expiry. Per RFC 7009 ยง2.1, revoking a refresh token also + * invalidates associated access tokens, so we prefer the refresh token when + * present and fall back to the access token otherwise. + * + * Never throws: if there are no saved tokens, no advertised revocation_endpoint, + * or the POST fails for any reason, this resolves quietly. A 3s timeout keeps a + * slow AS from blocking the UI. + */ +export const revokeTokens = async ({ + serverUrl, + fetchFn, +}: { + serverUrl: string; + fetchFn?: typeof fetch; +}): Promise => { + try { + const tokensRaw = sessionStorage.getItem( + getServerSpecificKey(SESSION_KEYS.TOKENS, serverUrl), + ); + if (!tokensRaw) { + return; + } + const tokens = await OAuthTokensSchema.parseAsync(JSON.parse(tokensRaw)); + + const metadata = await discoverAuthorizationServerMetadata( + new URL("/", serverUrl), + { fetchFn }, + ); + // `revocation_endpoint` is declared on OAuthMetadata but not on the OIDC + // branch of the union, even though OIDC providers may advertise it at + // runtime per RFC 8414. Narrow with `in` so we read it from either shape. + const revocationEndpoint = + metadata && "revocation_endpoint" in metadata + ? (metadata as { revocation_endpoint?: string }).revocation_endpoint + : undefined; + if (!revocationEndpoint) { + return; + } + + const token = tokens.refresh_token ?? tokens.access_token; + const tokenTypeHint = tokens.refresh_token + ? "refresh_token" + : "access_token"; + + const body = new URLSearchParams(); + body.set("token", token); + body.set("token_type_hint", tokenTypeHint); + + // Public-client convention: include client_id in the form body. Try + // preregistered first, then dynamically registered (same priority as + // InspectorOAuthClientProvider.clientInformation()). + const clientInfo = + (await getClientInformationFromSessionStorage({ + serverUrl, + isPreregistered: true, + })) ?? + (await getClientInformationFromSessionStorage({ + serverUrl, + isPreregistered: false, + })); + if (clientInfo?.client_id) { + body.set("client_id", clientInfo.client_id); + } + + const response = await (fetchFn ?? fetch)(revocationEndpoint, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: body.toString(), + signal: AbortSignal.timeout(3000), + }); + + if (!response.ok) { + console.warn( + `Token revocation responded ${response.status} ${response.statusText} (best-effort, continuing)`, + ); + return; + } + console.debug("Token revocation succeeded"); + } catch (error) { + console.warn("Token revocation failed (best-effort):", error); + } +}; + export class InspectorOAuthClientProvider implements OAuthClientProvider { constructor(protected serverUrl: string) { // Save the server URL to session storage diff --git a/client/src/lib/configurationTypes.ts b/client/src/lib/configurationTypes.ts index 60a993564..a1c0705c8 100644 --- a/client/src/lib/configurationTypes.ts +++ b/client/src/lib/configurationTypes.ts @@ -45,4 +45,12 @@ export type InspectorConfig = { * Default Time-to-Live (TTL) in milliseconds for newly created tasks. */ MCP_TASK_TTL: ConfigItem; + + /** + * Whether to send an RFC 7009 token revocation request to the authorization server + * on Disconnect (when the server advertises a `revocation_endpoint`). Default `true` + * (spec-compliant). Disable to test server behavior when a client disconnects + * without revoking, or to suppress the network call during offline testing. + */ + MCP_OAUTH_REVOKE_ON_DISCONNECT: ConfigItem; }; diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index d986d3802..283598f37 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -92,4 +92,11 @@ export const DEFAULT_INSPECTOR_CONFIG: InspectorConfig = { value: 60000, is_session_item: false, }, + MCP_OAUTH_REVOKE_ON_DISCONNECT: { + label: "Revoke OAuth Tokens on Disconnect", + description: + "When disconnecting, send an RFC 7009 token revocation request to the authorization server before clearing local state. Disable to test how a server behaves when a client disconnects without revoking.", + value: true, + is_session_item: false, + }, } as const; diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index 1d4f4bd0f..27a0c591d 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -25,7 +25,7 @@ import { ElicitRequest, } from "@modelcontextprotocol/sdk/types.js"; import { auth } from "@modelcontextprotocol/sdk/client/auth.js"; -import { discoverScopes } from "../../auth"; +import { discoverScopes, revokeTokens } from "../../auth"; import { CustomHeaders } from "../../types/customHeaders"; // Mock fetch @@ -145,18 +145,23 @@ jest.mock("../../auth", () => ({ InspectorOAuthClientProvider: jest.fn().mockImplementation(() => ({ tokens: jest.fn().mockResolvedValue({ access_token: "mock-token" }), redirectUrl: "http://localhost:3000/oauth/callback", + clear: jest.fn(), })), clearClientInformationFromSessionStorage: jest.fn(), saveClientInformationToSessionStorage: jest.fn(), saveScopeToSessionStorage: jest.fn(), clearScopeFromSessionStorage: jest.fn(), discoverScopes: jest.fn(), + revokeTokens: jest.fn().mockResolvedValue(undefined), })); const mockAuth = auth as jest.MockedFunction; const mockDiscoverScopes = discoverScopes as jest.MockedFunction< typeof discoverScopes >; +const mockRevokeTokens = revokeTokens as jest.MockedFunction< + typeof revokeTokens +>; describe("useConnection", () => { const defaultProps: Parameters[0] = { @@ -1791,4 +1796,82 @@ describe("useConnection", () => { ).toBeNull(); }); }); + + describe("disconnect", () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test("revokes tokens (RFC 7009) before clearing local OAuth state", async () => { + const { InspectorOAuthClientProvider } = jest.requireMock("../../auth"); + const providerCtor = InspectorOAuthClientProvider as jest.Mock; + providerCtor.mockClear(); + + const { result } = renderHook(() => useConnection(defaultProps)); + await act(async () => { + await result.current.connect(); + }); + + await act(async () => { + await result.current.disconnect(); + }); + + expect(mockRevokeTokens).toHaveBeenCalledTimes(1); + expect(mockRevokeTokens).toHaveBeenCalledWith({ + serverUrl: defaultProps.sseUrl, + fetchFn: expect.any(Function), + }); + + // disconnect constructs an InspectorOAuthClientProvider just for the + // local-clear step; that instance's clear() must have been invoked. + const disconnectInstance = providerCtor.mock.results[ + providerCtor.mock.results.length - 1 + ].value as { clear: jest.Mock }; + expect(disconnectInstance.clear).toHaveBeenCalledTimes(1); + + // Revocation must precede the local clear() so we have tokens to send. + expect(mockRevokeTokens.mock.invocationCallOrder[0]).toBeLessThan( + disconnectInstance.clear.mock.invocationCallOrder[0], + ); + + expect(result.current.connectionStatus).toBe("disconnected"); + }); + + test("skips revokeTokens when MCP_OAUTH_REVOKE_ON_DISCONNECT is false", async () => { + const { InspectorOAuthClientProvider } = jest.requireMock("../../auth"); + const providerCtor = InspectorOAuthClientProvider as jest.Mock; + providerCtor.mockClear(); + + const propsWithRevokeDisabled: Parameters[0] = { + ...defaultProps, + config: { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_OAUTH_REVOKE_ON_DISCONNECT: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_OAUTH_REVOKE_ON_DISCONNECT, + value: false, + }, + }, + }; + + const { result } = renderHook(() => + useConnection(propsWithRevokeDisabled), + ); + await act(async () => { + await result.current.connect(); + }); + + await act(async () => { + await result.current.disconnect(); + }); + + expect(mockRevokeTokens).not.toHaveBeenCalled(); + // Local clear still runs so the user gets a fresh slate even when + // remote revocation is opted out. + const disconnectInstance = providerCtor.mock.results[ + providerCtor.mock.results.length - 1 + ].value as { clear: jest.Mock }; + expect(disconnectInstance.clear).toHaveBeenCalledTimes(1); + expect(result.current.connectionStatus).toBe("disconnected"); + }); + }); }); diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 016f8aa4f..0aa42a548 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -62,6 +62,7 @@ import { saveScopeToSessionStorage, clearScopeFromSessionStorage, discoverScopes, + revokeTokens, } from "../auth"; import { createProxyFetch } from "../proxyFetch"; import { @@ -70,6 +71,7 @@ import { getMCPServerRequestMaxTotalTimeout, resetRequestTimeoutOnProgress, getMCPProxyAuthToken, + revokeOAuthTokensOnDisconnect, } from "@/utils/configUtils"; import { getMCPServerRequestTimeout } from "@/utils/configUtils"; import { InspectorConfig } from "../configurationTypes"; @@ -1181,6 +1183,14 @@ export function useConnection({ clientTransport as StreamableHTTPClientTransport ).terminateSession(); await mcpClient?.close(); + // RFC 7009: revoke tokens at the AS before wiping local state, so the + // server doesn't keep a still-valid token around as a tombstone. Users + // testing the inverse scenario can opt out via the config toggle. + if (revokeOAuthTokensOnDisconnect(config)) { + const fetchFn = + connectionType === "proxy" ? createProxyFetch(config) : undefined; + await revokeTokens({ serverUrl: sseUrl, fetchFn }); + } const authProvider = new InspectorOAuthClientProvider(sseUrl); authProvider.clear(); setMcpClient(null); diff --git a/client/src/utils/configUtils.ts b/client/src/utils/configUtils.ts index bc081b8f8..0457e55e5 100644 --- a/client/src/utils/configUtils.ts +++ b/client/src/utils/configUtils.ts @@ -58,6 +58,12 @@ export const getMCPTaskTtl = (config: InspectorConfig): number => { return config.MCP_TASK_TTL.value as number; }; +export const revokeOAuthTokensOnDisconnect = ( + config: InspectorConfig, +): boolean => { + return config.MCP_OAUTH_REVOKE_ON_DISCONNECT.value as boolean; +}; + export const getInitialTransportType = (): | "stdio" | "sse"