diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c68732..b0f23e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Unreleased +### New Features + +- **`ground_location_tool` task-based streaming** (experimental): Convert `ground_location_tool` to use the MCP tasks extension (`server.experimental.tasks.registerToolTask`). The tool now returns a task handle immediately on `tools/call` instead of blocking until all API calls complete. Reverse geocoding and sampling classification run in parallel in the background; POI search and isochrone follow once the strategy is known. Task-capable clients get streaming updates; clients without task support get the same synchronous result as before via the SDK automatic polling path (`taskSupport: 'optional'`). The server is configured with `InMemoryTaskStore` to support task lifecycle management. See issue #197. + ### Security - **static_map_image_tool**: Stop embedding the Mapbox access token in tool results. Previously the tool returned a `createUIResource({ iframeUrl })` whose URL carried the caller's `?access_token=` query param, leaking the secret token via the MCP-UI resource item. The credentialed URL is now only used server-side to fetch the image, which is returned inline as base64. The tool's `meta.ui.resourceUri` declaration is removed (the iframe path required the credentialed URL to function and cannot be reinstated without leaking). A regression test asserts the access token does not appear in any content item. diff --git a/src/index.ts b/src/index.ts index d0f3dc6..fd1d0ee 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,6 +10,7 @@ import { existsSync } from 'node:fs'; import { SpanStatusCode } from '@opentelemetry/api'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { InMemoryTaskStore } from '@modelcontextprotocol/sdk/experimental/tasks'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; import { registerAppResource, @@ -107,7 +108,8 @@ const server = new McpServer( resources: {}, prompts: {}, logging: {} - } + }, + taskStore: new InMemoryTaskStore() } ); diff --git a/src/tools/MapboxApiBasedTool.ts b/src/tools/MapboxApiBasedTool.ts index 4febfcc..4f0353a 100644 --- a/src/tools/MapboxApiBasedTool.ts +++ b/src/tools/MapboxApiBasedTool.ts @@ -46,7 +46,7 @@ export abstract class MapboxApiBasedTool< * @param token The token string to validate * @returns boolean indicating if the token has valid JWT format */ - private isValidJwtFormat(token: string): boolean { + protected isValidJwtFormat(token: string): boolean { // JWT consists of three parts separated by dots: header.payload.signature const parts = token.split('.'); if (parts.length !== 3) return false; diff --git a/src/tools/ground-location-tool/GroundLocationTool.ts b/src/tools/ground-location-tool/GroundLocationTool.ts index 86f5bfa..eeeb0cb 100644 --- a/src/tools/ground-location-tool/GroundLocationTool.ts +++ b/src/tools/ground-location-tool/GroundLocationTool.ts @@ -4,6 +4,11 @@ import type { z } from 'zod'; import { MapboxApiBasedTool } from '../MapboxApiBasedTool.js'; import type { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; +import type { + McpServer, + RegisteredTool +} from '@modelcontextprotocol/sdk/server/mcp.js'; +import type { RequestTaskStore } from '@modelcontextprotocol/sdk/shared/protocol.js'; import type { HttpRequest } from '../../utils/types.js'; import { GroundLocationInputSchema } from './GroundLocationTool.input.schema.js'; import { @@ -88,6 +93,168 @@ export class GroundLocationTool extends MapboxApiBasedTool< }); } + override installTo(server: McpServer): RegisteredTool { + this.server = server; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const inputShape = (this.inputSchema as unknown as { shape: any }).shape; + return server.experimental.tasks.registerToolTask( + this.name, + { + title: this.annotations.title, + description: this.description, + inputSchema: inputShape, + outputSchema: this.outputSchema, + annotations: this.annotations, + execution: { taskSupport: 'optional' } + }, + { + createTask: async ( + args: z.infer, + extra + ) => { + const accessToken = + extra.authInfo?.token || MapboxApiBasedTool.mapboxAccessToken; + if (!accessToken || !this.isValidJwtFormat(accessToken)) { + throw new Error( + 'No valid access token. Provide via Bearer auth or MAPBOX_ACCESS_TOKEN env var.' + ); + } + // pollInterval is set low so the SDK automatic polling path (used for + // clients that do not support tasks) has minimal extra latency. + const task = await extra.taskStore.createTask({ + ttl: 60_000, + pollInterval: 50 + }); + void this.runTaskBackground( + args, + accessToken, + task.taskId, + extra.taskStore + ); + return { task }; + }, + getTask: async (_args, extra) => { + return extra.taskStore.getTask(extra.taskId); + }, + getTaskResult: async (_args, extra) => { + return extra.taskStore.getTaskResult( + extra.taskId + ) as Promise; + } + } + ); + } + + private async runTaskBackground( + rawArgs: z.infer, + accessToken: string, + taskId: string, + taskStore: RequestTaskStore + ): Promise { + try { + const { + longitude, + latitude, + query, + profile, + contours_minutes, + limit, + language + } = GroundLocationInputSchema.parse(rawArgs); + const citations: string[] = ['Mapbox Geocoding API']; + + // Kick off sampling + a fast initial geocode in parallel so the place name + // is available as soon as possible regardless of sampling latency. + const [strategy, initialGeocode] = await Promise.all([ + this.classifyGroundingStrategy(query, longitude, latitude), + this.reverseGeocode( + longitude, + latitude, + accessToken, + 'neighborhood,locality,place', + language + ) + ]); + + // Refine geocode types now that we know the strategy. + const geocodeTypes = + strategy === 'routing' + ? 'address,poi' + : strategy === 'region' + ? 'region,district,place' + : 'neighborhood,locality,place'; + + const geocodeResult = + geocodeTypes !== 'neighborhood,locality,place' + ? await this.reverseGeocode( + longitude, + latitude, + accessToken, + geocodeTypes, + language + ) + : initialGeocode; + + // Fan out POIs + isochrone now that strategy is known. + const [poisResult, isochroneResult] = await Promise.all([ + query || strategy === 'poi' + ? this.categorySearch( + query ?? 'place', + longitude, + latitude, + strategy === 'poi' ? Math.max(limit, 15) : limit, + accessToken, + language + ).then((pois) => { + if (pois?.length) citations.push('Mapbox Search API'); + return pois; + }) + : Promise.resolve(undefined), + strategy === 'region' || strategy === 'neighborhood' + ? this.isochrone( + longitude, + latitude, + profile, + contours_minutes, + accessToken + ).then((iso) => { + if (iso) citations.push('Mapbox Isochrone API'); + return iso; + }) + : Promise.resolve(undefined) + ]); + + const result: GroundLocationOutput = { + place: geocodeResult.place, + full_address: geocodeResult.full_address, + longitude, + latitude, + nearby_pois: poisResult ?? undefined, + isochrone: isochroneResult ?? undefined, + citations + }; + + const validated = GroundLocationOutputSchema.safeParse(result); + const output = validated.success ? validated.data : result; + + await taskStore.storeTaskResult(taskId, 'completed', { + content: [{ type: 'text', text: this.formatOutput(output, strategy) }], + structuredContent: output as unknown as Record, + isError: false + }); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + try { + await taskStore.storeTaskResult(taskId, 'failed', { + content: [{ type: 'text', text: message }], + isError: true + }); + } catch { + // Task may have been cancelled before we could store the failure. + } + } + } + /** * Use sampling to classify what kind of grounding the query needs. * Falls back to 'neighborhood' if sampling is unavailable or classification fails. diff --git a/test/tools/ground-location-tool/GroundLocationTool.test.ts b/test/tools/ground-location-tool/GroundLocationTool.test.ts index 9ccf192..11a81c4 100644 --- a/test/tools/ground-location-tool/GroundLocationTool.test.ts +++ b/test/tools/ground-location-tool/GroundLocationTool.test.ts @@ -7,6 +7,7 @@ process.env.MAPBOX_ACCESS_TOKEN = import { describe, it, expect, afterEach, vi } from 'vitest'; import { setupHttpRequest } from '../../utils/httpPipelineUtils.js'; import { GroundLocationTool } from '../../../src/tools/ground-location-tool/GroundLocationTool.js'; +import { InMemoryTaskStore } from '@modelcontextprotocol/sdk/experimental/tasks'; const geocodeResponse = { features: [ @@ -268,3 +269,140 @@ describe('GroundLocationTool', () => { expect(categoryCall?.[0]).toContain('limit=15'); }); }); + +describe('GroundLocationTool — task-based flow', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + function buildTaskStore() { + const store = new InMemoryTaskStore(); + const requestId = 1; + const request = { + method: 'tools/call', + params: { name: 'ground_location_tool', arguments: {} } + }; + // Wrap in a RequestTaskStore-compatible shim bound to a fixed session. + const taskStore = { + createTask: (params: { ttl?: number }) => + store.createTask(params, requestId, request), + getTask: (taskId: string) => + store.getTask(taskId).then((t) => { + if (!t) throw new Error(`task not found: ${taskId}`); + return t; + }), + storeTaskResult: ( + taskId: string, + status: 'completed' | 'failed', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + result: any + ) => store.storeTaskResult(taskId, status, result), + getTaskResult: (taskId: string) => store.getTaskResult(taskId), + updateTaskStatus: ( + taskId: string, + status: Parameters[1] + ) => store.updateTaskStatus(taskId, status), + listTasks: (cursor?: string) => store.listTasks(cursor) + }; + return taskStore; + } + + it('creates task immediately and resolves with place name', async () => { + const { httpRequest } = setupHttpRequest(); + const mockFetch = vi.fn().mockImplementation((url: string) => { + if (url.includes('geocode/v6/reverse')) + return Promise.resolve({ + ok: true, + json: async () => geocodeResponse + }); + if (url.includes('isochrone/v1')) + return Promise.resolve({ + ok: true, + json: async () => isochroneResponse + }); + return Promise.resolve({ ok: false, json: async () => ({}) }); + }); + const tool = new GroundLocationTool({ + httpRequest: mockFetch as unknown as typeof httpRequest + }); + + const taskStore = buildTaskStore(); + const task = await taskStore.createTask({ ttl: 60_000 }); + + // Simulate what createTask handler does + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await (tool as any).runTaskBackground( + { longitude: -122.419, latitude: 37.759 }, + process.env.MAPBOX_ACCESS_TOKEN, + task.taskId, + taskStore + ); + + const completedTask = await taskStore.getTask(task.taskId); + expect(completedTask.status).toBe('completed'); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result = (await taskStore.getTaskResult(task.taskId)) as any; + expect(result.isError).toBe(false); + const text = result.content[0].text as string; + expect(text).toContain('Mission District'); + }); + + it('stores failed result when API errors out', async () => { + const { httpRequest } = setupHttpRequest(); + const mockFetch = vi.fn().mockRejectedValue(new Error('network error')); + const tool = new GroundLocationTool({ + httpRequest: mockFetch as unknown as typeof httpRequest + }); + + const taskStore = buildTaskStore(); + const task = await taskStore.createTask({ ttl: 60_000 }); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await (tool as any).runTaskBackground( + { longitude: -122.419, latitude: 37.759 }, + process.env.MAPBOX_ACCESS_TOKEN, + task.taskId, + taskStore + ); + + const completedTask = await taskStore.getTask(task.taskId); + expect(completedTask.status).toBe('failed'); + }); + + it('non-task clients still get a result via runTaskBackground', async () => { + const { httpRequest } = setupHttpRequest(); + const mockFetch = vi.fn().mockImplementation((url: string) => { + if (url.includes('geocode/v6/reverse')) + return Promise.resolve({ + ok: true, + json: async () => geocodeResponse + }); + if (url.includes('isochrone/v1')) + return Promise.resolve({ + ok: true, + json: async () => isochroneResponse + }); + return Promise.resolve({ ok: false, json: async () => ({}) }); + }); + const tool = new GroundLocationTool({ + httpRequest: mockFetch as unknown as typeof httpRequest + }); + + const taskStore = buildTaskStore(); + const task = await taskStore.createTask({ ttl: 60_000, pollInterval: 50 }); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await (tool as any).runTaskBackground( + { longitude: -122.419, latitude: 37.759 }, + process.env.MAPBOX_ACCESS_TOKEN, + task.taskId, + taskStore + ); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result = (await taskStore.getTaskResult(task.taskId)) as any; + expect(result.isError).toBe(false); + expect(result.content[0].text).toContain('Mission District'); + }); +});