diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index 0cfc962ea..3119b38b6 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -16,6 +16,7 @@ from apps.model_managment_app import router as model_manager_router from apps.oauth_app import router as oauth_router from apps.prompt_app import router as prompt_router +from apps.prompt_template_app import router as prompt_template_router from apps.remote_mcp_app import router as remote_mcp_router from apps.skill_app import router as skill_router from apps.tenant_config_app import router as tenant_config_router @@ -30,6 +31,7 @@ from apps.monitoring_app import router as monitoring_router from apps.a2a_server_app import router as a2a_server_router from consts.const import IS_SPEED_MODE +from services.prompt_template_service import sync_system_default_prompt_template # Create logger instance logger = logging.getLogger("base_app") @@ -37,6 +39,16 @@ # Create FastAPI app with common configurations app = create_app(title="Nexent Config API", description="Configuration APIs") + +@app.on_event("startup") +async def sync_default_prompt_template_on_startup(): + """Sync the YAML-backed system default prompt template into the database on startup.""" + try: + sync_system_default_prompt_template() + logger.info("System default prompt template synced successfully.") + except Exception as exc: + logger.error(f"Failed to sync system default prompt template: {str(exc)}") + app.include_router(model_manager_router) app.include_router(config_sync_router) app.include_router(agent_router) @@ -62,6 +74,7 @@ app.include_router(summary_router) app.include_router(prompt_router) +app.include_router(prompt_template_router) app.include_router(skill_router) app.include_router(tenant_config_router) app.include_router(remote_mcp_router) diff --git a/backend/apps/prompt_app.py b/backend/apps/prompt_app.py index a9bd8d3a6..a7e7b736a 100644 --- a/backend/apps/prompt_app.py +++ b/backend/apps/prompt_app.py @@ -25,6 +25,7 @@ async def generate_and_save_system_prompt_api( agent_id=prompt_request.agent_id, model_id=prompt_request.model_id, task_description=prompt_request.task_description, + prompt_template_id=prompt_request.prompt_template_id, user_id=user_id, tenant_id=tenant_id, language=language, diff --git a/backend/apps/prompt_template_app.py b/backend/apps/prompt_template_app.py new file mode 100644 index 000000000..0f12bd614 --- /dev/null +++ b/backend/apps/prompt_template_app.py @@ -0,0 +1,143 @@ +import logging +from http import HTTPStatus +from typing import Optional + +from fastapi import APIRouter, Header, HTTPException +from starlette.responses import JSONResponse + +from consts.exceptions import DuplicateError, NotFoundException, ValidationError +from consts.model import PromptTemplateRequest +from services.prompt_template_service import ( + create_prompt_template_impl, + delete_prompt_template_impl, + get_prompt_template_detail_impl, + list_prompt_templates_impl, + update_prompt_template_impl, +) +from utils.auth_utils import get_current_user_id + +router = APIRouter(prefix="/prompt_templates") +logger = logging.getLogger("prompt_template_app") + + +@router.get("") +async def list_prompt_templates_api( + authorization: Optional[str] = Header(None), +): + """List prompt templates for the current user.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = list_prompt_templates_impl(tenant_id=tenant_id, user_id=user_id) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except Exception as exc: + logger.error(f"Prompt template list error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template list error.", + ) + + +@router.get("/{template_id}") +async def get_prompt_template_api( + template_id: int, + authorization: Optional[str] = Header(None), +): + """Get prompt template detail.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = get_prompt_template_detail_impl( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except NotFoundException as exc: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template detail error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template detail error.", + ) + + +@router.post("") +async def create_prompt_template_api( + request: PromptTemplateRequest, + authorization: Optional[str] = Header(None), +): + """Create a prompt template.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = create_prompt_template_impl( + request=request, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except DuplicateError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except ValidationError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template create error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template create error.", + ) + + +@router.put("/{template_id}") +async def update_prompt_template_api( + template_id: int, + request: PromptTemplateRequest, + authorization: Optional[str] = Header(None), +): + """Update a prompt template.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = update_prompt_template_impl( + template_id=template_id, + request=request, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except NotFoundException as exc: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except DuplicateError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except ValidationError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template update error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template update error.", + ) + + +@router.delete("/{template_id}") +async def delete_prompt_template_api( + template_id: int, + authorization: Optional[str] = Header(None), +): + """Delete a prompt template.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = delete_prompt_template_impl( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except NotFoundException as exc: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(exc)) + except ValidationError as exc: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc)) + except Exception as exc: + logger.error(f"Prompt template delete error: {str(exc)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Prompt template delete error.", + ) diff --git a/backend/consts/model.py b/backend/consts/model.py index 2f1d7aae3..9a36f8f17 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -1,9 +1,11 @@ from enum import Enum from typing import Optional, Any, List, Dict -from pydantic import BaseModel, Field, EmailStr +from pydantic import BaseModel, Field, EmailStr, ConfigDict from nexent.core.agents.agent_model import ToolConfig +from consts.prompt_template import PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP + class ModelConnectStatusEnum(Enum): """Enum class for model connection status""" @@ -312,6 +314,7 @@ class GeneratePromptRequest(BaseModel): task_description: str agent_id: int model_id: int + prompt_template_id: Optional[int] = None tool_ids: Optional[List[int]] = Field( None, description="Optional: tool IDs from frontend (takes precedence over database query)") sub_agent_ids: Optional[List[int]] = Field( @@ -320,6 +323,52 @@ class GeneratePromptRequest(BaseModel): None, description="Optional: knowledge base display names from frontend (takes precedence over database query)") +class PromptTemplateContentRequest(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + duty_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["duty_system_prompt"] + ) + constraint_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["constraint_system_prompt"] + ) + few_shots_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["few_shots_system_prompt"] + ) + agent_variable_name_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_variable_name_system_prompt"] + ) + agent_display_name_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_display_name_system_prompt"] + ) + agent_description_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_description_system_prompt"] + ) + user_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["user_prompt"] + ) + agent_name_regenerate_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_name_regenerate_system_prompt"] + ) + agent_name_regenerate_user_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_name_regenerate_user_prompt"] + ) + agent_display_name_regenerate_system_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_display_name_regenerate_system_prompt"] + ) + agent_display_name_regenerate_user_prompt: str = Field( + alias=PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP["agent_display_name_regenerate_user_prompt"] + ) + + +class PromptTemplateRequest(BaseModel): + template_name: str + description: Optional[str] = None + template_type: str = "agent_generate" + template_content_zh: PromptTemplateContentRequest + template_content_en: Optional[PromptTemplateContentRequest] = None + + class GenerateTitleRequest(BaseModel): conversation_id: int question: str @@ -343,6 +392,8 @@ class AgentInfoRequest(BaseModel): enabled: Optional[bool] = None business_logic_model_name: Optional[str] = None business_logic_model_id: Optional[int] = None + prompt_template_id: Optional[int] = None + prompt_template_name: Optional[str] = None enabled_tool_ids: Optional[List[int]] = None enabled_skill_ids: Optional[List[int]] = None related_agent_ids: Optional[List[int]] = None @@ -431,6 +482,8 @@ class ExportAndImportAgentInfo(BaseModel): model_name: Optional[str] = None business_logic_model_id: Optional[int] = None business_logic_model_name: Optional[str] = None + prompt_template_id: Optional[int] = None + prompt_template_name: Optional[str] = None class Config: arbitrary_types_allowed = True diff --git a/backend/consts/prompt_template.py b/backend/consts/prompt_template.py new file mode 100644 index 000000000..febcaeca5 --- /dev/null +++ b/backend/consts/prompt_template.py @@ -0,0 +1,15 @@ +PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} + +PROMPT_GENERATE_TEMPLATE_FIELDS = tuple(PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys()) diff --git a/backend/database/agent_db.py b/backend/database/agent_db.py index 7d14d7b8e..90de64ca9 100644 --- a/backend/database/agent_db.py +++ b/backend/database/agent_db.py @@ -192,6 +192,8 @@ def create_agent(agent_info, tenant_id: str, user_id: str): "business_description": new_agent.business_description, "business_logic_model_id": new_agent.business_logic_model_id, "business_logic_model_name": new_agent.business_logic_model_name, + "prompt_template_id": new_agent.prompt_template_id, + "prompt_template_name": new_agent.prompt_template_name, "group_ids": new_agent.group_ids, "is_new": new_agent.is_new, "enable_context_manager": new_agent.enable_context_manager, diff --git a/backend/database/db_models.py b/backend/database/db_models.py index baa8e903e..b6bdf3b57 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -1,3 +1,4 @@ +from sqlalchemy import BigInteger, Boolean, Column, ForeignKey, ForeignKeyConstraint, Integer, JSON, Numeric, PrimaryKeyConstraint, Sequence, String, Text, TIMESTAMP, UniqueConstraint, Index, Float, text from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, Numeric, Sequence, String, Text, TIMESTAMP, UniqueConstraint, Index, Float from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase @@ -313,6 +314,8 @@ class AgentInfo(TableBase): Text, doc="Manually entered by the user to describe the entire business process") business_logic_model_name = Column(String(100), doc="Model name used for business logic prompt generation") business_logic_model_id = Column(Integer, doc="Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id") + prompt_template_id = Column(Integer, doc="Prompt template ID used for business logic prompt generation") + prompt_template_name = Column(String(100), doc="Prompt template name used for business logic prompt generation") group_ids = Column(String, doc="Agent group IDs list") is_new = Column(Boolean, default=False, doc="Whether this agent is marked as new for the user") current_version_no = Column(Integer, nullable=True, doc="Current published version number. NULL means no version published yet") @@ -320,6 +323,41 @@ class AgentInfo(TableBase): enable_context_manager = Column(Boolean, default=False, doc="Whether to enable context management (compression) for this agent") +class PromptTemplate(TableBase): + """ + Prompt template table for user-defined prompt generation templates. + """ + __tablename__ = "ag_prompt_template_t" + __table_args__ = ( + Index( + "uq_prompt_template_user_name_active", + "tenant_id", + "user_id", + "template_name", + unique=True, + postgresql_where=text("delete_flag = 'N'"), + ), + Index( + "idx_ag_prompt_template_t_user", + "tenant_id", + "user_id", + "template_type", + postgresql_where=text("delete_flag = 'N'"), + ), + {"schema": SCHEMA}, + ) + + template_id = Column(Integer, Sequence( + "ag_prompt_template_t_template_id_seq", schema=SCHEMA), primary_key=True, nullable=False, autoincrement=True, doc="Prompt template ID") + template_name = Column(String(100), nullable=False, doc="Prompt template name") + description = Column(String(500), doc="Prompt template description") + template_type = Column(String(50), nullable=False, default="agent_generate", doc="Prompt template type") + tenant_id = Column(String(100), nullable=False, doc="Tenant ID") + user_id = Column(String(100), nullable=False, doc="User ID") + template_content_zh = Column(JSONB, nullable=False, doc="Chinese prompt template content") + template_content_en = Column(JSONB, doc="English prompt template content") + + class ToolInstance(TableBase): """ Information table for tenant tool configuration. diff --git a/backend/database/prompt_template_db.py b/backend/database/prompt_template_db.py new file mode 100644 index 000000000..fbc286cf9 --- /dev/null +++ b/backend/database/prompt_template_db.py @@ -0,0 +1,165 @@ +import logging +from typing import Optional + +from sqlalchemy import select, update + +from database.client import as_dict, filter_property, get_db_session +from database.db_models import PromptTemplate + +logger = logging.getLogger("prompt_template_db") + + +def create_prompt_template(template_data: dict) -> dict: + """Create a prompt template.""" + with get_db_session() as session: + prompt_template = PromptTemplate( + **filter_property(template_data, PromptTemplate) + ) + prompt_template.delete_flag = "N" + session.add(prompt_template) + session.flush() + return as_dict(prompt_template) + + +def upsert_prompt_template_by_id(template_id: int, template_data: dict, user_id: str) -> dict: + """Create or update a prompt template with a fixed template ID.""" + with get_db_session() as session: + prompt_template = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + ).first() + + filtered_data = filter_property(template_data, PromptTemplate) + if prompt_template: + for key, value in filtered_data.items(): + setattr(prompt_template, key, value) + prompt_template.updated_by = user_id + else: + prompt_template = PromptTemplate(**filtered_data) + prompt_template.template_id = template_id + prompt_template.delete_flag = filtered_data.get("delete_flag", "N") + session.add(prompt_template) + + session.flush() + return as_dict(prompt_template) + + +def update_prompt_template(template_id: int, template_data: dict, user_id: str) -> dict: + """Update a prompt template.""" + with get_db_session() as session: + prompt_template = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + PromptTemplate.delete_flag == "N", + ).first() + + if not prompt_template: + raise ValueError("prompt template not found") + + for key, value in filter_property(template_data, PromptTemplate).items(): + if value is None: + continue + setattr(prompt_template, key, value) + + prompt_template.updated_by = user_id + session.flush() + return as_dict(prompt_template) + + +def delete_prompt_template(template_id: int, user_id: str) -> int: + """Soft-delete a prompt template.""" + with get_db_session() as session: + result = session.execute( + update(PromptTemplate) + .where( + PromptTemplate.template_id == template_id, + PromptTemplate.delete_flag == "N", + ) + .values(delete_flag="Y", updated_by=user_id) + ) + return result.rowcount + + +def query_prompt_templates_by_user( + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> list[dict]: + """Query prompt templates by tenant and user.""" + with get_db_session() as session: + templates = session.query(PromptTemplate).filter( + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ).order_by(PromptTemplate.update_time.desc(), PromptTemplate.template_id.desc()).all() + return [as_dict(template) for template in templates] + + +def get_prompt_template_by_id( + template_id: int, + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> Optional[dict]: + """Get a prompt template by ID.""" + with get_db_session() as session: + template = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ).first() + return as_dict(template) if template else None + + +def get_prompt_template_by_name( + template_name: str, + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> Optional[dict]: + """Get a prompt template by name.""" + with get_db_session() as session: + template = session.query(PromptTemplate).filter( + PromptTemplate.template_name == template_name, + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ).first() + return as_dict(template) if template else None + + +def get_prompt_template_by_template_id( + template_id: int, + template_type: str = "agent_generate", + include_deleted: bool = False, +) -> Optional[dict]: + """Get a prompt template by template ID regardless of owner.""" + with get_db_session() as session: + query = session.query(PromptTemplate).filter( + PromptTemplate.template_id == template_id, + PromptTemplate.template_type == template_type, + ) + if not include_deleted: + query = query.filter(PromptTemplate.delete_flag == "N") + template = query.first() + return as_dict(template) if template else None + + +def query_prompt_template_names( + tenant_id: str, + user_id: str, + template_type: str = "agent_generate", +) -> set[str]: + """Query all active prompt template names for the current user.""" + with get_db_session() as session: + rows = session.execute( + select(PromptTemplate.template_name).where( + PromptTemplate.tenant_id == tenant_id, + PromptTemplate.user_id == user_id, + PromptTemplate.template_type == template_type, + PromptTemplate.delete_flag == "N", + ) + ).all() + return {row[0] for row in rows if row and row[0]} diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 73c6a4640..32e879606 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -16,6 +16,7 @@ from agents.create_agent_info import create_agent_run_info, create_tool_config_list from agents.preprocess_manager import preprocess_manager from services.agent_version_service import publish_version_impl +from utils.prompt_template_utils import normalize_prompt_generate_template_content from consts.const import MEMORY_SEARCH_START_MSG, MEMORY_SEARCH_DONE_MSG, MEMORY_SEARCH_FAIL_MSG, TOOL_TYPE_MAPPING, \ LANGUAGE, MESSAGE_ROLE, MODEL_CONFIG_MAPPING, CAN_EDIT_ALL_USER_ROLES, PERMISSION_EDIT, PERMISSION_READ, PERMISSION_PRIVATE from consts.exceptions import MemoryPreparationException @@ -63,6 +64,11 @@ from database.group_db import query_group_ids_by_user from database.user_tenant_db import get_user_tenant_by_user_id from database.a2a_agent_db import get_server_agent_ids +from services.prompt_template_service import ( + SYSTEM_PROMPT_TEMPLATE_ID, + SYSTEM_PROMPT_TEMPLATE_NAME, + get_prompt_template_summary, +) from utils.str_utils import convert_list_to_string, convert_string_to_list from services.conversation_management_service import save_conversation_assistant, save_conversation_user from services.memory_config_service import build_memory_context @@ -311,12 +317,25 @@ def _regenerate_agent_value_with_llm( user_prompt_key: str, default_system_prompt: str, default_user_prompt_builder: Callable[[dict], str], - fallback_fn: Callable[[str], str] + fallback_fn: Callable[[str], str], + prompt_template_id: Optional[int] = None, + user_id: Optional[str] = None, ) -> str: """ Shared helper to regenerate agent-related values with an LLM. """ - prompt_template = get_prompt_generate_prompt_template(language) + if user_id is not None: + from services.prompt_template_service import resolve_prompt_generate_template + prompt_template = resolve_prompt_generate_template( + tenant_id=tenant_id, + user_id=user_id, + language=language, + prompt_template_id=prompt_template_id, + ) + else: + prompt_template = normalize_prompt_generate_template_content( + get_prompt_generate_prompt_template(language) + ) system_prompt = _render_prompt_template( prompt_template.get(system_prompt_key, ""), original_value=original_value @@ -373,7 +392,9 @@ def _regenerate_agent_name_with_llm( tenant_id: str, language: str = LANGUAGE["ZH"], agents_cache: list[dict] | None = None, - exclude_agent_id: int | None = None + exclude_agent_id: int | None = None, + prompt_template_id: Optional[int] = None, + user_id: Optional[str] = None, ) -> str: return _regenerate_agent_value_with_llm( original_value=original_name, @@ -382,8 +403,8 @@ def _regenerate_agent_name_with_llm( model_id=model_id, tenant_id=tenant_id, language=language, - system_prompt_key="AGENT_NAME_REGENERATE_SYSTEM_PROMPT", - user_prompt_key="AGENT_NAME_REGENERATE_USER_PROMPT", + system_prompt_key="agent_name_regenerate_system_prompt", + user_prompt_key="agent_name_regenerate_user_prompt", default_system_prompt=( "You refine agent variable names so that they stay close to the " "original meaning and remain unique within the tenant." @@ -401,7 +422,9 @@ def _regenerate_agent_name_with_llm( tenant_id=tenant_id, agents_cache=agents_cache, exclude_agent_id=exclude_agent_id - ) + ), + prompt_template_id=prompt_template_id, + user_id=user_id, ) @@ -414,7 +437,9 @@ def _regenerate_agent_display_name_with_llm( tenant_id: str, language: str = LANGUAGE["ZH"], agents_cache: list[dict] | None = None, - exclude_agent_id: int | None = None + exclude_agent_id: int | None = None, + prompt_template_id: Optional[int] = None, + user_id: Optional[str] = None, ) -> str: return _regenerate_agent_value_with_llm( original_value=original_display_name, @@ -423,8 +448,8 @@ def _regenerate_agent_display_name_with_llm( model_id=model_id, tenant_id=tenant_id, language=language, - system_prompt_key="AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", - user_prompt_key="AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", + system_prompt_key="agent_display_name_regenerate_system_prompt", + user_prompt_key="agent_display_name_regenerate_user_prompt", default_system_prompt=( "You refine agent display names so they remain unique, concise, " "and aligned with the agent's capability." @@ -441,7 +466,9 @@ def _regenerate_agent_display_name_with_llm( tenant_id=tenant_id, agents_cache=agents_cache, exclude_agent_id=exclude_agent_id - ) + ), + prompt_template_id=prompt_template_id, + user_id=user_id, ) @@ -748,6 +775,11 @@ async def get_agent_info_impl(agent_id: int, tenant_id: str, version_no: int = 0 elif "business_logic_model_name" not in agent_info: agent_info["business_logic_model_name"] = None + if not agent_info.get("prompt_template_id"): + agent_info["prompt_template_id"] = SYSTEM_PROMPT_TEMPLATE_ID + if not agent_info.get("prompt_template_name"): + agent_info["prompt_template_name"] = SYSTEM_PROMPT_TEMPLATE_NAME + if agent_info.get("group_ids") is not None: agent_info["group_ids"] = convert_string_to_list(agent_info.get("group_ids")) @@ -804,6 +836,11 @@ async def get_creating_sub_agent_info_impl(authorization: str = Header(None)): async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = Header(None)): user_id, tenant_id, _ = get_current_user_info(authorization) + prompt_template_id, prompt_template_name = get_prompt_template_summary( + template_id=request.prompt_template_id, + tenant_id=tenant_id, + user_id=user_id, + ) # If agent_id is None, create a new agent; otherwise, update existing agent_id: Optional[int] = request.agent_id @@ -821,6 +858,8 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = "model_name": request.model_name, "business_logic_model_id": request.business_logic_model_id, "business_logic_model_name": request.business_logic_model_name, + "prompt_template_id": prompt_template_id, + "prompt_template_name": prompt_template_name, "max_steps": request.max_steps, "provide_run_summary": request.provide_run_summary, "duty_prompt": request.duty_prompt, @@ -833,6 +872,8 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = agent_id = created["agent_id"] else: # Update agent + request.prompt_template_id = prompt_template_id + request.prompt_template_name = prompt_template_name update_agent(agent_id, request, user_id) except Exception as e: logger.error(f"Failed to update agent info: {str(e)}") @@ -1145,7 +1186,9 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) model_id=model_id, model_name=model_display_name, business_logic_model_id=business_logic_model_id, - business_logic_model_name=business_logic_model_display_name) + business_logic_model_name=business_logic_model_display_name, + prompt_template_id=agent_info.get("prompt_template_id"), + prompt_template_name=agent_info.get("prompt_template_name")) return agent_info @@ -1278,6 +1321,8 @@ async def import_agent_by_agent_id( "model_name": import_agent_info.model_name, "business_logic_model_id": business_logic_model_id, "business_logic_model_name": import_agent_info.business_logic_model_name, + "prompt_template_id": import_agent_info.prompt_template_id or SYSTEM_PROMPT_TEMPLATE_ID, + "prompt_template_name": import_agent_info.prompt_template_name or SYSTEM_PROMPT_TEMPLATE_NAME, "max_steps": import_agent_info.max_steps, "provide_run_summary": import_agent_info.provide_run_summary, "duty_prompt": import_agent_info.duty_prompt, diff --git a/backend/services/agent_version_service.py b/backend/services/agent_version_service.py index 067fd0e1c..647d02740 100644 --- a/backend/services/agent_version_service.py +++ b/backend/services/agent_version_service.py @@ -387,6 +387,11 @@ def rollback_version_impl( if not target_agent: raise ValueError(f"Agent snapshot for version {target_version_no} not found") + # Ensure the draft still exists before attempting an in-place restore. + draft_agent, _, _ = query_agent_draft(agent_id, tenant_id) + if not draft_agent: + raise ValueError("Agent draft not found") + # Get skill snapshots for target version from database import skill_db as skill_db_module target_skills = skill_db_module.query_skill_instances_by_agent_id( diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index aa4d420d5..54d431241 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -23,14 +23,14 @@ _generate_unique_agent_name_with_suffix, _generate_unique_display_name_with_suffix ) +from services.prompt_template_service import resolve_prompt_generate_template from utils.llm_utils import call_llm_for_system_prompt -from utils.prompt_template_utils import get_prompt_generate_prompt_template # Configure logging logger = logging.getLogger("prompt_service") -def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): +def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, prompt_template_id: Optional[int] = None, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): try: for system_prompt in generate_and_save_system_prompt_impl( agent_id=agent_id, @@ -39,6 +39,7 @@ def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: user_id=user_id, tenant_id=tenant_id, language=language, + prompt_template_id=prompt_template_id, tool_ids=tool_ids, sub_agent_ids=sub_agent_ids, knowledge_base_display_names=knowledge_base_display_names @@ -64,6 +65,7 @@ def generate_and_save_system_prompt_impl(agent_id: int, user_id: str, tenant_id: str, language: str, + prompt_template_id: Optional[int] = None, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): @@ -128,8 +130,17 @@ def generate_and_save_system_prompt_impl(agent_id: int, ] # Collect results and yield non-name fields immediately, but hold name fields for duplicate checking - for result_data in generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id, - model_id, language, knowledge_base_display_names): + for result_data in generate_system_prompt( + sub_agent_info_list, + task_description, + tool_info_list, + tenant_id, + user_id, + model_id, + language, + prompt_template_id, + knowledge_base_display_names, + ): result_type = result_data["type"] final_results[result_type] = result_data["content"] @@ -158,7 +169,9 @@ def generate_and_save_system_prompt_impl(agent_id: int, tenant_id=tenant_id, language=language, agents_cache=all_agents, - exclude_agent_id=agent_id + exclude_agent_id=agent_id, + prompt_template_id=prompt_template_id, + user_id=user_id, ) logger.info(f"Regenerated agent name: '{agent_name}'") final_results["agent_var_name"] = agent_name @@ -199,7 +212,9 @@ def generate_and_save_system_prompt_impl(agent_id: int, tenant_id=tenant_id, language=language, agents_cache=all_agents, - exclude_agent_id=agent_id + exclude_agent_id=agent_id, + prompt_template_id=prompt_template_id, + user_id=user_id, ) logger.info(f"Regenerated agent display_name: '{agent_display_name}'") final_results["agent_display_name"] = agent_display_name @@ -238,9 +253,14 @@ def generate_and_save_system_prompt_impl(agent_id: int, raise Exception("Failed to generate prompt content.") -def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, model_id: int, language: str = LANGUAGE["ZH"], knowledge_base_display_names: Optional[List[str]] = None): +def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, user_id: str, model_id: int, language: str = LANGUAGE["ZH"], prompt_template_id: Optional[int] = None, knowledge_base_display_names: Optional[List[str]] = None): """Main function for generating system prompts""" - prompt_for_generate = get_prompt_generate_prompt_template(language) + prompt_for_generate = resolve_prompt_generate_template( + tenant_id=tenant_id, + user_id=user_id, + language=language, + prompt_template_id=prompt_template_id, + ) # Prepare content for generating system prompts content = join_info_for_generate_system_prompt( @@ -292,15 +312,15 @@ def run_and_flag(tag, sys_prompt): logger.info("Generating system prompt") prompt_configs = [ - ("duty", prompt_for_generate["DUTY_SYSTEM_PROMPT"]), - ("constraint", prompt_for_generate["CONSTRAINT_SYSTEM_PROMPT"]), - ("few_shots", prompt_for_generate["FEW_SHOTS_SYSTEM_PROMPT"]), + ("duty", prompt_for_generate["duty_system_prompt"]), + ("constraint", prompt_for_generate["constraint_system_prompt"]), + ("few_shots", prompt_for_generate["few_shots_system_prompt"]), ("agent_var_name", - prompt_for_generate["AGENT_VARIABLE_NAME_SYSTEM_PROMPT"]), + prompt_for_generate["agent_variable_name_system_prompt"]), ("agent_display_name", - prompt_for_generate["AGENT_DISPLAY_NAME_SYSTEM_PROMPT"]), + prompt_for_generate["agent_display_name_system_prompt"]), ("agent_description", - prompt_for_generate["AGENT_DESCRIPTION_SYSTEM_PROMPT"]) + prompt_for_generate["agent_description_system_prompt"]) ] for tag, sys_prompt in prompt_configs: @@ -398,7 +418,7 @@ def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_lis template_context["knowledge_base_names"] = kb_names_str # Generate content using template - content = Template(prompt_for_generate["USER_PROMPT"], undefined=StrictUndefined).render(template_context) + content = Template(prompt_for_generate["user_prompt"], undefined=StrictUndefined).render(template_context) return content diff --git a/backend/services/prompt_template_service.py b/backend/services/prompt_template_service.py new file mode 100644 index 000000000..14224a099 --- /dev/null +++ b/backend/services/prompt_template_service.py @@ -0,0 +1,322 @@ +import logging +from typing import Optional + +from consts.const import DEFAULT_TENANT_ID, DEFAULT_USER_ID +from consts.const import LANGUAGE +from consts.exceptions import DuplicateError, NotFoundException, ValidationError +from consts.model import PromptTemplateRequest +from database.prompt_template_db import ( + create_prompt_template, + delete_prompt_template, + get_prompt_template_by_id, + get_prompt_template_by_name, + get_prompt_template_by_template_id, + query_prompt_templates_by_user, + upsert_prompt_template_by_id, + update_prompt_template, +) +from utils.prompt_template_utils import ( + get_prompt_generate_prompt_template, + merge_prompt_generate_templates, + normalize_prompt_generate_template_content, +) + +logger = logging.getLogger("prompt_template_service") + +SYSTEM_PROMPT_TEMPLATE_ID = 0 +SYSTEM_PROMPT_TEMPLATE_NAME = "system_default" +PROMPT_TEMPLATE_TYPE_AGENT_GENERATE = "agent_generate" +SYSTEM_PROMPT_TEMPLATE_DESCRIPTION = "System default prompt template" +SYSTEM_PROMPT_TEMPLATE_TENANT_ID = DEFAULT_TENANT_ID +SYSTEM_PROMPT_TEMPLATE_USER_ID = DEFAULT_USER_ID + + +def _normalize_prompt_template_entity(template: Optional[dict]) -> Optional[dict]: + """Normalize prompt template entity content keys to lowercase.""" + if not template: + return template + + normalized_template = dict(template) + normalized_template["template_content_zh"] = normalize_prompt_generate_template_content( + normalized_template.get("template_content_zh") + ) + template_content_en = normalize_prompt_generate_template_content( + normalized_template.get("template_content_en") + ) + normalized_template["template_content_en"] = template_content_en or None + return normalized_template + + +def build_system_default_prompt_template_payload() -> dict: + """Build the canonical system default prompt template payload from YAML files.""" + system_template_zh = normalize_prompt_generate_template_content( + get_prompt_generate_prompt_template(LANGUAGE["ZH"]) + ) + system_template_en = normalize_prompt_generate_template_content( + get_prompt_generate_prompt_template(LANGUAGE["EN"]) + ) + return { + "template_id": SYSTEM_PROMPT_TEMPLATE_ID, + "template_name": SYSTEM_PROMPT_TEMPLATE_NAME, + "description": SYSTEM_PROMPT_TEMPLATE_DESCRIPTION, + "template_type": PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + "tenant_id": SYSTEM_PROMPT_TEMPLATE_TENANT_ID, + "user_id": SYSTEM_PROMPT_TEMPLATE_USER_ID, + "template_content_zh": system_template_zh, + "template_content_en": system_template_en, + "created_by": SYSTEM_PROMPT_TEMPLATE_USER_ID, + "updated_by": SYSTEM_PROMPT_TEMPLATE_USER_ID, + "delete_flag": "N", + } + + +def sync_system_default_prompt_template() -> dict: + """Sync the YAML-backed system default prompt template into the database.""" + payload = build_system_default_prompt_template_payload() + prompt_template = upsert_prompt_template_by_id( + template_id=SYSTEM_PROMPT_TEMPLATE_ID, + template_data=payload, + user_id=SYSTEM_PROMPT_TEMPLATE_USER_ID, + ) + prompt_template["is_system_default"] = True + return _normalize_prompt_template_entity(prompt_template) + + +def get_system_default_prompt_template() -> dict: + """Return the system default prompt generation template from the database.""" + prompt_template = get_prompt_template_by_template_id( + template_id=SYSTEM_PROMPT_TEMPLATE_ID, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not prompt_template: + prompt_template = sync_system_default_prompt_template() + else: + prompt_template["is_system_default"] = True + return _normalize_prompt_template_entity({ + **prompt_template, + "is_system_default": True, + }) + + +def _normalize_template_request(request: PromptTemplateRequest) -> dict: + """Normalize prompt template request payload.""" + template_name = (request.template_name or "").strip() + if not template_name: + raise ValidationError("template_name is required") + + if request.template_type != PROMPT_TEMPLATE_TYPE_AGENT_GENERATE: + raise ValidationError("Unsupported template type") + + zh_content = normalize_prompt_generate_template_content( + request.template_content_zh.model_dump() + ) + if len(zh_content) == 0: + raise ValidationError("template_content_zh is required") + + en_content = None + if request.template_content_en is not None: + en_content = normalize_prompt_generate_template_content( + request.template_content_en.model_dump() + ) + if len(en_content) == 0: + en_content = None + + return { + "template_name": template_name, + "description": (request.description or "").strip() or None, + "template_type": request.template_type, + "template_content_zh": zh_content, + "template_content_en": en_content, + } + + +def list_prompt_templates_impl(tenant_id: str, user_id: str) -> list[dict]: + """List all prompt templates for the current user.""" + system_default_template = sync_system_default_prompt_template() + templates = query_prompt_templates_by_user( + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + return [system_default_template, *[ + _normalize_prompt_template_entity({ + **template, + "is_system_default": False, + }) + for template in templates + if template.get("template_id") != SYSTEM_PROMPT_TEMPLATE_ID + ]] + + +def get_prompt_template_detail_impl(template_id: int, tenant_id: str, user_id: str) -> dict: + """Get prompt template detail.""" + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + return get_system_default_prompt_template() + + template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not template: + raise NotFoundException("Prompt template not found") + + template["is_system_default"] = False + return _normalize_prompt_template_entity(template) + + +def create_prompt_template_impl( + request: PromptTemplateRequest, + tenant_id: str, + user_id: str, +) -> dict: + """Create a prompt template.""" + normalized_request = _normalize_template_request(request) + existing_template = get_prompt_template_by_name( + template_name=normalized_request["template_name"], + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if existing_template: + raise DuplicateError("Prompt template name already exists") + + created_template = create_prompt_template({ + **normalized_request, + "tenant_id": tenant_id, + "user_id": user_id, + "created_by": user_id, + "updated_by": user_id, + }) + created_template["is_system_default"] = False + return _normalize_prompt_template_entity(created_template) + + +def update_prompt_template_impl( + template_id: int, + request: PromptTemplateRequest, + tenant_id: str, + user_id: str, +) -> dict: + """Update a prompt template.""" + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + raise ValidationError("System default prompt template cannot be updated") + + existing_template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not existing_template: + raise NotFoundException("Prompt template not found") + + normalized_request = _normalize_template_request(request) + duplicate_template = get_prompt_template_by_name( + template_name=normalized_request["template_name"], + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if duplicate_template and duplicate_template["template_id"] != template_id: + raise DuplicateError("Prompt template name already exists") + + updated_template = update_prompt_template( + template_id=template_id, + template_data=normalized_request, + user_id=user_id, + ) + updated_template["is_system_default"] = False + return _normalize_prompt_template_entity(updated_template) + + +def delete_prompt_template_impl(template_id: int, tenant_id: str, user_id: str) -> dict: + """Delete a prompt template.""" + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + raise ValidationError("System default prompt template cannot be deleted") + + existing_template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not existing_template: + raise NotFoundException("Prompt template not found") + + deleted_count = delete_prompt_template(template_id=template_id, user_id=user_id) + return { + "template_id": template_id, + "deleted": deleted_count > 0, + } + + +def resolve_prompt_generate_template( + tenant_id: str, + user_id: str, + language: str, + prompt_template_id: Optional[int] = None, +) -> dict: + """Resolve prompt generation template for the current user and language.""" + system_default_template = sync_system_default_prompt_template() + system_template = ( + system_default_template.get("template_content_en") + if language == LANGUAGE["EN"] + else system_default_template.get("template_content_zh") + ) + fallback_system_template = system_default_template.get("template_content_zh") + + if not prompt_template_id or prompt_template_id == SYSTEM_PROMPT_TEMPLATE_ID: + return merge_prompt_generate_templates(system_template, fallback_system_template) + + prompt_template = get_prompt_template_by_id( + template_id=prompt_template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not prompt_template: + logger.warning( + "Prompt template %s not found for tenant %s user %s, falling back to system default", + prompt_template_id, + tenant_id, + user_id, + ) + return merge_prompt_generate_templates(system_template, fallback_system_template) + + custom_language_template = ( + prompt_template.get("template_content_en") + if language == LANGUAGE["EN"] + else prompt_template.get("template_content_zh") + ) + return merge_prompt_generate_templates( + custom_language_template, + prompt_template.get("template_content_zh"), + system_template, + fallback_system_template, + ) + + +def get_prompt_template_summary( + template_id: Optional[int], + tenant_id: str, + user_id: str, +) -> tuple[Optional[int], Optional[str]]: + """Resolve prompt template identity for saving on agent.""" + if template_id is None: + return None, None + + if template_id == SYSTEM_PROMPT_TEMPLATE_ID: + return SYSTEM_PROMPT_TEMPLATE_ID, SYSTEM_PROMPT_TEMPLATE_NAME + + prompt_template = get_prompt_template_by_id( + template_id=template_id, + tenant_id=tenant_id, + user_id=user_id, + template_type=PROMPT_TEMPLATE_TYPE_AGENT_GENERATE, + ) + if not prompt_template: + raise NotFoundException("Prompt template not found") + + return prompt_template["template_id"], prompt_template["template_name"] diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index e99b9f384..3add98952 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -100,9 +100,18 @@ def call_llm_for_system_prompt( reasoning_content_seen = False content_tokens_seen = 0 for chunk in current_request: - delta = chunk.choices[0].delta + choices = getattr(chunk, "choices", None) or [] + if len(choices) == 0: + logger.debug("Skipping LLM stream chunk without choices") + continue + + delta = getattr(choices[0], "delta", None) + if delta is None: + logger.debug("Skipping LLM stream chunk without delta") + continue + reasoning_content = getattr(delta, "reasoning_content", None) - new_token = delta.content + new_token = getattr(delta, "content", None) # Note: reasoning_content is separate metadata and doesn't affect content filtering # We only filter content based on tags in delta.content diff --git a/backend/utils/prompt_template_utils.py b/backend/utils/prompt_template_utils.py index cf83bfa60..b298ef777 100644 --- a/backend/utils/prompt_template_utils.py +++ b/backend/utils/prompt_template_utils.py @@ -5,9 +5,56 @@ import yaml from consts.const import LANGUAGE +from consts.prompt_template import ( + PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP, + PROMPT_GENERATE_TEMPLATE_FIELDS, +) logger = logging.getLogger("prompt_template_utils") +PROMPT_GENERATE_TEMPLATE_KEY_MAP = PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP +PROMPT_GENERATE_TEMPLATE_KEYS = PROMPT_GENERATE_TEMPLATE_FIELDS + + +def get_prompt_generate_template_keys() -> list[str]: + """Return the supported prompt generation template keys.""" + return list(PROMPT_GENERATE_TEMPLATE_FIELDS) + + +def normalize_prompt_generate_template_content( + template_content: Optional[Dict[str, Any]] +) -> Dict[str, str]: + """Normalize prompt generation template content and keep non-empty fields only.""" + normalized: Dict[str, str] = {} + if not isinstance(template_content, dict): + return normalized + + for key in PROMPT_GENERATE_TEMPLATE_FIELDS: + legacy_key = PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP[key] + value = template_content.get(key) + if value is None: + value = template_content.get(legacy_key) + if isinstance(value, str) and value.strip(): + normalized[key] = value + + return normalized + + +def merge_prompt_generate_templates( + *template_contents: Optional[Dict[str, Any]] +) -> Dict[str, str]: + """Merge multiple prompt generation templates with first-non-empty priority.""" + merged: Dict[str, str] = {} + + for template_content in template_contents: + normalized = normalize_prompt_generate_template_content(template_content) + for key in PROMPT_GENERATE_TEMPLATE_FIELDS: + value = normalized.get(key) + if value and key not in merged: + merged[key] = value + + return merged + def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kwargs) -> Dict[str, Any]: """ diff --git a/docker/init.sql b/docker/init.sql index adfe65019..246ddf79a 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -316,6 +316,8 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tenant_agent_t ( model_id INTEGER, business_logic_model_name VARCHAR(100), business_logic_model_id INTEGER, + prompt_template_id INTEGER, + prompt_template_name VARCHAR(100), max_steps INTEGER, duty_prompt TEXT, constraint_prompt TEXT, @@ -366,6 +368,8 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_name IS '[DEPRECATED] Name of t COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_id IS 'Model ID, foreign key reference to model_record_t.model_id'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_name IS 'Model name used for business logic prompt generation'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_id IS 'Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_id IS 'Prompt template ID used for business logic prompt generation'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_name IS 'Prompt template name used for business logic prompt generation'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.max_steps IS 'Maximum number of steps'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.duty_prompt IS 'Duty prompt'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.constraint_prompt IS 'Constraint prompt'; @@ -389,8 +393,98 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.enable_context_manager IS 'Whether to -- Create index for is_new queries CREATE INDEX IF NOT EXISTS idx_ag_tenant_agent_t_is_new ON nexent.ag_tenant_agent_t (tenant_id, is_new) + +CREATE TABLE IF NOT EXISTS nexent.ag_prompt_template_t ( + template_id SERIAL PRIMARY KEY, + template_name VARCHAR(100) NOT NULL, + description VARCHAR(500), + template_type VARCHAR(50) NOT NULL DEFAULT 'agent_generate', + tenant_id VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + template_content_zh JSONB NOT NULL, + template_content_en JSONB, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE nexent.ag_prompt_template_t OWNER TO "root"; + +CREATE OR REPLACE FUNCTION update_ag_prompt_template_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER update_ag_prompt_template_update_time_trigger +BEFORE UPDATE ON nexent.ag_prompt_template_t +FOR EACH ROW +EXECUTE FUNCTION update_ag_prompt_template_update_time(); + +COMMENT ON TABLE nexent.ag_prompt_template_t IS 'Prompt template table for user-defined business logic generation prompts'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_id IS 'Prompt template ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_name IS 'Prompt template name'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.description IS 'Prompt template description'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_type IS 'Prompt template type'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.tenant_id IS 'Tenant ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.user_id IS 'User ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_zh IS 'Chinese prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_en IS 'English prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; + +CREATE UNIQUE INDEX IF NOT EXISTS uq_prompt_template_user_name_active +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_name) WHERE delete_flag = 'N'; +CREATE INDEX IF NOT EXISTS idx_ag_prompt_template_t_user +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_type) +WHERE delete_flag = 'N'; + +INSERT INTO nexent.ag_prompt_template_t ( + template_id, + template_name, + description, + template_type, + tenant_id, + user_id, + template_content_zh, + template_content_en, + created_by, + updated_by, + delete_flag +) +VALUES ( + 0, + 'system_default', + 'System default prompt template', + 'agent_generate', + 'tenant_id', + 'user_id', + '{}'::jsonb, + '{}'::jsonb, + 'user_id', + 'user_id', + 'N' +) +ON CONFLICT (template_id) DO UPDATE SET + template_name = EXCLUDED.template_name, + description = EXCLUDED.description, + template_type = EXCLUDED.template_type, + tenant_id = EXCLUDED.tenant_id, + user_id = EXCLUDED.user_id, + template_content_zh = EXCLUDED.template_content_zh, + template_content_en = EXCLUDED.template_content_en, + updated_by = EXCLUDED.updated_by, + delete_flag = 'N'; + -- Create the ag_tool_instance_t table in the nexent schema CREATE TABLE IF NOT EXISTS nexent.ag_tool_instance_t ( diff --git a/docker/sql/v2.1.0_0503_add_prompt_template_t.sql b/docker/sql/v2.1.0_0503_add_prompt_template_t.sql new file mode 100644 index 000000000..3db9a9701 --- /dev/null +++ b/docker/sql/v2.1.0_0503_add_prompt_template_t.sql @@ -0,0 +1,115 @@ +-- Migration: Add prompt template table and agent prompt template fields +-- Date: 2026-05-03 +-- Description: Add user-scoped prompt template storage and bind selected prompt template to agents + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS prompt_template_id INTEGER; + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS prompt_template_name VARCHAR(100); + +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_id IS 'Prompt template ID used for business logic prompt generation'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.prompt_template_name IS 'Prompt template name used for business logic prompt generation'; + +UPDATE nexent.ag_tenant_agent_t +SET prompt_template_id = 0, + prompt_template_name = 'system_default' +WHERE delete_flag = 'N' + AND (prompt_template_id IS NULL OR prompt_template_name IS NULL); + +CREATE TABLE IF NOT EXISTS nexent.ag_prompt_template_t ( + template_id SERIAL PRIMARY KEY, + template_name VARCHAR(100) NOT NULL, + description VARCHAR(500), + template_type VARCHAR(50) NOT NULL DEFAULT 'agent_generate', + tenant_id VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + template_content_zh JSONB NOT NULL, + template_content_en JSONB, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE nexent.ag_prompt_template_t OWNER TO "root"; + +CREATE OR REPLACE FUNCTION update_ag_prompt_template_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS update_ag_prompt_template_update_time_trigger ON nexent.ag_prompt_template_t; + +CREATE TRIGGER update_ag_prompt_template_update_time_trigger +BEFORE UPDATE ON nexent.ag_prompt_template_t +FOR EACH ROW +EXECUTE FUNCTION update_ag_prompt_template_update_time(); + +ALTER TABLE nexent.ag_prompt_template_t +DROP CONSTRAINT IF EXISTS uq_prompt_template_user_name; + +COMMENT ON TABLE nexent.ag_prompt_template_t IS 'Prompt template table for user-defined business logic generation prompts'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_id IS 'Prompt template ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_name IS 'Prompt template name'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.description IS 'Prompt template description'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_type IS 'Prompt template type'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.tenant_id IS 'Tenant ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.user_id IS 'User ID'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_zh IS 'Chinese prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.template_content_en IS 'English prompt template content'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.created_by IS 'Creator'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.updated_by IS 'Updater'; +COMMENT ON COLUMN nexent.ag_prompt_template_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; + +DROP INDEX IF EXISTS nexent.uq_prompt_template_user_name_active; +CREATE UNIQUE INDEX IF NOT EXISTS uq_prompt_template_user_name_active +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_name) +WHERE delete_flag = 'N'; + +CREATE INDEX IF NOT EXISTS idx_ag_prompt_template_t_user +ON nexent.ag_prompt_template_t (tenant_id, user_id, template_type) +WHERE delete_flag = 'N'; + +INSERT INTO nexent.ag_prompt_template_t ( + template_id, + template_name, + description, + template_type, + tenant_id, + user_id, + template_content_zh, + template_content_en, + created_by, + updated_by, + delete_flag +) +VALUES ( + 0, + 'system_default', + 'System default prompt template', + 'agent_generate', + 'tenant_id', + 'user_id', + '{}'::jsonb, + '{}'::jsonb, + 'user_id', + 'user_id', + 'N' +) +ON CONFLICT (template_id) DO UPDATE SET + template_name = EXCLUDED.template_name, + description = EXCLUDED.description, + template_type = EXCLUDED.template_type, + tenant_id = EXCLUDED.tenant_id, + user_id = EXCLUDED.user_id, + template_content_zh = EXCLUDED.template_content_zh, + template_content_en = EXCLUDED.template_content_en, + updated_by = EXCLUDED.updated_by, + delete_flag = 'N'; diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 1dd8422fa..7c6a166cb 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -17,10 +17,14 @@ import { App, } from "antd"; import type { TabsProps } from "antd"; -import { Zap, Maximize2 } from "lucide-react"; +import { Zap, Maximize2, Settings2 } from "lucide-react"; import log from "@/lib/logger"; -import { AgentProfileInfo, AgentBusinessInfo } from "@/types/agentConfig"; +import { + AgentProfileInfo, + AgentBusinessInfo, + PromptTemplate, +} from "@/types/agentConfig"; import { getAgentGenerationCache, setAgentGenerationStatus, @@ -39,10 +43,12 @@ import { useModelList } from "@/hooks/model/useModelList"; import { useConfig } from "@/hooks/useConfig"; import { useTenantList } from "@/hooks/tenant/useTenantList"; import { useGroupList } from "@/hooks/group/useGroupList"; +import { usePromptTemplateList } from "@/hooks/agent/usePromptTemplateList"; import { USER_ROLES } from "@/const/auth"; import { Can } from "@/components/permission/Can"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import ExpandEditModal from "./ExpandEditModal"; +import PromptTemplateManagerModal from "./PromptTemplateManagerModal"; const { TextArea } = Input; @@ -74,6 +80,11 @@ export default function AgentGenerateDetail({ // Model data: default LLM name from config, resolve to full model from model list const { defaultLlmModelName } = useConfig(); const { availableLlmModels, models, isLoading: loadingModels } = useModelList(); + const { + templates: promptTemplates, + isLoading: loadingPromptTemplates, + invalidate: invalidatePromptTemplates, + } = usePromptTemplateList(); const defaultLlmModel = useMemo(() => { if (defaultLlmModelName) { const found = availableLlmModels.find( @@ -115,6 +126,7 @@ export default function AgentGenerateDetail({ // Modal states const [expandModalOpen, setExpandModalOpen] = useState(false); const [expandModalType, setExpandModalType] = useState<'duty' | 'constraint' | 'few-shots' | null>(null); + const [promptTemplateManagerOpen, setPromptTemplateManagerOpen] = useState(false); // Use ref to track generation initiator - this doesn't trigger re-renders // but is accessible in closures @@ -133,14 +145,24 @@ export default function AgentGenerateDetail({ useEffect(() => { if (editedAgent.business_description !== businessInfo.businessDescription || editedAgent.business_logic_model_name !== businessInfo.businessLogicModelName || - editedAgent.business_logic_model_id !== businessInfo.businessLogicModelId) { + editedAgent.business_logic_model_id !== businessInfo.businessLogicModelId || + (editedAgent.prompt_template_id ?? 0) !== businessInfo.promptTemplateId || + (editedAgent.prompt_template_name || "system_default") !== businessInfo.promptTemplateName) { setBusinessInfo({ businessDescription: editedAgent.business_description || "", businessLogicModelName: editedAgent.business_logic_model_name || "", businessLogicModelId: editedAgent.business_logic_model_id || 0, + promptTemplateId: editedAgent.prompt_template_id ?? 0, + promptTemplateName: editedAgent.prompt_template_name || "system_default", }); } - }, [editedAgent.business_description, editedAgent.business_logic_model_name, editedAgent.business_logic_model_id]); + }, [ + editedAgent.business_description, + editedAgent.business_logic_model_name, + editedAgent.business_logic_model_id, + editedAgent.prompt_template_id, + editedAgent.prompt_template_name, + ]); // Only show "no edit permission" tooltip when the panel is active and agent is read-only. // Note: when no agent is selected, AgentInfoComp shows an overlay and we should not show @@ -194,6 +216,8 @@ export default function AgentGenerateDetail({ businessDescription: "", businessLogicModelName: "", businessLogicModelId: 0, + promptTemplateId: 0, + promptTemplateName: "system_default", }); const normalizeNumberArray = (value: unknown): number[] => { @@ -286,6 +310,8 @@ export default function AgentGenerateDetail({ "", businessLogicModelId: editedAgent.business_logic_model_id || defaultLlmModel?.id || 0, + promptTemplateId: editedAgent.prompt_template_id ?? 0, + promptTemplateName: editedAgent.prompt_template_name || "system_default", }; // Initialize local business description state setBusinessInfo(initialBusinessInfo); @@ -400,6 +426,8 @@ export default function AgentGenerateDetail({ business_description: value, business_logic_model_id: businessInfo.businessLogicModelId, business_logic_model_name: businessInfo.businessLogicModelName, + prompt_template_id: businessInfo.promptTemplateId, + prompt_template_name: businessInfo.promptTemplateName, }); }; @@ -418,6 +446,34 @@ export default function AgentGenerateDetail({ business_description: businessInfo.businessDescription || "", business_logic_model_id: selectedModel?.id || 0, business_logic_model_name: modelName, + prompt_template_id: businessInfo.promptTemplateId, + prompt_template_name: businessInfo.promptTemplateName, + }); + }; + + const handlePromptTemplateChange = (templateId: number) => { + const selectedTemplate = promptTemplates.find( + (template) => template.template_id === templateId + ); + if (!selectedTemplate) { + return; + } + handleSelectPromptTemplate(selectedTemplate); + }; + + const handleSelectPromptTemplate = (template: PromptTemplate) => { + setBusinessInfo((prev) => ({ + ...prev, + promptTemplateId: template.template_id, + promptTemplateName: template.template_name, + })); + + updateBusinessInfo({ + business_description: businessInfo.businessDescription || "", + business_logic_model_id: businessInfo.businessLogicModelId, + business_logic_model_name: businessInfo.businessLogicModelName, + prompt_template_id: template.template_id, + prompt_template_name: template.template_name, }); }; @@ -610,6 +666,7 @@ export default function AgentGenerateDetail({ agent_id: effectiveAgentId, task_description: businessInfo.businessDescription, model_id: businessInfo.businessLogicModelId.toString(), + prompt_template_id: businessInfo.promptTemplateId, sub_agent_ids: editedAgent.sub_agent_id_list, tool_ids: Array.isArray(editedAgent.tools) ? editedAgent.tools.map((tool: any) => @@ -808,6 +865,47 @@ export default function AgentGenerateDetail({ disabled: model.connect_status !== "available", })); + const promptTemplateSelectOptions = useMemo(() => { + const options = promptTemplates.map((template) => ({ + value: template.template_id, + label: template.is_system_default + ? t("businessLogic.config.template.systemDefault") + : template.template_name, + })); + + if ( + businessInfo.promptTemplateId && + !options.some((option) => option.value === businessInfo.promptTemplateId) + ) { + options.unshift({ + value: businessInfo.promptTemplateId, + label: businessInfo.promptTemplateName || t("businessLogic.config.template.label"), + }); + } + + return options; + }, [ + businessInfo.promptTemplateId, + businessInfo.promptTemplateName, + promptTemplates, + t, + ]); + + const generationControlSelectStyle = { + width: "min(300px, 100%)", + minWidth: "220px", + maxWidth: "300px", + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + }; + + const generationControlLabelStyle = { + width: 84, + minWidth: 84, + flexShrink: 0, + }; + // Tab items configuration const tabItems = [ { @@ -1165,46 +1263,77 @@ export default function AgentGenerateDetail({ )} {/* Control area */} - -
- - {t("model.type.llm")}: - - } + disabled={!editable || isGenerating} + style={generationControlSelectStyle} + /> +
+
+ {wrapNoEditTooltipInline( + + )} +
+
+ + +
+ - - {isGenerating - ? t("businessLogic.config.button.generating") - : t("businessLogic.config.button.generatePrompt")} - - - )} -
+ {t("model.type.llm")}: + + option.value === selectedTemplateId)?.label + } + disabled + style={{ flex: 1, minWidth: 220 }} + /> +
+ + { + const isSelected = selectedTemplateId === template.template_id; + const isSystemDefault = template.is_system_default; + return ( + + + + + + + {isSystemDefault + ? t("businessLogic.config.template.systemDefault") + : template.template_name} + + {isSystemDefault ? ( + + {t("businessLogic.config.template.system")} + + ) : null} + {isSelected ? ( + + {t("businessLogic.config.template.current")} + + ) : null} + + + {template.description || t("businessLogic.config.template.noDescription")} + + + + + + + + + + + + + ); + }} + /> + + + + + + + + {t("businessLogic.config.template.manageDescription")} + + + +
+ + + + + + + + + + +
+
+ + ); +} diff --git a/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx b/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx index edfeff559..a1d809b50 100644 --- a/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx +++ b/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx @@ -259,6 +259,8 @@ export default function AgentList({ few_shots_prompt: detail.few_shots_prompt, business_logic_model_name: detail.business_logic_model_name ?? undefined, business_logic_model_id: detail.business_logic_model_id ?? undefined, + prompt_template_id: detail.prompt_template_id ?? 0, + prompt_template_name: detail.prompt_template_name ?? "system_default", enabled_tool_ids: enabledToolIds, related_agent_ids: subAgentIds, }); diff --git a/frontend/const/promptTemplate.ts b/frontend/const/promptTemplate.ts new file mode 100644 index 000000000..aada2371e --- /dev/null +++ b/frontend/const/promptTemplate.ts @@ -0,0 +1,82 @@ +export const PROMPT_TEMPLATE_FIELD_CONFIG = [ + { + key: "duty_system_prompt", + labelKey: "systemPrompt.card.duty.title", + section: "basic", + }, + { + key: "constraint_system_prompt", + labelKey: "systemPrompt.card.constraint.title", + section: "basic", + }, + { + key: "few_shots_system_prompt", + labelKey: "systemPrompt.card.fewShots.title", + section: "basic", + }, + { + key: "user_prompt", + labelKey: "businessLogic.config.template.field.userPrompt", + section: "basic", + }, + { + key: "agent_variable_name_system_prompt", + labelKey: "businessLogic.config.template.field.agentVariableName", + section: "advanced", + }, + { + key: "agent_display_name_system_prompt", + labelKey: "businessLogic.config.template.field.agentDisplayName", + section: "advanced", + }, + { + key: "agent_description_system_prompt", + labelKey: "businessLogic.config.template.field.agentDescription", + section: "advanced", + }, + { + key: "agent_name_regenerate_system_prompt", + labelKey: "businessLogic.config.template.field.agentNameRegenerateSystem", + section: "advanced", + }, + { + key: "agent_name_regenerate_user_prompt", + labelKey: "businessLogic.config.template.field.agentNameRegenerateUser", + section: "advanced", + }, + { + key: "agent_display_name_regenerate_system_prompt", + labelKey: "businessLogic.config.template.field.agentDisplayNameRegenerateSystem", + section: "advanced", + }, + { + key: "agent_display_name_regenerate_user_prompt", + labelKey: "businessLogic.config.template.field.agentDisplayNameRegenerateUser", + section: "advanced", + }, +] as const; + +export type PromptTemplateFieldConfig = (typeof PROMPT_TEMPLATE_FIELD_CONFIG)[number]; +export type PromptTemplateFieldKey = PromptTemplateFieldConfig["key"]; + +export const PROMPT_TEMPLATE_FIELD_KEYS = PROMPT_TEMPLATE_FIELD_CONFIG.map( + (field) => field.key +) as PromptTemplateFieldKey[]; + +export const BASIC_PROMPT_TEMPLATE_FIELDS = PROMPT_TEMPLATE_FIELD_CONFIG.filter( + (field) => field.section === "basic" +); + +export const ADVANCED_PROMPT_TEMPLATE_FIELDS = PROMPT_TEMPLATE_FIELD_CONFIG.filter( + (field) => field.section === "advanced" +); + +export function createEmptyPromptTemplateContent(): Record { + return PROMPT_TEMPLATE_FIELD_KEYS.reduce( + (content, key) => { + content[key] = ""; + return content; + }, + {} as Record + ); +} diff --git a/frontend/hooks/agent/usePromptTemplateList.ts b/frontend/hooks/agent/usePromptTemplateList.ts new file mode 100644 index 000000000..592776b7c --- /dev/null +++ b/frontend/hooks/agent/usePromptTemplateList.ts @@ -0,0 +1,22 @@ +import { useQuery, useQueryClient } from "@tanstack/react-query"; + +import { promptTemplateService } from "@/services/promptTemplateService"; +import { PromptTemplate } from "@/types/agentConfig"; + +export function usePromptTemplateList() { + const queryClient = useQueryClient(); + + const query = useQuery({ + queryKey: ["promptTemplates"], + queryFn: async (): Promise => { + return promptTemplateService.list(); + }, + staleTime: 60_000, + }); + + return { + ...query, + templates: query.data ?? [], + invalidate: () => queryClient.invalidateQueries({ queryKey: ["promptTemplates"] }), + }; +} diff --git a/frontend/hooks/agent/useSaveGuard.ts b/frontend/hooks/agent/useSaveGuard.ts index 131e1aa59..38e56b3df 100644 --- a/frontend/hooks/agent/useSaveGuard.ts +++ b/frontend/hooks/agent/useSaveGuard.ts @@ -138,6 +138,8 @@ export const useSaveGuard = () => { few_shots_prompt: currentEditedAgent.few_shots_prompt, business_logic_model_name: currentEditedAgent.business_logic_model_name ?? undefined, business_logic_model_id: currentEditedAgent.business_logic_model_id ?? undefined, + prompt_template_id: currentEditedAgent.prompt_template_id ?? 0, + prompt_template_name: currentEditedAgent.prompt_template_name ?? "system_default", enabled_tool_ids: enabledToolIds, enabled_skill_ids: enabledSkillIds, related_agent_ids: relatedAgentIds, diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 22c17c2ca..e3bbc1b93 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -937,6 +937,43 @@ "businessLogic.config.message.agentDeleteSuccess": "Agent delete success", "businessLogic.config.message.agentDeleteFailed": "Agent delete failed", "businessLogic.config.message.agentSaveSuccess": "Agent save success", + "businessLogic.config.template.label": "Prompt Template", + "businessLogic.config.template.manage": "Manage Templates", + "businessLogic.config.template.manageDescription": "Choose a prompt template for generation, or create your own private templates.", + "businessLogic.config.template.create": "New Template", + "businessLogic.config.template.use": "Use", + "businessLogic.config.template.current": "Current", + "businessLogic.config.template.system": "System", + "businessLogic.config.template.systemDefault": "System Default", + "businessLogic.config.template.empty": "No prompt templates", + "businessLogic.config.template.noDescription": "No description", + "businessLogic.config.template.name": "Template Name", + "businessLogic.config.template.nameRequired": "Please enter a template name", + "businessLogic.config.template.description": "Description", + "businessLogic.config.template.language.zh": "Chinese Template", + "businessLogic.config.template.language.en": "English Template", + "businessLogic.config.template.contentRequired": "This field is required", + "businessLogic.config.template.basicSection": "Basic Configuration", + "businessLogic.config.template.basicDescription": "Configure the core prompts users most often care about. The remaining prompt segments can be adjusted in Advanced Configuration.", + "businessLogic.config.template.englishOptionalDescription": "English content is optional. Leave it blank to fall back to the Chinese template during generation.", + "businessLogic.config.template.advancedSection": "Advanced Configuration", + "businessLogic.config.template.advancedDescription": "These fields are still stored with the template and are suitable for fine-grained control of naming and regeneration behavior.", + "businessLogic.config.template.createTitle": "Create Prompt Template", + "businessLogic.config.template.editTitle": "Edit Prompt Template", + "businessLogic.config.template.saveSuccess": "Prompt template saved successfully", + "businessLogic.config.template.saveError": "Failed to save prompt template", + "businessLogic.config.template.deleteSuccess": "Prompt template deleted successfully", + "businessLogic.config.template.deleteError": "Failed to delete prompt template", + "businessLogic.config.template.deleteConfirm": "Are you sure you want to delete prompt template {{name}}?", + "businessLogic.config.template.loadError": "Failed to load prompt template", + "businessLogic.config.template.field.agentVariableName": "Agent Variable Name Prompt", + "businessLogic.config.template.field.agentDisplayName": "Agent Display Name Prompt", + "businessLogic.config.template.field.agentDescription": "Agent Description Prompt", + "businessLogic.config.template.field.userPrompt": "User Prompt", + "businessLogic.config.template.field.agentNameRegenerateSystem": "Agent Name Regenerate System Prompt", + "businessLogic.config.template.field.agentNameRegenerateUser": "Agent Name Regenerate User Prompt", + "businessLogic.config.template.field.agentDisplayNameRegenerateSystem": "Agent Display Name Regenerate System Prompt", + "businessLogic.config.template.field.agentDisplayNameRegenerateUser": "Agent Display Name Regenerate User Prompt", "businessLogic.config.import.duplicateTitle": "Duplicate Agent detected", "businessLogic.config.import.duplicateDescription": "The imported Agent name or display name conflicts with an existing Agent. You can choose to import directly or call the LLM to regenerate a unique name before importing.", "businessLogic.config.import.duplicateConfirm": "Regenerate and import", @@ -1858,6 +1895,7 @@ "common.loading": "Loading", "common.save": "Save", "common.cancel": "Cancel", + "common.close": "Close", "common.confirm": "Confirm", "common.skip": "Skip", "common.saving": "Saving...", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 1cc83a802..798c0285e 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -938,6 +938,43 @@ "businessLogic.config.message.agentDeleteSuccess": "智能体删除成功", "businessLogic.config.message.agentDeleteFailed": "智能体删除失败", "businessLogic.config.message.agentSaveSuccess": "智能体保存成功", + "businessLogic.config.template.label": "提示词模板", + "businessLogic.config.template.manage": "管理模板", + "businessLogic.config.template.manageDescription": "选择用于生成的提示词模板,或创建仅自己可见的私有模板。", + "businessLogic.config.template.create": "新建模板", + "businessLogic.config.template.use": "使用", + "businessLogic.config.template.current": "当前使用", + "businessLogic.config.template.system": "系统", + "businessLogic.config.template.systemDefault": "系统默认", + "businessLogic.config.template.empty": "暂无提示词模板", + "businessLogic.config.template.noDescription": "暂无描述", + "businessLogic.config.template.name": "模板名称", + "businessLogic.config.template.nameRequired": "请输入模板名称", + "businessLogic.config.template.description": "模板描述", + "businessLogic.config.template.language.zh": "中文模板", + "businessLogic.config.template.language.en": "英文模板", + "businessLogic.config.template.contentRequired": "该字段不能为空", + "businessLogic.config.template.basicSection": "基础配置", + "businessLogic.config.template.basicDescription": "默认展示用户最常调整的核心提示词,其余提示词片段可在高级配置中继续编辑。", + "businessLogic.config.template.englishOptionalDescription": "英文内容为选填,留空时生成阶段会回退使用中文模板。", + "businessLogic.config.template.advancedSection": "高级配置", + "businessLogic.config.template.advancedDescription": "这些字段也会随模板一并入库,适合精细控制名称生成和重生成行为。", + "businessLogic.config.template.createTitle": "新建提示词模板", + "businessLogic.config.template.editTitle": "编辑提示词模板", + "businessLogic.config.template.saveSuccess": "提示词模板保存成功", + "businessLogic.config.template.saveError": "提示词模板保存失败", + "businessLogic.config.template.deleteSuccess": "提示词模板删除成功", + "businessLogic.config.template.deleteError": "提示词模板删除失败", + "businessLogic.config.template.deleteConfirm": "确定要删除提示词模板 {{name}} 吗?", + "businessLogic.config.template.loadError": "加载提示词模板失败", + "businessLogic.config.template.field.agentVariableName": "智能体变量名提示词", + "businessLogic.config.template.field.agentDisplayName": "智能体展示名提示词", + "businessLogic.config.template.field.agentDescription": "智能体描述提示词", + "businessLogic.config.template.field.userPrompt": "用户提示词", + "businessLogic.config.template.field.agentNameRegenerateSystem": "变量名重生成系统提示词", + "businessLogic.config.template.field.agentNameRegenerateUser": "变量名重生成用户提示词", + "businessLogic.config.template.field.agentDisplayNameRegenerateSystem": "展示名重生成系统提示词", + "businessLogic.config.template.field.agentDisplayNameRegenerateUser": "展示名重生成用户提示词", "businessLogic.config.import.duplicateTitle": "检测到重名智能体", "businessLogic.config.import.duplicateDescription": "导入的智能体名称或展示名称与已有智能体重复。您可以选择直接导入或调用 LLM 重新生成唯一名称后导入。", "businessLogic.config.import.duplicateConfirm": "重新生成并导入", @@ -1915,6 +1952,7 @@ "common.loading": "加载中", "common.save": "保存", "common.cancel": "取消", + "common.close": "关闭", "common.confirm": "确定", "common.skip": "跳过", "common.saving": "保存中...", diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index 37f621e95..aceabcd1e 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -401,6 +401,8 @@ export interface UpdateAgentInfoPayload { business_description?: string; business_logic_model_name?: string; business_logic_model_id?: number; + prompt_template_id?: number; + prompt_template_name?: string; enabled_tool_ids?: number[]; enabled_skill_ids?: number[]; related_agent_ids?: number[]; @@ -697,6 +699,8 @@ export const searchAgentInfo = async (agentId: number, tenantId?: string, versio business_description: data.business_description, business_logic_model_name: data.business_logic_model_name, business_logic_model_id: data.business_logic_model_id, + prompt_template_id: data.prompt_template_id ?? 0, + prompt_template_name: data.prompt_template_name ?? "system_default", provide_run_summary: data.provide_run_summary, enabled: data.enabled, is_available: data.is_available, diff --git a/frontend/services/api.ts b/frontend/services/api.ts index 34d359d0c..074f0f69b 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -86,6 +86,13 @@ export const API_ENDPOINTS = { prompt: { generate: `${API_BASE_URL}/prompt/generate`, }, + promptTemplates: { + list: `${API_BASE_URL}/prompt_templates`, + detail: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + create: `${API_BASE_URL}/prompt_templates`, + update: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + delete: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + }, stt: { ws: `/api/voice/stt/ws`, }, diff --git a/frontend/services/promptTemplateService.ts b/frontend/services/promptTemplateService.ts new file mode 100644 index 000000000..c88275ae1 --- /dev/null +++ b/frontend/services/promptTemplateService.ts @@ -0,0 +1,90 @@ +import { API_ENDPOINTS, fetchWithErrorHandling } from "./api"; + +import { getAuthHeaders } from "@/lib/auth"; +import log from "@/lib/logger"; +import { + PromptTemplate, + PromptTemplatePayload, +} from "@/types/agentConfig"; + +async function requestJson(url: string, options: RequestInit = {}): Promise { + const response = await fetchWithErrorHandling(url, { + ...options, + headers: { + ...getAuthHeaders(), + ...(options.headers || {}), + }, + }); + return response.json(); +} + +export const promptTemplateService = { + async list(): Promise { + try { + const data = await requestJson(API_ENDPOINTS.promptTemplates.list, { + method: "GET", + }); + return data || []; + } catch (error) { + log.error("Failed to list prompt templates:", error); + return []; + } + }, + + async detail(templateId: number): Promise { + try { + const data = await requestJson( + API_ENDPOINTS.promptTemplates.detail(templateId), + { method: "GET" } + ); + return data; + } catch (error) { + log.error("Failed to get prompt template detail:", error); + return null; + } + }, + + async create(payload: PromptTemplatePayload): Promise { + try { + const data = await requestJson( + API_ENDPOINTS.promptTemplates.create, + { + method: "POST", + body: JSON.stringify(payload), + } + ); + return data; + } catch (error) { + log.error("Failed to create prompt template:", error); + throw error; + } + }, + + async update(templateId: number, payload: PromptTemplatePayload): Promise { + try { + const data = await requestJson( + API_ENDPOINTS.promptTemplates.update(templateId), + { + method: "PUT", + body: JSON.stringify(payload), + } + ); + return data; + } catch (error) { + log.error("Failed to update prompt template:", error); + throw error; + } + }, + + async remove(templateId: number): Promise { + try { + await requestJson(API_ENDPOINTS.promptTemplates.delete(templateId), { + method: "DELETE", + }); + return true; + } catch (error) { + log.error("Failed to delete prompt template:", error); + throw error; + } + }, +}; diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index e0840acf3..83fbef586 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -35,6 +35,8 @@ export type EditableAgent = Pick< | "business_description" | "business_logic_model_name" | "business_logic_model_id" + | "prompt_template_id" + | "prompt_template_name" | "sub_agent_id_list" | "group_ids" | "ingroup_permission" @@ -159,6 +161,8 @@ const emptyEditableAgent: EditableAgent = { business_description: "", business_logic_model_name: "", business_logic_model_id: 0, + prompt_template_id: 0, + prompt_template_name: "system_default", sub_agent_id_list: [], group_ids: [], ingroup_permission: "READ_ONLY", @@ -183,6 +187,8 @@ const toEditable = (agent: Agent | null): EditableAgent => business_description: agent.business_description || "", business_logic_model_name: agent.business_logic_model_name || "", business_logic_model_id: agent.business_logic_model_id || 0, + prompt_template_id: agent.prompt_template_id ?? 0, + prompt_template_name: agent.prompt_template_name || "system_default", sub_agent_id_list: agent.sub_agent_id_list || [], group_ids: agent.group_ids || [], ingroup_permission: agent.ingroup_permission || "READ_ONLY", @@ -200,13 +206,17 @@ const isBusinessInfoDirty = (baselineAgent: EditableAgent | null, editedAgent: E return ( editedAgent.business_description !== "" || editedAgent.business_logic_model_name !== "" || - editedAgent.business_logic_model_id !== 0 + editedAgent.business_logic_model_id !== 0 || + (editedAgent.prompt_template_id ?? 0) !== 0 || + (editedAgent.prompt_template_name || "system_default") !== "system_default" ); } return ( baselineAgent.business_description !== editedAgent.business_description || baselineAgent.business_logic_model_name !== editedAgent.business_logic_model_name || - baselineAgent.business_logic_model_id !== editedAgent.business_logic_model_id + baselineAgent.business_logic_model_id !== editedAgent.business_logic_model_id || + (baselineAgent.prompt_template_id ?? 0) !== (editedAgent.prompt_template_id ?? 0) || + (baselineAgent.prompt_template_name || "system_default") !== (editedAgent.prompt_template_name || "system_default") ); }; diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index e6d36daaf..69f382a1b 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -4,10 +4,15 @@ import type { Dispatch, SetStateAction } from "react"; import { ChatMessageType } from "./chat"; import { ModelOption } from "@/types/modelConfig"; import { GENERATE_PROMPT_STREAM_TYPES } from "../const/agentConfig"; +import type { PromptTemplateFieldKey } from "../const/promptTemplate"; export type AgentBusinessInfo = Partial>; export type AgentProfileInfo = Partial< @@ -26,6 +31,8 @@ export type AgentProfileInfo = Partial< | "few_shots_prompt" | "group_ids" | "ingroup_permission" + | "prompt_template_id" + | "prompt_template_name" > >; @@ -50,6 +57,8 @@ export interface Agent { business_description?: string; business_logic_model_name?: string; business_logic_model_id?: number; + prompt_template_id?: number; + prompt_template_name?: string; is_available?: boolean; is_new?: boolean; sub_agent_id_list?: number[]; @@ -408,6 +417,7 @@ export interface GeneratePromptParams { agent_id: number; task_description: string; model_id: string; + prompt_template_id?: number; tool_ids?: number[]; // Optional: tool IDs selected in frontend (takes precedence over database query) sub_agent_ids?: number[]; // Optional: sub-agent IDs selected in frontend (takes precedence over database query) /** @@ -427,3 +437,25 @@ export interface StreamResponseData { content: string; is_complete: boolean; } + +export type PromptTemplateContent = Record; + +export interface PromptTemplate { + template_id: number; + template_name: string; + description?: string | null; + template_type: string; + template_content_zh: PromptTemplateContent; + template_content_en?: PromptTemplateContent | null; + is_system_default?: boolean; + create_time?: string; + update_time?: string; +} + +export interface PromptTemplatePayload { + template_name: string; + description?: string; + template_type?: string; + template_content_zh: PromptTemplateContent; + template_content_en?: PromptTemplateContent | null; +} diff --git a/test/backend/app/test_prompt_template_app.py b/test/backend/app/test_prompt_template_app.py new file mode 100644 index 000000000..8cd78cf1d --- /dev/null +++ b/test/backend/app/test_prompt_template_app.py @@ -0,0 +1,397 @@ +import importlib +import os +import sys +import types +from http import HTTPStatus + +import pytest + + +BACKEND_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../backend") +) + + +@pytest.fixture(autouse=True) +def _reset_prompt_template_app_modules(): + yield + sys.modules.pop("apps.prompt_template_app", None) + sys.modules.pop("services.prompt_template_service", None) + sys.modules.pop("utils.auth_utils", None) + + +@pytest.fixture +def prompt_template_app_module(monkeypatch): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + + service_module = types.ModuleType("services.prompt_template_service") + for name in [ + "create_prompt_template_impl", + "delete_prompt_template_impl", + "get_prompt_template_detail_impl", + "list_prompt_templates_impl", + "update_prompt_template_impl", + ]: + setattr(service_module, name, lambda *args, **kwargs: None) + monkeypatch.setitem(sys.modules, "services.prompt_template_service", service_module) + + auth_module = types.ModuleType("utils.auth_utils") + auth_module.get_current_user_id = lambda authorization: ("user-1", "tenant-1") + monkeypatch.setitem(sys.modules, "utils.auth_utils", auth_module) + + sys.modules.pop("apps.prompt_template_app", None) + module = importlib.import_module("apps.prompt_template_app") + return importlib.reload(module) + + +@pytest.fixture +def prompt_template_exceptions(): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + return importlib.import_module("consts.exceptions") + + +@pytest.fixture +def prompt_template_client(prompt_template_app_module): + from fastapi import FastAPI + from fastapi.testclient import TestClient + + app = FastAPI() + app.include_router(prompt_template_app_module.router) + return TestClient(app) + + +@pytest.fixture +def prompt_template_payload(): + return { + "template_name": "template-a", + "description": "template description", + "template_type": "agent_generate", + "template_content_zh": { + "duty_system_prompt": "zh-duty", + "constraint_system_prompt": "zh-constraint", + "few_shots_system_prompt": "zh-few-shots", + "agent_variable_name_system_prompt": "zh-agent-name", + "agent_display_name_system_prompt": "zh-display-name", + "agent_description_system_prompt": "zh-description", + "user_prompt": "zh-user", + "agent_name_regenerate_system_prompt": "zh-regen-name-system", + "agent_name_regenerate_user_prompt": "zh-regen-name-user", + "agent_display_name_regenerate_system_prompt": "zh-regen-display-system", + "agent_display_name_regenerate_user_prompt": "zh-regen-display-user", + }, + "template_content_en": { + "duty_system_prompt": "en-duty", + "constraint_system_prompt": "en-constraint", + "few_shots_system_prompt": "en-few-shots", + "agent_variable_name_system_prompt": "en-agent-name", + "agent_display_name_system_prompt": "en-display-name", + "agent_description_system_prompt": "en-description", + "user_prompt": "en-user", + "agent_name_regenerate_system_prompt": "en-regen-name-system", + "agent_name_regenerate_user_prompt": "en-regen-name-user", + "agent_display_name_regenerate_system_prompt": "en-regen-display-system", + "agent_display_name_regenerate_user_prompt": "en-regen-display-user", + }, + } + + +def test_list_prompt_templates_api_success( + mocker, prompt_template_app_module, prompt_template_client +): + auth_mock = mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + list_mock = mocker.patch.object( + prompt_template_app_module, + "list_prompt_templates_impl", + return_value=[{"template_id": 0, "template_name": "system_default"}], + ) + + response = prompt_template_client.get( + "/prompt_templates", + headers={"Authorization": "Bearer token"}, + ) + + assert response.status_code == HTTPStatus.OK + assert response.json() == [{"template_id": 0, "template_name": "system_default"}] + auth_mock.assert_called_once_with("Bearer token") + list_mock.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + + +def test_list_prompt_templates_api_returns_internal_error_on_unexpected_exception( + mocker, prompt_template_app_module, prompt_template_client +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "list_prompt_templates_impl", + side_effect=Exception("db error"), + ) + + response = prompt_template_client.get("/prompt_templates") + + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + assert response.json()["detail"] == "Prompt template list error." + + +def test_get_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + detail_mock = mocker.patch.object( + prompt_template_app_module, + "get_prompt_template_detail_impl", + return_value={"template_id": 1, "template_name": "template-a"}, + ) + + response = prompt_template_client.get("/prompt_templates/1") + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 1, "template_name": "template-a"} + detail_mock.assert_called_once_with(template_id=1, tenant_id="tenant-1", user_id="user-1") + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("not_found", HTTPStatus.NOT_FOUND, "Prompt template not found", id="not-found"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template detail error."), + ], +) +def test_get_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "not_found": + side_effect = prompt_template_exceptions.NotFoundException( + "Prompt template not found" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "get_prompt_template_detail_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.get("/prompt_templates/3") + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail + + +def test_create_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client, prompt_template_payload +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + create_mock = mocker.patch.object( + prompt_template_app_module, + "create_prompt_template_impl", + return_value={"template_id": 9, "template_name": "template-a"}, + ) + + response = prompt_template_client.post("/prompt_templates", json=prompt_template_payload) + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 9, "template_name": "template-a"} + assert create_mock.call_args.kwargs["tenant_id"] == "tenant-1" + assert create_mock.call_args.kwargs["user_id"] == "user-1" + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("duplicate", HTTPStatus.BAD_REQUEST, "Prompt template name already exists", id="duplicate"), + pytest.param("validation", HTTPStatus.BAD_REQUEST, "template_content_zh is required", id="validation"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template create error."), + ], +) +def test_create_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + prompt_template_payload, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "duplicate": + side_effect = prompt_template_exceptions.DuplicateError( + "Prompt template name already exists" + ) + elif side_effect == "validation": + side_effect = prompt_template_exceptions.ValidationError( + "template_content_zh is required" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "create_prompt_template_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.post("/prompt_templates", json=prompt_template_payload) + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail + + +def test_update_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client, prompt_template_payload +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + update_mock = mocker.patch.object( + prompt_template_app_module, + "update_prompt_template_impl", + return_value={"template_id": 4, "template_name": "template-a"}, + ) + + response = prompt_template_client.put("/prompt_templates/4", json=prompt_template_payload) + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 4, "template_name": "template-a"} + assert update_mock.call_args.kwargs["template_id"] == 4 + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("not_found", HTTPStatus.NOT_FOUND, "Prompt template not found", id="not-found"), + pytest.param("duplicate", HTTPStatus.BAD_REQUEST, "Prompt template name already exists", id="duplicate"), + pytest.param("validation", HTTPStatus.BAD_REQUEST, "System default prompt template cannot be updated", id="validation"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template update error."), + ], +) +def test_update_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + prompt_template_payload, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "not_found": + side_effect = prompt_template_exceptions.NotFoundException( + "Prompt template not found" + ) + elif side_effect == "duplicate": + side_effect = prompt_template_exceptions.DuplicateError( + "Prompt template name already exists" + ) + elif side_effect == "validation": + side_effect = prompt_template_exceptions.ValidationError( + "System default prompt template cannot be updated" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "update_prompt_template_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.put("/prompt_templates/7", json=prompt_template_payload) + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail + + +def test_delete_prompt_template_api_success( + mocker, prompt_template_app_module, prompt_template_client +): + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + delete_mock = mocker.patch.object( + prompt_template_app_module, + "delete_prompt_template_impl", + return_value={"template_id": 8, "deleted": True}, + ) + + response = prompt_template_client.delete("/prompt_templates/8") + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"template_id": 8, "deleted": True} + delete_mock.assert_called_once_with(template_id=8, tenant_id="tenant-1", user_id="user-1") + + +@pytest.mark.parametrize( + ("side_effect", "expected_status", "expected_detail"), + [ + pytest.param("not_found", HTTPStatus.NOT_FOUND, "Prompt template not found", id="not-found"), + pytest.param("validation", HTTPStatus.BAD_REQUEST, "System default prompt template cannot be deleted", id="validation"), + (Exception("unexpected"), HTTPStatus.INTERNAL_SERVER_ERROR, "Prompt template delete error."), + ], +) +def test_delete_prompt_template_api_error_mapping( + mocker, + prompt_template_app_module, + prompt_template_client, + prompt_template_exceptions, + side_effect, + expected_status, + expected_detail, +): + if side_effect == "not_found": + side_effect = prompt_template_exceptions.NotFoundException( + "Prompt template not found" + ) + elif side_effect == "validation": + side_effect = prompt_template_exceptions.ValidationError( + "System default prompt template cannot be deleted" + ) + mocker.patch.object( + prompt_template_app_module, + "get_current_user_id", + return_value=("user-1", "tenant-1"), + ) + mocker.patch.object( + prompt_template_app_module, + "delete_prompt_template_impl", + side_effect=side_effect, + ) + + response = prompt_template_client.delete("/prompt_templates/11") + + assert response.status_code == expected_status + assert response.json()["detail"] == expected_detail diff --git a/test/backend/database/test_agent_db.py b/test/backend/database/test_agent_db.py index 6f2c780e5..de2ed8864 100644 --- a/test/backend/database/test_agent_db.py +++ b/test/backend/database/test_agent_db.py @@ -119,6 +119,8 @@ def __init__(self): self.parent_agent_id = None self.provide_run_summary = None self.business_description = None + self.prompt_template_id = None + self.prompt_template_name = None self.group_ids = None self.is_new = True self.enable_context_manager = False diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 27298f25f..393695c09 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -1,6 +1,7 @@ import sys import asyncio import json +import types from contextlib import contextmanager from unittest.mock import patch, MagicMock, mock_open, call, Mock, AsyncMock import os @@ -62,10 +63,23 @@ def model_dump(self, **kwargs): sys.modules['database.a2a_agent_db'] = MagicMock() # Mock services submodules -sys.modules['services'] = MagicMock() -sys.modules['services.conversation_management_service'] = MagicMock() -sys.modules['services.memory_config_service'] = MagicMock() -sys.modules['services.agent_version_service'] = MagicMock() +services_module = types.ModuleType("services") +services_module.__path__ = [] +sys.modules['services'] = services_module + +conversation_management_service_mock = MagicMock() +memory_config_service_mock = MagicMock() +agent_version_service_mock = MagicMock() +prompt_template_service_mock = MagicMock() +prompt_template_service_mock.SYSTEM_PROMPT_TEMPLATE_ID = 0 +prompt_template_service_mock.SYSTEM_PROMPT_TEMPLATE_NAME = "system_default" +prompt_template_service_mock.get_prompt_template_summary = MagicMock(return_value=(None, None)) +prompt_template_service_mock.resolve_prompt_generate_template = MagicMock(return_value={}) + +sys.modules['services.conversation_management_service'] = conversation_management_service_mock +sys.modules['services.memory_config_service'] = memory_config_service_mock +sys.modules['services.agent_version_service'] = agent_version_service_mock +sys.modules['services.prompt_template_service'] = prompt_template_service_mock # Mock agents submodules sys.modules['agents'] = MagicMock() @@ -282,6 +296,18 @@ def reset_mocks(): yield +def apply_default_prompt_template_request_fields(request, prompt_template_id=None): + """Populate default request fields needed by prompt template aware service logic.""" + request.prompt_template_id = prompt_template_id + request.prompt_template_name = None + request.enabled_skill_ids = None + if not hasattr(request, "related_agent_ids"): + request.related_agent_ids = None + if not hasattr(request, "enabled_tool_ids"): + request.enabled_tool_ids = None + return request + + @pytest.mark.asyncio async def test_get_enable_tool_id_by_agent_id(): """ @@ -421,6 +447,8 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -479,6 +507,8 @@ async def test_get_agent_info_impl_with_version_no(mock_search_agent_info, mock_ "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -584,6 +614,7 @@ async def test_update_agent_info_impl_success(mock_get_current_user_info, mock_u request.business_description = "Updated agent" request.display_name = "Updated Display Name" request.enabled_tool_ids = None # Explicitly set to None to avoid tool handling path + apply_default_prompt_template_request_fields(request) # Execute await update_agent_info_impl(request, authorization="Bearer token") @@ -662,6 +693,7 @@ async def test_update_agent_info_impl_exception_handling(mock_get_current_user_i request.display_name = "Test Display Name" request.enabled_tool_ids = None request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute & Assert with pytest.raises(ValueError) as context: @@ -701,6 +733,7 @@ async def test_update_agent_info_impl_with_enabled_tool_ids( request.agent_id = 123 request.enabled_tool_ids = [1, 2] # Enable tools 1 and 2 request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -758,6 +791,7 @@ async def test_update_agent_info_impl_with_enabled_tool_ids_instance_having_null request.agent_id = 123 request.enabled_tool_ids = [1] # Enable only tool 1 request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -805,6 +839,7 @@ async def test_update_agent_info_impl_with_enabled_tool_ids_disabled_existing_to request.enabled_tool_ids = [2] # Only enable tool 2 (new tool) # Tool 1 exists but is NOT in enabled_tool_ids, so it should be disabled request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -858,6 +893,7 @@ async def test_update_agent_info_impl_with_related_agent_ids( request.agent_id = 123 request.enabled_tool_ids = None request.related_agent_ids = [456, 789] + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -896,6 +932,7 @@ async def test_update_agent_info_impl_circular_dependency_detection( request.agent_id = 123 request.enabled_tool_ids = None request.related_agent_ids = [123] # Agent tries to relate to itself + apply_default_prompt_template_request_fields(request) # Execute & Assert - self-reference should raise ValueError with pytest.raises(ValueError, match="Circular dependency detected"): @@ -941,6 +978,7 @@ async def test_update_agent_info_impl_with_both_tool_and_related_agents( request.agent_id = 123 request.enabled_tool_ids = [1] request.related_agent_ids = [456] + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -983,6 +1021,7 @@ async def test_update_agent_info_impl_tool_update_exception( request.agent_id = 123 request.enabled_tool_ids = [1] request.related_agent_ids = None + apply_default_prompt_template_request_fields(request) # Execute & Assert with pytest.raises(ValueError, match="Failed to update agent tools"): @@ -1015,6 +1054,7 @@ async def test_update_agent_info_impl_related_agent_update_exception( request.agent_id = 123 request.enabled_tool_ids = None request.related_agent_ids = [456] + apply_default_prompt_template_request_fields(request) # Execute & Assert with pytest.raises(ValueError, match="Failed to update related agents"): @@ -1216,6 +1256,7 @@ async def test_update_agent_info_impl_create_agent_auto_group_ids(mock_get_curre request.enabled_tool_ids = None request.related_agent_ids = None request.group_ids = None + apply_default_prompt_template_request_fields(request) # Execute result = await update_agent_info_impl(request, authorization="Bearer token") @@ -1563,6 +1604,8 @@ async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1651,6 +1694,8 @@ async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_age "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1702,6 +1747,8 @@ async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_age "sub_agent_id_list": mock_sub_agent_ids, "model_name": None, "business_logic_model_name": None, + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1777,6 +1824,8 @@ def mock_get_model(model_id): "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": "Claude-3.5", + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1848,6 +1897,8 @@ def mock_get_model(model_id): "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": None, # Should be None when model info is not found + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -1926,6 +1977,8 @@ def mock_get_model(model_id): "sub_agent_id_list": mock_sub_agent_ids, "model_name": "GPT-4", "business_logic_model_name": None, # Should be None when display_name is not in model_info + "prompt_template_id": 0, + "prompt_template_name": "system_default", "is_available": True, "unavailable_reasons": [] } @@ -8015,6 +8068,7 @@ async def test_update_agent_info_impl_create_agent_with_ingroup_permission( request.related_agent_ids = None request.group_ids = [1, 2] request.ingroup_permission = PERMISSION_READ + apply_default_prompt_template_request_fields(request) result = await update_agent_info_impl(request, authorization="Bearer token") @@ -8065,6 +8119,7 @@ async def test_update_agent_info_impl_create_agent_with_ingroup_permission_none( request.related_agent_ids = None request.group_ids = None request.ingroup_permission = None + apply_default_prompt_template_request_fields(request) result = await update_agent_info_impl(request, authorization="Bearer token") @@ -8766,6 +8821,8 @@ async def test_update_agent_info_impl_skill_update_exception( mock_request.related_agent_ids = None mock_request.group_ids = None mock_request.ingroup_permission = None + mock_request.prompt_template_id = None + mock_request.prompt_template_name = None mock_query_skills.return_value = [] mock_create_skill.side_effect = Exception("Skill update failed") diff --git a/test/backend/services/test_agent_version_service.py b/test/backend/services/test_agent_version_service.py index d44ae737c..0db70fb14 100644 --- a/test/backend/services/test_agent_version_service.py +++ b/test/backend/services/test_agent_version_service.py @@ -601,23 +601,19 @@ def test_rollback_version_impl_success(monkeypatch): } mock_search = MagicMock(return_value=mock_version) monkeypatch.setattr(agent_version_service_module, "search_version_by_version_no", mock_search) - - # Mock query_agent_snapshot - mock_agent_snapshot = {"agent_id": 1, "name": "test"} - mock_tools_snapshot = [] - mock_relations_snapshot = [] - mock_query_snapshot = MagicMock(return_value=(mock_agent_snapshot, mock_tools_snapshot, mock_relations_snapshot)) - monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot) - - # Mock restore_agent_draft - mock_restore = MagicMock() - monkeypatch.setattr(agent_version_service_module, "restore_agent_draft", mock_restore) - mock_query_snapshot = MagicMock(return_value=({"agent_id": 1}, [], [])) + mock_query_snapshot = MagicMock( + return_value=( + {"agent_id": 1, "version_no": 1, "name": "Test Agent"}, + [{"tool_id": 1, "version_no": 1}], + [{"relation_id": 1, "version_no": 1}], + ) + ) monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot) + mock_query_draft = MagicMock(return_value=({"agent_id": 1, "version_no": 0}, [], [])) + monkeypatch.setattr(agent_version_service_module, "query_agent_draft", mock_query_draft) + mock_restore_draft = MagicMock() + monkeypatch.setattr(agent_version_service_module, "restore_agent_draft", mock_restore_draft) monkeypatch.setattr(skill_db_mock, "query_skill_instances_by_agent_id", MagicMock(return_value=[])) - monkeypatch.setattr(agent_version_db_mock, "restore_agent_draft", MagicMock(return_value=True)) - mock_update_current = MagicMock(return_value=1) - monkeypatch.setattr(agent_version_service_module, "update_agent_current_version", mock_update_current) result = rollback_version_impl( agent_id=1, @@ -630,7 +626,7 @@ def test_rollback_version_impl_success(monkeypatch): assert "Successfully rolled back" in result["message"] mock_search.assert_called_once_with(1, "tenant1", 1) mock_query_snapshot.assert_called_once_with(1, "tenant1", 1) - mock_restore.assert_called_once() + mock_restore_draft.assert_called_once() def test_rollback_version_impl_version_not_found(monkeypatch): @@ -647,18 +643,22 @@ def test_rollback_version_impl_version_not_found(monkeypatch): def test_rollback_version_impl_draft_not_found(monkeypatch): - """Test rolling back when snapshot is not found""" + """Test rolling back when draft doesn't exist""" mock_version = {"version_no": 1} mock_search = MagicMock(return_value=mock_version) monkeypatch.setattr(agent_version_service_module, "search_version_by_version_no", mock_search) - mock_query_snapshot = MagicMock(return_value=(None, [], [])) - monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot) - - # Mock query_agent_snapshot to return empty agent (falsy) - mock_query_snapshot = MagicMock(return_value=(None, [], [])) + mock_query_snapshot = MagicMock( + return_value=( + {"agent_id": 1, "version_no": 1, "name": "Test Agent"}, + [], + [], + ) + ) monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot) + mock_query_draft = MagicMock(return_value=(None, [], [])) + monkeypatch.setattr(agent_version_service_module, "query_agent_draft", mock_query_draft) - with pytest.raises(ValueError, match="Agent snapshot for version 1 not found"): + with pytest.raises(ValueError, match="Agent draft not found"): rollback_version_impl( agent_id=1, tenant_id="tenant1", diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 601e6a934..7c85011b4 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -172,9 +172,11 @@ def mock_generator(*args, **kwargs): "Test task", enabled_tools, # tool_info_list from helper "tenant456", + "user123", self.test_model_id, "zh", - None # knowledge_base_display_names + None, + None, ) @patch('backend.services.prompt_service._regenerate_agent_display_name_with_llm') @@ -567,6 +569,7 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): user_id="user123", tenant_id="tenant456", language="zh", + prompt_template_id=None, tool_ids=None, sub_agent_ids=None, knowledge_base_display_names=None, @@ -580,19 +583,19 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt(self, mock_get_prompt_template, mock_join_info, mock_call_llm): + @patch('backend.services.prompt_service.resolve_prompt_generate_template') + def test_generate_system_prompt(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" @@ -644,6 +647,7 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): mock_task_description, mock_tools, mock_tenant_id, + "test_user", self.test_model_id, mock_language ): @@ -651,7 +655,12 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): # Assert # Verify template loading - mock_get_prompt_template.assert_called_once_with(mock_language) + mock_resolve_prompt_template.assert_called_once_with( + tenant_id=mock_tenant_id, + user_id="test_user", + language=mock_language, + prompt_template_id=None, + ) # Verify template joining - now includes knowledge_base_display_names parameter mock_join_info.assert_called_once_with( @@ -697,19 +706,19 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') - def test_generate_system_prompt_with_exception(self, mock_get_prompt_template, mock_join_info, mock_call_llm): + @patch('backend.services.prompt_service.resolve_prompt_generate_template') + def test_generate_system_prompt_with_exception(self, mock_resolve_prompt_template, mock_join_info, mock_call_llm): # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" # Mock call_llm_for_system_prompt to raise exception for one prompt type @@ -741,6 +750,7 @@ def mock_llm_call_with_exception(model_id, content, sys_prompt, callback, tenant mock_task_description, mock_tools, mock_tenant_id, + "test_user", self.test_model_id, mock_language ): @@ -752,7 +762,7 @@ def mock_llm_call_with_exception(model_id, content, sys_prompt, callback, tenant @patch('backend.services.prompt_service.Template') def test_join_info_for_generate_system_prompt(self, mock_template): # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [ {"name": "agent1", "description": "Agent 1 desc"}, {"name": "agent2", "description": "Agent 2 desc"} @@ -777,7 +787,7 @@ def test_join_info_for_generate_system_prompt(self, mock_template): # Assert self.assertEqual(result, "Rendered content") mock_template.assert_called_once_with( - mock_prompt_for_generate["USER_PROMPT"], undefined=StrictUndefined) + mock_prompt_for_generate["user_prompt"], undefined=StrictUndefined) mock_template_instance.render.assert_called_once() # Check template variables template_vars = mock_template_instance.render.call_args[0][0] @@ -994,25 +1004,25 @@ def mock_gen(*args, **kwargs): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') + @patch('backend.services.prompt_service.resolve_prompt_generate_template') def test_generate_system_prompt_error_before_streaming( self, - mock_get_prompt_template, + mock_resolve_prompt_template, mock_join_info, mock_call_llm, ): """Test generate_system_prompt handles error that occurs before streaming (line 307-311)""" # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" # Mock call_llm_for_system_prompt to raise exception immediately @@ -1034,6 +1044,7 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): "Test task", [{"name": "tool1"}], "tenant123", + "test_user", self.test_model_id, "zh" ): @@ -1043,25 +1054,25 @@ def mock_llm_call_error(model_id, content, sys_prompt, callback, tenant_id): @patch('backend.services.prompt_service.call_llm_for_system_prompt') @patch('backend.services.prompt_service.join_info_for_generate_system_prompt') - @patch('backend.services.prompt_service.get_prompt_generate_prompt_template') + @patch('backend.services.prompt_service.resolve_prompt_generate_template') def test_generate_system_prompt_error_during_streaming( self, - mock_get_prompt_template, + mock_resolve_prompt_template, mock_join_info, mock_call_llm, ): """Test generate_system_prompt handles error that occurs during streaming (line 330-331)""" # Setup mock_prompt_config = { - "USER_PROMPT": "Test user prompt template", - "DUTY_SYSTEM_PROMPT": "Generate duty prompt", - "CONSTRAINT_SYSTEM_PROMPT": "Generate constraint prompt", - "FEW_SHOTS_SYSTEM_PROMPT": "Generate few shots prompt", - "AGENT_VARIABLE_NAME_SYSTEM_PROMPT": "Generate agent var name", - "AGENT_DISPLAY_NAME_SYSTEM_PROMPT": "Generate agent display name", - "AGENT_DESCRIPTION_SYSTEM_PROMPT": "Generate agent description" + "user_prompt": "Test user prompt template", + "duty_system_prompt": "Generate duty prompt", + "constraint_system_prompt": "Generate constraint prompt", + "few_shots_system_prompt": "Generate few shots prompt", + "agent_variable_name_system_prompt": "Generate agent var name", + "agent_display_name_system_prompt": "Generate agent display name", + "agent_description_system_prompt": "Generate agent description" } - mock_get_prompt_template.return_value = mock_prompt_config + mock_resolve_prompt_template.return_value = mock_prompt_config mock_join_info.return_value = "Joined template content" # Track which call we're on @@ -1092,6 +1103,7 @@ def mock_llm_call_error_after_first( "Test task", [{"name": "tool1"}], "tenant123", + "test_user", self.test_model_id, "zh" ): @@ -1146,7 +1158,7 @@ def test_get_enabled_sub_agent_description_for_generate_prompt_empty( def test_join_info_for_generate_system_prompt_english(self, mock_template): """Test join_info_for_generate_system_prompt with English language""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [ {"name": "agent1", "description": "Agent 1 desc"} ] @@ -1176,7 +1188,7 @@ def test_join_info_for_generate_system_prompt_english(self, mock_template): def test_join_info_for_generate_system_prompt_empty_tools_and_agents(self, mock_template): """Test join_info_for_generate_system_prompt with empty tools and sub-agents""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [] mock_task_description = "Test task" mock_tools = [] @@ -1197,7 +1209,7 @@ def test_join_info_for_generate_system_prompt_empty_tools_and_agents(self, mock_ def test_join_info_for_generate_system_prompt_with_knowledge_base_names(self, mock_template): """Test join_info_for_generate_system_prompt with knowledge_base_display_names""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [] mock_task_description = "Test task" mock_tools = [ @@ -1226,7 +1238,7 @@ def test_join_info_for_generate_system_prompt_with_knowledge_base_names(self, mo def test_join_info_for_generate_system_prompt_without_knowledge_base_names(self, mock_template): """Test join_info_for_generate_system_prompt without knowledge_base_display_names""" # Setup - mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_prompt_for_generate = {"user_prompt": "Test User Prompt"} mock_sub_agents = [] mock_task_description = "Test task" mock_tools = [ diff --git a/test/backend/services/test_prompt_template_service.py b/test/backend/services/test_prompt_template_service.py new file mode 100644 index 000000000..34415b203 --- /dev/null +++ b/test/backend/services/test_prompt_template_service.py @@ -0,0 +1,501 @@ +import importlib +import os +import sys +import types + +import pytest + + +BACKEND_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../backend") +) + + +@pytest.fixture(autouse=True) +def _reset_prompt_template_service_modules(): + yield + sys.modules.pop("services.prompt_template_service", None) + sys.modules.pop("database.prompt_template_db", None) + + +@pytest.fixture +def prompt_template_models(monkeypatch): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + + nexent_module = types.ModuleType("nexent") + nexent_core_module = types.ModuleType("nexent.core") + nexent_agents_module = types.ModuleType("nexent.core.agents") + agent_model_module = types.ModuleType("nexent.core.agents.agent_model") + agent_model_module.ToolConfig = type("ToolConfig", (), {}) + + monkeypatch.setitem(sys.modules, "nexent", nexent_module) + monkeypatch.setitem(sys.modules, "nexent.core", nexent_core_module) + monkeypatch.setitem(sys.modules, "nexent.core.agents", nexent_agents_module) + monkeypatch.setitem(sys.modules, "nexent.core.agents.agent_model", agent_model_module) + + consts_model = importlib.import_module("consts.model") + consts_exceptions = importlib.import_module("consts.exceptions") + return consts_model, consts_exceptions + + +@pytest.fixture +def prompt_template_service_module(monkeypatch): + if BACKEND_PATH not in sys.path: + sys.path.insert(0, BACKEND_PATH) + + db_module = types.ModuleType("database.prompt_template_db") + for name in [ + "create_prompt_template", + "delete_prompt_template", + "get_prompt_template_by_id", + "get_prompt_template_by_name", + "get_prompt_template_by_template_id", + "query_prompt_templates_by_user", + "upsert_prompt_template_by_id", + "update_prompt_template", + ]: + setattr(db_module, name, lambda *args, **kwargs: None) + monkeypatch.setitem(sys.modules, "database.prompt_template_db", db_module) + + sys.modules.pop("services.prompt_template_service", None) + module = importlib.import_module("services.prompt_template_service") + return importlib.reload(module) + + +@pytest.fixture +def template_content_factory(): + def _build(seed: str = "value", **overrides): + content = { + "duty_system_prompt": f"{seed}-duty", + "constraint_system_prompt": f"{seed}-constraint", + "few_shots_system_prompt": f"{seed}-few-shots", + "agent_variable_name_system_prompt": f"{seed}-agent-name", + "agent_display_name_system_prompt": f"{seed}-display-name", + "agent_description_system_prompt": f"{seed}-description", + "user_prompt": f"{seed}-user", + "agent_name_regenerate_system_prompt": f"{seed}-regen-name-system", + "agent_name_regenerate_user_prompt": f"{seed}-regen-name-user", + "agent_display_name_regenerate_system_prompt": f"{seed}-regen-display-system", + "agent_display_name_regenerate_user_prompt": f"{seed}-regen-display-user", + } + content.update(overrides) + return content + + return _build + + +@pytest.fixture +def prompt_template_request_factory(template_content_factory, prompt_template_models): + consts_model, _ = prompt_template_models + + def _build( + template_name: str = "template-a", + description: str | None = "template description", + template_type: str = "agent_generate", + template_content_zh: dict | None = None, + template_content_en: dict | None = None, + ): + return consts_model.PromptTemplateRequest( + template_name=template_name, + description=description, + template_type=template_type, + template_content_zh=consts_model.PromptTemplateContentRequest( + **(template_content_zh or template_content_factory("zh")) + ), + template_content_en=( + consts_model.PromptTemplateContentRequest( + **(template_content_en or template_content_factory("en")) + ) + if template_content_en is not None + else None + ), + ) + + return _build + + +def test_build_system_default_prompt_template_payload( + mocker, prompt_template_service_module, template_content_factory +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_generate_prompt_template", + side_effect=[ + template_content_factory("zh"), + template_content_factory("en"), + ], + ) + + payload = prompt_template_service_module.build_system_default_prompt_template_payload() + + assert payload["template_id"] == 0 + assert payload["template_name"] == "system_default" + assert payload["tenant_id"] == prompt_template_service_module.SYSTEM_PROMPT_TEMPLATE_TENANT_ID + assert payload["user_id"] == prompt_template_service_module.SYSTEM_PROMPT_TEMPLATE_USER_ID + assert payload["template_content_zh"]["duty_system_prompt"] == "zh-duty" + assert payload["template_content_en"]["duty_system_prompt"] == "en-duty" + + +def test_sync_system_default_prompt_template_marks_system_default( + mocker, prompt_template_service_module +): + payload = {"template_id": 0, "template_name": "system_default"} + mocker.patch.object( + prompt_template_service_module, + "build_system_default_prompt_template_payload", + return_value=payload, + ) + upsert_mock = mocker.patch.object( + prompt_template_service_module, + "upsert_prompt_template_by_id", + return_value={"template_id": 0, "template_name": "system_default"}, + ) + + result = prompt_template_service_module.sync_system_default_prompt_template() + + upsert_mock.assert_called_once_with( + template_id=0, + template_data=payload, + user_id=prompt_template_service_module.SYSTEM_PROMPT_TEMPLATE_USER_ID, + ) + assert result["is_system_default"] is True + + +def test_get_system_default_prompt_template_syncs_when_missing( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_template_id", + return_value=None, + ) + sync_mock = mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={"template_id": 0, "template_name": "system_default"}, + ) + + result = prompt_template_service_module.get_system_default_prompt_template() + + sync_mock.assert_called_once_with() + assert result["template_id"] == 0 + assert result["is_system_default"] is True + + +def test_normalize_template_request_trims_and_drops_empty_optional_fields( + prompt_template_service_module, prompt_template_request_factory, template_content_factory +): + request = prompt_template_request_factory( + template_name=" template-a ", + description=" ", + template_content_zh=template_content_factory( + "zh", + constraint_system_prompt="", + few_shots_system_prompt=" ", + ), + template_content_en=template_content_factory( + "en", + duty_system_prompt="", + constraint_system_prompt="", + few_shots_system_prompt="", + agent_variable_name_system_prompt="", + agent_display_name_system_prompt="", + agent_description_system_prompt="", + user_prompt="", + agent_name_regenerate_system_prompt="", + agent_name_regenerate_user_prompt="", + agent_display_name_regenerate_system_prompt="", + agent_display_name_regenerate_user_prompt="", + ), + ) + + result = prompt_template_service_module._normalize_template_request(request) + + assert result["template_name"] == "template-a" + assert result["description"] is None + assert "constraint_system_prompt" not in result["template_content_zh"] + assert result["template_content_en"] is None + + +def test_normalize_template_request_requires_non_empty_zh_content( + prompt_template_service_module, + prompt_template_request_factory, + template_content_factory, + prompt_template_models, +): + _, consts_exceptions = prompt_template_models + request = prompt_template_request_factory( + template_content_zh=template_content_factory( + "zh", + duty_system_prompt="", + constraint_system_prompt="", + few_shots_system_prompt="", + agent_variable_name_system_prompt="", + agent_display_name_system_prompt="", + agent_description_system_prompt="", + user_prompt="", + agent_name_regenerate_system_prompt="", + agent_name_regenerate_user_prompt="", + agent_display_name_regenerate_system_prompt="", + agent_display_name_regenerate_user_prompt="", + ) + ) + + with pytest.raises( + consts_exceptions.ValidationError, match="template_content_zh is required" + ): + prompt_template_service_module._normalize_template_request(request) + + +def test_list_prompt_templates_impl_prepends_system_default_and_filters_duplicate_id( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={"template_id": 0, "template_name": "system_default", "is_system_default": True}, + ) + mocker.patch.object( + prompt_template_service_module, + "query_prompt_templates_by_user", + return_value=[ + {"template_id": 0, "template_name": "system_default"}, + {"template_id": 2, "template_name": "custom-template"}, + ], + ) + + result = prompt_template_service_module.list_prompt_templates_impl("tenant-1", "user-1") + + assert [item["template_id"] for item in result] == [0, 2] + assert result[0]["is_system_default"] is True + assert result[1]["is_system_default"] is False + + +def test_create_prompt_template_impl_rejects_duplicate_name( + mocker, + prompt_template_service_module, + prompt_template_request_factory, + prompt_template_models, +): + _, consts_exceptions = prompt_template_models + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_name", + return_value={"template_id": 1, "template_name": "template-a"}, + ) + + with pytest.raises( + consts_exceptions.DuplicateError, match="Prompt template name already exists" + ): + prompt_template_service_module.create_prompt_template_impl( + prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + +def test_create_prompt_template_impl_persists_user_template( + mocker, prompt_template_service_module, prompt_template_request_factory +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_name", + return_value=None, + ) + create_mock = mocker.patch.object( + prompt_template_service_module, + "create_prompt_template", + return_value={"template_id": 9, "template_name": "template-a"}, + ) + + result = prompt_template_service_module.create_prompt_template_impl( + prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + create_payload = create_mock.call_args.args[0] + assert create_payload["tenant_id"] == "tenant-1" + assert create_payload["user_id"] == "user-1" + assert create_payload["created_by"] == "user-1" + assert result["is_system_default"] is False + + +def test_update_prompt_template_impl_rejects_system_default( + prompt_template_service_module, + prompt_template_request_factory, + prompt_template_models, +): + _, consts_exceptions = prompt_template_models + with pytest.raises( + consts_exceptions.ValidationError, + match="System default prompt template cannot be updated", + ): + prompt_template_service_module.update_prompt_template_impl( + template_id=0, + request=prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + +def test_update_prompt_template_impl_updates_existing_template( + mocker, prompt_template_service_module, prompt_template_request_factory +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value={"template_id": 3, "template_name": "template-a"}, + ) + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_name", + return_value={"template_id": 3, "template_name": "template-a"}, + ) + update_mock = mocker.patch.object( + prompt_template_service_module, + "update_prompt_template", + return_value={"template_id": 3, "template_name": "template-a"}, + ) + + result = prompt_template_service_module.update_prompt_template_impl( + template_id=3, + request=prompt_template_request_factory(), + tenant_id="tenant-1", + user_id="user-1", + ) + + assert update_mock.call_args.kwargs["template_id"] == 3 + assert update_mock.call_args.kwargs["user_id"] == "user-1" + assert result["is_system_default"] is False + + +@pytest.mark.parametrize("deleted_count, expected_deleted", [(1, True), (0, False)]) +def test_delete_prompt_template_impl_returns_deleted_status( + mocker, prompt_template_service_module, deleted_count, expected_deleted +): + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value={"template_id": 5, "template_name": "template-a"}, + ) + mocker.patch.object( + prompt_template_service_module, + "delete_prompt_template", + return_value=deleted_count, + ) + + result = prompt_template_service_module.delete_prompt_template_impl( + template_id=5, + tenant_id="tenant-1", + user_id="user-1", + ) + + assert result == {"template_id": 5, "deleted": expected_deleted} + + +def test_resolve_prompt_generate_template_falls_back_to_system_default_when_custom_missing( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={ + "template_content_en": {"duty_system_prompt": "system-en-duty"}, + "template_content_zh": {"constraint_system_prompt": "system-zh-constraint"}, + }, + ) + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value=None, + ) + + result = prompt_template_service_module.resolve_prompt_generate_template( + tenant_id="tenant-1", + user_id="user-1", + language=prompt_template_service_module.LANGUAGE["EN"], + prompt_template_id=8, + ) + + assert result == { + "duty_system_prompt": "system-en-duty", + "constraint_system_prompt": "system-zh-constraint", + } + + +def test_resolve_prompt_generate_template_merges_custom_and_system_fallbacks( + mocker, prompt_template_service_module +): + mocker.patch.object( + prompt_template_service_module, + "sync_system_default_prompt_template", + return_value={ + "template_content_en": {"few_shots_system_prompt": "system-en-few"}, + "template_content_zh": {"user_prompt": "system-zh-user"}, + }, + ) + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value={ + "template_id": 6, + "template_content_en": {"duty_system_prompt": "custom-en-duty"}, + "template_content_zh": {"constraint_system_prompt": "custom-zh-constraint"}, + }, + ) + + result = prompt_template_service_module.resolve_prompt_generate_template( + tenant_id="tenant-1", + user_id="user-1", + language=prompt_template_service_module.LANGUAGE["EN"], + prompt_template_id=6, + ) + + assert result == { + "duty_system_prompt": "custom-en-duty", + "constraint_system_prompt": "custom-zh-constraint", + "few_shots_system_prompt": "system-en-few", + "user_prompt": "system-zh-user", + } + + +@pytest.mark.parametrize( + ("template_id", "expected"), + [ + (None, (None, None)), + (0, (0, "system_default")), + ], +) +def test_get_prompt_template_summary_handles_none_and_system_default( + prompt_template_service_module, template_id, expected +): + assert ( + prompt_template_service_module.get_prompt_template_summary( + template_id=template_id, + tenant_id="tenant-1", + user_id="user-1", + ) + == expected + ) + + +def test_get_prompt_template_summary_raises_when_template_missing( + mocker, prompt_template_service_module, prompt_template_models +): + _, consts_exceptions = prompt_template_models + mocker.patch.object( + prompt_template_service_module, + "get_prompt_template_by_id", + return_value=None, + ) + + with pytest.raises( + consts_exceptions.NotFoundException, match="Prompt template not found" + ): + prompt_template_service_module.get_prompt_template_summary( + template_id=10, + tenant_id="tenant-1", + user_id="user-1", + ) diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py index 82af6d5ba..dd24a9f20 100644 --- a/test/backend/test_cluster_summarization.py +++ b/test/backend/test_cluster_summarization.py @@ -35,10 +35,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py index 4fb094618..33d97c776 100644 --- a/test/backend/test_document_vector_integration.py +++ b/test/backend/test_document_vector_integration.py @@ -36,10 +36,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py index 9bce2af29..53c87a022 100644 --- a/test/backend/test_document_vector_utils.py +++ b/test/backend/test_document_vector_utils.py @@ -35,10 +35,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py index 23a6923c8..2b4278603 100644 --- a/test/backend/test_document_vector_utils_coverage.py +++ b/test/backend/test_document_vector_utils_coverage.py @@ -34,10 +34,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py index be9d6a20d..247e20399 100644 --- a/test/backend/test_summary_formatting.py +++ b/test/backend/test_summary_formatting.py @@ -32,10 +32,28 @@ consts_error_code_mock.ErrorCode = MagicMock() consts_exceptions_mock = MagicMock() consts_exceptions_mock.AppException = Exception +consts_prompt_template_mock = MagicMock() +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP = { + "duty_system_prompt": "DUTY_SYSTEM_PROMPT", + "constraint_system_prompt": "CONSTRAINT_SYSTEM_PROMPT", + "few_shots_system_prompt": "FEW_SHOTS_SYSTEM_PROMPT", + "agent_variable_name_system_prompt": "AGENT_VARIABLE_NAME_SYSTEM_PROMPT", + "agent_display_name_system_prompt": "AGENT_DISPLAY_NAME_SYSTEM_PROMPT", + "agent_description_system_prompt": "AGENT_DESCRIPTION_SYSTEM_PROMPT", + "user_prompt": "USER_PROMPT", + "agent_name_regenerate_system_prompt": "AGENT_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_name_regenerate_user_prompt": "AGENT_NAME_REGENERATE_USER_PROMPT", + "agent_display_name_regenerate_system_prompt": "AGENT_DISPLAY_NAME_REGENERATE_SYSTEM_PROMPT", + "agent_display_name_regenerate_user_prompt": "AGENT_DISPLAY_NAME_REGENERATE_USER_PROMPT", +} +consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELDS = tuple( + consts_prompt_template_mock.PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP.keys() +) sys.modules['consts'] = consts_mock sys.modules['consts.const'] = consts_const_mock sys.modules['consts.error_code'] = consts_error_code_mock sys.modules['consts.exceptions'] = consts_exceptions_mock +sys.modules['consts.prompt_template'] = consts_prompt_template_mock # Add backend to path before patching backend modules sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py index 2052bba54..ba8a3c6e0 100644 --- a/test/backend/utils/test_llm_utils.py +++ b/test/backend/utils/test_llm_utils.py @@ -496,6 +496,30 @@ def gen(): res = call_llm_for_system_prompt(2, "u2", "s2") assert res == "ABC" + def test_call_llm_for_system_prompt_skips_chunk_without_choices(self, mocker: MockFixture): + mock_get_model_by_id = mocker.patch('backend.utils.llm_utils.get_model_by_model_id') + mock_get_model_name = mocker.patch('backend.utils.llm_utils.get_model_name_from_config') + mock_openai = mocker.patch('backend.utils.llm_utils.OpenAIModel') + + mock_get_model_by_id.return_value = {"base_url": "http://y", "api_key": "k2"} + mock_get_model_name.return_value = "gpt-6" + + mock_instance = mock_openai.return_value + + empty_chunk = MagicMock() + empty_chunk.choices = [] + + valid_chunk = MagicMock() + valid_chunk.choices = [MagicMock()] + valid_chunk.choices[0].delta.content = "OK" + + mock_instance.client = MagicMock() + mock_instance.client.chat.completions.create.return_value = [empty_chunk, valid_chunk] + mock_instance._prepare_completion_kwargs.return_value = {} + + res = call_llm_for_system_prompt(2, "u2", "s2") + assert res == "OK" + def test_call_llm_for_system_prompt_with_callback(self, mocker: MockFixture): """Test call_llm_for_system_prompt with callback""" mock_get_model_by_id = mocker.patch('backend.utils.llm_utils.get_model_by_model_id') @@ -1204,4 +1228,4 @@ def test_error_empty_message(self, mocker: MockFixture): with pytest.raises(AppException) as exc_info: call_llm_for_system_prompt(1, "user prompt", "system prompt") - assert exc_info.value.error_code == ErrorCode.MODEL_PROMPT_GENERATION_FAILED \ No newline at end of file + assert exc_info.value.error_code == ErrorCode.MODEL_PROMPT_GENERATION_FAILED