Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions packages/core/src/llm-core/prompt/chat_history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {
PromptContextRuntime,
PromptPipelineMiddleware
} from './context_manager'
import { countMessagesTokens, countMessageTokens } from './system_prompts'
import { countMessageTokens } from './system_prompts'
import { logger } from 'koishi-plugin-chatluna'
import { isChatLunaUserMessage } from 'koishi-plugin-chatluna/utils/langchain'

Expand All @@ -24,17 +24,8 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware {

// Pre-account input tokens
if (runtime.input) {
const input = runtime.input
const inputMessageForCount = {
...input,
content:
typeof input.content === 'string'
? input.content
: JSON.stringify(input.content),
getType: () => input.getType()
} as BaseMessage
const inputTokens = await countMessageTokens(
inputMessageForCount,
runtime.input,
runtime.tokenCounter
)
runtime.usedTokens += inputTokens
Expand All @@ -43,10 +34,12 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware {
// Pre-account scratchpad tokens
if (runtime.agentScratchpad) {
if (Array.isArray(runtime.agentScratchpad)) {
runtime.usedTokens += await countMessagesTokens(
runtime.agentScratchpad,
runtime.tokenCounter
)
for (const msg of runtime.agentScratchpad) {
runtime.usedTokens += await countMessageTokens(
msg,
runtime.tokenCounter
Comment thread
dingyi222666 marked this conversation as resolved.
)
}
Comment thread
dingyi222666 marked this conversation as resolved.
} else {
runtime.usedTokens += await countMessageTokens(
runtime.agentScratchpad as BaseMessage,
Expand All @@ -66,10 +59,13 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware {

for (let i = rounds.length - 1; i >= 0; i--) {
const round = rounds[i]
const roundTokens = await countMessagesTokens(
round,
runtime.tokenCounter
)
let roundTokens = 0
for (const msg of round) {
roundTokens += await countMessageTokens(
msg,
runtime.tokenCounter
)
}
Comment thread
dingyi222666 marked this conversation as resolved.
const exceedsLimit = hasValidLimit
? usedTokens + roundTokens > availableLimit
: false
Expand All @@ -91,10 +87,12 @@ export function createChatHistoryMiddleware(): PromptPipelineMiddleware {
// Ensure at least one round
if (rounds.length > 0 && selectedRounds.length === 0) {
const lastRound = rounds[rounds.length - 1]
usedTokens += await countMessagesTokens(
lastRound,
runtime.tokenCounter
)
for (const msg of lastRound) {
usedTokens += await countMessageTokens(
msg,
runtime.tokenCounter
)
}
Comment thread
dingyi222666 marked this conversation as resolved.
selectedRounds.unshift(lastRound)
truncated = hasValidLimit
}
Expand Down
13 changes: 11 additions & 2 deletions packages/core/src/llm-core/utils/count_tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,20 @@ export async function countMessageTokens(
content = content.replaceAll(/!\[.*?\]\(.*?\)/g, '')
}

return (
let tokens =
(await tokenCounter(content)) +
(await tokenCounter(messageTypeToOpenAIRole(message.getType()))) +
(message.name ? await tokenCounter(message.name) : 0)
)

// Account for tool_calls payload on AI messages
const toolCalls =
(message as AIMessage).tool_calls ??
(message.additional_kwargs?.tool_calls as unknown[] | undefined)
if (Array.isArray(toolCalls) && toolCalls.length > 0) {
tokens += await tokenCounter(JSON.stringify(toolCalls))
}

return tokens
}

/**
Expand Down
Loading