Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/adapter-gemini/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ function processImageParts(
!(
(model.includes('vision') ||
model.includes('gemini') ||
model.includes('gemma')) &&
model.includes('gemma2')) &&
!model.includes('gemini-1.0')
)
) {
Expand Down
31 changes: 23 additions & 8 deletions packages/core/src/llm-core/agent/legacy-executor.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { CallbackManagerForChainRun } from '@langchain/core/callbacks/manager'
import { AIMessage, AIMessageChunk } from '@langchain/core/messages'
import { isDirectToolOutput } from '@langchain/core/messages/tool'
import { OutputParserException } from '@langchain/core/output_parsers'
import {
patchConfig,
Expand Down Expand Up @@ -257,12 +258,8 @@ export async function* runAgent(
}

if (output.length > 0) {
const last = output[output.length - 1]
const tool = toolMap[last.tool?.toLowerCase()]

yield {
type: 'round-decision',
canContinue: !tool?.returnDirect
type: 'round-decision'
}

yield {
Expand Down Expand Up @@ -293,7 +290,15 @@ export async function* runAgent(
const last = newSteps[newSteps.length - 1]
const tool = last ? toolMap[last.action.tool?.toLowerCase()] : undefined

if (tool?.returnDirect && last != null) {
if (
last != null &&
(tool?.returnDirect || isDirectToolOutput(last.observation))
) {
yield {
type: 'round-decision',
canContinue: false
}

const pending = queue?.drain() ?? []
if (pending.length > 0) {
yield {
Expand All @@ -304,14 +309,24 @@ export async function* runAgent(

yield {
type: 'done',
output: toOutput(last.observation),
output:
// TODO: remove this property
last.observation['replyEmitted'] === true
? ''
: toOutput(last.observation),
log: last.action.log,
steps
steps,
replyEmitted: last.observation['replyEmitted'] === true
}

return
}

yield {
type: 'round-decision',
canContinue: true
}

iterations += 1
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/llm-core/agent/sub-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ async function onTaskEvent(
at: Date.now(),
title: '最终输出',
text:
(event.replyEmitted ? '最终回复已由工具发送。' : '') ||
getMessageContent(event.message?.content ?? '') ||
event.output ||
event.log
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/llm-core/agent/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ export type AgentEvent =
log: string
steps: AgentStep[]
message?: AIMessage
replyEmitted?: boolean
}

export interface AgentRuntimeConfigurable {
Expand Down
31 changes: 30 additions & 1 deletion packages/core/src/llm-core/platform/service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Context, Dict } from 'koishi'
import { Awaitable, Context, Dict, Session } from 'koishi'
import {
BasePlatformClient,
PlatformEmbeddingsClient,
Expand Down Expand Up @@ -28,6 +28,7 @@ import { computed, ComputedRef, reactive } from '@vue/reactivity'
import { randomUUID } from 'crypto'
import { RunnableConfig } from '@langchain/core/runnables'
import { ToolMask } from '../agent'
import type { ConversationRecord } from '../../services/conversation_types'

export class PlatformService {
private _platformClients: Record<string, BasePlatformClient> = reactive({})
Expand All @@ -36,6 +37,7 @@ export class PlatformService {

private _tools: Record<string, ChatLunaTool> = reactive({})
private _tmpTools: Record<string, StructuredTool> = reactive({})
private _toolMaskResolvers: Record<string, ToolMaskResolver> = {}
private _models: Record<string, ModelInfo[]> = reactive({})
private _chatChains: Record<string, ChatLunaChainInfo> = reactive({})
private _vectorStore: Record<string, CreateVectorStoreFunction> = reactive(
Expand Down Expand Up @@ -218,6 +220,23 @@ export class PlatformService {
return allNames.filter((name) => !mask.deny.includes(name))
}

/**
 * Registers a named tool-mask resolver. A later registration under the
 * same name overwrites the earlier one.
 *
 * @param name - Unique key identifying the resolver.
 * @param resolver - Callback consulted by `resolveToolMask`.
 * @returns A disposer that unregisters the resolver again.
 */
registerToolMaskResolver(name: string, resolver: ToolMaskResolver) {
    // Capture the registry so the disposer works even if `this` rebinds.
    const registry = this._toolMaskResolvers
    registry[name] = resolver

    return () => {
        delete registry[name]
    }
}

/**
 * Asks each registered resolver, in registration (string-key insertion)
 * order, for a tool mask. The first non-undefined answer wins.
 *
 * @param arg - Session/conversation context passed to every resolver.
 * @returns The first mask produced, or `undefined` when no resolver
 *          claims the request.
 */
async resolveToolMask(arg: ToolMaskArg) {
    for (const resolver of Object.values(this._toolMaskResolvers)) {
        const mask = await resolver(arg)

        if (mask) {
            return mask
        }
    }

    // No resolver produced a mask — same implicit-undefined contract.
    return undefined
}

static buildToolMask(rule: {
mode?: 'inherit' | 'all' | 'allow' | 'deny'
allow?: string[]
Expand Down Expand Up @@ -453,3 +472,13 @@ declare module 'koishi' {
'chatluna/tool-updated': (service: PlatformService) => void
}
}

/**
 * Context handed to a {@link ToolMaskResolver} when the platform service
 * decides which tools are visible for a request.
 */
export interface ToolMaskArg {
    session: Session
    // Conversation being processed, when one has already been resolved —
    // presumably optional because resolvers can run pre-resolution; TODO confirm.
    conversation?: ConversationRecord
    bindingKey?: string
}

/**
 * Resolver registered via `PlatformService.registerToolMaskResolver`.
 * Returns a mask to apply, or `undefined` to defer to the next resolver.
 */
export type ToolMaskResolver = (
    arg: ToolMaskArg
) => Awaitable<ToolMask | undefined>
2 changes: 1 addition & 1 deletion packages/core/src/middlewares/chat/rollback_chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import { gzipDecode } from 'koishi-plugin-chatluna/utils/string'
import { Config } from '../../config'
import {
ChainMiddlewareRunStatus,
type ChainMiddlewareContext,
ChainMiddlewareRunStatus,
type ChatChain
} from '../../chains/chain'
import { MessageRecord } from '../../services/conversation_types'
Expand Down Expand Up @@ -126,135 +126,135 @@
}
}

/**
 * Rolls a conversation back by `rollbackRound` human turns: walks the
 * message chain backwards from the latest message, deletes everything from
 * the target human message onward, and re-points the conversation head at
 * that message's parent so the chat can be replayed from there.
 *
 * @param ctx - Koishi context providing chatluna services and the database.
 * @param config - Plugin config (used for quote-reply handling on re-transform).
 * @param session - Triggering session, used for permission checks and i18n.
 * @param context - Middleware context; `inputMessage` may be replaced here.
 * @param conversation - Object carrying the id of the conversation to roll back.
 * @param rollbackRound - Number of human turns to rewind (>= 1 expected).
 * @returns STOP with a user-facing message on failure, or CONTINUE with the
 *          (possibly re-transformed) input message on success.
 */
async function rollbackConversation(
    ctx: Context,
    config: Config,
    session: Session,
    context: ChainMiddlewareContext,
    conversation: { id: string },
    rollbackRound: number
) {
    const current = await ctx.chatluna.conversation.getConversation(
        conversation.id
    )

    // Conversation vanished between lookup and now — bail out politely.
    if (current == null) {
        return {
            status: ChainMiddlewareRunStatus.STOP,
            msg: session.text('.conversation_not_exist')
        }
    }

    const resolved = await ctx.chatluna.conversation.resolveContext(session, {
        conversationId: current.id,
        presetLane: context.options.presetLane,
        bindingKey: current.bindingKey
    })

    // Admin-managed conversations may only be rolled back by admins.
    // NOTE: the same "not exist" text is reused here, presumably to avoid
    // leaking the conversation's existence to non-admins — TODO confirm.
    if (
        resolved.constraint.manageMode === 'admin' &&
        !(await checkAdmin(session))
    ) {
        return {
            status: ChainMiddlewareRunStatus.STOP,
            msg: session.text('.conversation_not_exist')
        }
    }

    // Locked conversations cannot be mutated.
    if (resolved.constraint.lockConversation) {
        return {
            status: ChainMiddlewareRunStatus.STOP,
            msg: session.text('.conversation_not_exist')
        }
    }

    // Drop any cached interface/runtime state before rewriting history.
    await ctx.chatluna.conversationRuntime.clearConversationInterfaceLocked(
        current
    )

    // Walk the parent chain from the newest message towards the root,
    // collecting messages (newest-last via unshift) until we have passed
    // `rollbackRound` human messages.
    let parentId = current.latestMessageId
    const messages: MessageRecord[] = []
    let humanMessage: MessageRecord | undefined
    let humanCount = 0
    const seen = new Set<string>()

    while (parentId != null) {
        // Defensive: corrupted data could make the parent chain cyclic.
        if (seen.has(parentId)) {
            ctx.logger.warn(`rollback cycle detected: ${parentId}`)
            break
        }

        // Hard cap on traversal depth to bound DB round-trips.
        if (seen.size >= MAX_ROLLBACK_HOPS) {
            ctx.logger.warn(`rollback hop limit reached: ${current.id}`)
            break
        }

        seen.add(parentId)

        const message = await ctx.database.get('chatluna_message', {
            conversationId: current.id,
            id: parentId
        })
        const currentMessage = message[0]

        // Dangling parent reference — chain ends here.
        if (currentMessage == null) {
            break
        }

        parentId = currentMessage.parentId
        messages.unshift(currentMessage)

        if (currentMessage.role === 'human') {
            humanMessage = currentMessage
            humanCount += 1

            // Found the human turn we want to rewind to.
            if (humanCount >= rollbackRound) {
                break
            }
        }
    }

    // Not enough history to satisfy the requested rollback depth.
    if (humanCount < rollbackRound || humanMessage == null) {
        return {
            status: ChainMiddlewareRunStatus.STOP,
            msg: session.text('.no_chat_history')
        }
    }

    let inputMessage = context.options.inputMessage

    // No replacement message was supplied by the caller: re-submit the
    // original human message by decoding and re-transforming its content.
    if ((context.options.message?.length ?? 0) < 1) {
        const humanContent = await decodeMessageContent(humanMessage)

        inputMessage = await ctx.chatluna.messageTransformer.transform(
            session,
            transformMessageContentToElements(humanContent),
            resolved.effectiveModel ?? current.model,
            undefined,
            {
                quote: false,
                includeQuoteReply: config.includeQuoteReply
            }
        )
    }

    // Delete every message from the target human turn onward…
    await ctx.database.remove('chatluna_message', {
        id: messages.map((message) => message.id)
    })

    // …and rewind the conversation head to just before it.
    await ctx.database.upsert('chatluna_conversation', [
        {
            id: current.id,
            latestMessageId: humanMessage.parentId ?? null,
            updatedAt: new Date()
        }
    ])

    return {
        status: ChainMiddlewareRunStatus.CONTINUE,
        inputMessage
    }
}

Check notice on line 257 in packages/core/src/middlewares/chat/rollback_chat.ts

View check run for this annotation

codefactor.io / CodeFactor

packages/core/src/middlewares/chat/rollback_chat.ts#L129-L257

Complex Method

declare module '../../chains/chain' {
interface ChainMiddlewareName {
Expand Down
16 changes: 4 additions & 12 deletions packages/core/src/middlewares/chat/stop_chat.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Context } from 'koishi'
import { Config } from '../../config'
import { ChainMiddlewareRunStatus, ChatChain } from '../../chains/chain'
import { getRequestId } from '../../utils/chat_request'
import { checkAdmin } from 'koishi-plugin-chatluna/utils/koishi'

export function apply(ctx: Context, config: Config, chain: ChatChain) {
Expand Down Expand Up @@ -82,20 +81,13 @@ export function apply(ctx: Context, config: Config, chain: ChatChain) {
}

context.options.conversationId = conversation.id
const requestId = getRequestId(session, conversation.id)

if (requestId == null) {
context.message = session.text('.no_active_chat')
return ChainMiddlewareRunStatus.STOP
}

const status =
await ctx.chatluna.conversationRuntime.stopRequest(requestId)
ctx.chatluna.conversationRuntime.stopConversationRequest(
conversation.id
)

if (status === null) {
if (!status) {
context.message = session.text('.no_active_chat')
} else if (!status) {
context.message = session.text('.stop_failed')
}

return ChainMiddlewareRunStatus.STOP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
MessageContent,
MessageContentComplex
} from '@langchain/core/messages'
import { createRequestId } from '../../utils/chat_request'
import { AgentAction } from 'koishi-plugin-chatluna/llm-core/agent'

let logger: Logger
Expand All @@ -39,150 +38,146 @@
export function apply(ctx: Context, config: Config, chain: ChatChain) {
logger = createLogger(ctx)
chain
.middleware('request_conversation', async (session, context) => {
const { inputMessage } = context.options
const useRoutePresetLane =
context.options.presetLane == null &&
context.options.conversationId == null &&
(context.command == null || context.command.length === 0)
const resolved =
await ctx.chatluna.conversation.ensureActiveConversation(
session,
{
conversationId: context.options.conversationId,
presetLane: context.options.presetLane,
useRoutePresetLane
}
)
const conversation = resolved.conversation

context.options.conversationId = conversation.id
context.options.resolvedConversation = conversation
context.options.resolvedConversationContext = resolved

const presetTemplate = ctx.chatluna.preset.getPreset(
conversation.preset
).value

if (presetTemplate == null) {
throw new ChatLunaError(
ChatLunaErrorCode.PRESET_NOT_FOUND,
new Error(`Preset ${conversation.preset} not found`)
)
}

const originContent = inputMessage.content

if (presetTemplate.formatUserPromptString != null) {
inputMessage.content = await processUserPrompt(
config,
presetTemplate,
session,
inputMessage.content,
conversation
)
}

const bufferText = new StreamingBufferText(
3,
presetTemplate.config?.postHandler?.prefix,
presetTemplate.config?.postHandler?.postfix
)

const postHandler = presetTemplate.config?.postHandler
? new PresetPostHandler(
ctx,
config,
presetTemplate.config?.postHandler
)
: undefined

let streamPromise: Promise<void> = Promise.resolve()
if (config.streamResponse) {
const isEditMessage =
session.bot.editMessage != null &&
session.bot.platform !== 'onebot'

if (isEditMessage) {
streamPromise = setupEditMessageStream(
context,
session,
config,
bufferText
)
} else {
streamPromise = setupRegularMessageStream(
context,
config,
config.splitMessage
? bufferText.splitByPunctuations()
: bufferText.splitByMarkdown()
)
}
}

let responseMessage: Message

inputMessage.conversationId = conversation.id
inputMessage.name =
session.author?.name ?? session.author?.id ?? session.username

const requestId = createRequestId(
session,
conversation.id,
context.options.messageId
)
const requestId = context.options.messageId

const chatCallbacks = createChatCallbacks(
context,
config,
bufferText
)

try {
;[responseMessage] = await Promise.all([
ctx.chatluna.conversationRuntime.chat(
session,
conversation,
inputMessage,
chatCallbacks,
config.streamResponse,
{
prompt: getMessageContent(originContent),
...getSystemPromptVariables(
session,
config,
conversation
)
},
postHandler,
requestId
),
streamPromise
])
} catch (e) {
if (e?.message?.includes('output values have 1 keys')) {
throw new ChatLunaError(
ChatLunaErrorCode.MODEL_RESPONSE_IS_EMPTY
)
} else {
throw e
}
}

if (!config.streamResponse) {
context.options.responseMessage = responseMessage
} else {
context.options.responseMessage = null
context.message = null
}

await ctx.chatluna.conversation.touchConversation(conversation.id, {
lastChatAt: new Date()
})

return ChainMiddlewareRunStatus.CONTINUE
})

Check notice on line 180 in packages/core/src/middlewares/conversation/request_conversation.ts

View check run for this annotation

codefactor.io / CodeFactor

packages/core/src/middlewares/conversation/request_conversation.ts#L41-L180

Complex Method
.after('lifecycle-request_conversation')
}

Expand Down
23 changes: 8 additions & 15 deletions packages/core/src/services/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ import {
ChatLunaBaseEmbeddings,
ChatLunaChatModel
} from 'koishi-plugin-chatluna/llm-core/platform/model'
import { PlatformService } from 'koishi-plugin-chatluna/llm-core/platform/service'
import {
PlatformService,
ToolMaskArg,
ToolMaskResolver
} from 'koishi-plugin-chatluna/llm-core/platform/service'
import {
ChatLunaTool,
CreateChatLunaLLMChainParams,
Expand All @@ -49,7 +53,7 @@ import {
ChatLunaErrorCode
} from 'koishi-plugin-chatluna/utils/error'
import { MessageTransformer } from './message_transform'
import { ChatEvents, ToolMaskArg, ToolMaskResolver } from './types'
import { ChatEvents } from './types'
import { ConversationService } from './conversation'
import { ConversationRuntime } from './conversation_runtime'
import { ConstraintRecord, ConversationRecord } from './conversation_types'
Expand Down Expand Up @@ -83,8 +87,6 @@ export class ChatLunaService extends Service<Config> {
private readonly _contextManager: ChatLunaContextManagerService
private readonly _conversation: ConversationService
private readonly _conversationRuntime: ConversationRuntime
private _toolMaskResolvers: Record<string, ToolMaskResolver> = {}

declare public config: Config

declare public currentConfig: Config
Expand Down Expand Up @@ -205,20 +207,11 @@ export class ChatLunaService extends Service<Config> {
}

registerToolMaskResolver(name: string, resolver: ToolMaskResolver) {
this._toolMaskResolvers[name] = resolver

return () => {
delete this._toolMaskResolvers[name]
}
return this._platformService.registerToolMaskResolver(name, resolver)
}

async resolveToolMask(arg: ToolMaskArg) {
for (const name in this._toolMaskResolvers) {
const mask = await this._toolMaskResolvers[name](arg)
if (mask) {
return mask
}
}
return this._platformService.resolveToolMask(arg)
}

getPlugin(platformName: string) {
Expand Down
Loading
Loading