Expose the prompt instruction through getInstruction() and use it as the
GEPA base instruction.

* src/ax/dsp/prompt.ts: AxPromptTemplate.getInstruction() returns
  this.task.text (the value stored by setInstruction()).
* src/ax/dsp/generate.ts: AxGen.getInstruction() delegates to its
  promptTemplate.
* src/ax/dsp/optimizers/gepa.ts: getBaseInstruction() now reads
  program.getInstruction() instead of probing the signature inside a
  try/catch; the generic fallback text is still returned for an empty value.
* src/ax/dsp/optimizers/gepa.test.ts: new vitest case checking that the
  custom instruction reaches getBaseInstruction() and reflectInstruction().

NOTE(review): generic type arguments look stripped from a few lines
(`Readonly>`, `Promise {`, `Readonly;`) and the indentation below is
reconstructed as 2-space; both are reproduced/assumed here, so confirm the
context lines against the repository before applying.

diff --git a/src/ax/dsp/generate.ts b/src/ax/dsp/generate.ts
index 2fc36e67d..f972abd18 100644
--- a/src/ax/dsp/generate.ts
+++ b/src/ax/dsp/generate.ts
@@ -180,6 +180,10 @@ export class AxGen
     this.promptTemplate.setInstruction(instruction);
   }
 
+  public getInstruction(): string {
+    return this.promptTemplate.getInstruction();
+  }
+
   private getSignatureName(): string {
     return this.signature.getDescription() || 'unknown_signature';
   }
diff --git a/src/ax/dsp/optimizers/gepa.test.ts b/src/ax/dsp/optimizers/gepa.test.ts
new file mode 100644
index 000000000..b762ff859
--- /dev/null
+++ b/src/ax/dsp/optimizers/gepa.test.ts
@@ -0,0 +1,65 @@
+import { AxGEPA } from './gepa.js';
+import { ax } from '../template.js';
+import type { AxAIService } from '../../ai/types.js';
+import { vi, describe, it, expect } from 'vitest';
+
+describe('AxGEPA Optimizer', () => {
+  it('should use the instruction from the program', async () => {
+    const ai: AxAIService = {
+      name: 'mockAI',
+      chat: vi.fn().mockResolvedValue({
+        results: [{ content: JSON.stringify({ answer: '4' }) }],
+      }),
+      getOptions: vi.fn().mockReturnValue({}),
+      getLogger: vi.fn().mockReturnValue(undefined),
+      clone: vi.fn().mockReturnThis(),
+    };
+
+    const program = ax('question:string -> answer:string');
+    const customInstruction = 'This is a custom instruction.';
+    program.setInstruction(customInstruction);
+
+    const examples = [
+      { question: 'What is 2+2?', answer: '4' },
+      { question: 'What is 3+3?', answer: '6' },
+    ];
+
+    const metricFn = () => 1;
+
+    const optimizer = new AxGEPA({
+      studentAI: ai,
+      teacherAI: ai,
+      numTrials: 1, // Run only one trial for a predictable test
+    });
+
+    // Spy on getBaseInstruction to confirm it's called and what it returns.
+    const getBaseInstructionSpy = vi.spyOn(
+      optimizer,
+      'getBaseInstruction' as any
+    );
+
+    // Mock the reflectInstruction to prevent it from running and making real AI calls
+    const reflectSpy = vi
+      .spyOn(optimizer, 'reflectInstruction' as any)
+      .mockResolvedValue('a new evolved instruction');
+
+    await optimizer.compile(program, examples, metricFn, {
+      maxMetricCalls: 10,
+    });
+
+    // 1. Verify that our patched getBaseInstruction is working
+    expect(getBaseInstructionSpy).toHaveBeenCalled();
+    const baseInstruction = await getBaseInstructionSpy.mock.results[0].value;
+    expect(baseInstruction).toBe(customInstruction);
+
+    // 2. Verify that this base instruction is passed to the first reflection call
+    expect(reflectSpy).toHaveBeenCalled();
+    expect(reflectSpy).toHaveBeenCalledWith(
+      customInstruction,
+      expect.anything(),
+      expect.anything(),
+      expect.anything(),
+      expect.anything()
+    );
+  });
+});
diff --git a/src/ax/dsp/optimizers/gepa.ts b/src/ax/dsp/optimizers/gepa.ts
index 59efcd964..543a0b06b 100644
--- a/src/ax/dsp/optimizers/gepa.ts
+++ b/src/ax/dsp/optimizers/gepa.ts
@@ -836,17 +836,10 @@ export class AxGEPA extends AxBaseOptimizer {
   private async getBaseInstruction(
     program: Readonly>
   ): Promise {
-    try {
-      // If program exposes instruction via signature, prefer it
-      const sig: any = program.getSignature?.();
-      if (
-        sig &&
-        typeof sig.instruction === 'string' &&
-        sig.instruction.length > 0
-      ) {
-        return sig.instruction as string;
-      }
-    } catch {}
+    const instruction = program.getInstruction();
+    if (instruction && instruction.length > 0) {
+      return instruction;
+    }
     return 'Follow the task precisely. Be concise, correct, and consistent.';
   }
 
diff --git a/src/ax/dsp/prompt.ts b/src/ax/dsp/prompt.ts
index 9b2072fa3..e663e1f5e 100644
--- a/src/ax/dsp/prompt.ts
+++ b/src/ax/dsp/prompt.ts
@@ -50,6 +50,10 @@ export class AxPromptTemplate {
   public setInstruction(instruction: string): void {
     this.task = { type: 'text', text: instruction };
   }
+
+  public getInstruction(): string {
+    return this.task.text;
+  }
   private readonly thoughtFieldName: string;
   private readonly functions?: Readonly;
   private readonly cacheSystemPrompt?: boolean;