diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 914f20b524e..0b5d8af6852 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -384,10 +384,11 @@ def _handle_mcp_request( if tool is None or tool.is_active is False: continue executor = ToolExecutor() + init_params_default_value = {i["field"]: i.get('default_value') for i in tool.init_field_list} if tool.init_params is not None: - tool_init_params = json.loads(rsa_long_decrypt(tool.init_params)) + tool_init_params = init_params_default_value | json.loads(rsa_long_decrypt(tool.init_params)) else: - tool_init_params = {i["field"]: i.get("default_value") for i in tool.init_field_list} + tool_init_params = init_params_default_value tool_config = executor.get_tool_mcp_config(tool, tool_init_params) mcp_servers_config[str(tool.id)] = tool_config diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 5e2b94ade0c..885b63bd9dd 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -27,7 +27,7 @@ from common.utils.rsa_util import rsa_long_decrypt from common.utils.shared_resource_auth import filter_authorized_ids from common.utils.tool_code import ToolExecutor -from knowledge.models import File +from common.utils.logger import maxkb_logger from models_provider.models import Model from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id from tools.models import Tool, ToolType @@ -280,10 +280,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids if tool is None: continue executor = ToolExecutor() + init_params_default_value = {i["field"]: i.get('default_value') for i in tool.init_field_list} if tool.init_params is not None: - tool_init_params = json.loads(rsa_long_decrypt(tool.init_params)) + tool_init_params = init_params_default_value | json.loads(rsa_long_decrypt(tool.init_params)) else: - tool_init_params = {i["field"]: i.get('default_value') for i in tool.init_field_list} + tool_init_params = init_params_default_value + tool_config = executor.get_tool_mcp_config(tool, tool_init_params) mcp_servers_config[str(tool.id)] = tool_config @@ -482,6 +484,21 @@ def reset_message_list(message_list: List[BaseMessage], answer_text): return result def get_details(self, index: int, **kwargs): + tool_call_list = [] + answer = self.context.get('answer', '') + if answer: + for match in re.finditer( + r'(.*?)', answer, re.DOTALL + ): + try: + tool_data = json.loads(match.group(1)) + tool_call_list.append({ + 'name': tool_data.get('title', ''), + 'input': tool_data.get('content', {}).get('input', ''), + 'icon': tool_data.get('icon', ''), + }) + except (json.JSONDecodeError, Exception) as e: + maxkb_logger.error(f"get_details error {e}") return { 'name': self.node.properties.get('stepName'), "index": index, @@ -490,6 +507,7 @@ def get_details(self, index: int, **kwargs): 'history_message': self.context.get('history_message'), 'question': self.context.get('question'), 'answer': self.context.get('answer'), + 'tool_call_list': tool_call_list, 'reasoning_content': self.context.get('reasoning_content'), 'enableException': self.node.properties.get('enableException'), 'type': self.node.type, diff --git a/apps/application/flow/step_node/tool_node/impl/base_tool_node.py b/apps/application/flow/step_node/tool_node/impl/base_tool_node.py index c5595bc805e..e269a2b2ba7 100644 --- a/apps/application/flow/step_node/tool_node/impl/base_tool_node.py +++ b/apps/application/flow/step_node/tool_node/impl/base_tool_node.py @@ -100,8 +100,19 @@ def execute(self, input_field_list, code, **kwargs) -> NodeResult: params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), field.get('is_required'), field.get('source'), self) for field in input_field_list} - result = function_executor.exec_code(code, params) - self.context['params'] = params + # 合并启动参数默认值(如果有 init_field_list 定义) + init_field_list = self.node_params.get('init_field_list', []) + if init_field_list: + init_params_default_value = {i["field"]: i.get('default_value') for i in init_field_list} + init_params = kwargs.get('init_params') + if init_params is not None: + all_params = init_params_default_value | init_params | params + else: + all_params = init_params_default_value | params + else: + all_params = params + result = function_executor.exec_code(code, all_params) + self.context['params'] = all_params return NodeResult({'result': result}, {}, _write_context=write_context) def get_details(self, index: int, **kwargs): diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index 0d1cadd7444..99a8ced368f 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -569,6 +569,7 @@ def debug(self, debug_instance): input_field_list = debug_instance.get("input_field_list") code = debug_instance.get("code") debug_field_list = debug_instance.get("debug_field_list") + init_field_list = debug_instance.get("init_field_list", []) init_params = debug_instance.get("init_params") params = { field.get("name"): self.convert_value( @@ -582,11 +583,13 @@ def debug(self, debug_instance): for field in input_field_list ] } + # 合并初始化参数(默认值 → 已保存的启动参数 → 运行时入参) + init_params_default_value = {i["field"]: i.get('default_value') for i in init_field_list} # 合并初始化参数 if init_params is not None: - all_params = init_params | params + all_params = init_params_default_value | init_params | params else: - all_params = params + all_params = init_params_default_value | params return tool_executor.exec_code(code, all_params) @staticmethod