diff --git a/packages/_example/src/forest/agent.ts b/packages/_example/src/forest/agent.ts index e3130780ee..a64852843d 100644 --- a/packages/_example/src/forest/agent.ts +++ b/packages/_example/src/forest/agent.ts @@ -2,6 +2,7 @@ import type { Schema } from './typings'; import type { AgentOptions } from '@forestadmin/agent'; import { createAgent } from '@forestadmin/agent'; +import { createAiProvider } from '@forestadmin/ai-proxy'; import { createMongoDataSource } from '@forestadmin/datasource-mongo'; import { createMongooseDataSource } from '@forestadmin/datasource-mongoose'; import { createSequelizeDataSource } from '@forestadmin/datasource-sequelize'; @@ -94,5 +95,11 @@ export default function makeAgent() { .customizeCollection('post', customizePost) .customizeCollection('comment', customizeComment) .customizeCollection('review', customizeReview) - .customizeCollection('sales', customizeSales); + .customizeCollection('sales', customizeSales) + .addAi(createAiProvider({ + model: 'gpt-4o', + provider: 'openai', + name: 'test', + apiKey: process.env.OPENAI_API_KEY, + })); } diff --git a/packages/agent/package.json b/packages/agent/package.json index 8e9d40e4a5..cac83afb24 100644 --- a/packages/agent/package.json +++ b/packages/agent/package.json @@ -13,7 +13,6 @@ }, "dependencies": { "@fast-csv/format": "^4.3.5", - "@forestadmin/ai-proxy": "1.4.1", "@forestadmin/datasource-customizer": "1.67.3", "@forestadmin/datasource-toolkit": "1.50.1", "@forestadmin/forestadmin-client": "1.37.10", diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index 6a611aeb0d..5dd455f691 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -1,11 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { ForestAdminHttpDriverServices } from './services'; -import type { - AgentOptions, - AgentOptionsWithDefaults, - AiConfiguration, - HttpCallback, -} from './types'; +import type { AgentOptions, AgentOptionsWithDefaults, HttpCallback } from './types'; import type { CollectionCustomizer, DataSourceChartDefinition, @@ -14,7 +9,11 @@ import type { TCollectionName, TSchema, } from '@forestadmin/datasource-customizer'; -import type { DataSource, DataSourceFactory } from '@forestadmin/datasource-toolkit'; +import type { + AiProviderDefinition, + DataSource, + DataSourceFactory, +} from '@forestadmin/datasource-toolkit'; import type { ForestSchema } from '@forestadmin/forestadmin-client'; import { DataSourceCustomizer } from '@forestadmin/datasource-customizer'; @@ -47,7 +46,7 @@ export default class Agent extends FrameworkMounter protected nocodeCustomizer: DataSourceCustomizer; protected customizationService: CustomizationService; protected schemaGenerator: SchemaGenerator; - protected aiConfigurations: AiConfiguration[] = []; + protected aiProvider: AiProviderDefinition | null = null; /** Whether MCP server should be mounted */ private mcpEnabled = false; @@ -222,42 +221,49 @@ export default class Agent extends FrameworkMounter * All AI requests from Forest Admin are forwarded to your agent and processed locally. * Your data and API keys never transit through Forest Admin servers, ensuring full privacy. 
* - * @param configuration - The AI provider configuration - * @param configuration.name - A unique name to identify this AI configuration - * @param configuration.provider - The AI provider to use ('openai') - * @param configuration.apiKey - Your API key for the chosen provider - * @param configuration.model - The model to use (e.g., 'gpt-4o') + * Requires the `@forestadmin/ai-proxy` package to be installed: + * ```bash + * npm install @forestadmin/ai-proxy + * ``` + * + * @param provider - An AI provider definition created via `createAiProvider` from `@forestadmin/ai-proxy` * @returns The agent instance for chaining * @throws Error if addAi is called more than once * * @example - * agent.addAi({ + * import { createAiProvider } from '@forestadmin/ai-proxy'; + * + * agent.addAi(createAiProvider({ * name: 'assistant', * provider: 'openai', * apiKey: process.env.OPENAI_API_KEY, * model: 'gpt-4o', - * }); + * })); */ - addAi(configuration: AiConfiguration): this { - if (this.aiConfigurations.length > 0) { + addAi(provider: AiProviderDefinition): this { + if (this.aiProvider) { throw new Error( 'addAi can only be called once. Multiple AI configurations are not supported yet.', ); } - this.options.logger( - 'Warn', - `AI configuration added with model '${configuration.model}'. ` + - 'Make sure to test Forest Admin AI features thoroughly to ensure compatibility.', - ); + this.aiProvider = provider; - this.aiConfigurations.push(configuration); + for (const p of provider.providers) { + this.options.logger( + 'Warn', + `AI configuration added with model '${p.model}'. ` + + 'Make sure to test Forest Admin AI features thoroughly to ensure compatibility.', + ); + } return this; } protected getRoutes(dataSource: DataSource, services: ForestAdminHttpDriverServices) { - return makeRoutes(dataSource, this.options, services, this.aiConfigurations); + const aiRouter = this.aiProvider?.init(this.options.logger) ?? null; + + return makeRoutes(dataSource, this.options, services, aiRouter); } /** @@ -380,9 +386,10 @@ export default class Agent extends FrameworkMounter let schema: Pick; // Get the AI configurations for schema metadata + const aiMeta = this.aiProvider?.providers ?? 
[]; const { meta } = SchemaGenerator.buildMetadata( this.customizationService.buildFeatures(), - this.aiConfigurations, + aiMeta, ); // When using experimental no-code features even in production we need to build a new schema diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts index 49ade7e5dd..dafa1613c5 100644 --- a/packages/agent/src/index.ts +++ b/packages/agent/src/index.ts @@ -9,6 +9,7 @@ export function createAgent(options: AgentOptions): export { Agent }; export { AgentOptions } from './types'; +export type { AiProviderDefinition } from './types'; export * from '@forestadmin/datasource-customizer'; // export is necessary for the agent-generator package diff --git a/packages/agent/src/routes/ai/ai-proxy.ts b/packages/agent/src/routes/ai/ai-proxy.ts index c36b308c89..71844c6b9b 100644 --- a/packages/agent/src/routes/ai/ai-proxy.ts +++ b/packages/agent/src/routes/ai/ai-proxy.ts @@ -1,40 +1,23 @@ import type { ForestAdminHttpDriverServices } from '../../services'; -import type { AgentOptionsWithDefaults, AiConfiguration } from '../../types'; +import type { AgentOptionsWithDefaults } from '../../types'; +import type { AiRouter } from '@forestadmin/datasource-toolkit'; import type KoaRouter from '@koa/router'; import type { Context } from 'koa'; -import { - AIBadRequestError, - AIError, - AINotConfiguredError, - AINotFoundError, - Router as AiProxyRouter, - extractMcpOauthTokensFromHeaders, - injectOauthTokens, -} from '@forestadmin/ai-proxy'; -import { - BadRequestError, - NotFoundError, - UnprocessableError, -} from '@forestadmin/datasource-toolkit'; - import { HttpCode, RouteType } from '../../types'; import BaseRoute from '../base-route'; export default class AiProxyRoute extends BaseRoute { readonly type = RouteType.PrivateRoute; - private readonly aiProxyRouter: AiProxyRouter; + private readonly aiRouter: AiRouter; constructor( services: ForestAdminHttpDriverServices, options: AgentOptionsWithDefaults, - aiConfigurations: AiConfiguration[], + aiRouter: AiRouter, ) { super(services, options); - this.aiProxyRouter = new AiProxyRouter({ - aiConfigurations, - logger: this.options.logger, - }); + this.aiRouter = aiRouter; } setupRoutes(router: KoaRouter): void { @@ -42,33 +25,16 @@ export default class AiProxyRoute extends BaseRoute { } private async handleAiProxy(context: Context): Promise { - try { - const tokensByMcpServerName = extractMcpOauthTokensFromHeaders(context.request.headers); - - const mcpConfigs = - await this.options.forestAdminClient.mcpServerConfigService.getConfiguration(); - - context.response.body = await this.aiProxyRouter.route({ - route: context.params.route, - body: context.request.body, - query: context.query, - mcpConfigs: injectOauthTokens({ mcpConfigs, tokensByMcpServerName }), - }); - context.response.status = HttpCode.Ok; - } catch (error) { - if (error instanceof AIError) { - this.options.logger('Error', `AI proxy error: ${error.message}`, error); - - if (error instanceof AINotConfiguredError) { - throw new UnprocessableError('AI is not configured. 
Please call addAi() on your agent.'); - } - - if (error instanceof AIBadRequestError) throw new BadRequestError(error.message); - if (error instanceof AINotFoundError) throw new NotFoundError(error.message); - throw new UnprocessableError(error.message); - } - - throw error; - } + const mcpServerConfigs = + await this.options.forestAdminClient.mcpServerConfigService.getConfiguration(); + + context.response.body = await this.aiRouter.route({ + route: context.params.route, + body: context.request.body, + query: context.query, + mcpServerConfigs, + requestHeaders: context.request.headers, + }); + context.response.status = HttpCode.Ok; } } diff --git a/packages/agent/src/routes/index.ts b/packages/agent/src/routes/index.ts index cf9d4ab6d6..8673b45d6a 100644 --- a/packages/agent/src/routes/index.ts +++ b/packages/agent/src/routes/index.ts @@ -1,7 +1,7 @@ import type { ForestAdminHttpDriverServices as Services } from '../services'; -import type { AiConfiguration, AgentOptionsWithDefaults as Options } from '../types'; +import type { AgentOptionsWithDefaults as Options } from '../types'; import type BaseRoute from './base-route'; -import type { DataSource } from '@forestadmin/datasource-toolkit'; +import type { AiRouter, DataSource } from '@forestadmin/datasource-toolkit'; import CollectionApiChartRoute from './access/api-chart-collection'; import DataSourceApiChartRoute from './access/api-chart-datasource'; @@ -165,21 +165,17 @@ function getActionRoutes( return routes; } -function getAiRoutes( - options: Options, - services: Services, - aiConfigurations: AiConfiguration[], -): BaseRoute[] { - if (aiConfigurations.length === 0) return []; +function getAiRoutes(options: Options, services: Services, aiRouter: AiRouter | null): BaseRoute[] { + if (!aiRouter) return []; - return [new AiProxyRoute(services, options, aiConfigurations)]; + return [new AiProxyRoute(services, options, aiRouter)]; } export default function makeRoutes( dataSource: DataSource, options: Options, services: Services, - aiConfigurations: AiConfiguration[] = [], + aiRouter: AiRouter | null = null, ): BaseRoute[] { const routes = [ ...getRootRoutes(options, services), @@ -189,7 +185,7 @@ export default function makeRoutes( ...getApiChartRoutes(dataSource, options, services), ...getRelatedRoutes(dataSource, options, services), ...getActionRoutes(dataSource, options, services), - ...getAiRoutes(options, services, aiConfigurations), + ...getAiRoutes(options, services, aiRouter), ]; // Ensure routes and middlewares are loaded in the right order. 
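Taken together, the agent-side hunks above only move responsibility to the call site: `@forestadmin/ai-proxy` is dropped from the agent's dependencies and the host application installs it and hands the agent a provider definition. A minimal sketch of the resulting usage, assuming an `agent` built with `createAgent` as in `packages/_example` (it mirrors the JSDoc example above and adds no new behaviour):

```ts
// Sketch only: `agent` is assumed to come from createAgent(), as in packages/_example.
import { createAiProvider } from '@forestadmin/ai-proxy';

agent.addAi(
  createAiProvider({
    name: 'assistant',
    provider: 'openai',
    apiKey: process.env.OPENAI_API_KEY, // stays local; never transits through Forest Admin servers
    model: 'gpt-4o',
  }),
);
```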
diff --git a/packages/agent/src/types.ts b/packages/agent/src/types.ts index ec3c823b29..86d3355726 100644 --- a/packages/agent/src/types.ts +++ b/packages/agent/src/types.ts @@ -1,9 +1,13 @@ -import type { AiConfiguration, AiProvider } from '@forestadmin/ai-proxy'; -import type { CompositeId, Logger, LoggerLevel } from '@forestadmin/datasource-toolkit'; +import type { + AiProviderDefinition, + CompositeId, + Logger, + LoggerLevel, +} from '@forestadmin/datasource-toolkit'; import type { ForestAdminClient } from '@forestadmin/forestadmin-client'; import type { IncomingMessage, ServerResponse } from 'http'; -export type { AiConfiguration, AiProvider }; +export type { AiProviderDefinition }; /** Options to configure behavior of an agent's forestadmin driver */ export type AgentOptions = { diff --git a/packages/agent/src/utils/forest-schema/generator.ts b/packages/agent/src/utils/forest-schema/generator.ts index 28047768e9..6bfa1cb7ec 100644 --- a/packages/agent/src/utils/forest-schema/generator.ts +++ b/packages/agent/src/utils/forest-schema/generator.ts @@ -1,5 +1,5 @@ -import type { AgentOptionsWithDefaults, AiConfiguration } from '../../types'; -import type { DataSource } from '@forestadmin/datasource-toolkit'; +import type { AgentOptionsWithDefaults } from '../../types'; +import type { AiProviderMeta, DataSource } from '@forestadmin/datasource-toolkit'; import type { ForestSchema } from '@forestadmin/forestadmin-client'; import SchemaGeneratorCollection from './generator-collection'; @@ -23,7 +23,7 @@ export default class SchemaGenerator { static buildMetadata( features: Record | null, - aiConfigurations: AiConfiguration[] = [], + aiProviders: AiProviderMeta[] = [], ): Pick { const { version } = require('../../../package.json'); // eslint-disable-line @typescript-eslint/no-var-requires,global-require @@ -33,8 +33,8 @@ export default class SchemaGenerator { liana_version: version, liana_features: features, ai_llms: - aiConfigurations.length > 0 - ? aiConfigurations.map(c => ({ name: c.name, provider: c.provider })) + aiProviders.length > 0 + ? 
aiProviders.map(p => ({ name: p.name, provider: p.provider })) : null, stack: { engine: 'nodejs', diff --git a/packages/agent/test/agent.test.ts b/packages/agent/test/agent.test.ts index 5874f26866..c955b00eba 100644 --- a/packages/agent/test/agent.test.ts +++ b/packages/agent/test/agent.test.ts @@ -1,7 +1,7 @@ /* eslint-disable max-classes-per-file */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import type { DataSourceFactory } from '@forestadmin/datasource-toolkit'; +import type { AiProviderDefinition, DataSourceFactory } from '@forestadmin/datasource-toolkit'; import { DataSourceCustomizer } from '@forestadmin/datasource-customizer'; import * as McpServer from '@forestadmin/mcp-server'; @@ -34,6 +34,14 @@ beforeEach(() => { .mockResolvedValue(factories.dataSource.build()); }); +function createMockAiProvider(overrides: Partial = {}): AiProviderDefinition { + return { + providers: [{ name: 'gpt4o', provider: 'openai', model: 'gpt-4o' }], + init: jest.fn().mockReturnValue({ route: jest.fn() }), + ...overrides, + }; +} + describe('Agent', () => { describe('Development', () => { const options = factories.forestAdminHttpDriverOptions.build({ @@ -407,14 +415,10 @@ describe('Agent', () => { forestAdminClient: factories.forestAdminClient.build({ postSchema: mockPostSchema }), }); - test('should store the AI configuration', () => { + test('should store the AI provider and return agent for chaining', () => { const agent = new Agent(options); - const result = agent.addAi({ - name: 'gpt4o', - provider: 'openai', - apiKey: 'test-key', - model: 'gpt-4o', - }); + const provider = createMockAiProvider(); + const result = agent.addAi(provider); expect(result).toBe(agent); }); @@ -422,50 +426,59 @@ describe('Agent', () => { test('should throw an error when addAi is called more than once', () => { const agent = new Agent(options); - agent.addAi({ - name: 'gpt4o', - provider: 'openai', - apiKey: 'test-key', - model: 'gpt-4o', - }); + agent.addAi( + createMockAiProvider({ + providers: [{ name: 'gpt4o', provider: 'openai', model: 'gpt-4o' }], + }), + ); expect(() => - agent.addAi({ - name: 'gpt4o-mini', - provider: 'openai', - apiKey: 'another-key', - model: 'gpt-4o-mini', - }), + agent.addAi( + createMockAiProvider({ + providers: [{ name: 'gpt4o-mini', provider: 'openai', model: 'gpt-4o-mini' }], + }), + ), ).toThrow('addAi can only be called once. 
Multiple AI configurations are not supported yet.'); }); - test('should throw an error on start when model does not support tools', async () => { - // Use the real makeRoutes to trigger validation in AiProxyRouter + test('should log a warning with model name when addAi is called', () => { + const mockLogger = jest.fn(); + const agentOptions = factories.forestAdminHttpDriverOptions.build({ + isProduction: false, + logger: mockLogger, + forestAdminClient: factories.forestAdminClient.build({ postSchema: mockPostSchema }), + }); + + const agent = new Agent(agentOptions); + agent.addAi( + createMockAiProvider({ + providers: [{ name: 'gpt4o', provider: 'openai', model: 'gpt-4o' }], + }), + ); + + expect(mockLogger).toHaveBeenCalledWith('Warn', expect.stringContaining("model 'gpt-4o'")); + }); + + test('should call init with logger on start to create AI router', async () => { const realMakeRoutes = jest.requireActual('../src/routes').default; mockMakeRoutes.mockImplementation(realMakeRoutes); + const provider = createMockAiProvider(); const agent = new Agent(options); + agent.addAi(provider); - agent.addAi({ - name: 'gpt4-base', - provider: 'openai', - apiKey: 'test-key', - model: 'gpt-4', - }); + await agent.start(); - await expect(agent.start()).rejects.toThrow( - "Model 'gpt-4' does not support tools. Please use a model that supports function calling.", - ); + expect(provider.init).toHaveBeenCalledWith(options.logger); }); test('should include ai_llms in schema meta when AI is configured', async () => { const agent = new Agent(options); - agent.addAi({ - name: 'gpt4o', - provider: 'openai', - apiKey: 'test-key', - model: 'gpt-4o', - }); + agent.addAi( + createMockAiProvider({ + providers: [{ name: 'gpt4o', provider: 'openai', model: 'gpt-4o' }], + }), + ); await agent.start(); diff --git a/packages/agent/test/routes/ai/ai-proxy.test.ts b/packages/agent/test/routes/ai/ai-proxy.test.ts index fd4eb83793..53b9ac8fc6 100644 --- a/packages/agent/test/routes/ai/ai-proxy.test.ts +++ b/packages/agent/test/routes/ai/ai-proxy.test.ts @@ -1,56 +1,28 @@ -// eslint-disable-next-line import/no-extraneous-dependencies -import { - AIBadRequestError, - AIError, - AINotConfiguredError, - AINotFoundError, - AIToolNotFoundError, - AIUnprocessableError, -} from '@forestadmin/ai-proxy'; -import { - BadRequestError, - NotFoundError, - UnprocessableError, -} from '@forestadmin/datasource-toolkit'; +import type { AiRouter } from '@forestadmin/datasource-toolkit'; + import { createMockContext } from '@shopify/jest-koa-mocks'; import AiProxyRoute from '../../../src/routes/ai/ai-proxy'; import { HttpCode, RouteType } from '../../../src/types'; import * as factories from '../../__factories__'; -const mockRoute = jest.fn(); - -jest.mock('@forestadmin/ai-proxy', () => { - const actual = jest.requireActual('@forestadmin/ai-proxy'); - - return { - ...actual, - Router: jest.fn().mockImplementation(() => ({ - route: mockRoute, - })), - }; -}); - describe('AiProxyRoute', () => { const options = factories.forestAdminHttpDriverOptions.build(); const services = factories.forestAdminHttpDriverServices.build(); const router = factories.router.mockAllMethods().build(); - const aiConfigurations = [ - { - name: 'gpt4', - provider: 'openai' as const, - apiKey: 'test-key', - model: 'gpt-4o', - }, - ]; + + let mockRoute: jest.Mock; + let aiRouter: AiRouter; beforeEach(() => { jest.clearAllMocks(); + mockRoute = jest.fn(); + aiRouter = { route: mockRoute }; }); describe('constructor', () => { test('should have RouteType.PrivateRoute', () => 
{ - const route = new AiProxyRoute(services, options, aiConfigurations); + const route = new AiProxyRoute(services, options, aiRouter); expect(route.type).toBe(RouteType.PrivateRoute); }); @@ -58,7 +30,7 @@ describe('AiProxyRoute', () => { describe('setupRoutes', () => { test('should register POST route at /_internal/ai-proxy/:route', () => { - const route = new AiProxyRoute(services, options, aiConfigurations); + const route = new AiProxyRoute(services, options, aiRouter); route.setupRoutes(router); expect(router.post).toHaveBeenCalledWith('/_internal/ai-proxy/:route', expect.any(Function)); @@ -67,7 +39,7 @@ describe('AiProxyRoute', () => { describe('handleAiProxy', () => { test('should return 200 with response body on successful request', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); + const route = new AiProxyRoute(services, options, aiRouter); const expectedResponse = { result: 'success' }; mockRoute.mockResolvedValueOnce(expectedResponse); @@ -85,8 +57,8 @@ describe('AiProxyRoute', () => { expect(context.response.body).toEqual(expectedResponse); }); - test('should pass route, body, query, mcpConfigs and tokensByMcpServerName to router', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); + test('should pass route, body, query, mcpServerConfigs and requestHeaders to router', async () => { + const route = new AiProxyRoute(services, options, aiRouter); mockRoute.mockResolvedValueOnce({}); const context = createMockContext({ @@ -95,7 +67,6 @@ describe('AiProxyRoute', () => { }, requestBody: { messages: [{ role: 'user', content: 'Hello' }] }, }); - // Set query directly on context as createMockContext doesn't handle it properly context.query = { 'ai-name': 'gpt4' }; await (route as any).handleAiProxy(context); @@ -104,31 +75,29 @@ describe('AiProxyRoute', () => { route: 'ai-query', body: { messages: [{ role: 'user', content: 'Hello' }] }, query: { 'ai-name': 'gpt4' }, - mcpConfigs: undefined, // mcpServerConfigService.getConfiguration returns undefined in test + mcpServerConfigs: undefined, + requestHeaders: context.request.headers, }); }); - test('should inject oauth tokens into mcpConfigs when header is provided', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); + test('should pass mcpServerConfigs from forestAdminClient to router', async () => { + const route = new AiProxyRoute(services, options, aiRouter); mockRoute.mockResolvedValueOnce({}); const mcpConfigs = { configs: { server1: { type: 'http' as const, url: 'https://server1.com' }, - server2: { type: 'http' as const, url: 'https://server2.com' }, }, }; jest .spyOn(options.forestAdminClient.mcpServerConfigService, 'getConfiguration') .mockResolvedValueOnce(mcpConfigs); - const tokens = { server1: 'Bearer token1', server2: 'Bearer token2' }; const context = createMockContext({ customProperties: { params: { route: 'ai-query' }, }, requestBody: { messages: [] }, - headers: { 'x-mcp-oauth-tokens': JSON.stringify(tokens) }, }); context.query = {}; @@ -136,154 +105,26 @@ describe('AiProxyRoute', () => { expect(mockRoute).toHaveBeenCalledWith( expect.objectContaining({ - mcpConfigs: { - configs: { - server1: { - type: 'http', - url: 'https://server1.com', - headers: { Authorization: 'Bearer token1' }, - }, - server2: { - type: 'http', - url: 'https://server2.com', - headers: { Authorization: 'Bearer token2' }, - }, - }, - }, + mcpServerConfigs: mcpConfigs, + requestHeaders: context.request.headers, }), ); }); - test('should throw 
BadRequestError when x-mcp-oauth-tokens header contains invalid JSON', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); + test('should let errors from aiRouter propagate unchanged', async () => { + const route = new AiProxyRoute(services, options, aiRouter); + const error = new Error('AI error'); + mockRoute.mockRejectedValueOnce(error); const context = createMockContext({ customProperties: { params: { route: 'ai-query' }, + query: {}, }, - requestBody: { messages: [] }, - headers: { 'x-mcp-oauth-tokens': '{ invalid json }' }, - }); - context.query = {}; - - await expect((route as any).handleAiProxy(context)).rejects.toThrow(BadRequestError); - await expect((route as any).handleAiProxy(context)).rejects.toThrow( - 'Invalid JSON in x-mcp-oauth-tokens header', - ); - }); - - describe('error handling', () => { - test('should convert AINotConfiguredError to UnprocessableError with agent-specific message', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - mockRoute.mockRejectedValueOnce(new AINotConfiguredError()); - - const context = createMockContext({ - customProperties: { - params: { route: 'ai-query' }, - query: {}, - }, - requestBody: {}, - }); - - await expect((route as any).handleAiProxy(context)).rejects.toMatchObject({ - name: 'UnprocessableError', - message: 'AI is not configured. Please call addAi() on your agent.', - }); - }); - - test('should convert AIToolNotFoundError to NotFoundError', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - mockRoute.mockRejectedValueOnce(new AIToolNotFoundError('tool-name')); - - const context = createMockContext({ - customProperties: { - params: { route: 'invoke-remote-tool' }, - query: { 'tool-name': 'unknown-tool' }, - }, - requestBody: {}, - }); - - await expect((route as any).handleAiProxy(context)).rejects.toThrow(NotFoundError); + requestBody: {}, }); - test('should convert AINotFoundError to NotFoundError', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - mockRoute.mockRejectedValueOnce(new AINotFoundError('Resource not found')); - - const context = createMockContext({ - customProperties: { - params: { route: 'ai-query' }, - query: {}, - }, - requestBody: {}, - }); - - await expect((route as any).handleAiProxy(context)).rejects.toThrow(NotFoundError); - }); - - test('should convert AIBadRequestError to BadRequestError', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - mockRoute.mockRejectedValueOnce(new AIBadRequestError('Invalid input')); - - const context = createMockContext({ - customProperties: { - params: { route: 'ai-query' }, - query: {}, - }, - requestBody: {}, - }); - - await expect((route as any).handleAiProxy(context)).rejects.toThrow(BadRequestError); - }); - - test('should convert AIUnprocessableError to UnprocessableError', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - mockRoute.mockRejectedValueOnce(new AIUnprocessableError('Invalid input')); - - const context = createMockContext({ - customProperties: { - params: { route: 'ai-query' }, - query: {}, - }, - requestBody: {}, - }); - - await expect((route as any).handleAiProxy(context)).rejects.toThrow(UnprocessableError); - }); - - test('should convert generic AIError to UnprocessableError', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - mockRoute.mockRejectedValueOnce(new AIError('Generic AI error')); - - const 
context = createMockContext({ - customProperties: { - params: { route: 'ai-query' }, - query: {}, - }, - requestBody: {}, - }); - - await expect((route as any).handleAiProxy(context)).rejects.toThrow(UnprocessableError); - }); - - test('should re-throw unknown errors unchanged', async () => { - const route = new AiProxyRoute(services, options, aiConfigurations); - const unknownError = new Error('Unknown error'); - mockRoute.mockRejectedValueOnce(unknownError); - - const context = createMockContext({ - customProperties: { - params: { route: 'ai-query' }, - }, - requestBody: {}, - }); - context.query = {}; - - const promise = (route as any).handleAiProxy(context); - - await expect(promise).rejects.toBe(unknownError); - expect(unknownError).not.toBeInstanceOf(UnprocessableError); - }); + await expect((route as any).handleAiProxy(context)).rejects.toBe(error); }); }); }); diff --git a/packages/agent/test/routes/index.test.ts b/packages/agent/test/routes/index.test.ts index d2dc94aac8..d2e56b8ddb 100644 --- a/packages/agent/test/routes/index.test.ts +++ b/packages/agent/test/routes/index.test.ts @@ -300,8 +300,8 @@ describe('Route index', () => { }); }); - describe('with AI configurations', () => { - test('should not include AI routes when aiConfigurations is empty', () => { + describe('with AI router', () => { + test('should not include AI routes when aiRouter is null', () => { const dataSource = factories.dataSource.buildWithCollections([ factories.collection.build({ name: 'books' }), ]); @@ -310,69 +310,31 @@ describe('Route index', () => { dataSource, factories.forestAdminHttpDriverOptions.build(), factories.forestAdminHttpDriverServices.build(), - [], + null, ); const aiRoute = routes.find(route => route instanceof AiProxyRoute); expect(aiRoute).toBeUndefined(); }); - test('should include AiProxyRoute when AI configurations are provided', () => { + test('should include AiProxyRoute when an AI router is provided', () => { const dataSource = factories.dataSource.buildWithCollections([ factories.collection.build({ name: 'books' }), ]); - const aiConfigurations = [ - { - name: 'gpt4', - provider: 'openai' as const, - apiKey: 'test-key', - model: 'gpt-4o', - }, - ]; + const aiRouter = { route: jest.fn() }; const routes = makeRoutes( dataSource, factories.forestAdminHttpDriverOptions.build(), factories.forestAdminHttpDriverServices.build(), - aiConfigurations, + aiRouter, ); const aiRoute = routes.find(route => route instanceof AiProxyRoute); expect(aiRoute).toBeTruthy(); expect(aiRoute).toBeInstanceOf(AiProxyRoute); }); - - test('should include only one AiProxyRoute even with multiple AI configurations', () => { - const dataSource = factories.dataSource.buildWithCollections([ - factories.collection.build({ name: 'books' }), - ]); - - const aiConfigurations = [ - { - name: 'gpt4', - provider: 'openai' as const, - apiKey: 'test-key', - model: 'gpt-4o', - }, - { - name: 'gpt3', - provider: 'openai' as const, - apiKey: 'test-key-2', - model: 'gpt-3.5-turbo', - }, - ]; - - const routes = makeRoutes( - dataSource, - factories.forestAdminHttpDriverOptions.build(), - factories.forestAdminHttpDriverServices.build(), - aiConfigurations, - ); - - const aiRoutes = routes.filter(route => route instanceof AiProxyRoute); - expect(aiRoutes).toHaveLength(1); - }); }); }); }); diff --git a/packages/agent/test/utils/forest-schema/generator.test.ts b/packages/agent/test/utils/forest-schema/generator.test.ts index 4183b60473..67f90f9a2d 100644 --- a/packages/agent/test/utils/forest-schema/generator.test.ts +++ 
b/packages/agent/test/utils/forest-schema/generator.test.ts @@ -72,19 +72,19 @@ describe('SchemaGenerator', () => { }); }); - test('it should serialize ai_llms when AI configurations are provided', async () => { - const aiConfigurations = [ - { name: 'gpt4', provider: 'openai' as const, apiKey: 'key1', model: 'gpt-4o' }, - { name: 'claude', provider: 'openai' as const, apiKey: 'key2', model: 'claude-3' }, + test('it should serialize ai_llms when AI providers are provided', async () => { + const aiProviders = [ + { name: 'gpt4', provider: 'openai', model: 'gpt-4o' }, + { name: 'claude', provider: 'anthropic', model: 'claude-sonnet-4-5-20250929' }, ]; - const schema = await SchemaGenerator.buildMetadata(null, aiConfigurations); + const schema = await SchemaGenerator.buildMetadata(null, aiProviders); expect(schema).toStrictEqual({ meta: { ai_llms: [ { name: 'gpt4', provider: 'openai' }, - { name: 'claude', provider: 'openai' }, + { name: 'claude', provider: 'anthropic' }, ], liana: 'forest-nodejs-agent', liana_version: expect.any(String), diff --git a/packages/ai-proxy/src/create-ai-provider.ts b/packages/ai-proxy/src/create-ai-provider.ts new file mode 100644 index 0000000000..13bf3a5d40 --- /dev/null +++ b/packages/ai-proxy/src/create-ai-provider.ts @@ -0,0 +1,38 @@ +import type { McpConfiguration } from './mcp-client'; +import type { AiConfiguration } from './provider'; +import type { RouterRouteArgs } from './schemas/route'; +import type { AiProviderDefinition, AiRouter } from '@forestadmin/datasource-toolkit'; + +import { extractMcpOauthTokensFromHeaders, injectOauthTokens } from './oauth-token-injector'; +import { Router } from './router'; + +function resolveMcpConfigs(args: Parameters[0]): McpConfiguration | undefined { + const tokensByMcpServerName = args.requestHeaders + ? extractMcpOauthTokensFromHeaders(args.requestHeaders) + : undefined; + + return injectOauthTokens({ + mcpConfigs: args.mcpServerConfigs as McpConfiguration | undefined, + tokensByMcpServerName, + }); +} + +// eslint-disable-next-line import/prefer-default-export +export function createAiProvider(config: AiConfiguration): AiProviderDefinition { + return { + providers: [{ name: config.name, provider: config.provider, model: config.model }], + init(logger) { + const router = new Router({ aiConfigurations: [config], logger }); + + return { + route: args => + router.route({ + route: args.route, + body: args.body, + query: args.query, + mcpConfigs: resolveMcpConfigs(args), + } as RouterRouteArgs), + }; + }, + }; +} diff --git a/packages/ai-proxy/src/errors.ts b/packages/ai-proxy/src/errors.ts index fc1e5e2507..f817a4e7c6 100644 --- a/packages/ai-proxy/src/errors.ts +++ b/packages/ai-proxy/src/errors.ts @@ -1,33 +1,44 @@ /** - * ------------------------------------- - * ------------------------------------- - * ------------------------------------- - * All custom errors must extend the AIError class. - * This inheritance is crucial for proper error translation - * and consistent handling throughout the system. - * ------------------------------------- - * ------------------------------------- - * ------------------------------------- + * All custom AI errors extend HTTP-status error classes (BadRequestError, NotFoundError, + * UnprocessableError) from datasource-toolkit. This allows the agent's error middleware + * to map them to their natural HTTP status codes automatically. 
+ * + * Hierarchy: + * + * UnprocessableError (422) + * ├── AIError (general AI errors) + * │ ├── AINotConfiguredError + * │ └── McpError + * │ ├── McpConnectionError, McpConflictError, McpConfigError + * └── AIUnprocessableError (provider/tool input errors) + * ├── OpenAIUnprocessableError, AIToolUnprocessableError + * + * BadRequestError (400) + * └── AIBadRequestError + * └── AIModelNotSupportedError + * + * NotFoundError (404) + * └── AINotFoundError + * └── AIToolNotFoundError */ // eslint-disable-next-line max-classes-per-file -export class AIError extends Error { - readonly status: number; - - constructor(message: string, status = 422) { - if (status < 100 || status > 599) { - throw new RangeError(`Invalid HTTP status code: ${status}`); - } +import { + BadRequestError, + NotFoundError, + UnprocessableError, +} from '@forestadmin/datasource-toolkit'; +export class AIError extends UnprocessableError { + constructor(message: string) { super(message); this.name = 'AIError'; - this.status = status; } } -export class AIBadRequestError extends AIError { +export class AIBadRequestError extends BadRequestError { constructor(message: string) { - super(message, 400); + super(message); this.name = 'AIBadRequestError'; } } @@ -41,23 +52,23 @@ export class AIModelNotSupportedError extends AIBadRequestError { } } -export class AINotFoundError extends AIError { +export class AINotFoundError extends NotFoundError { constructor(message: string) { - super(message, 404); + super(message); this.name = 'AINotFoundError'; } } -export class AIUnprocessableError extends AIError { +export class AIUnprocessableError extends UnprocessableError { constructor(message: string) { - super(message, 422); + super(message); this.name = 'AIUnprocessableError'; } } export class AINotConfiguredError extends AIError { constructor(message = 'AI is not configured') { - super(message, 422); + super(message); this.name = 'AINotConfiguredError'; } } diff --git a/packages/ai-proxy/src/index.ts b/packages/ai-proxy/src/index.ts index bee805b44d..dfa50e46ed 100644 --- a/packages/ai-proxy/src/index.ts +++ b/packages/ai-proxy/src/index.ts @@ -2,6 +2,7 @@ import type { McpConfiguration } from './mcp-client'; import McpConfigChecker from './mcp-config-checker'; +export { createAiProvider } from './create-ai-provider'; export * from './provider-dispatcher'; export * from './remote-tools'; export * from './router'; diff --git a/packages/ai-proxy/src/schemas/route.ts b/packages/ai-proxy/src/schemas/route.ts index 3e8c75aba4..2baa234fad 100644 --- a/packages/ai-proxy/src/schemas/route.ts +++ b/packages/ai-proxy/src/schemas/route.ts @@ -1,3 +1,5 @@ +import type { McpConfiguration } from '../mcp-client'; + import { z } from 'zod'; // Base query schema with common optional parameters @@ -83,6 +85,7 @@ export type RemoteToolsArgs = z.infer; // Derived types for consumers export type DispatchBody = AiQueryArgs['body']; +export type RouterRouteArgs = RouteArgs & { mcpConfigs?: McpConfiguration }; // Backward compatibility types export type InvokeRemoteToolBody = InvokeRemoteToolArgs['body']; diff --git a/packages/ai-proxy/test/create-ai-provider.test.ts b/packages/ai-proxy/test/create-ai-provider.test.ts new file mode 100644 index 0000000000..43e90bad5d --- /dev/null +++ b/packages/ai-proxy/test/create-ai-provider.test.ts @@ -0,0 +1,131 @@ +import type { AiConfiguration } from '../src/provider'; + +import { createAiProvider } from '../src/create-ai-provider'; +import { Router } from '../src/router'; + +jest.mock('../src/router'); + +const 
routeMock = jest.fn(); +jest.mocked(Router).mockImplementation(() => ({ route: routeMock } as any)); + +describe('createAiProvider', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + const config: AiConfiguration = { + name: 'my-ai', + provider: 'openai', + model: 'gpt-4o', + apiKey: 'test-key', + }; + + test('should return providers array from config', () => { + const result = createAiProvider(config); + + expect(result.providers).toEqual([{ name: 'my-ai', provider: 'openai', model: 'gpt-4o' }]); + }); + + test('init should create a Router with the config and logger', () => { + const provider = createAiProvider(config); + const mockLogger = jest.fn(); + provider.init(mockLogger); + + expect(Router).toHaveBeenCalledWith({ + aiConfigurations: [config], + logger: mockLogger, + }); + }); + + test('init should return an AiRouter with a route method', () => { + const provider = createAiProvider(config); + const result = provider.init(jest.fn()); + + expect(typeof result.route).toBe('function'); + }); + + describe('route wrapper', () => { + test('should pass route, body, query to underlying Router', async () => { + routeMock.mockResolvedValue({ result: 'ok' }); + const provider = createAiProvider(config); + const aiRouter = provider.init(jest.fn()); + + const result = await aiRouter.route({ + route: 'ai-query', + body: { messages: [] }, + query: { 'ai-name': 'my-ai' }, + }); + + expect(routeMock).toHaveBeenCalledWith({ + route: 'ai-query', + body: { messages: [] }, + query: { 'ai-name': 'my-ai' }, + mcpConfigs: undefined, + }); + expect(result).toEqual({ result: 'ok' }); + }); + + test('should pass mcpServerConfigs as mcpConfigs to Router when no requestHeaders', async () => { + routeMock.mockResolvedValue({}); + const provider = createAiProvider(config); + const aiRouter = provider.init(jest.fn()); + + await aiRouter.route({ + route: 'remote-tools', + mcpServerConfigs: { configs: { server1: { command: 'test', args: [] } } }, + }); + + expect(routeMock).toHaveBeenCalledWith({ + route: 'remote-tools', + body: undefined, + query: undefined, + mcpConfigs: { configs: { server1: { command: 'test', args: [] } } }, + }); + }); + + test('should inject OAuth tokens from requestHeaders into mcpConfigs', async () => { + routeMock.mockResolvedValue({}); + const provider = createAiProvider(config); + const aiRouter = provider.init(jest.fn()); + const oauthTokens = JSON.stringify({ server1: 'Bearer token123' }); + + await aiRouter.route({ + route: 'remote-tools', + mcpServerConfigs: { + configs: { server1: { type: 'http', url: 'https://server1.com' } }, + }, + requestHeaders: { 'x-mcp-oauth-tokens': oauthTokens }, + }); + + expect(routeMock).toHaveBeenCalledWith({ + route: 'remote-tools', + body: undefined, + query: undefined, + mcpConfigs: { + configs: { + server1: { + type: 'http', + url: 'https://server1.com', + headers: { Authorization: 'Bearer token123' }, + }, + }, + }, + }); + }); + + test('should pass mcpConfigs as undefined when no mcpServerConfigs provided', async () => { + routeMock.mockResolvedValue({}); + const provider = createAiProvider(config); + const aiRouter = provider.init(jest.fn()); + + await aiRouter.route({ route: 'remote-tools' }); + + expect(routeMock).toHaveBeenCalledWith({ + route: 'remote-tools', + body: undefined, + query: undefined, + mcpConfigs: undefined, + }); + }); + }); +}); diff --git a/packages/ai-proxy/test/errors.test.ts b/packages/ai-proxy/test/errors.test.ts new file mode 100644 index 0000000000..527f3bab65 --- /dev/null +++ 
b/packages/ai-proxy/test/errors.test.ts @@ -0,0 +1,103 @@ +import { + BadRequestError, + NotFoundError, + UnprocessableError, +} from '@forestadmin/datasource-toolkit'; + +import { + AIBadRequestError, + AIError, + AIModelNotSupportedError, + AINotConfiguredError, + AINotFoundError, + AIToolNotFoundError, + AIToolUnprocessableError, + AIUnprocessableError, + McpConfigError, + McpConflictError, + McpConnectionError, + McpError, + OpenAIUnprocessableError, +} from '../src/errors'; + +describe('AI Error Hierarchy', () => { + describe('UnprocessableError branch (422)', () => { + test('AIError extends UnprocessableError', () => { + const error = new AIError('test'); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('AINotConfiguredError extends UnprocessableError via AIError', () => { + const error = new AINotConfiguredError(); + expect(error).toBeInstanceOf(AIError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('McpError extends UnprocessableError via AIError', () => { + const error = new McpError('test'); + expect(error).toBeInstanceOf(AIError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('McpConnectionError extends UnprocessableError via McpError', () => { + const error = new McpConnectionError('test'); + expect(error).toBeInstanceOf(McpError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('McpConflictError extends UnprocessableError via McpError', () => { + const error = new McpConflictError('entity'); + expect(error).toBeInstanceOf(McpError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('McpConfigError extends UnprocessableError via McpError', () => { + const error = new McpConfigError('test'); + expect(error).toBeInstanceOf(McpError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('AIUnprocessableError extends UnprocessableError', () => { + const error = new AIUnprocessableError('test'); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('OpenAIUnprocessableError extends UnprocessableError via AIUnprocessableError', () => { + const error = new OpenAIUnprocessableError('test'); + expect(error).toBeInstanceOf(AIUnprocessableError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + + test('AIToolUnprocessableError extends UnprocessableError via AIUnprocessableError', () => { + const error = new AIToolUnprocessableError('test'); + expect(error).toBeInstanceOf(AIUnprocessableError); + expect(error).toBeInstanceOf(UnprocessableError); + }); + }); + + describe('BadRequestError branch (400)', () => { + test('AIBadRequestError extends BadRequestError', () => { + const error = new AIBadRequestError('test'); + expect(error).toBeInstanceOf(BadRequestError); + }); + + test('AIModelNotSupportedError extends BadRequestError via AIBadRequestError', () => { + const error = new AIModelNotSupportedError('gpt-4'); + expect(error).toBeInstanceOf(AIBadRequestError); + expect(error).toBeInstanceOf(BadRequestError); + }); + }); + + describe('NotFoundError branch (404)', () => { + test('AINotFoundError extends NotFoundError', () => { + const error = new AINotFoundError('test'); + expect(error).toBeInstanceOf(NotFoundError); + }); + + test('AIToolNotFoundError extends NotFoundError via AINotFoundError', () => { + const error = new AIToolNotFoundError('test'); + expect(error).toBeInstanceOf(AINotFoundError); + expect(error).toBeInstanceOf(NotFoundError); + }); + }); +}); diff --git a/packages/ai-proxy/test/router.test.ts b/packages/ai-proxy/test/router.test.ts index 
82abb1dbf1..8a64759b26 100644 --- a/packages/ai-proxy/test/router.test.ts +++ b/packages/ai-proxy/test/router.test.ts @@ -1,7 +1,7 @@ import type { DispatchBody, InvokeRemoteToolArgs, Route } from '../src'; import type { Logger } from '@forestadmin/datasource-toolkit'; -import { AIBadRequestError, AIModelNotSupportedError, Router } from '../src'; +import { AIModelNotSupportedError, Router } from '../src'; import McpClient from '../src/mcp-client'; const invokeToolMock = jest.fn(); diff --git a/packages/datasource-toolkit/src/index.ts b/packages/datasource-toolkit/src/index.ts index 5a60731a29..b00d534e01 100644 --- a/packages/datasource-toolkit/src/index.ts +++ b/packages/datasource-toolkit/src/index.ts @@ -1,6 +1,7 @@ // Misc export * from './errors'; export * from './factory'; +export type { AiProviderDefinition, AiProviderMeta, AiRouter } from './interfaces/ai'; export { MAP_ALLOWED_OPERATORS_FOR_COLUMN_TYPE as allowedOperatorsForColumnType } from './validation/rules'; // Base Collection & DataSource diff --git a/packages/datasource-toolkit/src/interfaces/ai.ts b/packages/datasource-toolkit/src/interfaces/ai.ts new file mode 100644 index 0000000000..781bb4fe66 --- /dev/null +++ b/packages/datasource-toolkit/src/interfaces/ai.ts @@ -0,0 +1,29 @@ +import type { Logger } from '../factory'; + +/** Metadata describing a configured AI provider, used in schema reporting and logging. */ +export interface AiProviderMeta { + name: string; + provider: string; + model: string; +} + +export interface AiRouter { + /** + * Route a request to the AI proxy. + * + * Implementations should throw BusinessError subclasses (BadRequestError, NotFoundError, + * UnprocessableError) for proper HTTP status mapping by the agent's error middleware. + */ + route(args: { + route: string; + body?: unknown; + query?: Record; + mcpServerConfigs?: unknown; + requestHeaders?: Record; + }): Promise; +} + +export interface AiProviderDefinition { + providers: AiProviderMeta[]; + init(logger: Logger): AiRouter; +} diff --git a/packages/forestadmin-client/src/schema/types.ts b/packages/forestadmin-client/src/schema/types.ts index e86d2b296c..b65e2a5a7e 100644 --- a/packages/forestadmin-client/src/schema/types.ts +++ b/packages/forestadmin-client/src/schema/types.ts @@ -6,7 +6,7 @@ export type ForestSchema = { liana: string; liana_version: string; liana_features: Record | null; - ai_llms?: Array<{ provider: string }> | null; + ai_llms?: Array<{ name: string; provider: string }> | null; stack: { engine: string; engine_version: string;
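Because `addAi` now accepts any `AiProviderDefinition` exported by `@forestadmin/datasource-toolkit`, a provider does not have to come from `createAiProvider`. Below is a minimal sketch of a hand-written definition, assuming the interfaces land exactly as declared in `packages/datasource-toolkit/src/interfaces/ai.ts` above; `createEchoProvider` and its echo behaviour are purely illustrative, not part of this change:

```ts
import type { AiProviderDefinition, AiRouter, Logger } from '@forestadmin/datasource-toolkit';

// Hypothetical provider used only to illustrate the contract.
function createEchoProvider(): AiProviderDefinition {
  return {
    // Surfaces in the schema meta as ai_llms: [{ name, provider }].
    providers: [{ name: 'echo', provider: 'custom', model: 'echo-1' }],

    // Called once by the agent (Agent.getRoutes) with its logger.
    init(logger: Logger): AiRouter {
      return {
        async route({ route, body, query }) {
          // Only the 'Warn' and 'Error' levels appear in this diff; 'Info' is assumed valid.
          logger('Info', `AI proxy route '${route}' invoked`);

          // Real implementations should throw datasource-toolkit BusinessErrors
          // (BadRequestError, NotFoundError, UnprocessableError) so the agent's
          // error middleware maps them to the right HTTP status, as the AiRouter
          // doc comment above requires.
          return { route, body, query };
        },
      };
    },
  };
}

// agent.addAi(createEchoProvider()); // addAi may only be called once per agent
```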