diff --git a/package-lock.json b/package-lock.json index da4b5a0cba9..1dcc3e34a63 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16768,6 +16768,15 @@ } } }, + "node_modules/@vercel/oidc": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.0.5.tgz", + "integrity": "sha512-fnYhv671l+eTTp48gB4zEsTW/YtRgRPnkI2nT7x6qw5rkI1Lq2hTmQIpHPgyThI0znLK+vX2n9XxKdXZ7BUbbw==", + "license": "Apache-2.0", + "engines": { + "node": ">= 20" + } + }, "node_modules/@vue/compiler-core": { "version": "3.5.6", "resolved": "https://registry.npmjs.org/@vue/compiler-core/-/compiler-core-3.5.6.tgz", @@ -25619,6 +25628,15 @@ "node": ">=0.4.x" } }, + "node_modules/eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/evp_bytestokey": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/evp_bytestokey/-/evp_bytestokey-1.0.3.tgz", @@ -49661,6 +49679,7 @@ "version": "0.68.0", "license": "SSPL", "dependencies": { + "@ai-sdk/openai": "^2.0.4", "@mongodb-js/atlas-service": "^0.73.0", "@mongodb-js/compass-app-registry": "^9.4.29", "@mongodb-js/compass-components": "^1.59.2", @@ -49669,9 +49688,11 @@ "@mongodb-js/compass-telemetry": "^1.19.5", "@mongodb-js/compass-utils": "^0.9.23", "@mongodb-js/connection-info": "^0.24.0", + "ai": "^5.0.26", "bson": "^6.10.4", "compass-preferences-model": "^2.66.3", "mongodb": "^6.19.0", + "mongodb-query-parser": "^4.5.0", "mongodb-schema": "^12.6.3", "react": "^17.0.2", "react-redux": "^8.1.3", @@ -49694,7 +49715,6 @@ "depcheck": "^1.4.1", "electron-mocha": "^12.2.0", "mocha": "^10.2.0", - "mongodb-query-parser": "^4.5.0", "nyc": "^15.1.0", "p-queue": "^7.4.1", "sinon": "^9.2.3", @@ -49702,6 +49722,74 @@ "xvfb-maybe": "^0.2.1" } }, + 
"packages/compass-generative-ai/node_modules/@ai-sdk/gateway": { + "version": "2.0.17", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-2.0.17.tgz", + "integrity": "sha512-oVAG6q72KsjKlrYdLhWjRO7rcqAR8CjokAbYuyVZoCO4Uh2PH/VzZoxZav71w2ipwlXhHCNaInGYWNs889MMDA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18", + "@vercel/oidc": "3.0.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "packages/compass-generative-ai/node_modules/@ai-sdk/openai": { + "version": "2.0.75", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.75.tgz", + "integrity": "sha512-ThDHg1+Jes7S0AOXa01EyLBSzZiZwzB5do9vAlufNkoiRHGTH1BmoShrCyci/TUsg4ky1HwbK4hPK+Z0isiE6g==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "packages/compass-generative-ai/node_modules/@ai-sdk/provider-utils": { + "version": "3.0.18", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.18.tgz", + "integrity": "sha512-ypv1xXMsgGcNKUP+hglKqtdDuMg68nWHucPPAhIENrbFAI+xCHiqPVN8Zllxyv1TNZwGWUghPxJXU+Mqps0YRQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0", + "@standard-schema/spec": "^1.0.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "packages/compass-generative-ai/node_modules/ai": { + "version": "5.0.104", + "resolved": "https://registry.npmjs.org/ai/-/ai-5.0.104.tgz", + "integrity": "sha512-MZOkL9++nY5PfkpWKBR3Rv+Oygxpb9S16ctv8h91GvrSif7UnNEdPMVZe3bUyMd2djxf0AtBk/csBixP0WwWZQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "2.0.17", + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18", + 
"@opentelemetry/api": "1.9.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "packages/compass-generative-ai/node_modules/diff": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", @@ -62375,6 +62463,7 @@ "@mongodb-js/compass-generative-ai": { "version": "file:packages/compass-generative-ai", "requires": { + "@ai-sdk/openai": "^2.0.4", "@mongodb-js/atlas-service": "^0.73.0", "@mongodb-js/compass-app-registry": "^9.4.29", "@mongodb-js/compass-components": "^1.59.2", @@ -62393,6 +62482,7 @@ "@types/mocha": "^9.0.0", "@types/react": "^17.0.5", "@types/sinon-chai": "^3.2.5", + "ai": "^5.0.26", "bson": "^6.10.4", "chai": "^4.3.6", "compass-preferences-model": "^2.66.3", @@ -62414,6 +62504,46 @@ "zod": "^3.25.76" }, "dependencies": { + "@ai-sdk/gateway": { + "version": "2.0.17", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-2.0.17.tgz", + "integrity": "sha512-oVAG6q72KsjKlrYdLhWjRO7rcqAR8CjokAbYuyVZoCO4Uh2PH/VzZoxZav71w2ipwlXhHCNaInGYWNs889MMDA==", + "requires": { + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18", + "@vercel/oidc": "3.0.5" + } + }, + "@ai-sdk/openai": { + "version": "2.0.75", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.75.tgz", + "integrity": "sha512-ThDHg1+Jes7S0AOXa01EyLBSzZiZwzB5do9vAlufNkoiRHGTH1BmoShrCyci/TUsg4ky1HwbK4hPK+Z0isiE6g==", + "requires": { + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18" + } + }, + "@ai-sdk/provider-utils": { + "version": "3.0.18", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.18.tgz", + "integrity": "sha512-ypv1xXMsgGcNKUP+hglKqtdDuMg68nWHucPPAhIENrbFAI+xCHiqPVN8Zllxyv1TNZwGWUghPxJXU+Mqps0YRQ==", + "requires": { + "@ai-sdk/provider": "2.0.0", + "@standard-schema/spec": "^1.0.0", + "eventsource-parser": "^3.0.6" + } + }, + "ai": { + "version": "5.0.104", + "resolved": 
"https://registry.npmjs.org/ai/-/ai-5.0.104.tgz", + "integrity": "sha512-MZOkL9++nY5PfkpWKBR3Rv+Oygxpb9S16ctv8h91GvrSif7UnNEdPMVZe3bUyMd2djxf0AtBk/csBixP0WwWZQ==", + "requires": { + "@ai-sdk/gateway": "2.0.17", + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18", + "@opentelemetry/api": "1.9.0" + } + }, "diff": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", @@ -70978,6 +71108,11 @@ "dev": true, "requires": {} }, + "@vercel/oidc": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.0.5.tgz", + "integrity": "sha512-fnYhv671l+eTTp48gB4zEsTW/YtRgRPnkI2nT7x6qw5rkI1Lq2hTmQIpHPgyThI0znLK+vX2n9XxKdXZ7BUbbw==" + }, "@vue/compiler-core": { "version": "3.5.6", "resolved": "https://registry.npmjs.org/@vue/compiler-core/-/compiler-core-3.5.6.tgz", @@ -78133,6 +78268,11 @@ "resolved": "https://registry.npmjs.org/events/-/events-1.1.1.tgz", "integrity": "sha512-kEcvvCBByWXGnZy6JUlgAp2gBIUjfCAV6P6TgT1/aaQKcmuAEC4OZTV1I4EWQLz2gxZw76atuVyvHhTxvi0Flw==" }, + "eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==" + }, "evp_bytestokey": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/evp_bytestokey/-/evp_bytestokey-1.0.3.tgz", diff --git a/packages/compass-e2e-tests/helpers/assistant-service.ts b/packages/compass-e2e-tests/helpers/assistant-service.ts index c36a3d9ae3d..970f8edd6dd 100644 --- a/packages/compass-e2e-tests/helpers/assistant-service.ts +++ b/packages/compass-e2e-tests/helpers/assistant-service.ts @@ -170,12 +170,13 @@ export async function startMockAssistantServer( let response = _response; const server = http .createServer((req, res) => { - res.setHeader('Access-Control-Allow-Origin', '*'); + res.setHeader('Access-Control-Allow-Origin', req.headers.origin || '*'); 
res.setHeader('Access-Control-Allow-Methods', 'POST, OPTIONS'); res.setHeader( 'Access-Control-Allow-Headers', - 'Content-Type, Authorization, X-Request-Origin, User-Agent' + 'Content-Type, Authorization, X-Request-Origin, User-Agent, X-CSRF-Token, X-CSRF-Time' ); + res.setHeader('Access-Control-Allow-Credentials', 'true'); // Handle preflight requests if (req.method === 'OPTIONS') { @@ -212,8 +213,8 @@ export async function startMockAssistantServer( }); if (response.status !== 200) { - res.writeHead(response.status); res.setHeader('Content-Type', 'application/json'); + res.writeHead(response.status); return res.end(JSON.stringify({ error: response.body })); } diff --git a/packages/compass-e2e-tests/tests/collection-ai-query.test.ts b/packages/compass-e2e-tests/tests/collection-ai-query.test.ts index 919aa54490b..ac09eaae4c9 100644 --- a/packages/compass-e2e-tests/tests/collection-ai-query.test.ts +++ b/packages/compass-e2e-tests/tests/collection-ai-query.test.ts @@ -8,12 +8,14 @@ import { cleanup, screenshotIfFailed, DEFAULT_CONNECTION_NAME_1, + screenshotPathName, } from '../helpers/compass'; import type { Compass } from '../helpers/compass'; import * as Selectors from '../helpers/selectors'; import { createNumbersCollection } from '../helpers/insert-data'; import { startMockAtlasServiceServer } from '../helpers/mock-atlas-service'; import type { MockAtlasServerResponse } from '../helpers/mock-atlas-service'; +import { startMockAssistantServer } from '../helpers/assistant-service'; describe('Collection ai query (with mocked backend)', function () { let compass: Compass; @@ -171,3 +173,145 @@ describe('Collection ai query (with mocked backend)', function () { }); }); }); + +async function setup( + browser: CompassBrowser, + dbName: string, + collName: string +) { + await createNumbersCollection(); + await browser.setupDefaultConnections(); + await browser.connectToDefaults(); + await browser.navigateToCollectionTab( + DEFAULT_CONNECTION_NAME_1, + dbName, + 
collName, + 'Documents' + ); + + await browser.setFeature('enableChatbotEndpointForGenAI', true); + await browser.setFeature('enableGenAIFeatures', true); + await browser.setFeature('enableGenAISampleDocumentPassing', true); + await browser.setFeature('optInGenAIFeatures', true); +} + +describe('Collection ai query with chatbot (with mocked backend)', function () { + const dbName = 'test'; + const collName = 'numbers'; + let compass: Compass; + let browser: CompassBrowser; + + let mockAssistantServer: Awaited<ReturnType<typeof startMockAssistantServer>>; + + before(async function () { + mockAssistantServer = await startMockAssistantServer(); + compass = await init(this.test?.fullTitle()); + browser = compass.browser; + + await browser.setEnv( + 'COMPASS_ASSISTANT_BASE_URL_OVERRIDE', + mockAssistantServer.endpoint + ); + }); + + after(async function () { + await mockAssistantServer.stop(); + await cleanup(compass); + }); + + afterEach(async function () { + await screenshotIfFailed(compass, this.currentTest); + try { + mockAssistantServer.clearRequests(); + } catch (err) { + await browser.screenshot(screenshotPathName('afterEach-GenAi-Query')); + throw err; + } + }); + + describe('when the ai model response is valid', function () { + beforeEach(async function () { + await setup(browser, dbName, collName); + mockAssistantServer.setResponse({ + status: 200, + body: '{i: {$gt: 50}}', + }); + }); + + it('makes request to the server and updates the query bar with the response', async function () { + // Click the ai entry button. + await browser.clickVisible(Selectors.GenAIEntryButton); + + // Enter the ai prompt. + await browser.clickVisible(Selectors.GenAITextInput); + + const testUserInput = 'find all documents where i is greater than 50'; + await browser.setValueVisible(Selectors.GenAITextInput, testUserInput); + + // Click generate. + await browser.clickVisible(Selectors.GenAIGenerateQueryButton); + + // Wait for the ipc events to succeed.
+ await browser.waitUntil(async function () { + // Make sure the query bar was updated. + const queryBarFilterContent = await browser.getCodemirrorEditorText( + Selectors.queryBarOptionInputFilter('Documents') + ); + return queryBarFilterContent === '{i:{$gt:50}}'; + }); + + // Check that the request was made with the correct parameters. + const requests = mockAssistantServer.getRequests(); + expect(requests.length).to.equal(1); + + const queryRequest = requests[0]; + // TODO(COMPASS-10125): Switch the model to `mongodb-slim-latest` when + // enabling this feature. + expect(queryRequest.content.model).to.equal('mongodb-chat-latest'); + expect(queryRequest.content.instructions).to.be.a('string'); + expect(queryRequest.content.input).to.be.an('array').of.length(1); + + const message = queryRequest.content.input[0]; + expect(message.role).to.equal('user'); + expect(message.content).to.be.an('array').of.length(1); + expect(message.content[0]).to.have.property('type'); + expect(message.content[0]).to.have.property('text'); + + // Run it and check that the correct documents are shown. + await browser.runFind('Documents', true); + const modifiedResult = await browser.getFirstListDocument(); + expect(modifiedResult.i).to.be.equal('51'); + }); + }); + + describe('when the chatbot api request errors', function () { + beforeEach(async function () { + await setup(browser, dbName, collName); + mockAssistantServer.setResponse({ + status: 500, + body: '', + }); + }); + + it('the error is shown to the user', async function () { + // Click the ai entry button. + await browser.clickVisible(Selectors.GenAIEntryButton); + + // Enter the ai prompt. + await browser.clickVisible(Selectors.GenAITextInput); + + const testUserInput = 'find all documents where i is greater than 50'; + await browser.setValueVisible(Selectors.GenAITextInput, testUserInput); + + // Click generate. + await browser.clickVisible(Selectors.GenAIGenerateQueryButton); + + // Check that the error is shown.
+ const errorBanner = browser.$(Selectors.GenAIErrorMessageBanner); + await errorBanner.waitForDisplayed(); + expect(await errorBanner.getText()).to.equal( + 'Sorry, we were unable to generate the query, please try again. If the error persists, try changing your prompt.' + ); + }); + }); +}); diff --git a/packages/compass-generative-ai/package.json b/packages/compass-generative-ai/package.json index 69a9776dd94..1ba664073fb 100644 --- a/packages/compass-generative-ai/package.json +++ b/packages/compass-generative-ai/package.json @@ -63,12 +63,15 @@ "bson": "^6.10.4", "compass-preferences-model": "^2.66.3", "mongodb": "^6.19.0", + "mongodb-query-parser": "^4.5.0", "mongodb-schema": "^12.6.3", "react": "^17.0.2", "react-redux": "^8.1.3", "redux": "^4.2.1", "redux-thunk": "^2.4.2", - "zod": "^3.25.76" + "zod": "^3.25.76", + "@ai-sdk/openai": "^2.0.4", + "ai": "^5.0.26" }, "devDependencies": { "@mongodb-js/eslint-config-compass": "^1.4.12", @@ -85,7 +88,6 @@ "depcheck": "^1.4.1", "electron-mocha": "^12.2.0", "mocha": "^10.2.0", - "mongodb-query-parser": "^4.5.0", "nyc": "^15.1.0", "p-queue": "^7.4.1", "sinon": "^9.2.3", diff --git a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts index 1767fef6dd5..e56ffbcef2f 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts @@ -51,6 +51,7 @@ class MockAtlasService { getCurrentUser = () => Promise.resolve(ATLAS_USER); cloudEndpoint = (url: string) => `${['/cloud', url].join('/')}`; adminApiEndpoint = (url: string) => `${[BASE_URL, url].join('/')}`; + assistantApiEndpoint = (url: string) => `${[BASE_URL, url].join('/')}`; authenticatedFetch = (url: string, init: RequestInit) => { return fetch(url, init); }; @@ -736,4 +737,357 @@ describe('AtlasAiService', function () { }); }); } + + describe('with chatbot api', function () { + describe('getQueryFromUserInput and 
getAggregationFromUserInput', function () { + type Chunk = { type: 'text' | 'error'; content: string }; + let atlasAiService: AtlasAiService; + const mockConnectionInfo = getMockConnectionInfo(); + + function streamChunkResponse( + readableStreamController: ReadableStreamController, + chunks: Chunk[] + ) { + const responseId = `resp_${Date.now()}`; + const itemId = `item_${Date.now()}`; + let sequenceNumber = 0; + + const encoder = new TextEncoder(); + + // openai response format: + // https://github.com/vercel/ai/blob/811119c1808d7b62a4857bcad42353808cdba17c/packages/openai/src/responses/openai-responses-api.ts#L322 + + // Send response.created event + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.created', + response: { + id: responseId, + object: 'realtime.response', + status: 'in_progress', + output: [], + usage: { + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, + }, + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + + // Send output_item.added event + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.output_item.added', + response_id: responseId, + output_index: 0, + item: { + id: itemId, + object: 'realtime.item', + type: 'message', + role: 'assistant', + content: [], + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + + for (const chunk of chunks) { + if (chunk.type === 'error') { + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: `error`, + response_id: responseId, + item_id: itemId, + output_index: 0, + error: { + type: 'model_error', + code: 'model_error', + message: chunk.content, + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + } else { + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.output_text.delta', + response_id: responseId, + item_id: itemId, + output_index: 0, + delta: chunk.content, + sequence_number: 
sequenceNumber++, + })}\n\n` + ) + ); + } + } + + const content = chunks + .filter((c) => c.type === 'text') + .map((c) => c.content) + .join(''); + + // Send output_item.done event + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.output_item.done', + response_id: responseId, + output_index: 0, + item: { + id: itemId, + object: 'realtime.item', + type: 'message', + role: 'assistant', + content: [ + { + type: 'text', + text: content, + }, + ], + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + + // Send response.completed event + const tokenCount = Math.ceil(content.length / 4); // assume 4 chars per token + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.completed', + response: { + id: responseId, + object: 'realtime.response', + status: 'completed', + output: [ + { + id: itemId, + object: 'realtime.item', + type: 'message', + role: 'assistant', + content: [ + { + type: 'text', + text: content, + }, + ], + }, + ], + usage: { + input_tokens: 10, + output_tokens: tokenCount, + total_tokens: 10 + tokenCount, + }, + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + } + + function streamableFetchMock(chunks: Chunk[]) { + const readableStream = new ReadableStream({ + start(controller) { + streamChunkResponse(controller, chunks); + controller.close(); + }, + }); + return new Response(readableStream, { + headers: { 'Content-Type': 'text/event-stream' }, + }); + } + + beforeEach(async function () { + const mockAtlasService = new MockAtlasService(); + await preferences.savePreferences({ + enableChatbotEndpointForGenAI: true, + }); + atlasAiService = new AtlasAiService({ + apiURLPreset: 'cloud', + atlasService: mockAtlasService as any, + preferences, + logger: createNoopLogger(), + }); + // Enable the AI feature + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: { + GEN_AI_COMPASS: { + enabled: true, + }, + }, + }) + ); + 
global.fetch = fetchStub; + await atlasAiService['setupAIAccess'](); + }); + + after(function () { + global.fetch = initialFetch; + }); + + const testCases = [ + { + functionName: 'getQueryFromUserInput', + successResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world' }, + { + type: 'text', + content: '. This is some non relevant text in the output', + }, + { type: 'text', content: '{test: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}' }, + ] as Chunk[], + response: { + content: { + query: { + filter: "{test:'pineapple'}", + project: null, + sort: null, + skip: null, + limit: null, + }, + }, + }, + }, + invalidModelResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world.' }, + { type: 'text', content: '{test: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}' }, + { type: 'error', content: 'Model crashed!' }, + ] as Chunk[], + errorMessage: 'Model crashed!', + }, + }, + { + functionName: 'getAggregationFromUserInput', + successResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world' }, + { + type: 'text', + content: '. This is some non relevant text in the output', + }, + { type: 'text', content: '[{$count: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}]' }, + ] as Chunk[], + response: { + content: { + aggregation: { + pipeline: "[{$count:'pineapple'}]", + }, + }, + }, + }, + invalidModelResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world.' }, + { type: 'text', content: '[{test: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}]' }, + { type: 'error', content: 'Model crashed!' 
}, + ] as Chunk[], + errorMessage: 'Model crashed!', + }, + }, + ] as const; + + for (const { + functionName, + successResponse, + invalidModelResponse, + } of testCases) { + describe(functionName, function () { + it('makes a post request with the user input to the endpoint in the environment', async function () { + const fetchStub = sandbox + .stub() + .resolves(streamableFetchMock(successResponse.request)); + global.fetch = fetchStub; + + const input = { + userInput: 'test', + signal: new AbortController().signal, + collectionName: 'jam', + databaseName: 'peanut', + schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, + sampleDocuments: [ + { _id: new ObjectId('642d766b7300158b1f22e972') }, + ], + requestId: 'abc', + }; + + const res = await atlasAiService[functionName]( + input as any, + mockConnectionInfo + ); + + expect(fetchStub).to.have.been.calledOnce; + + const { args } = fetchStub.firstCall; + const requestBody = JSON.parse(args[1].body as string); + + expect(requestBody.model).to.equal('mongodb-chat-latest'); + expect(requestBody.store).to.equal(false); + expect(requestBody.instructions).to.be.a('string'); + expect(requestBody.input).to.be.an('array'); + + const { role, content } = requestBody.input[0]; + expect(role).to.equal('user'); + expect(content[0].text).to.include( + `Database name: "${input.databaseName}"` + ); + expect(content[0].text).to.include( + `Collection name: "${input.collectionName}"` + ); + expect(res).to.deep.eq(successResponse.response); + }); + + it('should throw an error when the stream contains an error chunk', async function () { + const fetchStub = sandbox + .stub() + .resolves(streamableFetchMock(invalidModelResponse.request)); + global.fetch = fetchStub; + + try { + await atlasAiService[functionName]( + { + userInput: 'test', + collectionName: 'test', + databaseName: 'peanut', + requestId: 'abc', + signal: new AbortController().signal, + }, + mockConnectionInfo + ); + expect.fail(`Expected ${functionName} to throw`); + } 
catch (err) { + expect((err as Error).message).to.match( + new RegExp(invalidModelResponse.errorMessage, 'i') + ); + } + }); + }); + } + }); + }); }); diff --git a/packages/compass-generative-ai/src/atlas-ai-service.ts b/packages/compass-generative-ai/src/atlas-ai-service.ts index 04f75a90161..5baf8c62b60 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.ts @@ -16,6 +16,15 @@ import { AtlasAiServiceInvalidInputError, AtlasAiServiceApiResponseParseError, } from './atlas-ai-errors'; +import { createOpenAI } from '@ai-sdk/openai'; +import { type LanguageModel } from 'ai'; +import type { AiQueryPrompt } from './utils/gen-ai-prompt'; +import { + buildAggregateQueryPrompt, + buildFindQueryPrompt, +} from './utils/gen-ai-prompt'; +import { parseXmlToJsonResponse } from './utils/parse-xml-response'; +import { getAiQueryResponse } from './utils/gen-ai-response'; type GenerativeAiInput = { userInput: string; @@ -40,14 +49,6 @@ type AIAggregation = { }; }; -type AIFeatureEnablement = { - features: { - [featureName: string]: { - enabled: boolean; - }; - }; -}; - type AIQuery = { content: { query: Record< @@ -271,6 +272,8 @@ export class AtlasAiService { private preferences: PreferencesAccess; private logger: Logger; + private aiModel: LanguageModel; + constructor({ apiURLPreset, atlasService, @@ -286,8 +289,26 @@ export class AtlasAiService { this.atlasService = atlasService; this.preferences = preferences; this.logger = logger; - this.initPromise = this.setupAIAccess(); + + const PLACEHOLDER_BASE_URL = + 'http://PLACEHOLDER_BASE_URL_TO_BE_REPLACED.invalid'; + this.aiModel = createOpenAI({ + apiKey: '', + baseURL: PLACEHOLDER_BASE_URL, + fetch: (url, init) => { + // The `baseUrl` can be dynamically changed, but `createOpenAI` + // doesn't allow us to change it after initial call. 
Instead + // we're going to update it every time the fetch call happens + const uri = String(url).replace( + PLACEHOLDER_BASE_URL, + this.atlasService.assistantApiEndpoint() + ); + return this.atlasService.authenticatedFetch(uri, init); + }, + // TODO(COMPASS-10125): Switch the model to `mongodb-slim-latest` when + // enabling this feature (to use edu-chatbot for GenAI). + }).responses('mongodb-chat-latest'); } /** @@ -423,6 +444,14 @@ export class AtlasAiService { input: GenerativeAiInput, connectionInfo: ConnectionInfo ) { + if (this.preferences.getPreferences().enableChatbotEndpointForGenAI) { + const message = buildAggregateQueryPrompt(input); + return this.generateQueryUsingChatbot( + message, + validateAIAggregationResponse, + { signal: input.signal } + ); + } return this.getQueryOrAggregationFromUserInput( { connectionInfo, @@ -437,6 +466,12 @@ export class AtlasAiService { input: GenerativeAiInput, connectionInfo: ConnectionInfo ) { + if (this.preferences.getPreferences().enableChatbotEndpointForGenAI) { + const message = buildFindQueryPrompt(input); + return this.generateQueryUsingChatbot(message, validateAIQueryResponse, { + signal: input.signal, + }); + } return this.getQueryOrAggregationFromUserInput( { urlId: 'query', @@ -527,12 +562,19 @@ export class AtlasAiService { }); } - private validateAIFeatureEnablementResponse( - response: any - ): asserts response is AIFeatureEnablement { - const { features } = response; - if (typeof features !== 'object') { - throw new Error('Unexpected response: expected features to be an object'); - } + private async generateQueryUsingChatbot<T>( + message: AiQueryPrompt, + validateFn: (res: any) => asserts res is T, + options: { signal: AbortSignal } + ): Promise<T> { + this.throwIfAINotEnabled(); + const response = await getAiQueryResponse( + this.aiModel, + message, + options.signal + ); + const parsedResponse = parseXmlToJsonResponse(response, this.logger); + validateFn(parsedResponse); + return parsedResponse; } } diff
--git a/packages/compass-generative-ai/src/chatbot-errors.ts b/packages/compass-generative-ai/src/chatbot-errors.ts new file mode 100644 index 00000000000..f115482b1b0 --- /dev/null +++ b/packages/compass-generative-ai/src/chatbot-errors.ts @@ -0,0 +1,8 @@ +import { AtlasServiceError } from '@mongodb-js/atlas-service/renderer'; + +export class AiChatbotInvalidResponseError extends AtlasServiceError { + constructor(message: string) { + super('ServerError', 500, message, 'INVALID_RESPONSE'); + this.name = 'AiChatbotInvalidResponseError'; + } +} diff --git a/packages/compass-generative-ai/src/utils/gen-ai-prompt.spec.ts b/packages/compass-generative-ai/src/utils/gen-ai-prompt.spec.ts index 526861c71af..f1aa217c21d 100644 --- a/packages/compass-generative-ai/src/utils/gen-ai-prompt.spec.ts +++ b/packages/compass-generative-ai/src/utils/gen-ai-prompt.spec.ts @@ -2,13 +2,13 @@ import { expect } from 'chai'; import { buildFindQueryPrompt, buildAggregateQueryPrompt, - type UserPromptForQueryOptions, + type PromptContextOptions, } from './gen-ai-prompt'; import { toJSString } from 'mongodb-query-parser'; import { ObjectId } from 'bson'; -const OPTIONS: UserPromptForQueryOptions = { - userPrompt: 'Find all users older than 30', +const OPTIONS: PromptContextOptions = { + userInput: 'Find all users older than 30', databaseName: 'airbnb', collectionName: 'listings', schema: { @@ -50,7 +50,7 @@ describe('GenAI Prompts', function () { expect(prompt).to.be.a('string'); expect(prompt).to.include( - `Write a query that does the following: "${OPTIONS.userPrompt}"`, + `Write a query that does the following: "${OPTIONS.userInput}"`, 'includes user prompt' ); expect(prompt).to.include( @@ -93,7 +93,7 @@ describe('GenAI Prompts', function () { expect(prompt).to.be.a('string'); expect(prompt).to.include( - `Generate an aggregation that does the following: "${OPTIONS.userPrompt}"`, + `Generate an aggregation that does the following: "${OPTIONS.userInput}"`, 'includes user prompt' ); 
expect(prompt).to.include( @@ -121,4 +121,63 @@ describe('GenAI Prompts', function () { 'includes actual sample documents' ); }); + + it('throws if user prompt exceeds the max size', function () { + try { + buildFindQueryPrompt({ + ...OPTIONS, + userInput: 'a'.repeat(512001), + }); + expect.fail('Expected buildFindQueryPrompt to throw'); + } catch (err) { + expect(err).to.have.property( + 'message', + 'Sorry, your request is too large. Please use a smaller prompt or try using this feature on a collection with smaller documents.' + ); + } + }); + + context('handles large sample documents', function () { + it('sends all the sample docs if within limits', function () { + const sampleDocuments = [ + { a: '1' }, + { a: '2' }, + { a: '3' }, + { a: '4'.repeat(5120) }, + ]; + const prompt = buildFindQueryPrompt({ + ...OPTIONS, + sampleDocuments, + }).prompt; + + expect(prompt).to.include(toJSString(sampleDocuments)); + }); + it('sends only one sample doc if all exceed limits', function () { + const sampleDocuments = [ + { a: '1'.repeat(5120) }, + { a: '2'.repeat(5120001) }, + { a: '3'.repeat(5120001) }, + { a: '4'.repeat(5120001) }, + ]; + const prompt = buildFindQueryPrompt({ + ...OPTIONS, + sampleDocuments, + }).prompt; + expect(prompt).to.include(toJSString([sampleDocuments[0]])); + }); + it('should not send sample docs if even one exceeds limits', function () { + const sampleDocuments = [ + { a: '1'.repeat(5120001) }, + { a: '2'.repeat(5120001) }, + { a: '3'.repeat(5120001) }, + { a: '4'.repeat(5120001) }, + ]; + const prompt = buildFindQueryPrompt({ + ...OPTIONS, + sampleDocuments, + }).prompt; + expect(prompt).to.not.include('Sample document from the collection:'); + expect(prompt).to.not.include('Sample documents from the collection:'); + }); + }); }); diff --git a/packages/compass-generative-ai/src/utils/gen-ai-prompt.ts b/packages/compass-generative-ai/src/utils/gen-ai-prompt.ts index c235feaabf0..ac45fdf4eaf 100644 --- 
a/packages/compass-generative-ai/src/utils/gen-ai-prompt.ts +++ b/packages/compass-generative-ai/src/utils/gen-ai-prompt.ts @@ -57,8 +57,8 @@ function buildInstructionsForAggregateQuery() { ].join('\n'); } -export type UserPromptForQueryOptions = { - userPrompt: string; +export type PromptContextOptions = { + userInput: string; databaseName?: string; collectionName?: string; schema?: unknown; @@ -76,18 +76,18 @@ function withCodeFence(code: string): string { function buildUserPromptForQuery({ type, - userPrompt, + userInput, databaseName, collectionName, schema, sampleDocuments, -}: UserPromptForQueryOptions & { type: 'find' | 'aggregate' }): string { +}: PromptContextOptions & { type: 'find' | 'aggregate' }): string { const messages = []; const queryPrompt = [ type === 'find' ? 'Write a query' : 'Generate an aggregation', 'that does the following:', - `"${userPrompt}"`, + `"${userInput}"`, ].join(' '); if (databaseName) { @@ -137,7 +137,16 @@ function buildUserPromptForQuery({ } } messages.push(queryPrompt); - return messages.join('\n'); + + const prompt = messages.join('\n'); + + // If at this point we have exceeded the limit, throw an error. + if (prompt.length > MAX_TOTAL_PROMPT_LENGTH) { + throw new Error( + 'Sorry, your request is too large. Please use a smaller prompt or try using this feature on a collection with smaller documents.' 
+ ); + } + return prompt; } export type AiQueryPrompt = { @@ -148,15 +157,15 @@ export type AiQueryPrompt = { }; export function buildFindQueryPrompt({ - userPrompt, + userInput, databaseName, collectionName, schema, sampleDocuments, -}: UserPromptForQueryOptions): AiQueryPrompt { +}: PromptContextOptions): AiQueryPrompt { const prompt = buildUserPromptForQuery({ type: 'find', - userPrompt, + userInput, databaseName, collectionName, schema, @@ -172,15 +181,15 @@ export function buildFindQueryPrompt({ } export function buildAggregateQueryPrompt({ - userPrompt, + userInput, databaseName, collectionName, schema, sampleDocuments, -}: UserPromptForQueryOptions): AiQueryPrompt { +}: PromptContextOptions): AiQueryPrompt { const prompt = buildUserPromptForQuery({ type: 'aggregate', - userPrompt, + userInput, databaseName, collectionName, schema, diff --git a/packages/compass-generative-ai/src/utils/gen-ai-response.ts b/packages/compass-generative-ai/src/utils/gen-ai-response.ts new file mode 100644 index 00000000000..823921f0079 --- /dev/null +++ b/packages/compass-generative-ai/src/utils/gen-ai-response.ts @@ -0,0 +1,32 @@ +import { AiChatbotInvalidResponseError } from '../chatbot-errors'; +import { type AiQueryPrompt } from './gen-ai-prompt'; +import type { LanguageModel } from 'ai'; +import { streamText } from 'ai'; + +export async function getAiQueryResponse( + model: LanguageModel, + message: AiQueryPrompt, + abortSignal: AbortSignal +): Promise { + const response = streamText({ + model, + messages: [{ role: 'user', content: message.prompt }], + providerOptions: { + openai: { + store: false, + instructions: message.metadata.instructions, + }, + }, + abortSignal, + }).toUIMessageStream(); + const chunks: string[] = []; + for await (const value of response) { + if (value.type === 'text-delta') { + chunks.push(value.delta); + } + if (value.type === 'error') { + throw new AiChatbotInvalidResponseError(value.errorText); + } + } + return chunks.join(''); +} diff --git 
a/packages/compass-generative-ai/src/utils/parse-xml-response.spec.ts b/packages/compass-generative-ai/src/utils/parse-xml-response.spec.ts new file mode 100644 index 00000000000..87fa314ac13 --- /dev/null +++ b/packages/compass-generative-ai/src/utils/parse-xml-response.spec.ts @@ -0,0 +1,120 @@ +import { expect } from 'chai'; +import { parseXmlToJsonResponse } from './parse-xml-response'; +import { createNoopLogger } from '@mongodb-js/compass-logging/provider'; + +describe('parseXmlToJsonResponse', function () { + it('should return prioritize aggregation over query when available and valid', function () { + const xmlString = ` + { age: { $gt: 25 } } + [{ $match: { status: "A" } }] + `; + + const result = parseXmlToJsonResponse(xmlString, createNoopLogger()); + + expect(result).to.deep.equal({ + content: { + aggregation: { + pipeline: "[{$match:{status:'A'}}]", + }, + query: { + filter: null, + project: null, + sort: null, + skip: null, + limit: null, + }, + }, + }); + }); + + it('should not return aggregation if its not available in the response', function () { + const xmlString = ` + { age: { $gt: 25 } } + `; + + const result = parseXmlToJsonResponse(xmlString, createNoopLogger()); + expect(result).to.deep.equal({ + content: { + query: { + filter: '{age:{$gt:25}}', + project: null, + sort: null, + skip: null, + limit: null, + }, + }, + }); + }); + + it('should not return query if its not available in the response', function () { + const xmlString = ` + [{ $match: { status: "A" } }] + `; + + const result = parseXmlToJsonResponse(xmlString, createNoopLogger()); + + expect(result).to.deep.equal({ + content: { + aggregation: { + pipeline: "[{$match:{status:'A'}}]", + }, + }, + }); + }); + + it('should return all the query fields if provided', function () { + const xmlString = ` + { age: { $gt: 25 } } + { name: 1, age: 1 } + { age: -1 } + 5 + 10 + + `; + + const result = parseXmlToJsonResponse(xmlString, createNoopLogger()); + + expect(result).to.deep.equal({ + 
content: { + query: { + filter: '{age:{$gt:25}}', + project: '{name:1,age:1}', + sort: '{age:-1}', + skip: '5', + limit: '10', + }, + }, + }); + }); + + context('it should handle invalid data', function () { + it('invalid json', function () { + const result = parseXmlToJsonResponse( + `{ age: { $gt: 25 `, + createNoopLogger() + ); + expect(result.content).to.not.have.property('query'); + }); + it('empty object', function () { + const result = parseXmlToJsonResponse( + `{}`, + createNoopLogger() + ); + expect(result.content).to.not.have.property('query'); + }); + it('empty array', function () { + const result = parseXmlToJsonResponse( + `[]`, + createNoopLogger() + ); + expect(result.content).to.not.have.property('aggregation'); + }); + it('zero value', function () { + const result = parseXmlToJsonResponse( + `0`, + createNoopLogger() + ); + expect(result.content).to.not.have.property('query'); + }); + }); +}); diff --git a/packages/compass-generative-ai/src/utils/parse-xml-response.ts b/packages/compass-generative-ai/src/utils/parse-xml-response.ts new file mode 100644 index 00000000000..44e34a6c971 --- /dev/null +++ b/packages/compass-generative-ai/src/utils/parse-xml-response.ts @@ -0,0 +1,105 @@ +import type { Logger } from '@mongodb-js/compass-logging'; +import parse, { toJSString } from 'mongodb-query-parser'; + +type ParsedXmlJsonResponse = { + content: { + query?: { + filter: string | null; + project: string | null; + sort: string | null; + skip: string | null; + limit: string | null; + }; + aggregation?: { + pipeline: string; + }; + }; +}; + +export function parseXmlToJsonResponse( + xmlString: string, + logger: Logger +): ParsedXmlJsonResponse { + const expectedTags = [ + 'filter', + 'project', + 'sort', + 'skip', + 'limit', + 'aggregation', + ] as const; + + const parser = new DOMParser(); + const xmlDoc = parser.parseFromString( + `${xmlString}`, + 'text/xml' + ); + + // Currently the prompt forces LLM to return xml-styled data + const result: 
Record<(typeof expectedTags)[number], string | null> = { + filter: null, + project: null, + sort: null, + skip: null, + limit: null, + aggregation: null, + }; + for (const tag of expectedTags) { + const value = xmlDoc.querySelector(tag)?.textContent?.trim(); + if (value) { + try { + const tagValue = parse(value); + if ( + !tagValue || + (typeof tagValue === 'object' && Object.keys(tagValue).length === 0) + ) { + result[tag] = null; + } else { + // No indentation + result[tag] = toJSString(tagValue, 0) ?? null; + } + } catch (e) { + logger.log.warn( + logger.mongoLogId(1_001_000_384), + 'AtlasAiService', + `Failed to parse value for tag <${tag}>: ${value}`, + { error: e } + ); + result[tag] = null; + } + } + } + + const { aggregation, ...query } = result; + const isQueryEmpty = Object.values(query).every((v) => v === null); + + // It prioritizes aggregation over query if both are present + if (aggregation && !isQueryEmpty) { + return { + content: { + aggregation: { + pipeline: aggregation, + }, + query: { + filter: null, + project: null, + sort: null, + skip: null, + limit: null, + }, + }, + }; + } + return { + content: { + ...(aggregation + ? { + aggregation: { + pipeline: aggregation, + }, + } + : {}), + ...(isQueryEmpty ? 
{} : { query }), + }, + }; +} diff --git a/packages/compass-web/src/entrypoint.tsx b/packages/compass-web/src/entrypoint.tsx index 599319b205c..9d424f3a1e9 100644 --- a/packages/compass-web/src/entrypoint.tsx +++ b/packages/compass-web/src/entrypoint.tsx @@ -110,7 +110,13 @@ const WithAtlasProviders: React.FC<{ children: React.ReactNode }> = ({ return ( - + {children} diff --git a/packages/compass/src/app/components/entrypoint.tsx b/packages/compass/src/app/components/entrypoint.tsx index b69b01f5853..a4ba8050933 100644 --- a/packages/compass/src/app/components/entrypoint.tsx +++ b/packages/compass/src/app/components/entrypoint.tsx @@ -62,6 +62,7 @@ export const WithAtlasProviders: React.FC = ({ children }) => { options={{ defaultHeaders: { 'User-Agent': `${getAppName()}/${getAppVersion()}`, + 'X-Request-Origin': 'mongodb-compass', }, }} >