Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 20 additions & 27 deletions core/src/agents/llm_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import {
import {context, trace} from '@opentelemetry/api';
import {cloneDeep} from 'lodash-es';
import {z} from 'zod';
import {z as z3} from 'zod/v3';
import {z as z4} from 'zod/v4';

import {
BaseCodeExecutor,
Expand Down Expand Up @@ -62,6 +64,7 @@ import {
traceCallLlm,
tracer,
} from '../telemetry/tracing.js';
import {isZodObject, zodObjectToSchema} from '../utils/simple_zod_to_json.js';
import {BaseAgent, BaseAgentConfig} from './base_agent.js';
import {
BaseLlmRequestProcessor,
Expand All @@ -86,6 +89,14 @@ import {InvocationContext} from './invocation_context.js';
import {ReadonlyContext} from './readonly_context.js';
import {StreamingMode} from './run_config.js';

/**
* Input/output schema type for agent.
*/
export type LlmAgentSchema =
| z3.ZodObject<z3.ZodRawShape>
| z4.ZodObject<z4.ZodRawShape>
| Schema;

/** An object that can provide an instruction string. */
export type InstructionProvider = (
context: ReadonlyContext,
Expand Down Expand Up @@ -264,16 +275,10 @@ export interface LlmAgentConfig extends BaseAgentConfig {
includeContents?: 'default' | 'none';

/** The input schema when agent is used as a tool. */
inputSchema?: Schema;
inputSchema?: LlmAgentSchema;

/**
* The output schema when agent replies.
*
* NOTE:
* When this is set, agent can ONLY reply and CANNOT use any tools, such as
* function tools, RAGs, agent transfer, etc.
*/
outputSchema?: Schema;
/** The output schema when agent replies. */
outputSchema?: LlmAgentSchema;

/**
* The key in session state to store the output of the agent.
Expand Down Expand Up @@ -1338,8 +1343,12 @@ export class LlmAgent extends BaseAgent {
this.disallowTransferToParent = config.disallowTransferToParent ?? false;
this.disallowTransferToPeers = config.disallowTransferToPeers ?? false;
this.includeContents = config.includeContents ?? 'default';
this.inputSchema = config.inputSchema;
this.outputSchema = config.outputSchema;
this.inputSchema = isZodObject(config.inputSchema)
? zodObjectToSchema(config.inputSchema)
: config.inputSchema;
this.outputSchema = isZodObject(config.outputSchema)
? zodObjectToSchema(config.outputSchema)
: config.outputSchema;
this.outputKey = config.outputKey;
this.beforeModelCallback = config.beforeModelCallback;
this.afterModelCallback = config.afterModelCallback;
Expand Down Expand Up @@ -1398,22 +1407,6 @@ export class LlmAgent extends BaseAgent {
this.disallowTransferToParent = true;
this.disallowTransferToPeers = true;
}

if (this.subAgents && this.subAgents.length > 0) {
throw new Error(
`Invalid config for agent ${
this.name
}: if outputSchema is set, subAgents must be empty to disable agent transfer.`,
);
}

if (this.tools && this.tools.length > 0) {
throw new Error(
`Invalid config for agent ${
this.name
}: if outputSchema is set, tools must be empty`,
);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions core/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export type {
BeforeToolCallback,
InstructionProvider,
LlmAgentConfig,
LlmAgentSchema,
SingleAfterModelCallback,
SingleAfterToolCallback,
SingleBeforeModelCallback,
Expand Down
156 changes: 155 additions & 1 deletion core/test/agents/llm_agent_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import {
PluginManager,
Session,
} from '@google/adk';
import {Content} from '@google/genai';
import {Content, Schema, Type} from '@google/genai';
import {z as z3} from 'zod/v3';
import {z as z4} from 'zod/v4';

class MockLlmConnection implements BaseLlmConnection {
sendHistory(_history: Content[]): Promise<void> {
Expand Down Expand Up @@ -230,3 +232,155 @@ describe('LlmAgent.callLlm', () => {
expect(result).toEqual([{errorCode: '500', errorMessage: 'LLM error'}]);
});
});

describe('LlmAgent Schema Initialization', () => {
it('should initialize inputSchema from Schema object', () => {
const inputSchema: Schema = {
type: Type.OBJECT,
properties: {foo: {type: Type.STRING}},
};
const agent = new LlmAgent({name: 'test', inputSchema});
expect(agent.inputSchema).toEqual(inputSchema);
});

it('should initialize inputSchema from Zod v4 object', () => {
const zodSchema = z4.object({foo: z4.string()});
const agent = new LlmAgent({
name: 'test',
inputSchema: zodSchema,
});
expect(agent.inputSchema).toBeDefined();
expect((agent.inputSchema as Schema).type).toBe('OBJECT');
expect((agent.inputSchema as Schema).properties?.foo?.type).toBe('STRING');
});

it('should initialize inputSchema from Zod v3 object', () => {
const zodSchema = z3.object({
foo: z3.string(),
});
const agent = new LlmAgent({
name: 'test',
inputSchema: zodSchema,
});
expect(agent.inputSchema).toBeDefined();
expect((agent.inputSchema as Schema).type).toBe('OBJECT');
expect((agent.inputSchema as Schema).properties?.foo?.type).toBe('STRING');
});

it('should initialize outputSchema from Schema object', () => {
const outputSchema: Schema = {
type: Type.OBJECT,
properties: {bar: {type: Type.NUMBER}},
};
const agent = new LlmAgent({name: 'test', outputSchema});
expect(agent.outputSchema).toEqual(outputSchema);
});

it('should initialize outputSchema from Zod z4 object', () => {
const zodSchema = z4.object({bar: z4.number()});
const agent = new LlmAgent({
name: 'test',
outputSchema: zodSchema,
});
expect(agent.outputSchema).toBeDefined();
expect((agent.outputSchema as Schema).type).toBe('OBJECT');
expect((agent.outputSchema as Schema).properties?.bar?.type).toBe('NUMBER');
});

it('should initialize outputSchema from Zod v3 object', () => {
const zodSchema = z3.object({
bar: z3.number(),
});
const agent = new LlmAgent({
name: 'test',
outputSchema: zodSchema,
});
expect(agent.outputSchema).toBeDefined();
expect((agent.outputSchema as Schema).type).toBe('OBJECT');
expect((agent.outputSchema as Schema).properties?.bar?.type).toBe('NUMBER');
});

it('should enforce transfer restrictions when outputSchema is present', () => {
const outputSchema: Schema = {type: Type.OBJECT};
const agent = new LlmAgent({
name: 'test',
outputSchema,
disallowTransferToParent: false,
disallowTransferToPeers: false,
});
expect(agent.disallowTransferToParent).toBe(true);
expect(agent.disallowTransferToPeers).toBe(true);
});
});

describe('LlmAgent Output Processing', () => {
let agent: LlmAgent;
let invocationContext: InvocationContext;
let validationSchema: Schema;

beforeEach(() => {
validationSchema = {
type: Type.OBJECT,
properties: {
answer: {type: Type.STRING},
},
};
agent = new LlmAgent({
name: 'test_agent',
outputSchema: validationSchema,
outputKey: 'result',
});
const mockState = {
hasDelta: () => false,
get: () => undefined,
set: () => {},
};
invocationContext = new InvocationContext({
invocationId: 'inv_123',
session: {
id: 'sess_123',
state: mockState,
events: [],
} as unknown as Session,
agent: agent,
pluginManager: new PluginManager(),
});
});

it('should save parsed JSON output to state based on outputKey', async () => {
const jsonOutput = JSON.stringify({answer: '42'});
const response: LlmResponse = {
content: {parts: [{text: jsonOutput}]},
};
agent.model = new MockLlm(response);

const generator = agent.runAsync(invocationContext);
const events: Event[] = [];
for await (const event of generator) {
events.push(event);
}

const lastEvent = events[events.length - 1];
expect(lastEvent).toBeDefined();
expect(lastEvent.content?.parts?.[0].text).toEqual(jsonOutput);
expect(lastEvent.actions?.stateDelta).toBeDefined();
expect(lastEvent.actions?.stateDelta?.['result']).toEqual({answer: '42'});
});

it('should not save output if invalid JSON', async () => {
const invalidJson = '{answer: 42'; // Missing closing brace
const response: LlmResponse = {
content: {parts: [{text: invalidJson}]},
};
agent.model = new MockLlm(response);

const generator = agent.runAsync(invocationContext);
const events: Event[] = [];
for await (const event of generator) {
events.push(event);
}

const lastEvent = events[events.length - 1];
expect(lastEvent.actions?.stateDelta?.['result']).toEqual(invalidJson);
});
});
Loading