diff --git a/package.json b/package.json index 0da0abb8..90eed241 100644 --- a/package.json +++ b/package.json @@ -44,7 +44,7 @@ "dist" ], "scripts": { - "start": "node dist/index.js --transport http --loggers stderr mcp --previewFeatures vectorSearch", + "start": "node dist/index.js --transport http --loggers stderr mcp --previewFeatures search", "start:stdio": "node dist/index.js --transport stdio --loggers stderr mcp", "prepare": "husky && pnpm run build", "build:clean": "rm -rf dist", diff --git a/src/tools/mongodb/create/insertMany.ts b/src/tools/mongodb/create/insertMany.ts index e68e97a1..600edf23 100644 --- a/src/tools/mongodb/create/insertMany.ts +++ b/src/tools/mongodb/create/insertMany.ts @@ -1,6 +1,6 @@ import { z } from "zod"; -import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; +import type { ToolResult } from "../../tool.js"; import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js"; import { zEJSON } from "../../args.js"; import { type Document } from "bson"; @@ -37,6 +37,13 @@ export class InsertManyTool extends MongoDBToolBase { ), } : commonArgs; + + protected outputShape = { + success: z.boolean(), + insertedCount: z.number(), + insertedIds: z.array(z.any()), + }; + static operationType: OperationType = "create"; protected async execute({ @@ -44,7 +51,7 @@ export class InsertManyTool extends MongoDBToolBase { collection, documents, embeddingParameters: providedEmbeddingParameters, - }: ToolArgs): Promise { + }: ToolArgs): Promise> { const provider = await this.ensureConnected(); const embeddingParameters = this.isFeatureEnabled("search") @@ -70,8 +77,14 @@ export class InsertManyTool extends MongoDBToolBase { `Inserted \`${result.insertedCount}\` document(s) into ${database}.${collection}.`, `Inserted IDs: ${Object.values(result.insertedIds).join(", ")}` ); + return { content, + structuredContent: { + success: true, + insertedCount: result.insertedCount, + insertedIds: Object.values(result.insertedIds), + }, }; } diff --git a/src/tools/mongodb/metadata/listDatabases.ts b/src/tools/mongodb/metadata/listDatabases.ts index 8cdb4aab..0f7347ad 100644 --- a/src/tools/mongodb/metadata/listDatabases.ts +++ b/src/tools/mongodb/metadata/listDatabases.ts @@ -1,24 +1,31 @@ -import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { MongoDBToolBase } from "../mongodbTool.js"; import type * as bson from "bson"; -import type { OperationType } from "../../tool.js"; +import type { OperationType, ToolResult } from "../../tool.js"; import { formatUntrustedData } from "../../tool.js"; +import z, { type ZodNever } from "zod"; + +export const ListDatabasesToolOutputShape = { + dbs: z.array(z.object({ name: z.string(), sizeOnDisk: z.string(), sizeUnit: z.literal("bytes") })), +}; + +export type ListDatabasesToolOutput = z.objectOutputType; export class ListDatabasesTool extends MongoDBToolBase { public name = "list-databases"; protected description = "List all databases for a MongoDB connection"; protected argsShape = {}; + protected outputShape = ListDatabasesToolOutputShape; static operationType: OperationType = "metadata"; - protected async execute(): Promise { + protected async execute(): Promise> { const provider = await this.ensureConnected(); - const dbs = (await provider.listDatabases("")).databases as { name: string; sizeOnDisk: bson.Long }[]; + const dbs = ((await provider.listDatabases("")).databases as { name: string; sizeOnDisk: bson.Long }[]).map( + (db) => ({ name: db.name, sizeOnDisk: db.sizeOnDisk.toString(), sizeUnit: "bytes" as const }) + ); return { - content: formatUntrustedData( - `Found ${dbs.length} databases`, - ...dbs.map((db) => `Name: ${db.name}, Size: ${db.sizeOnDisk.toString()} bytes`) - ), + content: formatUntrustedData(`Found ${dbs.length} databases`, JSON.stringify(dbs)), + structuredContent: { dbs }, }; } } diff --git a/src/tools/tool.ts b/src/tools/tool.ts index 5c00cfab..e550faa0 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -1,5 +1,4 @@ -import type { z } from "zod"; -import { type ZodRawShape, type ZodNever } from "zod"; +import type { z, ZodTypeAny, ZodRawShape, ZodNever } from "zod"; import type { RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { CallToolResult, ToolAnnotations } from "@modelcontextprotocol/sdk/types.js"; import type { Session } from "../common/session.js"; @@ -14,6 +13,12 @@ import type { PreviewFeature } from "../common/schemas.js"; export type ToolArgs = z.objectOutputType; export type ToolCallbackArgs = Parameters>; +export type ToolResult = { + content: { type: "text"; text: string }[]; + structuredContent: OutputSchema extends ZodRawShape ? z.objectOutputType : never; + isError?: boolean; +}; + export type ToolExecutionContext = Parameters>[1]; /** @@ -274,6 +279,8 @@ export abstract class ToolBase { */ protected abstract argsShape: ZodRawShape; + protected outputShape?: ZodRawShape; + private registeredTool: RegisteredTool | undefined; protected get annotations(): ToolAnnotations { @@ -462,11 +469,14 @@ export abstract class ToolBase { } }; - this.registeredTool = server.mcpServer.tool( + this.registeredTool = server.mcpServer.registerTool( this.name, - this.description, - this.argsShape, - this.annotations, + { + description: this.description, + inputSchema: this.argsShape, + annotations: this.annotations, + outputSchema: this.outputShape, + }, callback ); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 5bd44511..b0df445e 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -184,7 +184,7 @@ export function setupIntegrationTest( } // eslint-disable-next-line @typescript-eslint/no-redundant-type-constituents -export function getResponseContent(content: unknown | { content: unknown }): string { +export function getResponseContent(content: unknown | { content: unknown; structuredContent: unknown }): string { return getResponseElements(content) .map((item) => item.text) .join("\n"); diff --git a/tests/integration/tools/mongodb/create/insertMany.test.ts b/tests/integration/tools/mongodb/create/insertMany.test.ts index 2170efb9..2941e834 100644 --- a/tests/integration/tools/mongodb/create/insertMany.test.ts +++ b/tests/integration/tools/mongodb/create/insertMany.test.ts @@ -72,10 +72,11 @@ describeWithMongoDB("insertMany tool when search is disabled", (integration) => }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain(`Inserted \`1\` document(s) into ${integration.randomDbName()}.coll1.`); await validateDocuments("coll1", [{ prop1: "value1" }]); + validateStructuredContent(response.structuredContent, extractInsertedIds(content)); }); it("returns an error when inserting duplicates", async () => { @@ -95,7 +96,7 @@ describeWithMongoDB("insertMany tool when search is disabled", (integration) => }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Error running insert-many"); expect(content).toContain("duplicate key error"); expect(content).toContain(insertedIds[0]?.toString()); @@ -174,12 +175,14 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(1); const docCount = await collection.countDocuments({ _id: insertedIds[0] }); expect(docCount).toBe(1); + + validateStructuredContent(response.structuredContent, insertedIds); }); it("returns an error when there is a search index and embeddings parameter are wrong", async () => { @@ -214,7 +217,7 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Error running insert-many"); const untrustedContent = getDataFromUntrustedContent(content); expect(untrustedContent).toContain( @@ -263,10 +266,11 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(1); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect(doc).toBeDefined(); @@ -316,10 +320,11 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(2); + validateStructuredContent(response.structuredContent, insertedIds); const doc1 = await collection.findOne({ _id: insertedIds[0] }); expect(doc1?.title).toBe("The Matrix"); @@ -369,10 +374,11 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(1); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect(doc?.info).toBeDefined(); @@ -417,10 +423,11 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(1); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect(doc?.title).toBe("The Matrix"); @@ -452,10 +459,11 @@ describeWithMongoDB( }, }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(1); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect((doc?.title as Record)?.text).toBe("The Matrix"); @@ -495,7 +503,7 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Error running insert-many"); expect(content).toContain("Field 'nonExistentField' does not have a vector search index in collection"); expect(content).toContain("Only fields with vector search indexes can have embeddings generated"); @@ -529,10 +537,11 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); expect(insertedIds).toHaveLength(1); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect(doc?.title).toBe("The Matrix"); @@ -564,9 +573,10 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect(Array.isArray(doc?.titleEmbeddings)).toBe(true); @@ -614,9 +624,10 @@ describeWithMongoDB( }, }); - const content = getResponseContent(response.content); + const content = getResponseContent(response); expect(content).toContain("Documents were inserted successfully."); const insertedIds = extractInsertedIds(content); + validateStructuredContent(response.structuredContent, insertedIds); const doc = await collection.findOne({ _id: insertedIds[0] }); expect(doc?.title).toBe("The Matrix"); @@ -692,3 +703,11 @@ function extractInsertedIds(content: string): ObjectId[] { .map((e) => ObjectId.createFromHexString(e)) ?? [] ); } + +function validateStructuredContent(structuredContent: unknown, expectedIds: ObjectId[]): void { + expect(structuredContent).toEqual({ + success: true, + insertedCount: expectedIds.length, + insertedIds: expectedIds, + }); +} diff --git a/tests/integration/tools/mongodb/metadata/listDatabases.test.ts b/tests/integration/tools/mongodb/metadata/listDatabases.test.ts index 6caa016b..f654b759 100644 --- a/tests/integration/tools/mongodb/metadata/listDatabases.test.ts +++ b/tests/integration/tools/mongodb/metadata/listDatabases.test.ts @@ -1,6 +1,7 @@ import { describeWithMongoDB, validateAutoConnectBehavior } from "../mongodbHelpers.js"; import { getResponseElements, getParameters, expectDefined, getDataFromUntrustedContent } from "../../../helpers.js"; import { describe, expect, it } from "vitest"; +import type { ListDatabasesToolOutput } from "../../../../../src/tools/mongodb/metadata/listDatabases.js"; describeWithMongoDB("listDatabases tool", (integration) => { const defaultDatabases = ["admin", "config", "local"]; @@ -22,6 +23,9 @@ describeWithMongoDB("listDatabases tool", (integration) => { const dbNames = getDbNames(response.content); expect(dbNames).toIncludeSameMembers(defaultDatabases); + + const structuredContent = response.structuredContent as ListDatabasesToolOutput; + expect(structuredContent.dbs.map((db) => db.name)).toIncludeSameMembers(defaultDatabases); }); }); @@ -36,6 +40,13 @@ describeWithMongoDB("listDatabases tool", (integration) => { const response = await integration.mcpClient().callTool({ name: "list-databases", arguments: {} }); const dbNames = getDbNames(response.content); expect(dbNames).toIncludeSameMembers([...defaultDatabases, "foo", "baz"]); + + const structuredContent = response.structuredContent as ListDatabasesToolOutput; + expect(structuredContent.dbs.map((db) => db.name)).toIncludeSameMembers([ + ...defaultDatabases, + "foo", + "baz", + ]); }); }); @@ -68,11 +79,6 @@ function getDbNames(content: unknown): (string | null)[] { const responseItems = getResponseElements(content); expect(responseItems).toHaveLength(2); const data = getDataFromUntrustedContent(responseItems[1]?.text ?? "{}"); - return data - .split("\n") - .map((item) => { - const match = item.match(/Name: ([^,]+), Size: \d+ bytes/); - return match ? match[1] : null; - }) - .filter((item): item is string | null => item !== undefined); + + return (JSON.parse(data) as ListDatabasesToolOutput["dbs"]).map((db) => db.name); } diff --git a/tests/unit/toolBase.test.ts b/tests/unit/toolBase.test.ts index b5dc928c..8f0a6340 100644 --- a/tests/unit/toolBase.test.ts +++ b/tests/unit/toolBase.test.ts @@ -141,15 +141,13 @@ describe("ToolBase", () => { beforeEach(() => { const mockServer = { mcpServer: { - tool: ( + registerTool: ( name: string, - description: string, - paramsSchema: unknown, - annotations: ToolAnnotations, + config: { description: string; inputSchema: unknown; annotations: ToolAnnotations }, cb: ToolCallback ): void => { expect(name).toBe(testTool.name); - expect(description).toBe(testTool["description"]); + expect(config.description).toBe(testTool["description"]); mockCallback = cb; }, },