diff --git a/packages/core/src/tracing/ai/messageTruncation.ts b/packages/core/src/tracing/ai/messageTruncation.ts index 945761f6220c..1acec0bdfc35 100644 --- a/packages/core/src/tracing/ai/messageTruncation.ts +++ b/packages/core/src/tracing/ai/messageTruncation.ts @@ -12,13 +12,49 @@ type ContentMessage = { content: string; }; +/** + * Message format used by OpenAI and Anthropic APIs for media. + */ +type ContentArrayMessage = { + [key: string]: unknown; + content: { + [key: string]: unknown; + type: string; + }[]; +}; + +/** + * Inline media content source, with a potentially very large base64 + * blob or data: uri. + */ +type ContentMedia = Record & + ( + | { + media_type: string; + data: string; + } + | { + image_url: `data:${string}`; + } + | { + type: 'blob' | 'base64'; + content: string; + } + | { + b64_json: string; + } + | { + uri: `data:${string}`; + } + ); + /** * Message format used by Google GenAI API. * Parts can be strings or objects with a text property. */ type PartsMessage = { [key: string]: unknown; - parts: Array; + parts: Array; }; /** @@ -26,6 +62,14 @@ type PartsMessage = { */ type TextPart = string | { text: string }; +/** + * A part in a Google GenAI that contains media. + */ +type MediaPart = { + type: string; + content: string; +}; + /** * Calculate the UTF-8 byte length of a string. */ @@ -79,11 +123,12 @@ function truncateTextByBytes(text: string, maxBytes: number): string { * * @returns The text content */ -function getPartText(part: TextPart): string { +function getPartText(part: TextPart | MediaPart): string { if (typeof part === 'string') { return part; } - return part.text; + if ('text' in part) return part.text; + return ''; } /** @@ -93,7 +138,7 @@ function getPartText(part: TextPart): string { * @param text - New text content * @returns New part with updated text */ -function withPartText(part: TextPart, text: string): TextPart { +function withPartText(part: TextPart | MediaPart, text: string): TextPart { if (typeof part === 'string') { return text; } @@ -112,6 +157,33 @@ function isContentMessage(message: unknown): message is ContentMessage { ); } +/** + * Check if a message has the OpenAI/Anthropic content array format. + */ +function isContentArrayMessage(message: unknown): message is ContentArrayMessage { + return message !== null && typeof message === 'object' && 'content' in message && Array.isArray(message.content); +} + +/** + * Check if a content part is an OpenAI/Anthropic media source + */ +function isContentMedia(part: unknown): part is ContentMedia { + if (!part || typeof part !== 'object') return false; + + return ( + isContentMediaSource(part) || + ('media_type' in part && typeof part.media_type === 'string' && 'data' in part) || + ('image_url' in part && typeof part.image_url === 'string' && part.image_url.startsWith('data:')) || + ('type' in part && (part.type === 'blob' || part.type === 'base64')) || + 'b64_json' in part || + ('type' in part && 'result' in part && part.type === 'image_generation') || + ('uri' in part && typeof part.uri === 'string' && part.uri.startsWith('data:')) + ); +} +function isContentMediaSource(part: NonNullable): boolean { + return 'type' in part && typeof part.type === 'string' && 'source' in part && isContentMedia(part.source); +} + /** * Check if a message has the Google GenAI parts format. */ @@ -167,7 +239,7 @@ function truncatePartsMessage(message: PartsMessage, maxBytes: number): unknown[ } // Include parts until we run out of space - const includedParts: TextPart[] = []; + const includedParts: (TextPart | MediaPart)[] = []; for (const part of parts) { const text = getPartText(part); @@ -190,7 +262,10 @@ function truncatePartsMessage(message: PartsMessage, maxBytes: number): unknown[ } } + /* c8 ignore start + * for type safety only, algorithm guarantees SOME text included */ return includedParts.length > 0 ? [{ ...message, parts: includedParts }] : []; + /* c8 ignore stop */ } /** @@ -205,9 +280,11 @@ function truncatePartsMessage(message: PartsMessage, maxBytes: number): unknown[ * @returns Array containing the truncated message, or empty array if truncation fails */ function truncateSingleMessage(message: unknown, maxBytes: number): unknown[] { + /* c8 ignore start - unreachable */ if (!message || typeof message !== 'object') { return []; } + /* c8 ignore start - unreachable */ if (isContentMessage(message)) { return truncateContentMessage(message, maxBytes); @@ -221,6 +298,59 @@ function truncateSingleMessage(message: unknown, maxBytes: number): unknown[] { return []; } +const REMOVED_STRING = ''; + +const MEDIA_FIELDS = ['image_url', 'data', 'content', 'b64_json', 'result', 'uri'] as const; + +function stripInlineMediaFromSingleMessage(part: ContentMedia): ContentMedia { + const strip = { ...part }; + if (isContentMedia(strip.source)) { + strip.source = stripInlineMediaFromSingleMessage(strip.source); + } + for (const field of MEDIA_FIELDS) { + if (strip[field]) strip[field] = REMOVED_STRING; + } + return strip; +} + +/** + * Strip the inline media from message arrays. + * + * This returns a stripped message. We do NOT want to mutate the data in place, + * because of course we still want the actual API/client to handle the media. + */ +export function stripInlineMediaFromMessages(messages: unknown[]): unknown[] { + return messages.map(message => { + if (!!message && typeof message === 'object') { + if (isContentArrayMessage(message)) { + // eslint-disable-next-line no-param-reassign + message = { + ...message, + content: stripInlineMediaFromMessages(message.content), + }; + } else if ('content' in message && isContentMedia(message.content)) { + // eslint-disable-next-line no-param-reassign + message = { + ...message, + content: stripInlineMediaFromSingleMessage(message.content), + }; + } + if (isPartsMessage(message)) { + // eslint-disable-next-line no-param-reassign + message = { + ...message, + parts: stripInlineMediaFromMessages(message.parts), + }; + } + if (isContentMedia(message)) { + // eslint-disable-next-line no-param-reassign + message = stripInlineMediaFromSingleMessage(message); + } + } + return message; + }); +} + /** * Truncate an array of messages to fit within a byte limit. * @@ -246,6 +376,11 @@ export function truncateMessagesByBytes(messages: unknown[], maxBytes: number): return messages; } + // strip inline media first. This will often get us below the threshold, + // while preserving human-readable information about messages sent. + // eslint-disable-next-line no-param-reassign + messages = stripInlineMediaFromMessages(messages); + // Fast path: if all messages fit, return as-is const totalBytes = jsonBytes(messages); if (totalBytes <= maxBytes) { diff --git a/packages/core/test/lib/tracing/ai-message-truncation.test.ts b/packages/core/test/lib/tracing/ai-message-truncation.test.ts new file mode 100644 index 000000000000..8c49ad8c551f --- /dev/null +++ b/packages/core/test/lib/tracing/ai-message-truncation.test.ts @@ -0,0 +1,294 @@ +import { describe, expect, it } from 'vitest'; +import { truncateGenAiMessages, truncateGenAiStringInput } from '../../../src/tracing/ai/messageTruncation'; + +describe('message truncation utilities', () => { + describe('truncateGenAiMessages', () => { + it('leaves empty/non-array/small messages alone', () => { + // @ts-expect-error - exercising invalid type code path + expect(truncateGenAiMessages(null)).toBe(null); + expect(truncateGenAiMessages([])).toStrictEqual([]); + expect(truncateGenAiMessages([{ text: 'hello' }])).toStrictEqual([{ text: 'hello' }]); + expect(truncateGenAiStringInput('hello')).toBe('hello'); + }); + + it('strips inline media from messages', () => { + const b64 = Buffer.from('lots of data\n').toString('base64'); + const removed = ''; + const messages = [ + { + role: 'user', + content: [ + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/png', + data: b64, + }, + }, + ], + }, + { + role: 'user', + content: { + image_url: `data:image/png;base64,${b64}`, + }, + }, + { + role: 'agent', + type: 'image', + content: { + b64_json: b64, + }, + }, + { + role: 'system', + parts: [ + { + image_url: `data:image/png;base64,${b64}`, + }, + { + type: 'image_generation', + result: b64, + }, + { + uri: `data:image/png;base64,${b64}`, + mediaType: 'image/png', + }, + { + type: 'blob', + mediaType: 'image/png', + content: b64, + }, + { + type: 'text', + text: 'just some text!', + }, + 'unadorned text', + ], + }, + ]; + + // indented json makes for better diffs in test output + const messagesJson = JSON.stringify(messages, null, 2); + const result = truncateGenAiMessages(messages); + + // original messages objects must not be mutated + expect(JSON.stringify(messages, null, 2)).toBe(messagesJson); + expect(result).toStrictEqual([ + { + role: 'user', + content: [ + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/png', + data: removed, + }, + }, + ], + }, + { + role: 'user', + content: { + image_url: removed, + }, + }, + { + role: 'agent', + type: 'image', + content: { + b64_json: removed, + }, + }, + { + role: 'system', + parts: [ + { + image_url: removed, + }, + { + type: 'image_generation', + result: removed, + }, + { + uri: removed, + mediaType: 'image/png', + }, + { + type: 'blob', + mediaType: 'image/png', + content: removed, + }, + { + type: 'text', + text: 'just some text!', + }, + 'unadorned text', + ], + }, + ]); + }); + + const humongous = 'this is a long string '.repeat(10_000); + const giant = 'this is a long string '.repeat(1_000); + const big = 'this is a long string '.repeat(100); + + it('drops older messages to fit in the limit', () => { + const messages = [ + `0 ${giant}`, + { type: 'text', content: `1 ${big}` }, + { type: 'text', content: `2 ${big}` }, + { type: 'text', content: `3 ${giant}` }, + { type: 'text', content: `4 ${big}` }, + `5 ${big}`, + { type: 'text', content: `6 ${big}` }, + { type: 'text', content: `7 ${big}` }, + { type: 'text', content: `8 ${big}` }, + { type: 'text', content: `9 ${big}` }, + { type: 'text', content: `10 ${big}` }, + { type: 'text', content: `11 ${big}` }, + { type: 'text', content: `12 ${big}` }, + ]; + + const messagesJson = JSON.stringify(messages, null, 2); + const result = truncateGenAiMessages(messages); + // should not mutate original messages list + expect(JSON.stringify(messages, null, 2)).toBe(messagesJson); + + // just retain the messages that fit in the budget + expect(result).toStrictEqual([ + `5 ${big}`, + { type: 'text', content: `6 ${big}` }, + { type: 'text', content: `7 ${big}` }, + { type: 'text', content: `8 ${big}` }, + { type: 'text', content: `9 ${big}` }, + { type: 'text', content: `10 ${big}` }, + { type: 'text', content: `11 ${big}` }, + { type: 'text', content: `12 ${big}` }, + ]); + }); + + it('fully drops message if content cannot be made to fit', () => { + const messages = [{ some_other_field: humongous, content: 'hello' }]; + expect(truncateGenAiMessages(messages)).toStrictEqual([]); + }); + + it('truncates if the message content string will not fit', () => { + const messages = [{ content: `2 ${humongous}` }]; + const result = truncateGenAiMessages(messages); + const truncLen = 20_000 - JSON.stringify({ content: '' }).length; + expect(result).toStrictEqual([{ content: `2 ${humongous}`.substring(0, truncLen) }]); + }); + + it('fully drops message if first part overhead does not fit', () => { + const messages = [ + { + parts: [{ some_other_field: humongous }], + }, + ]; + expect(truncateGenAiMessages(messages)).toStrictEqual([]); + }); + + it('fully drops message if overhead too large', () => { + const messages = [ + { + some_other_field: humongous, + parts: [], + }, + ]; + expect(truncateGenAiMessages(messages)).toStrictEqual([]); + }); + + it('truncates if the first message part will not fit', () => { + const messages = [ + { + parts: [`2 ${humongous}`, { some_other_field: 'no text here' }], + }, + ]; + + const result = truncateGenAiMessages(messages); + + // interesting (unexpected?) edge case effect of this truncation. + // subsequent messages count towards truncation overhead limit, + // but are not included, even without their text. This is an edge + // case that seems unlikely in normal usage. + const truncLen = + 20_000 - + JSON.stringify({ + parts: ['', { some_other_field: 'no text here', text: '' }], + }).length; + + expect(result).toStrictEqual([ + { + parts: [`2 ${humongous}`.substring(0, truncLen)], + }, + ]); + }); + + it('truncates if the first message part will not fit, text object', () => { + const messages = [ + { + parts: [{ text: `2 ${humongous}` }], + }, + ]; + const result = truncateGenAiMessages(messages); + const truncLen = + 20_000 - + JSON.stringify({ + parts: [{ text: '' }], + }).length; + expect(result).toStrictEqual([ + { + parts: [ + { + text: `2 ${humongous}`.substring(0, truncLen), + }, + ], + }, + ]); + }); + + it('drops if subsequent message part will not fit, text object', () => { + const messages = [ + { + parts: [ + { text: `1 ${big}` }, + { some_other_field: 'ok' }, + { text: `2 ${big}` }, + { text: `3 ${big}` }, + { text: `4 ${giant}` }, + { text: `5 ${giant}` }, + { text: `6 ${big}` }, + { text: `7 ${big}` }, + { text: `8 ${big}` }, + ], + }, + ]; + const result = truncateGenAiMessages(messages); + expect(result).toStrictEqual([ + { + parts: [{ text: `1 ${big}` }, { some_other_field: 'ok' }, { text: `2 ${big}` }, { text: `3 ${big}` }], + }, + ]); + }); + + it('truncates first message if none fit', () => { + const messages = [{ content: `1 ${humongous}` }, { content: `2 ${humongous}` }, { content: `3 ${humongous}` }]; + const result = truncateGenAiMessages(messages); + const truncLen = 20_000 - JSON.stringify({ content: '' }).length; + expect(result).toStrictEqual([{ content: `3 ${humongous}`.substring(0, truncLen) }]); + }); + + it('drops if first message cannot be safely truncated', () => { + const messages = [ + { content: `1 ${humongous}` }, + { content: `2 ${humongous}` }, + { what_even_is_this: `? ${humongous}` }, + ]; + const result = truncateGenAiMessages(messages); + expect(result).toStrictEqual([]); + }); + }); +});