Skip to content

Commit a4d970e

Browse files
committed
better compaction support in our other chat variants
1 parent dfdab6b commit a4d970e

File tree

2 files changed

+185
-54
lines changed
  • packages/trigger-sdk/src/v3
  • references/ai-chat/src/trigger

2 files changed

+185
-54
lines changed

packages/trigger-sdk/src/v3/ai.ts

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2908,6 +2908,11 @@ async function pipeChatAndCapture(
29082908
class ChatMessageAccumulator {
29092909
modelMessages: ModelMessage[] = [];
29102910
uiMessages: UIMessage[] = [];
2911+
private _compaction?: ChatTaskCompactionOptions;
2912+
2913+
constructor(options?: { compaction?: ChatTaskCompactionOptions }) {
2914+
this._compaction = options?.compaction;
2915+
}
29112916

29122917
/**
29132918
* Add incoming messages from the transport payload.
@@ -2958,6 +2963,84 @@ class ChatMessageAccumulator {
29582963
// Conversion failed — skip model message accumulation for this response
29592964
}
29602965
}
2966+
2967+
/**
2968+
* Returns a `prepareStep` function for inner-loop compaction.
2969+
* Only available when `compaction` was provided to the constructor.
2970+
* Pass the result to `streamText({ prepareStep: conversation.prepareStep() })`.
2971+
*/
2972+
prepareStep(): ((args: { messages: ModelMessage[]; steps: CompactionStep[] }) => Promise<{ messages: ModelMessage[] } | undefined>) | undefined {
2973+
if (!this._compaction) return undefined;
2974+
const comp = this._compaction;
2975+
return async ({ messages, steps }) => {
2976+
const result = await chatCompact(messages, steps, {
2977+
shouldCompact: comp.shouldCompact,
2978+
summarize: (msgs) => comp.summarize({ messages: msgs, source: "inner" }),
2979+
});
2980+
return result.type === "skipped" ? undefined : result;
2981+
};
2982+
}
2983+
2984+
/**
2985+
* Run outer-loop compaction if needed. Call after adding the response
2986+
* and capturing usage. Applies `compactModelMessages` and `compactUIMessages`
2987+
* callbacks if configured.
2988+
*
2989+
* @returns `true` if compaction was performed, `false` otherwise.
2990+
*/
2991+
async compactIfNeeded(usage: LanguageModelUsage | undefined, context?: {
2992+
chatId?: string;
2993+
turn?: number;
2994+
clientData?: unknown;
2995+
totalUsage?: LanguageModelUsage;
2996+
}): Promise<boolean> {
2997+
if (!this._compaction || !usage) return false;
2998+
2999+
const shouldTrigger = await this._compaction.shouldCompact({
3000+
messages: this.modelMessages,
3001+
totalTokens: usage.totalTokens,
3002+
inputTokens: usage.inputTokens,
3003+
outputTokens: usage.outputTokens,
3004+
usage,
3005+
totalUsage: context?.totalUsage,
3006+
chatId: context?.chatId,
3007+
turn: context?.turn,
3008+
clientData: context?.clientData,
3009+
source: "outer",
3010+
});
3011+
3012+
if (!shouldTrigger) return false;
3013+
3014+
const summary = await this._compaction.summarize({
3015+
messages: this.modelMessages,
3016+
usage,
3017+
totalUsage: context?.totalUsage,
3018+
chatId: context?.chatId,
3019+
turn: context?.turn,
3020+
clientData: context?.clientData,
3021+
source: "outer",
3022+
});
3023+
3024+
const compactEvent: CompactMessagesEvent = {
3025+
summary,
3026+
uiMessages: this.uiMessages,
3027+
modelMessages: this.modelMessages,
3028+
chatId: context?.chatId ?? "",
3029+
turn: context?.turn ?? 0,
3030+
clientData: context?.clientData,
3031+
source: "outer",
3032+
};
3033+
3034+
this.modelMessages = this._compaction.compactModelMessages
3035+
? await this._compaction.compactModelMessages(compactEvent)
3036+
: [{ role: "assistant" as const, content: [{ type: "text" as const, text: `[Conversation summary]\n\n${summary}` }] }];
3037+
3038+
if (this._compaction.compactUIMessages) {
3039+
this.uiMessages = await this._compaction.compactUIMessages(compactEvent);
3040+
}
3041+
3042+
return true;
3043+
}
29613044
}
29623045

29633046
// ---------------------------------------------------------------------------
@@ -2973,6 +3056,8 @@ export type ChatSessionOptions = {
29733056
timeout?: string;
29743057
/** Max turns before ending. @default 100 */
29753058
maxTurns?: number;
3059+
/** Automatic context compaction — same options as `chat.task({ compaction })`. */
3060+
compaction?: ChatTaskCompactionOptions;
29763061
};
29773062

29783063
export type ChatTurn = {
@@ -3065,6 +3150,7 @@ function createChatSession(
30653150
idleTimeoutInSeconds = 30,
30663151
timeout = "1h",
30673152
maxTurns = 100,
3153+
compaction: sessionCompaction,
30683154
} = options;
30693155

30703156
return {
@@ -3168,14 +3254,62 @@ function createChatSession(
31683254
}
31693255

31703256
// Capture token usage from the streamText result
3257+
let turnUsage: LanguageModelUsage | undefined;
31713258
if (typeof (source as any).totalUsage?.then === "function") {
31723259
try {
31733260
const usage: LanguageModelUsage = await (source as any).totalUsage;
3261+
turnUsage = usage;
31743262
previousTurnUsage = usage;
31753263
cumulativeUsage = addUsage(cumulativeUsage, usage);
31763264
} catch { /* non-fatal */ }
31773265
}
31783266

3267+
// Outer-loop compaction (same logic as chat.task)
3268+
if (sessionCompaction && turnUsage && !turnObj.stopped) {
3269+
const shouldTrigger = await sessionCompaction.shouldCompact({
3270+
messages: accumulator.modelMessages,
3271+
totalTokens: turnUsage.totalTokens,
3272+
inputTokens: turnUsage.inputTokens,
3273+
outputTokens: turnUsage.outputTokens,
3274+
usage: turnUsage,
3275+
totalUsage: cumulativeUsage,
3276+
chatId: currentPayload.chatId,
3277+
turn,
3278+
clientData: currentPayload.metadata,
3279+
source: "outer",
3280+
});
3281+
3282+
if (shouldTrigger) {
3283+
const summary = await sessionCompaction.summarize({
3284+
messages: accumulator.modelMessages,
3285+
usage: turnUsage,
3286+
totalUsage: cumulativeUsage,
3287+
chatId: currentPayload.chatId,
3288+
turn,
3289+
clientData: currentPayload.metadata,
3290+
source: "outer",
3291+
});
3292+
3293+
const compactEvent: CompactMessagesEvent = {
3294+
summary,
3295+
uiMessages: accumulator.uiMessages,
3296+
modelMessages: accumulator.modelMessages,
3297+
chatId: currentPayload.chatId,
3298+
turn,
3299+
clientData: currentPayload.metadata,
3300+
source: "outer",
3301+
};
3302+
3303+
accumulator.modelMessages = sessionCompaction.compactModelMessages
3304+
? await sessionCompaction.compactModelMessages(compactEvent)
3305+
: [{ role: "assistant" as const, content: [{ type: "text" as const, text: `[Conversation summary]\n\n${summary}` }] }];
3306+
3307+
if (sessionCompaction.compactUIMessages) {
3308+
accumulator.uiMessages = await sessionCompaction.compactUIMessages(compactEvent);
3309+
}
3310+
}
3311+
}
3312+
31793313
await chatWriteTurnComplete();
31803314
return response;
31813315
},

references/ai-chat/src/trigger/chat.ts

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { chat, type ChatTaskWirePayload } from "@trigger.dev/sdk/ai";
22
import { logger, task, prompts } from "@trigger.dev/sdk";
33
import { streamText, generateText, tool, dynamicTool, stepCountIs, generateId, createProviderRegistry } from "ai";
4-
import type { LanguageModel, Tool as AITool, UIMessage } from "ai";
4+
import type { LanguageModel, LanguageModelUsage, Tool as AITool, UIMessage } from "ai";
55
import { openai } from "@ai-sdk/openai";
66
import { anthropic } from "@ai-sdk/anthropic";
77
import { z } from "zod";
@@ -565,7 +565,27 @@ export const aiChatRaw = task({
565565
}
566566

567567
const stop = chat.createStopSignal();
568-
const conversation = new chat.MessageAccumulator();
568+
const conversation = new chat.MessageAccumulator({
569+
compaction: {
570+
shouldCompact: ({ totalTokens }) => (totalTokens ?? 0) > COMPACT_AFTER_TOKENS,
571+
summarize: async ({ messages: msgs }) => {
572+
const resolved = await compactionPrompt.resolve({});
573+
return generateText({
574+
model: registry.languageModel(resolved.model ?? "openai:gpt-4o-mini"),
575+
...resolved.toAISDKTelemetry(),
576+
messages: [...msgs, { role: "user" as const, content: resolved.text }],
577+
}).then((r) => r.text);
578+
},
579+
// Flatten to summary only in the raw task variant
580+
compactUIMessages: ({ summary }) => [
581+
{
582+
id: generateId(),
583+
role: "assistant" as const,
584+
parts: [{ type: "text" as const, text: `[Summary]\n\n${summary}` }],
585+
},
586+
],
587+
},
588+
});
569589

570590
for (let turn = 0; turn < 100; turn++) {
571591
stop.reset();
@@ -622,33 +642,7 @@ export const aiChatRaw = task({
622642
...(useReasoning ? { thinking: { type: "enabled", budgetTokens: 10000 } } : {}),
623643
},
624644
},
625-
// Low-level compaction using chat.compact() — gives full control
626-
// while chat.compact handles the decision tree + stream chunks
627-
prepareStep: async ({ messages: stepMessages, steps }) => {
628-
// Custom logic before/around compaction
629-
const lastStep = steps.at(-1);
630-
if (lastStep?.usage.totalTokens) {
631-
logger.info("Raw task: step usage", { totalTokens: lastStep.usage.totalTokens, turn });
632-
}
633-
634-
const result = await chat.compact(stepMessages, steps, {
635-
threshold: COMPACT_AFTER_TOKENS,
636-
summarize: async (msgs) => {
637-
const resolved = await compactionPrompt.resolve({});
638-
return generateText({
639-
model: registry.languageModel(resolved.model ?? "openai:gpt-4o-mini"),
640-
...resolved.toAISDKTelemetry(),
641-
messages: [...msgs, { role: "user" as const, content: resolved.text }],
642-
}).then((r) => r.text);
643-
},
644-
});
645-
646-
if (result.type === "compacted") {
647-
logger.info("Raw task: compacted", { summary: result.summary.slice(0, 100) });
648-
}
649-
650-
return result.type === "skipped" ? undefined : result;
651-
},
645+
prepareStep: conversation.prepareStep(),
652646
});
653647

654648
let response: UIMessage | undefined;
@@ -673,6 +667,14 @@ export const aiChatRaw = task({
673667

674668
if (runSignal.aborted) break;
675669

670+
// Outer-loop compaction — runs if token threshold exceeded
671+
let turnUsage: LanguageModelUsage | undefined;
672+
try { turnUsage = await result.totalUsage; } catch { /* non-fatal */ }
673+
await conversation.compactIfNeeded(turnUsage, {
674+
chatId: currentPayload.chatId,
675+
turn,
676+
});
677+
676678
// Persist messages
677679
await prisma.chat.update({
678680
where: { id: currentPayload.chatId },
@@ -722,6 +724,26 @@ export const aiChatSession = task({
722724
signal,
723725
idleTimeoutInSeconds: payload.idleTimeoutInSeconds ?? 60,
724726
timeout: "1h",
727+
compaction: {
728+
shouldCompact: ({ totalTokens }) => (totalTokens ?? 0) > COMPACT_AFTER_TOKENS,
729+
summarize: async ({ messages: msgs }) => {
730+
const resolved = await compactionPrompt.resolve({});
731+
return generateText({
732+
model: registry.languageModel(resolved.model ?? "openai:gpt-4o-mini"),
733+
...resolved.toAISDKTelemetry(),
734+
messages: [...msgs, { role: "user" as const, content: resolved.text }],
735+
}).then((r) => r.text);
736+
},
737+
// Keep summary + last 4 messages in the session variant
738+
compactUIMessages: ({ uiMessages, summary }) => [
739+
{
740+
id: generateId(),
741+
role: "assistant" as const,
742+
parts: [{ type: "text" as const, text: `[Conversation summary]\n\n${summary}` }],
743+
},
744+
...uiMessages.slice(-4),
745+
],
746+
},
725747
});
726748

727749
for await (const turn of session) {
@@ -754,31 +776,6 @@ export const aiChatSession = task({
754776
...(useReasoning ? { thinking: { type: "enabled", budgetTokens: 10000 } } : {}),
755777
},
756778
},
757-
// Low-level compaction — same pattern as raw task
758-
prepareStep: async ({ messages: stepMessages, steps }) => {
759-
const lastStep = steps.at(-1);
760-
if (lastStep?.usage.totalTokens) {
761-
logger.info("Session: step usage", { totalTokens: lastStep.usage.totalTokens, turn: turn.number });
762-
}
763-
764-
const result = await chat.compact(stepMessages, steps, {
765-
threshold: COMPACT_AFTER_TOKENS,
766-
summarize: async (msgs) => {
767-
const resolved = await compactionPrompt.resolve({});
768-
return generateText({
769-
model: registry.languageModel(resolved.model ?? "openai:gpt-4o-mini"),
770-
...resolved.toAISDKTelemetry(),
771-
messages: [...msgs, { role: "user" as const, content: resolved.text }],
772-
}).then((r) => r.text);
773-
},
774-
});
775-
776-
if (result.type === "compacted") {
777-
logger.info("Session: compacted", { summary: result.summary.slice(0, 100) });
778-
}
779-
780-
return result.type === "skipped" ? undefined : result;
781-
},
782779
});
783780

784781
await turn.complete(result);

0 commit comments

Comments (0)