diff --git a/packages/core/src/llm-core/prompt/chat_history.ts b/packages/core/src/llm-core/prompt/chat_history.ts index 43a1d69cc..4d611861c 100644 --- a/packages/core/src/llm-core/prompt/chat_history.ts +++ b/packages/core/src/llm-core/prompt/chat_history.ts @@ -4,7 +4,7 @@ import { PromptContextRuntime, PromptPipelineMiddleware } from './context_manager' -import { countMessagesTokens, countMessageTokens } from './system_prompts' +import { countMessageTokens } from './system_prompts' import { logger } from 'koishi-plugin-chatluna' import { isChatLunaUserMessage } from 'koishi-plugin-chatluna/utils/langchain' @@ -24,17 +24,8 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware { // Pre-account input tokens if (runtime.input) { - const input = runtime.input - const inputMessageForCount = { - ...input, - content: - typeof input.content === 'string' - ? input.content - : JSON.stringify(input.content), - getType: () => input.getType() - } as BaseMessage const inputTokens = await countMessageTokens( - inputMessageForCount, + runtime.input, runtime.tokenCounter ) runtime.usedTokens += inputTokens @@ -43,10 +34,12 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware { // Pre-account scratchpad tokens if (runtime.agentScratchpad) { if (Array.isArray(runtime.agentScratchpad)) { - runtime.usedTokens += await countMessagesTokens( - runtime.agentScratchpad, - runtime.tokenCounter - ) + for (const msg of runtime.agentScratchpad) { + runtime.usedTokens += await countMessageTokens( + msg, + runtime.tokenCounter + ) + } } else { runtime.usedTokens += await countMessageTokens( runtime.agentScratchpad as BaseMessage, @@ -66,10 +59,13 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware { for (let i = rounds.length - 1; i >= 0; i--) { const round = rounds[i] - const roundTokens = await countMessagesTokens( - round, - runtime.tokenCounter - ) + let roundTokens = 0 + for (const msg of round) { + roundTokens += await countMessageTokens( + msg, + runtime.tokenCounter + ) + } const exceedsLimit = hasValidLimit ? usedTokens + roundTokens > availableLimit : false @@ -91,10 +87,12 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware { // Ensure at least one round if (rounds.length > 0 && selectedRounds.length === 0) { const lastRound = rounds[rounds.length - 1] - usedTokens += await countMessagesTokens( - lastRound, - runtime.tokenCounter - ) + for (const msg of lastRound) { + usedTokens += await countMessageTokens( + msg, + runtime.tokenCounter + ) + } selectedRounds.unshift(lastRound) truncated = hasValidLimit } diff --git a/packages/core/src/llm-core/utils/count_tokens.ts b/packages/core/src/llm-core/utils/count_tokens.ts index 1813589c2..afd364792 100644 --- a/packages/core/src/llm-core/utils/count_tokens.ts +++ b/packages/core/src/llm-core/utils/count_tokens.ts @@ -227,11 +227,20 @@ export async function countMessageTokens( content = content.replaceAll(/!\[.*?\]\(.*?\)/g, '') } - return ( + let tokens = (await tokenCounter(content)) + (await tokenCounter(messageTypeToOpenAIRole(message.getType()))) + (message.name ? await tokenCounter(message.name) : 0) - ) + + // Account for tool_calls payload on AI messages + const toolCalls = + (message as AIMessage).tool_calls ?? + (message.additional_kwargs?.tool_calls as unknown[] | undefined) + if (Array.isArray(toolCalls) && toolCalls.length > 0) { + tokens += await tokenCounter(JSON.stringify(toolCalls)) + } + + return tokens } /**