simple_mcp_client_stream.py
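"""simple_mcp_client_stream.py

A minimal MCP stdio client that connects to a single MCP server, exposes the
server's tools to an OpenAI-compatible chat model (DeepSeek by default), and
streams the model's answers. Prompt templates and resources published by the
server are used to enrich the user's question before each request.
"""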
import asyncio
import os
import sys
import json
from contextlib import AsyncExitStack
from typing import Optional, List, Dict, Any

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import AsyncOpenAI
from dotenv import load_dotenv

# Load environment variables from a local .env file, overriding existing values.
load_dotenv(dotenv_path=".env", override=True)
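# Environment variables read in main(); an illustrative .env sketch
# (the values below are placeholders, not required settings):
#   LLM_MODEL_TYPE=deepseek          # "openai" switches to the SDK's default endpoint
#   LLM_API_KEY=sk-...               # your API key
#   LLM_API_URL=https://api.deepseek.com
#   LLM_MODEL_NAME=deepseek-chat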
class MCPClient:
    def __init__(
        self,
        model_type: str,
        api_key: str,
        base_url: str = "https://api.deepseek.com",
        model_name: str = "deepseek-chat",
    ):
        self.model_type = model_type
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.session: Optional[ClientSession] = None
        self.stdio_transport = None
        self.exit_stack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.available_tools: List[Dict[str, Any]] = []
        # {resource_name: resource_text}
        self.resources_dict = {}
        # {prompt_name: description}
        self.prompts_dict = {}
        # The OpenAI SDK works with any OpenAI-compatible endpoint; only the
        # "openai" model type falls back to the SDK's default base URL.
        self.llm_client = AsyncOpenAI(
            api_key=api_key,
            base_url=None if model_type == "openai" else base_url,
        )
    @staticmethod
    def parse_arguments(args: List[str]) -> StdioServerParameters:
        if len(args) == 1:
            # Mode 1: a single server script path (.py or .js).
            server_script_path = args[0]
            if not server_script_path.endswith(('.py', '.js')):
                raise ValueError("[ERR] Server script must be a .py or .js file")
            command = "python" if server_script_path.endswith('.py') else "node"
            return StdioServerParameters(
                command=command,
                args=[server_script_path],
                env=None
            )
        elif len(args) == 2:
            # Mode 2: a server identifier plus a JSON config file path.
            server_identifier, config_path = args[0], args[1]
            try:
                with open(config_path, 'r') as f:
                    config = json.load(f)
            except Exception as e:
                raise ValueError(f"Failed to read config file: {str(e)}")
            mcp_servers = config.get('mcpServers', {})
            server_config = mcp_servers.get(server_identifier)
            if not server_config:
                raise ValueError(f"Server identifier not found: {server_identifier}")
            if not all(key in server_config for key in ['command', 'args']):
                raise ValueError("Server config is missing required fields (command/args)")
            return StdioServerParameters(
                command=server_config['command'],
                args=server_config['args'],
                env=None
            )
        else:
            raise ValueError("Wrong number of arguments")
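    # Illustrative shape of the config file expected by the two-argument mode
    # (the server identifier and script path below are hypothetical examples):
    #
    #   {
    #     "mcpServers": {
    #       "my-server": {
    #         "command": "python",
    #         "args": ["path/to/server.py"]
    #       }
    #     }
    #   }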
    async def connect_to_server(self, server_params: StdioServerParameters):
        print("[SYS]: Connecting to server...")
        self.stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        print("[SYS]: Transport established, creating session...")
        stdio_reader, stdio_writer = self.stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(stdio_reader, stdio_writer)
        )
        print("[SYS]: Initializing session...")
        await self.session.initialize()
        print("[SYS]: Server connected!\n")
        # Fetch the available tools and convert them to the OpenAI
        # function-calling schema.
        tools_response = await self.session.list_tools()
        self.available_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": tool.inputSchema
                }
            }
            for tool in tools_response.tools
        ]
        print(f"[SYS]: Available tools: {[t['function']['name'] for t in self.available_tools]}")
        # Fetch the resource list and cache each resource's text content.
        resources_response = await self.session.list_resources()
        resources_names = [resource.name for resource in resources_response.resources]
        for resource_name in resources_names:
            resource = await self.session.read_resource(resource_name)
            self.resources_dict[resource_name] = resource.contents[0].text
        print(f"[SYS]: Available resources: {resources_names}")
        # Fetch the prompt templates and remember their descriptions.
        prompts_response = await self.session.list_prompts()
        for prompt in prompts_response.prompts:
            self.prompts_dict[prompt.name] = prompt.description
        print(f"[SYS]: Available prompts: {self.prompts_dict}")
    async def select_prompt_template(self, user_question: str) -> str:
        """Pick a prompt template name based on the user's question."""
        # Indicator words that suggest a detailed answer is wanted
        # (kept in Chinese and English to match expected user input).
        detailed_indicators = [
            "解释", "说明", "详细", "具体", "详尽", "深入", "全面", "彻底",
            "分析", "为什么", "怎么样", "如何", "原因", "机制", "过程",
            "explain", "detail", "elaborate", "comprehensive", "thorough",
            "in-depth", "analysis", "why", "how does", "reasons",
            "背景", "历史", "发展", "比较", "区别", "联系", "影响", "意义",
            "优缺点", "利弊", "方法", "步骤", "案例", "举例", "证明",
            "理论", "原理", "依据", "论证", "详解", "指南", "教程",
            "细节", "要点", "关键", "系统", "完整", "清晰", "请详细"
        ]
        # Classify the question: short questions with no detail indicators
        # get the brief template.
        question_lower = user_question.lower()
        is_brief_question = len(question_lower.split()) < 10
        wants_details = any(
            indicator in question_lower for indicator in detailed_indicators
        )
        # Template names must match the ones registered on the server side.
        return (
            "detailed_response"
            if (wants_details or not is_brief_question)
            else "simply_replay"
        )
    async def add_relevant_resources(self, user_question: str) -> str:
        """Append matching cached resources to the user's question."""
        # Map question keywords to the resource URIs that should be attached.
        keywords_map = {
            "MCP规范协议": ["mcp-doc://4.MCP规范协议.md"],
            "MCP交互流程": ["mcp-doc://6.MCP核心交互流程.md"],
            "MCP": ["mcp-doc://4.MCP规范协议.md", "mcp-doc://6.MCP核心交互流程.md"],
        }
        # Collect the resources whose keyword appears in the question.
        matched_resources = []
        for keyword, resources in keywords_map.items():
            if keyword in user_question:
                for resource in resources:
                    if (
                        resource in self.resources_dict
                        and resource not in matched_resources
                    ):
                        matched_resources.append(resource)
        # No match: return the question unchanged.
        if not matched_resources:
            return user_question
        # Build the augmented question with the resource contents appended.
        context_parts = []
        for resource in matched_resources:
            context_parts.append(f"--- {resource} ---\n{self.resources_dict[resource]}")
        # "相关信息" ("related information") is kept in Chinese to match the
        # Chinese-language prompts and resources.
        return (
            user_question + "\n\n相关信息:\n\n" + "\n\n".join(context_parts)
        )
    async def process_query(self, query: str, messages: Optional[List[dict]] = None, depth: int = 0) -> str:
        if depth >= 5:
            return "[ERR] Maximum recursion depth exceeded; check the tool-calling logic"
        if messages:
            # Recursive call after a tool round: keep the accumulated history.
            messages = messages.copy()
        else:
            user_text = query.strip()
            # 1. Pick a prompt template and let the server render it.
            if self.prompts_dict:
                template_name = await self.select_prompt_template(user_text)
                prompt_response = await self.session.get_prompt(
                    template_name,
                    arguments={"question": user_text}
                )
                user_text = prompt_response.messages[0].content.text
                print(f"[LOG]: Selected prompt template: {template_name} \n")
            # 2. Attach any relevant cached resources.
            if self.resources_dict:
                user_text = await self.add_relevant_resources(user_text)
            messages = [{"role": "user", "content": user_text}]
        full_response = ""
        tool_calls_cache = {}
        print("[LOG] Call LLM Messages:", messages)
        print("[LOG] Call LLM Tools:", self.available_tools)
        # Start a streaming chat completion request.
        stream = await self.llm_client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            tools=self.available_tools,
            tool_choice="auto",
            stream=True,
        )
        sys.stdout.write("[LLM]: ")
        sys.stdout.flush()
        async for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            # Stream natural-language content to stdout as it arrives.
            if delta.content:
                sys.stdout.write(delta.content)
                sys.stdout.flush()
                full_response += delta.content
            # Accumulate incremental tool-call fragments, keyed by index.
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    index = tool_call.index
                    if index not in tool_calls_cache:
                        tool_calls_cache[index] = {
                            "id": "",
                            "name": "",
                            "arguments": ""
                        }
                    cached = tool_calls_cache[index]
                    cached["id"] = tool_call.id or cached["id"]
                    cached["name"] = tool_call.function.name or cached["name"]
                    cached["arguments"] += tool_call.function.arguments or ""
        print("\n")  # newline after the streamed output ends
        # Handle any tool calls requested by the model.
        if tool_calls_cache:
            # Log the fully assembled tool-call arguments.
            tool_calls_log = [
                {
                    "name": call["name"],
                    "arguments": json.loads(call["arguments"]) if call["arguments"] else {}
                }
                for call in tool_calls_cache.values()
            ]
            print(f"[LOG]: Assembled tool calls: {json.dumps(tool_calls_log, ensure_ascii=False)}")
            # Append the assistant's tool-call message to the history.
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "type": "function",
                        "id": call["id"],
                        "function": {
                            "name": call["name"],
                            "arguments": call["arguments"]
                        }
                    }
                    for call in tool_calls_cache.values()
                ]
            })
            # Execute each tool call, append the results, then recurse so the
            # model can continue with the tool output in context.
            for tool_call in tool_calls_cache.values():
                tool_name = tool_call["name"]
                try:
                    tool_args = json.loads(tool_call["arguments"])
                except json.JSONDecodeError:
                    tool_args = {"input": tool_call["arguments"]}
                result = await self.session.call_tool(tool_name, tool_args)
                print(f"[LOG]: Calling tool [{tool_name}] with args: {tool_args}")
                print(f"[LOG]: Raw result: {result.model_dump()}")
                print(f"[LOG]: Tool response: {result.content}\n")
                messages.append({
                    "role": "tool",
                    "content": result.model_dump()["content"],
                    "tool_call_id": tool_call["id"],
                    "name": tool_name
                })
            return await self.process_query(query, messages, depth + 1)
        return full_response
    async def chat_loop(self):
        print("[SYS]: MCP client started!")
        print("[SYS]: Type a natural-language query to interact (type 'quit' to exit)")
        loop = asyncio.get_running_loop()
        while True:
            try:
                # Run blocking input() in a thread so the event loop stays responsive.
                query = await loop.run_in_executor(
                    None,
                    lambda: input("[USR]: ").strip()
                )
                print()
                if not query:
                    continue
                if query.lower() == 'quit':
                    break
                await self.process_query(query)
                # process_query already streams the answer, so nothing to print here.
            except (KeyboardInterrupt, EOFError):
                print("\n[SYS]: Exit signal detected, shutting down...")
                break
            except Exception as e:
                print(f"\n[SYS]: An error occurred: {str(e)}")

    async def cleanup(self):
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
                self.stdio_transport = None
            except Exception as e:
                print(f"Error during cleanup of server: {e}")
async def main():
    try:
        server_params = MCPClient.parse_arguments(sys.argv[1:])
    except ValueError as e:
        print(f"[ERR]: Argument error: {str(e)}")
        print("Usage:")
        print("Mode 1: python simple_mcp_client_stream.py <server script path>")
        print("Mode 2: python simple_mcp_client_stream.py <server identifier> <config file path>")
        sys.exit(1)
    model_type = os.getenv("LLM_MODEL_TYPE", "")
    api_key = os.getenv("LLM_API_KEY", "")
    base_url = os.getenv("LLM_API_URL", "")
    model_name = os.getenv("LLM_MODEL_NAME", "")
    print("[SYS]: LLM_MODEL_TYPE: ", model_type)
    print("[SYS]: LLM_API_URL: ", base_url)
    print("[SYS]: LLM_MODEL_NAME: ", model_name)
    client = MCPClient(
        model_type=model_type,
        api_key=api_key,
        base_url=base_url,
        model_name=model_name,
    )
    try:
        await client.connect_to_server(server_params)
        await client.chat_loop()
    except ValueError as e:
        print(f"[ERR] Argument error: {str(e)}")
    except Exception as e:
        print(f"\n[ERR] Runtime error: {str(e)}")
    finally:
        await client.cleanup()


if __name__ == "__main__":
    import platform
    if platform.system().lower() == 'windows':
        # On Windows, drive the coroutine on the default event loop directly.
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    else:
        asyncio.run(main())
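# Example invocations (script and identifier names below are illustrative):
#   python simple_mcp_client_stream.py path/to/server.py
#   python simple_mcp_client_stream.py my-server mcp_config.json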