Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 66 additions & 33 deletions core/src/agents/llm_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1716,26 +1716,35 @@ export class LlmAgent extends BaseAgent {
ctx,
this,
async function* () {
for await (const llmResponse of this.callLlmAsync(
invocationContext,
llmRequest,
modelResponseEvent,
)) {
// ======================================================================
// Postprocess after calling the LLM
// ======================================================================
for await (const event of this.postprocess(
const responsesGenerator = async function* (this: LlmAgent) {
for await (const llmResponse of this.callLlmAsync(
invocationContext,
llmRequest,
llmResponse,
modelResponseEvent,
)) {
// Update the mutable event id to avoid conflict
modelResponseEvent.id = createNewEventId();
modelResponseEvent.timestamp = new Date().getTime();
yield event;
// ======================================================================
// Postprocess after calling the LLM
// ======================================================================
for await (const event of this.postprocess(
invocationContext,
llmRequest,
llmResponse,
modelResponseEvent,
)) {
// Update the mutable event id to avoid conflict
modelResponseEvent.id = createNewEventId();
modelResponseEvent.timestamp = new Date().getTime();
yield event;
}
}
}
};

yield* this.runAndHandleError(
responsesGenerator.call(this),
invocationContext,
llmRequest,
modelResponseEvent,
);
},
);
span.end();
Expand Down Expand Up @@ -1798,6 +1807,11 @@ export class LlmAgent extends BaseAgent {
return;
}

if (invocationContext.runConfig?.pauseOnToolCalls) {
invocationContext.endInvocation = true;
return;
}

// Call functions
// TODO - b/425992518: bloated function input, fix.
// Tool callback passed to get rid of cyclic dependency.
Expand Down Expand Up @@ -1831,6 +1845,8 @@ export class LlmAgent extends BaseAgent {
});
if (toolConfirmationEvent) {
yield toolConfirmationEvent;
invocationContext.endInvocation = true;
return;
}

// Yields the function response event.
Expand Down Expand Up @@ -1910,12 +1926,7 @@ export class LlmAgent extends BaseAgent {
StreamingMode.SSE,
);

for await (const llmResponse of this.runAndHandleError(
responsesGenerator,
invocationContext,
llmRequest,
modelResponseEvent,
)) {
for await (const llmResponse of responsesGenerator) {
traceCallLlm({
invocationContext,
eventId: modelResponseEvent.id,
Expand Down Expand Up @@ -2001,12 +2012,12 @@ export class LlmAgent extends BaseAgent {
return undefined;
}

private async *runAndHandleError(
responseGenerator: AsyncGenerator<LlmResponse, void, void>,
private async *runAndHandleError<T extends LlmResponse | Event>(
responseGenerator: AsyncGenerator<T, void, void>,
invocationContext: InvocationContext,
llmRequest: LlmRequest,
modelResponseEvent: Event,
): AsyncGenerator<LlmResponse, void, void> {
): AsyncGenerator<T, void, void> {
try {
for await (const response of responseGenerator) {
yield response;
Expand All @@ -2030,17 +2041,39 @@ export class LlmAgent extends BaseAgent {
});

if (onModelErrorCallbackResponse) {
yield onModelErrorCallbackResponse;
yield onModelErrorCallbackResponse as T;
} else {
// If no plugins, just return the message.
const errorResponse = JSON.parse(modelError.message) as {
error: {code: number; message: string};
};

yield {
errorCode: String(errorResponse.error.code),
errorMessage: errorResponse.error.message,
};
let errorCode = 'UNKNOWN_ERROR';
let errorMessage = modelError.message;

try {
const errorResponse = JSON.parse(modelError.message) as {
error: {code: number; message: string};
};
if (errorResponse?.error) {
errorCode = String(errorResponse.error.code || 'UNKNOWN_ERROR');
errorMessage = errorResponse.error.message || errorMessage;
}
} catch {
// Ignore JSON parse error, use original message.
}

if (modelResponseEvent.actions) {
// We are yielding an Event
yield createEvent({
invocationId: invocationContext.invocationId,
author: this.name,
errorCode,
errorMessage,
}) as T;
} else {
// We are yielding an LlmResponse
yield {
errorCode,
errorMessage,
} as T;
}
}
} else {
logger.error('Unknown error during response generation', modelError);
Expand Down
10 changes: 9 additions & 1 deletion core/test/agents/llm_agent_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,15 @@ describe('LlmAgent.callLlm', () => {
async function callLlmUnderTest(): Promise<LlmResponse[]> {
const responses: LlmResponse[] = [];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
for await (const response of (agent as any).callLlmAsync(
const responseGenerator = (agent as any).callLlmAsync(
invocationContext,
llmRequest,
modelResponseEvent,
);

// eslint-disable-next-line @typescript-eslint/no-explicit-any
for await (const response of (agent as any).runAndHandleError(
responseGenerator,
invocationContext,
llmRequest,
modelResponseEvent,
Expand Down
Loading