Skip to content

Commit abf06bb

Browse files
fix(agent): 明确 pending 消息消费决策 (#810)
* fix(agent): 明确 pending 消息消费决策 * fix(agent): 改为基于 direct tool output 决定终止 * fix(agent): 用 direct tool output 结束伪装回复 * fix(core,adapter-gemini): stabilize pending requests and model capability checks * fix(core): narrow agent observation typing * fix(core,agent): normalize chat stop and mcp server validation --------- Co-authored-by: dingyi <dingyi222666@foxmail.com>
1 parent b7bb563 commit abf06bb

21 files changed

Lines changed: 424 additions & 353 deletions

File tree

packages/adapter-gemini/src/utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ function processImageParts(
212212
!(
213213
(model.includes('vision') ||
214214
model.includes('gemini') ||
215-
model.includes('gemma')) &&
215+
model.includes('gemma2')) &&
216216
!model.includes('gemini-1.0')
217217
)
218218
) {

packages/core/src/llm-core/agent/legacy-executor.ts

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { CallbackManagerForChainRun } from '@langchain/core/callbacks/manager'
22
import { AIMessage, AIMessageChunk } from '@langchain/core/messages'
3+
import { isDirectToolOutput } from '@langchain/core/messages/tool'
34
import { OutputParserException } from '@langchain/core/output_parsers'
45
import {
56
patchConfig,
@@ -257,12 +258,8 @@ export async function* runAgent(
257258
}
258259

259260
if (output.length > 0) {
260-
const last = output[output.length - 1]
261-
const tool = toolMap[last.tool?.toLowerCase()]
262-
263261
yield {
264-
type: 'round-decision',
265-
canContinue: !tool?.returnDirect
262+
type: 'round-decision'
266263
}
267264

268265
yield {
@@ -293,7 +290,15 @@ export async function* runAgent(
293290
const last = newSteps[newSteps.length - 1]
294291
const tool = last ? toolMap[last.action.tool?.toLowerCase()] : undefined
295292

296-
if (tool?.returnDirect && last != null) {
293+
if (
294+
last != null &&
295+
(tool?.returnDirect || isDirectToolOutput(last.observation))
296+
) {
297+
yield {
298+
type: 'round-decision',
299+
canContinue: false
300+
}
301+
297302
const pending = queue?.drain() ?? []
298303
if (pending.length > 0) {
299304
yield {
@@ -304,14 +309,24 @@ export async function* runAgent(
304309

305310
yield {
306311
type: 'done',
307-
output: toOutput(last.observation),
312+
output:
313+
// TODO: remove this property
314+
last.observation['replyEmitted'] === true
315+
? ''
316+
: toOutput(last.observation),
308317
log: last.action.log,
309-
steps
318+
steps,
319+
replyEmitted: last.observation['replyEmitted'] === true
310320
}
311321

312322
return
313323
}
314324

325+
yield {
326+
type: 'round-decision',
327+
canContinue: true
328+
}
329+
315330
iterations += 1
316331
}
317332

packages/core/src/llm-core/agent/sub-agent.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ async function onTaskEvent(
788788
at: Date.now(),
789789
title: '最终输出',
790790
text:
791+
(event.replyEmitted ? '最终回复已由工具发送。' : '') ||
791792
getMessageContent(event.message?.content ?? '') ||
792793
event.output ||
793794
event.log

packages/core/src/llm-core/agent/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ export type AgentEvent =
252252
log: string
253253
steps: AgentStep[]
254254
message?: AIMessage
255+
replyEmitted?: boolean
255256
}
256257

257258
export interface AgentRuntimeConfigurable {

packages/core/src/llm-core/platform/service.ts

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Context, Dict } from 'koishi'
1+
import { Awaitable, Context, Dict, Session } from 'koishi'
22
import {
33
BasePlatformClient,
44
PlatformEmbeddingsClient,
@@ -28,6 +28,7 @@ import { computed, ComputedRef, reactive } from '@vue/reactivity'
2828
import { randomUUID } from 'crypto'
2929
import { RunnableConfig } from '@langchain/core/runnables'
3030
import { ToolMask } from '../agent'
31+
import type { ConversationRecord } from '../../services/conversation_types'
3132

3233
export class PlatformService {
3334
private _platformClients: Record<string, BasePlatformClient> = reactive({})
@@ -36,6 +37,7 @@ export class PlatformService {
3637

3738
private _tools: Record<string, ChatLunaTool> = reactive({})
3839
private _tmpTools: Record<string, StructuredTool> = reactive({})
40+
private _toolMaskResolvers: Record<string, ToolMaskResolver> = {}
3941
private _models: Record<string, ModelInfo[]> = reactive({})
4042
private _chatChains: Record<string, ChatLunaChainInfo> = reactive({})
4143
private _vectorStore: Record<string, CreateVectorStoreFunction> = reactive(
@@ -218,6 +220,23 @@ export class PlatformService {
218220
return allNames.filter((name) => !mask.deny.includes(name))
219221
}
220222

223+
registerToolMaskResolver(name: string, resolver: ToolMaskResolver) {
224+
this._toolMaskResolvers[name] = resolver
225+
226+
return () => {
227+
delete this._toolMaskResolvers[name]
228+
}
229+
}
230+
231+
async resolveToolMask(arg: ToolMaskArg) {
232+
for (const name in this._toolMaskResolvers) {
233+
const mask = await this._toolMaskResolvers[name](arg)
234+
if (mask) {
235+
return mask
236+
}
237+
}
238+
}
239+
221240
static buildToolMask(rule: {
222241
mode?: 'inherit' | 'all' | 'allow' | 'deny'
223242
allow?: string[]
@@ -453,3 +472,13 @@ declare module 'koishi' {
453472
'chatluna/tool-updated': (service: PlatformService) => void
454473
}
455474
}
475+
476+
export interface ToolMaskArg {
477+
session: Session
478+
conversation?: ConversationRecord
479+
bindingKey?: string
480+
}
481+
482+
export type ToolMaskResolver = (
483+
arg: ToolMaskArg
484+
) => Awaitable<ToolMask | undefined>

packages/core/src/middlewares/chat/rollback_chat.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ import type { Context, Session } from 'koishi'
22
import { gzipDecode } from 'koishi-plugin-chatluna/utils/string'
33
import { Config } from '../../config'
44
import {
5-
ChainMiddlewareRunStatus,
65
type ChainMiddlewareContext,
6+
ChainMiddlewareRunStatus,
77
type ChatChain
88
} from '../../chains/chain'
99
import { MessageRecord } from '../../services/conversation_types'

packages/core/src/middlewares/chat/stop_chat.ts

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import { Context } from 'koishi'
22
import { Config } from '../../config'
33
import { ChainMiddlewareRunStatus, ChatChain } from '../../chains/chain'
4-
import { getRequestId } from '../../utils/chat_request'
54
import { checkAdmin } from 'koishi-plugin-chatluna/utils/koishi'
65

76
export function apply(ctx: Context, config: Config, chain: ChatChain) {
@@ -82,20 +81,13 @@ export function apply(ctx: Context, config: Config, chain: ChatChain) {
8281
}
8382

8483
context.options.conversationId = conversation.id
85-
const requestId = getRequestId(session, conversation.id)
86-
87-
if (requestId == null) {
88-
context.message = session.text('.no_active_chat')
89-
return ChainMiddlewareRunStatus.STOP
90-
}
91-
9284
const status =
93-
await ctx.chatluna.conversationRuntime.stopRequest(requestId)
85+
ctx.chatluna.conversationRuntime.stopConversationRequest(
86+
conversation.id
87+
)
9488

95-
if (status === null) {
89+
if (!status) {
9690
context.message = session.text('.no_active_chat')
97-
} else if (!status) {
98-
context.message = session.text('.stop_failed')
9991
}
10092

10193
return ChainMiddlewareRunStatus.STOP

packages/core/src/middlewares/conversation/request_conversation.ts

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import {
3131
MessageContent,
3232
MessageContentComplex
3333
} from '@langchain/core/messages'
34-
import { createRequestId } from '../../utils/chat_request'
3534
import { AgentAction } from 'koishi-plugin-chatluna/llm-core/agent'
3635

3736
let logger: Logger
@@ -127,11 +126,7 @@ export function apply(ctx: Context, config: Config, chain: ChatChain) {
127126
inputMessage.name =
128127
session.author?.name ?? session.author?.id ?? session.username
129128

130-
const requestId = createRequestId(
131-
session,
132-
conversation.id,
133-
context.options.messageId
134-
)
129+
const requestId = context.options.messageId
135130

136131
const chatCallbacks = createChatCallbacks(
137132
context,

packages/core/src/services/chat.ts

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ import {
3434
ChatLunaBaseEmbeddings,
3535
ChatLunaChatModel
3636
} from 'koishi-plugin-chatluna/llm-core/platform/model'
37-
import { PlatformService } from 'koishi-plugin-chatluna/llm-core/platform/service'
37+
import {
38+
PlatformService,
39+
ToolMaskArg,
40+
ToolMaskResolver
41+
} from 'koishi-plugin-chatluna/llm-core/platform/service'
3842
import {
3943
ChatLunaTool,
4044
CreateChatLunaLLMChainParams,
@@ -49,7 +53,7 @@ import {
4953
ChatLunaErrorCode
5054
} from 'koishi-plugin-chatluna/utils/error'
5155
import { MessageTransformer } from './message_transform'
52-
import { ChatEvents, ToolMaskArg, ToolMaskResolver } from './types'
56+
import { ChatEvents } from './types'
5357
import { ConversationService } from './conversation'
5458
import { ConversationRuntime } from './conversation_runtime'
5559
import { ConstraintRecord, ConversationRecord } from './conversation_types'
@@ -83,8 +87,6 @@ export class ChatLunaService extends Service<Config> {
8387
private readonly _contextManager: ChatLunaContextManagerService
8488
private readonly _conversation: ConversationService
8589
private readonly _conversationRuntime: ConversationRuntime
86-
private _toolMaskResolvers: Record<string, ToolMaskResolver> = {}
87-
8890
declare public config: Config
8991

9092
declare public currentConfig: Config
@@ -205,20 +207,11 @@ export class ChatLunaService extends Service<Config> {
205207
}
206208

207209
registerToolMaskResolver(name: string, resolver: ToolMaskResolver) {
208-
this._toolMaskResolvers[name] = resolver
209-
210-
return () => {
211-
delete this._toolMaskResolvers[name]
212-
}
210+
return this._platformService.registerToolMaskResolver(name, resolver)
213211
}
214212

215213
async resolveToolMask(arg: ToolMaskArg) {
216-
for (const name in this._toolMaskResolvers) {
217-
const mask = await this._toolMaskResolvers[name](arg)
218-
if (mask) {
219-
return mask
220-
}
221-
}
214+
return this._platformService.resolveToolMask(arg)
222215
}
223216

224217
getPlugin(platformName: string) {

0 commit comments

Comments
 (0)