diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 5a11b550b..8a3fbc807 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -247,7 +247,9 @@ async def create_model_config_list(tenant_id): ), url=record["base_url"], ssl_verify=record.get("ssl_verify", True), - model_factory=record.get("model_factory"))) + model_factory=record.get("model_factory"), + timeout_seconds=record.get("timeout_seconds"), + concurrency_limit=record.get("concurrency_limit"))) # fit for old version, main_model and sub_model use default model main_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) @@ -258,7 +260,9 @@ async def create_model_config_list(tenant_id): "model_name") else "", url=main_model_config.get("base_url", ""), ssl_verify=main_model_config.get("ssl_verify", True), - model_factory=main_model_config.get("model_factory"))) + model_factory=main_model_config.get("model_factory"), + timeout_seconds=main_model_config.get("timeout_seconds"), + concurrency_limit=main_model_config.get("concurrency_limit"))) model_list.append( ModelConfig(cite_name="sub_model", api_key=main_model_config.get("api_key", ""), @@ -266,7 +270,9 @@ async def create_model_config_list(tenant_id): "model_name") else "", url=main_model_config.get("base_url", ""), ssl_verify=main_model_config.get("ssl_verify", True), - model_factory=main_model_config.get("model_factory"))) + model_factory=main_model_config.get("model_factory"), + timeout_seconds=main_model_config.get("timeout_seconds"), + concurrency_limit=main_model_config.get("concurrency_limit"))) return model_list diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index 0cfc962ea..22710c1e2 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -29,6 +29,7 @@ from apps.a2a_client_app import router as a2a_client_router from apps.monitoring_app import router as monitoring_router from apps.a2a_server_app import router as a2a_server_router +from apps.haotian_app import router as haotian_router from consts.const import IS_SPEED_MODE # Create logger instance @@ -71,3 +72,4 @@ app.include_router(invitation_router) app.include_router(a2a_client_router) app.include_router(a2a_server_router) +app.include_router(haotian_router) diff --git a/backend/consts/agent_unavailable_reasons.py b/backend/consts/agent_unavailable_reasons.py new file mode 100644 index 000000000..4e710ee7d --- /dev/null +++ b/backend/consts/agent_unavailable_reasons.py @@ -0,0 +1,43 @@ +""" +Agent Unavailable Reason Constants + +Centralized definition of all possible reasons why an agent may be unavailable. +These values are returned to the frontend via the 'unavailable_reasons' field. +""" + + +class AgentUnavailableReason: + """Reason codes for agent unavailability.""" + + # Identity conflicts + DUPLICATE_NAME = "duplicate_name" + DUPLICATE_DISPLAY_NAME = "duplicate_display_name" + + # Model issues + MODEL_NOT_CONFIGURED = "model_not_configured" + MODEL_UNAVAILABLE = "model_unavailable" + + # Tool issues + TOOL_UNAVAILABLE = "tool_unavailable" + ALL_TOOLS_DISABLED = "all_tools_disabled" + + # Agent issues + AGENT_NOT_FOUND = "agent_not_found" + + @classmethod + def all_reasons(cls) -> list[str]: + """Return all defined unavailable reason codes.""" + return [ + cls.DUPLICATE_NAME, + cls.DUPLICATE_DISPLAY_NAME, + cls.MODEL_NOT_CONFIGURED, + cls.MODEL_UNAVAILABLE, + cls.TOOL_UNAVAILABLE, + cls.ALL_TOOLS_DISABLED, + cls.AGENT_NOT_FOUND, + ] + + @classmethod + def is_valid_reason(cls, reason: str) -> bool: + """Check if a reason string is a valid reason code.""" + return reason in cls.all_reasons() diff --git a/backend/consts/model.py b/backend/consts/model.py index bcaffcae7..ae086a625 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -121,6 +121,8 @@ class ModelRequest(BaseModel): # STT specific fields model_appid: Optional[str] = None access_token: Optional[str] = None + timeout_seconds: Optional[int] = None + concurrency_limit: Optional[int] = None class ProviderModelRequest(BaseModel): @@ -756,6 +758,8 @@ class ManageTenantModelCreateRequest(BaseModel): # STT specific fields model_appid: Optional[str] = Field(None, description="Application ID for STT models (e.g., Volcano Engine)") access_token: Optional[str] = Field(None, description="Access token for STT models (e.g., Volcano Engine)") + timeout_seconds: Optional[int] = Field(None, description="Request timeout in seconds") + concurrency_limit: Optional[int] = Field(None, description="Maximum concurrent requests for this model") class ManageTenantModelUpdateRequest(BaseModel): @@ -776,6 +780,8 @@ class ManageTenantModelUpdateRequest(BaseModel): # STT specific fields model_appid: Optional[str] = Field(None, description="Application ID for STT models") access_token: Optional[str] = Field(None, description="Access token for STT models") + timeout_seconds: Optional[int] = Field(None, description="Request timeout in seconds") + concurrency_limit: Optional[int] = Field(None, description="Maximum concurrent requests for this model") class ManageTenantModelDeleteRequest(BaseModel): diff --git a/backend/database/db_models.py b/backend/database/db_models.py index baa8e903e..43e1bd7bc 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -182,6 +182,10 @@ class ModelRecord(TableBase): String(100), doc="Application ID for model authentication (used by some STT/TTS providers like Volcano Engine)") access_token = Column( String(100), doc="Access token for model authentication (used by some STT/TTS providers like Volcano Engine)") + timeout_seconds = Column( + Integer, doc="Request timeout in seconds for this model. Default is 120 seconds.") + concurrency_limit = Column( + Integer, doc="Maximum concurrent requests for this model. Default is null (unlimited).") class ModelMonitoringRecord(SimpleTableBase): diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..d501fd52f 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, List, Optional from sqlalchemy import and_, desc, func, insert, select, update @@ -7,6 +8,8 @@ from .db_models import ModelRecord from .utils import add_creation_tracking, add_update_tracking +logger = logging.getLogger("database.model_management_db") + def create_model_record(model_data: Dict[str, Any], user_id: str, tenant_id: str) -> bool: """ @@ -84,6 +87,58 @@ def update_model_record( return result.rowcount > 0 +def update_model_record_by_model_name( + model_name: str, + update_data: Dict[str, Any], + user_id: Optional[str] = None, + tenant_id: Optional[str] = None, + model_repo: Optional[str] = None +) -> bool: + """ + Update a model record by model_name and tenant_id. + + Args: + model_name: Model name (display name, not the primary key) + update_data: Dictionary containing update data + user_id: Reserved parameter for filling updated_by field + tenant_id: Tenant ID for filtering + model_repo: Optional model repo for more precise matching + + Returns: + bool: Whether the operation was successful + """ + import logging + db_logger = logging.getLogger("database.client") + + with get_db_session() as session: + # Data cleaning + cleaned_data = db_client.clean_string_values(update_data) + + # Add update timestamp + cleaned_data["update_time"] = func.current_timestamp() + if user_id: + cleaned_data = add_update_tracking(cleaned_data, user_id) + + db_logger.debug(f"update_model_record_by_model_name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}, cleaned_data={cleaned_data}") + + # Build conditions list + conditions = [ + ModelRecord.model_name == model_name, + ModelRecord.tenant_id == tenant_id + ] + if model_repo: + conditions.append(ModelRecord.model_repo == model_repo) + + # Build the update statement + stmt = update(ModelRecord).where(*conditions).values(cleaned_data) + + # Execute the update statement + result = session.execute(stmt) + db_logger.info(f"update_model_record_by_model_name: rowcount={result.rowcount}") + + return result.rowcount > 0 + + def delete_model_record(model_id: int, user_id: str, tenant_id: str) -> bool: """ Delete a model record (soft delete) and update the update timestamp diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 02fa7d8c6..ae0274f34 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -19,6 +19,7 @@ from consts.const import MEMORY_SEARCH_START_MSG, MEMORY_SEARCH_DONE_MSG, MEMORY_SEARCH_FAIL_MSG, TOOL_TYPE_MAPPING, \ LANGUAGE, MESSAGE_ROLE, MODEL_CONFIG_MAPPING, CAN_EDIT_ALL_USER_ROLES, PERMISSION_EDIT, PERMISSION_READ, PERMISSION_PRIVATE from consts.exceptions import MemoryPreparationException +from consts.agent_unavailable_reasons import AgentUnavailableReason from consts.model import ( AgentInfoRequest, AgentRequest, @@ -1533,8 +1534,8 @@ def _mark_duplicates(groups: dict[str, list[dict]], reason_key: str) -> None: for duplicate_entry in sorted_entries[1:]: duplicate_entry["unavailable_reasons"].append(reason_key) - _mark_duplicates(name_groups, "duplicate_name") - _mark_duplicates(display_name_groups, "duplicate_display_name") + _mark_duplicates(name_groups, AgentUnavailableReason.DUPLICATE_NAME) + _mark_duplicates(display_name_groups, AgentUnavailableReason.DUPLICATE_DISPLAY_NAME) def _collect_model_availability_reasons(agent: dict, tenant_id: str, model_cache: Dict[int, Optional[dict]]) -> list[str]: @@ -1546,7 +1547,7 @@ def _collect_model_availability_reasons(agent: dict, tenant_id: str, model_cache model_id=agent.get("model_id"), tenant_id=tenant_id, model_cache=model_cache, - reason_key="model_unavailable" + reason_key=AgentUnavailableReason.MODEL_UNAVAILABLE )) return reasons @@ -1604,7 +1605,7 @@ def check_agent_availability( agent_info = search_agent_info_by_agent_id(agent_id, tenant_id) if not agent_info: - return False, ["agent_not_found"] + return False, [AgentUnavailableReason.AGENT_NOT_FOUND] # Check tool availability tool_info = search_tools_for_sub_agent(agent_id=agent_id, tenant_id=tenant_id) @@ -1612,7 +1613,7 @@ def check_agent_availability( if tool_id_list: tool_statuses = check_tool_is_available(tool_id_list) if not all(tool_statuses): - unavailable_reasons.append("tool_unavailable") + unavailable_reasons.append(AgentUnavailableReason.TOOL_UNAVAILABLE) # Check model availability model_reasons = _collect_model_availability_reasons( diff --git a/backend/services/agent_version_service.py b/backend/services/agent_version_service.py index 69163dbc6..397361059 100644 --- a/backend/services/agent_version_service.py +++ b/backend/services/agent_version_service.py @@ -33,6 +33,7 @@ ) from database.model_management_db import get_model_by_model_id from utils.str_utils import convert_string_to_list +from consts.agent_unavailable_reasons import AgentUnavailableReason logger = logging.getLogger("agent_version_service") @@ -337,21 +338,18 @@ def _check_version_snapshot_availability( # Check if agent info exists if not agent_info: - return False, ["agent_not_found"] + return False, [AgentUnavailableReason.AGENT_NOT_FOUND] # Check model availability model_id = agent_info.get('model_id') if model_id is None or model_id == 0: - unavailable_reasons.append("model_not_configured") + unavailable_reasons.append(AgentUnavailableReason.MODEL_NOT_CONFIGURED) - # Check tools availability - if not tool_instances: - unavailable_reasons.append("no_tools") - else: - # Check if at least one tool is enabled + # Check tools availability (only when tools are configured) + if tool_instances: has_enabled_tool = any(t.get('enabled', True) for t in tool_instances) if not has_enabled_tool: - unavailable_reasons.append("all_tools_disabled") + unavailable_reasons.append(AgentUnavailableReason.ALL_TOOLS_DISABLED) return len(unavailable_reasons) == 0, unavailable_reasons diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index d5d4a85a4..302ec63a8 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -248,6 +248,8 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" display_name = model_config.get("display_name", "") if model_config else "" set_monitoring_operation("title_generation", display_name=display_name or None) + timeout_seconds = model_config.get("timeout_seconds") if model_config else None + # Create OpenAIModel instance llm = OpenAIModel( model_id=get_model_name_from_config(model_config) if model_config.get("model_name") else "", @@ -256,7 +258,9 @@ def call_llm_for_title(question: str, tenant_id: str, language: str = LANGUAGE[" temperature=0.7, top_p=0.95, model_factory=model_config.get("model_factory", None), - ssl_verify=model_config.get("ssl_verify", True) + ssl_verify=model_config.get("ssl_verify", True), + timeout_seconds=timeout_seconds, + stream=False, ) # Build messages - use new template variable 'question' instead of 'content' diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index b5cd048bf..7dad75a0a 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -352,6 +352,7 @@ def get_llm_model(tenant_id: str): # Get the tenant config main_model_config = tenant_config_manager.get_model_config( key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) + timeout_seconds = main_model_config.get("timeout_seconds") if main_model_config else None long_text_to_text_model = OpenAILongContextModel( observer=MessageObserver(), model_id=get_model_name_from_config(main_model_config), @@ -359,6 +360,7 @@ def get_llm_model(tenant_id: str): api_key=main_model_config.get("api_key"), max_context_tokens=main_model_config.get("max_tokens"), ssl_verify=main_model_config.get("ssl_verify", True), + timeout_seconds=timeout_seconds, ) return long_text_to_text_model diff --git a/backend/services/haotian_service.py b/backend/services/haotian_service.py index a49079ec7..e7f762244 100644 --- a/backend/services/haotian_service.py +++ b/backend/services/haotian_service.py @@ -11,6 +11,8 @@ logger = logging.getLogger("haotian_service") +_DEFAULT_KNOWLEDGE_BASE_ID = "abcdefg" + def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: """ @@ -24,7 +26,7 @@ def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: ] } - This function also filters out knowledge sets with name == "Public". + When dify_dataset_id is "null", it is replaced with the default ID. """ knowledge_sets = raw.get("knowledge_sets", []) if not isinstance(knowledge_sets, list): @@ -35,7 +37,7 @@ def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: if not isinstance(ks, dict): continue set_name = str(ks.get("name", "") or "").strip() - if not set_name or set_name == "Public": + if not set_name: continue bases = ks.get("knowledge_bases", []) @@ -48,15 +50,18 @@ def _normalize_list_payload(raw: Dict[str, Any]) -> Dict[str, Any]: continue dataset_id = str(kb.get("dify_dataset_id", "") or "").strip() kb_name = str(kb.get("name", "") or "").strip() - if not dataset_id or not kb_name: + if not kb_name: continue + if dataset_id == "null" or not dataset_id: + dataset_id = _DEFAULT_KNOWLEDGE_BASE_ID normalized_bases.append( {"dify_dataset_id": dataset_id, "name": kb_name} ) - normalized_sets.append( - {"name": set_name, "knowledge_bases": normalized_bases} - ) + if normalized_bases: + normalized_sets.append( + {"name": set_name, "knowledge_bases": normalized_bases} + ) return {"knowledge_sets": normalized_sets} @@ -77,7 +82,7 @@ async def fetch_haotian_knowledge_sets_impl( ) headers = {"Authorization": external_authorization} - async with httpx.AsyncClient(timeout=timeout_s, follow_redirects=True) as client: + async with httpx.AsyncClient(timeout=timeout_s, follow_redirects=True, trust_env=False) as client: resp = await client.get(list_url, headers=headers) if resp.status_code >= 400: raise RuntimeError( diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index a20b2a6ca..73adacc00 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -29,6 +29,7 @@ async def _embedding_dimension_check( model_base_url: str, model_api_key: str, ssl_verify: bool = True, + timeout_seconds: Optional[float] = None, ): # Test connectivity based on different model types if model_type == "embedding": @@ -38,6 +39,7 @@ async def _embedding_dimension_check( api_key=model_api_key, embedding_dim=0, ssl_verify=ssl_verify, + timeout_seconds=timeout_seconds, ).dimension_check() if len(embedding) > 0: return len(embedding[0]) @@ -51,6 +53,7 @@ async def _embedding_dimension_check( api_key=model_api_key, embedding_dim=0, ssl_verify=ssl_verify, + timeout_seconds=timeout_seconds, ).dimension_check() if len(embedding) > 0: return len(embedding[0]) @@ -71,6 +74,7 @@ async def _perform_connectivity_check( model_appid: Optional[str] = None, access_token: Optional[str] = None, display_name: Optional[str] = None, + timeout_seconds: Optional[float] = None, ) -> bool: """ Perform specific model connectivity check @@ -80,6 +84,8 @@ async def _perform_connectivity_check( model_base_url: Model base URL model_api_key: API key ssl_verify: Whether to verify SSL certificates (default: True) + display_name: Optional display name for monitoring + timeout_seconds: Optional request timeout in seconds Returns: bool: Connectivity check result """ @@ -91,21 +97,23 @@ async def _perform_connectivity_check( # Test connectivity based on different model types if model_type == "embedding": - connectivity = len(await OpenAICompatibleEmbedding( + embedding = OpenAICompatibleEmbedding( model_name=model_name, base_url=model_base_url, api_key=model_api_key, embedding_dim=0, - ssl_verify=ssl_verify - ).dimension_check()) > 0 + ssl_verify=ssl_verify, + ) + connectivity = len(await embedding.dimension_check(timeout=timeout_seconds if timeout_seconds else 5.0)) > 0 elif model_type == "multi_embedding": - connectivity = len(await JinaEmbedding( + embedding = JinaEmbedding( model_name=model_name, base_url=model_base_url, api_key=model_api_key, embedding_dim=0, - ssl_verify=ssl_verify - ).dimension_check()) > 0 + ssl_verify=ssl_verify, + ) + connectivity = len(await embedding.dimension_check(timeout=timeout_seconds if timeout_seconds else 5.0)) > 0 elif model_type == "llm": observer = MessageObserver() set_monitoring_operation("connectivity_check", @@ -115,7 +123,8 @@ async def _perform_connectivity_check( model_id=model_name, api_base=model_base_url, api_key=model_api_key, - ssl_verify=ssl_verify + ssl_verify=ssl_verify, + timeout_seconds=timeout_seconds, ).check_connectivity() elif model_type == "rerank": rerank_model = OpenAICompatibleRerank( @@ -192,14 +201,22 @@ async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: model_factory = model.get("model_factory") model_appid = model.get("model_appid") access_token = model.get("access_token") + timeout_seconds = model.get("timeout_seconds") try: set_monitoring_context(tenant_id=tenant_id) + ssl_verify_fallback = False connectivity = await _perform_connectivity_check( model_name, model_type, model_base_url, model_api_key, ssl_verify, - model_factory, model_appid, access_token,display_name=display_name, + model_factory, model_appid, access_token, display_name, timeout_seconds, ) + if not connectivity and ssl_verify: + ssl_verify_fallback = True + connectivity = await _perform_connectivity_check( + model_name, model_type, model_base_url, model_api_key, False, + model_factory, model_appid, access_token, display_name, timeout_seconds, + ) except Exception as e: update_data = { "connect_status": ModelConnectStatusEnum.UNAVAILABLE.value} @@ -215,6 +232,8 @@ async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: f"UNCONNECTED: {model_name}") connect_status = ModelConnectStatusEnum.AVAILABLE.value if connectivity else ModelConnectStatusEnum.UNAVAILABLE.value update_data = {"connect_status": connect_status} + if ssl_verify_fallback: + update_data["ssl_verify"] = False update_model_record(model["model_id"], update_data) return { "connectivity": connectivity, @@ -245,16 +264,18 @@ async def verify_model_config_connectivity(model_config: dict): model_factory = model_config.get("model_factory") model_appid = model_config.get("model_appid") access_token = model_config.get("access_token") + # Get timeout from model config if present + timeout_seconds = model_config.get("timeout_seconds") try: connectivity = await _perform_connectivity_check( model_name, model_type, model_base_url, model_api_key, ssl_verify, - model_factory, model_appid, access_token + model_factory, model_appid, access_token, None, timeout_seconds, ) if not connectivity and ssl_verify: connectivity = await _perform_connectivity_check( model_name, model_type, model_base_url, model_api_key, False, - model_factory, model_appid, access_token + model_factory, model_appid, access_token, None, timeout_seconds, ) if not connectivity: error_msg = f"Failed to connect to model '{model_name}' at {model_base_url}. Please verify the URL, API key, and network connection." @@ -296,9 +317,17 @@ async def embedding_dimension_check(model_config: dict): try: ssl_verify = model_config.get("ssl_verify", True) + timeout_seconds = model_config.get("timeout_seconds") dimension = await _embedding_dimension_check( - model_name, model_type, model_base_url, model_api_key, ssl_verify + model_name, model_type, model_base_url, model_api_key, ssl_verify, + timeout_seconds=timeout_seconds ) + # Fallback to ssl_verify=False if initial check fails + if dimension == 0 and ssl_verify: + dimension = await _embedding_dimension_check( + model_name, model_type, model_base_url, model_api_key, False, + timeout_seconds=timeout_seconds + ) return dimension except ValueError as e: logger.error(f"Error checking embedding dimension: {str(e)}") diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index d012803be..ab0e52259 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -13,10 +13,11 @@ get_model_records, get_models_by_tenant_factory_type, update_model_record, + update_model_record_by_model_name, ) from services.model_provider_service import ( prepare_model_dict, - merge_existing_model_tokens, + merge_existing_model_attributes, get_provider_models, ) from services.model_health_service import embedding_dimension_check @@ -45,9 +46,15 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict model_base_url.replace(LOCALHOST_NAME, DOCKER_INTERNAL_HOST) .replace(LOCALHOST_IP, DOCKER_INTERNAL_HOST) ) - model_data['ssl_verify'] = True - if "open/router" in model_base_url: - model_data['ssl_verify'] = False + # Auto-set ssl_verify based on api_key: + # - Empty api_key (local/LAN services) -> ssl_verify=False + # - "open/router" URL -> ssl_verify=False + # - Otherwise -> ssl_verify=True + model_api_key = model_data.get("api_key", "") + if not model_api_key or "open/router" in model_base_url: + model_data["ssl_verify"] = False + else: + model_data["ssl_verify"] = True # Split model_name into repo and name model_repo, model_name = split_repo_name( model_data["model_name"]) if model_data.get("model_name") else ("", "") @@ -114,8 +121,8 @@ async def create_provider_models_for_tenant(tenant_id: str, provider_request: Di # Get provider model list model_list = await get_provider_models(provider_request) - # Merge existing model's max_tokens attribute - model_list = merge_existing_model_tokens( + # Merge existing model's attributes (max_tokens, api_key, timeout_seconds, concurrency_limit) + model_list = merge_existing_model_attributes( model_list, tenant_id, provider_request["provider"], provider_request["model_type"]) # Sort model list by ID @@ -251,6 +258,15 @@ async def update_single_model_for_tenant( m.get("model_type") == "multi_embedding" for m in existing_models ) + # Auto-set ssl_verify based on api_key if provided: + # - Empty api_key -> ssl_verify=False + # - Otherwise -> ssl_verify=True + if "api_key" in model_data: + if not model_data["api_key"]: + model_data["ssl_verify"] = False + else: + model_data["ssl_verify"] = True + if has_multi_embedding: # Update both embedding and multi_embedding records for model in existing_models: @@ -276,12 +292,31 @@ async def update_single_model_for_tenant( async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_list: List[Dict[str, Any]]): - """Batch update models for a tenant.""" + """Batch update models for a tenant by model_id or model_name.""" try: for model in model_list: - update_model_record(model["model_id"], model, user_id, tenant_id) - - logging.debug("Batch update models successfully") + # Build update data excluding id fields + update_data = {k: v for k, v in model.items() if k not in ["model_id", "model_name"]} + + model_id_or_name = model.get("model_id") or model.get("model_name") + + # Check if model_id is a numeric string (primary key) + if model_id_or_name and model_id_or_name.isdigit(): + # Use model_id (primary key) for update + logging.info(f"[DEBUG] Updating model by id: model_id={model_id_or_name}, tenant_id={tenant_id}, update_data={update_data}") + update_model_record(int(model_id_or_name), update_data, user_id, tenant_id) + else: + # Parse "model_repo/model_name" format from frontend's model_id field + if "/" in model_id_or_name: + model_repo, model_name = model_id_or_name.split("/", 1) + else: + model_repo = None + model_name = model_id_or_name + + logging.info(f"[DEBUG] Updating model by name: model_name={model_name}, model_repo={model_repo}, tenant_id={tenant_id}, update_data={update_data}") + update_model_record_by_model_name(model_name, update_data, user_id, tenant_id, model_repo) + + logging.info("[DEBUG] Batch update models successfully") except Exception as e: logging.error(f"Failed to batch update models: {str(e)}") raise Exception(f"Failed to batch update models: {str(e)}") diff --git a/backend/services/model_provider_service.py b/backend/services/model_provider_service.py index dbff17082..9b9f26bd4 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -100,11 +100,13 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a # Build the canonical representation using the existing Pydantic schema for # consistency of validation and default handling. # For embedding/multi_embedding models, max_tokens will be set via connectivity check later, - # so use 0 as placeholder if not provided + # so use 0 as placeholder if not provided. + # Set default timeout_seconds to 120 for LLM models (embedding models don't need it). model_type = model["model_type"] is_embedding_type = model_type in ["embedding", "multi_embedding"] max_tokens_value = model.get( "max_tokens", 0) if not is_embedding_type else 0 + timeout_seconds_value = 120 if not is_embedding_type else None model_obj = ModelRequest( model_factory=provider, @@ -115,7 +117,8 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a display_name=model_display_name, expected_chunk_size=expected_chunk_size, maximum_chunk_size=maximum_chunk_size, - chunk_batch=chunk_batch + chunk_batch=chunk_batch, + timeout_seconds=timeout_seconds_value ) model_dict = model_obj.model_dump() @@ -155,19 +158,29 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a return model_dict -def merge_existing_model_tokens(model_list: List[dict], tenant_id: str, provider: str, model_type: str) -> List[dict]: +def merge_existing_model_attributes( + model_list: List[dict], + tenant_id: str, + provider: str, + model_type: str, + fields: List[str] = None +) -> List[dict]: """ - Merge existing model's max_tokens attribute into the model list. + Merge existing model's attributes into the model list. Args: model_list: List of models tenant_id: Tenant ID provider: Provider model_type: Model type + fields: List of fields to merge (defaults to max_tokens, api_key, timeout_seconds, concurrency_limit) Returns: List[dict]: Merged model list """ + if fields is None: + fields = ["max_tokens", "api_key", "timeout_seconds", "concurrency_limit"] + if model_type == "embedding" or model_type == "multi_embedding": return model_list @@ -184,15 +197,35 @@ def merge_existing_model_tokens(model_list: List[dict], tenant_id: str, provider "/" + existing_model["model_name"] existing_model_map[model_full_name] = existing_model - # Iterate through the model list, if the model exists in the existing model list, add max_tokens attribute + # Iterate through the model list, merge specified fields from existing models for model in model_list: if model.get("id") in existing_model_map: - model["max_tokens"] = existing_model_map[model.get( - "id")].get("max_tokens") + existing_model = existing_model_map[model.get("id")] + for field in fields: + if existing_model.get(field) is not None: + model[field] = existing_model.get(field) return model_list +def merge_existing_model_tokens(model_list: List[dict], tenant_id: str, provider: str, model_type: str) -> List[dict]: + """ + Merge existing model's max_tokens attribute into the model list. + + DEPRECATED: Use merge_existing_model_attributes instead. + + Args: + model_list: List of models + tenant_id: Tenant ID + provider: Provider + model_type: Model type + + Returns: + List[dict]: Merged model list + """ + return merge_existing_model_attributes(model_list, tenant_id, provider, model_type, ["max_tokens"]) + + # Re-export provider classes for backward compatibility __all__ = [ "AbstractModelProvider", @@ -200,6 +233,7 @@ def merge_existing_model_tokens(model_list: List[dict], tenant_id: str, provider "ModelEngineProvider", "prepare_model_dict", "merge_existing_model_tokens", + "merge_existing_model_attributes", "get_provider_models", "get_model_engine_raw_url", ] diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index aa4d420d5..e0f5f96a0 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -259,19 +259,51 @@ def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list stop_flags = {"duty": False, "constraint": False, "few_shots": False, "agent_var_name": False, "agent_display_name": False, "agent_description": False} - # Start all generation threads + # Get model concurrency limit to control the number of concurrent LLM calls + # If None or >= 6, no limit (all 6 calls run concurrently) + # If < 6, use semaphore to limit concurrent calls + from database.model_management_db import get_model_by_model_id + model_config = get_model_by_model_id(model_id, tenant_id) + concurrency_limit = model_config.get("concurrency_limit") if model_config else None + + # Start all generation threads with concurrency control threads, error_holder = _start_generation_threads( - content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id) + content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id, + concurrency_limit=concurrency_limit + ) # Stream results yield from _stream_results(produce_queue, latest, stop_flags, threads, error_holder) -def _start_generation_threads(content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id): - """Start all prompt generation threads""" +def _start_generation_threads(content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id, + concurrency_limit: Optional[int] = None): + """Start all prompt generation threads with optional concurrency control. + + Args: + concurrency_limit: Maximum concurrent LLM calls. If None or >= 6, no limit. + If < 6, use semaphore to control concurrency. + """ # Shared error tracking across threads error_holder = {"error": None} + # Total number of generation tasks + total_tasks = 6 + + # Determine effective concurrency limit + # None means unlimited, 0 or negative means unlimited + if concurrency_limit is None or concurrency_limit <= 0 or concurrency_limit >= total_tasks: + effective_limit = None + else: + effective_limit = concurrency_limit + + # Use semaphore if concurrency is limited + semaphore = threading.Semaphore(effective_limit) if effective_limit else None + if semaphore: + logger.info(f"Using concurrency limit of {effective_limit} for prompt generation (total tasks: {total_tasks})") + else: + logger.info("Using unlimited concurrency for prompt generation") + def make_callback(tag): def callback_fn(current_text): latest[tag] = current_text @@ -280,8 +312,16 @@ def callback_fn(current_text): def run_and_flag(tag, sys_prompt): try: - call_llm_for_system_prompt( - model_id, content, sys_prompt, make_callback(tag), tenant_id) + # Acquire semaphore before starting (if limited) + if semaphore: + semaphore.acquire() + try: + call_llm_for_system_prompt( + model_id, content, sys_prompt, make_callback(tag), tenant_id) + finally: + # Always release semaphore after completion + if semaphore: + semaphore.release() except Exception as e: logger.error(f"Error in {tag} generation: {e}") error_holder["error"] = e diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 5e5229ff6..0f779cb98 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -130,11 +130,15 @@ def get_local_tools() -> List[ToolInfo]: if hasattr(param.default, 'exclude') and param.default.exclude: continue + # Check if default is a Pydantic FieldInfo (has .default attribute) + is_pydantic_field = hasattr(param.default, 'default') + # Get description in both languages - param_description = param.default.description if hasattr(param.default, 'description') else "" + param_description = param.default.description if is_pydantic_field else "" # First try to get from param.default.description_zh (FieldInfo) - param_description_zh = param.default.description_zh if hasattr(param.default, 'description_zh') else None + # Note: Pydantic Field doesn't have description_zh attribute, so use getattr with default + param_description_zh = getattr(param.default, 'description_zh', None) if is_pydantic_field else None # Fallback to init_param_descriptions if not found if param_description_zh is None and param_name in init_param_descriptions: @@ -146,11 +150,21 @@ def get_local_tools() -> List[ToolInfo]: "description": param_description, "description_zh": param_description_zh } - if param.default.default is PydanticUndefined: - param_info["optional"] = False + + # Handle both Pydantic FieldInfo and simple defaults + if is_pydantic_field: + if param.default.default is PydanticUndefined: + param_info["optional"] = False + else: + param_info["default"] = param.default.default + param_info["optional"] = True else: - param_info["default"] = param.default.default - param_info["optional"] = True + # Simple default value (not a FieldInfo) + if param.default == inspect.Parameter.empty: + param_info["optional"] = False + else: + param_info["default"] = param.default + param_info["optional"] = True init_params_list.append(param_info) diff --git a/backend/utils/llm_utils.py b/backend/utils/llm_utils.py index e99b9f384..7d6b0dc17 100644 --- a/backend/utils/llm_utils.py +++ b/backend/utils/llm_utils.py @@ -73,6 +73,8 @@ def call_llm_for_system_prompt( set_monitoring_operation("system_prompt_generation", display_name=display_name or None) + timeout_seconds = llm_model_config.get("timeout_seconds") if llm_model_config else None + llm = OpenAIModel( model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", api_base=llm_model_config.get("base_url", "") if llm_model_config else "", @@ -82,6 +84,7 @@ def call_llm_for_system_prompt( model_factory=llm_model_config.get("model_factory") if llm_model_config else None, ssl_verify=llm_model_config.get("ssl_verify", True) if llm_model_config else True, display_name=display_name or None, + timeout_seconds=timeout_seconds, ) messages = [ {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, @@ -100,6 +103,15 @@ def call_llm_for_system_prompt( reasoning_content_seen = False content_tokens_seen = 0 for chunk in current_request: + # Safety check: skip non-standard chunks that lack expected attributes + if not hasattr(chunk, 'choices'): + if hasattr(chunk, '__str__'): + logger.warning(f"Received non-standard chunk (no 'choices'): {str(chunk)[:200]}") + continue + + if not chunk.choices: + continue + delta = chunk.choices[0].delta reasoning_content = getattr(delta, "reasoning_content", None) new_token = delta.content diff --git a/backend/utils/tool_utils.py b/backend/utils/tool_utils.py index f06f36bc3..f1d9147e3 100644 --- a/backend/utils/tool_utils.py +++ b/backend/utils/tool_utils.py @@ -46,7 +46,8 @@ def get_local_tools_description_zh() -> Dict[str, Dict]: if hasattr(param.default, 'exclude') and param.default.exclude: continue - param_description_zh = param.default.description_zh if hasattr(param.default, 'description_zh') else None + # Note: Pydantic Field doesn't have description_zh attribute + param_description_zh = getattr(param.default, 'description_zh', None) if hasattr(param.default, 'description_zh') else None if param_description_zh is None and param_name in init_param_descriptions: param_description_zh = init_param_descriptions[param_name].get('description_zh') diff --git a/docker/docker-compose.prod.yml b/docker/docker-compose.prod.yml index 934fe8b2f..3cc7ac59a 100644 --- a/docker/docker-compose.prod.yml +++ b/docker/docker-compose.prod.yml @@ -78,6 +78,8 @@ services: - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro - ${ROOT_DIR}/scripts/sync_user_supabase2pg.py:/opt/sync_user_supabase2pg.py:ro - /var/run/docker.sock:/var/run/docker.sock:ro # Docker socket for MCP container management + # CA certificates for external service SSL verification (e.g., SMTP) + - /etc/ssl/certs:/etc/ssl/certs:ro environment: <<: [*minio-vars, *es-vars] skip_proxy: "true" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 89088f2c3..4056683dc 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -89,6 +89,8 @@ services: - ${ROOT_DIR}/openssh-server/ssh-keys:/opt/ssh-keys:ro - ${ROOT_DIR}/scripts/sync_user_supabase2pg.py:/opt/sync_user_supabase2pg.py:ro - /var/run/docker.sock:/var/run/docker.sock:ro # Docker socket for MCP container management + # CA certificates for external service SSL verification (e.g., SMTP) + - /etc/ssl/certs:/etc/ssl/certs:ro environment: <<: [*minio-vars, *es-vars] skip_proxy: "true" diff --git a/docker/sql/v2.1.1_0507_add_concurrency_and_timeout_to_model_record_t.sql b/docker/sql/v2.1.1_0507_add_concurrency_and_timeout_to_model_record_t.sql new file mode 100644 index 000000000..59632f8ed --- /dev/null +++ b/docker/sql/v2.1.1_0507_add_concurrency_and_timeout_to_model_record_t.sql @@ -0,0 +1,13 @@ +-- Add concurrency_limit column to model_record_t table +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS concurrency_limit INTEGER DEFAULT NULL; + +-- Add comment to the column +COMMENT ON COLUMN nexent.model_record_t.concurrency_limit IS 'Maximum concurrent requests for this model. Default is NULL (unlimited).'; + +-- Add timeout_seconds column to model_record_t table +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS timeout_seconds INTEGER DEFAULT 120; + +-- Add comment to the column +COMMENT ON COLUMN nexent.model_record_t.timeout_seconds IS 'Request timeout in seconds for this model. Default is 120 seconds.'; diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 53c6d3f03..39c3bbce2 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -1474,10 +1474,21 @@ export default function ToolConfigModal({ case TOOL_PARAM_TYPES.ARRAY: case TOOL_PARAM_TYPES.OBJECT: default: - // Check if parameter name contains "password" for secure input - const isPasswordType = param.name.toLowerCase().includes("password"); + // Check if parameter name indicates a secure/sensitive field + const sensitivePatterns = [ + "password", + "authorization", + "api_key", + "apikey", + "api-key", + "secret", + "token", + ]; + const isSecureField = sensitivePatterns.some((pattern) => + param.name.toLowerCase().includes(pattern) + ); - if (isPasswordType) { + if (isSecureField) { return ( m.name === editedAgent.model || m.displayName === editedAgent.model + ); + let effectiveMainAgentModel = initialAgentInfo.mainAgentModel; + let effectiveMainAgentModelId = editedAgent.model_id || 0; + + if (!agentModelAvailable && defaultLlmModel) { + // Agent's original model is no longer available, switch to default model + effectiveMainAgentModel = defaultLlmModel.displayName || ""; + effectiveMainAgentModelId = defaultLlmModel.id || 0; + // Update the initialAgentInfo with the new model + initialAgentInfo.mainAgentModel = effectiveMainAgentModel; + } + const initialBusinessInfo = { businessDescription: editedAgent.business_description || "", businessLogicModelName: @@ -291,12 +306,18 @@ export default function AgentGenerateDetail({ setBusinessInfo(initialBusinessInfo); form.setFieldsValue(initialAgentInfo); - // Sync model to store if not already set (e.g., in create mode with default model) + // Sync model to store (use default model if original is unavailable) if (isCreatingMode && defaultLlmModel) { updateProfileInfo({ model: defaultLlmModel.displayName || "", model_id: defaultLlmModel.id || 0, }); + } else if (!agentModelAvailable && defaultLlmModel) { + // Update model in store when original model is no longer available + updateProfileInfo({ + model: effectiveMainAgentModel, + model_id: effectiveMainAgentModelId, + }); } // Sync max_step to store in create mode (default to 5) if (isCreatingMode && !editedAgent.max_step) { @@ -310,7 +331,7 @@ export default function AgentGenerateDetail({ }); } - }, [currentAgentId, defaultLlmModel?.id, isCreatingMode, forceRefreshKey]); + }, [currentAgentId, defaultLlmModel, isCreatingMode, forceRefreshKey, availableLlmModels]); // Default to selecting all groups when creating a new agent. // Only applies when groups are loaded and no group is selected yet. @@ -609,7 +630,7 @@ export default function AgentGenerateDetail({ { agent_id: effectiveAgentId, task_description: businessInfo.businessDescription, - model_id: businessInfo.businessLogicModelId.toString(), + model_id: businessInfo.businessLogicModelId, sub_agent_ids: editedAgent.sub_agent_id_list, tool_ids: Array.isArray(editedAgent.tools) ? editedAgent.tools.map((tool: any) => diff --git a/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx b/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx index edfeff559..4a4046c9b 100644 --- a/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx +++ b/frontend/app/[locale]/agents/components/agentManage/AgentList.tsx @@ -24,6 +24,7 @@ import { clearAgentNewMark } from "@/services/agentConfigService"; import { a2aClientService } from "@/services/a2aService"; import A2AServerSettingsPanel from "../a2a/A2AServerSettingsPanel"; import log from "@/lib/logger"; +import { getUnavailableReasonLabels } from "@/lib/agentLabelMapper"; interface AgentListProps { agentList: Agent[]; @@ -429,18 +430,8 @@ export default function AgentList({ { const reasons = agent.unavailable_reasons || []; - if (reasons.includes('agent_not_found')) { - return t('subAgentPool.tooltip.unavailableAgent'); - } else if (reasons.includes('tool_unavailable')) { - return t('toolPool.tooltip.unavailableTool'); - } else if (reasons.includes('duplicate_name')) { - return t('agent.error.nameExists', { name }); - } else if (reasons.includes('duplicate_display_name')) { - return t('agent.error.displayNameExists', { displayName }); - } else if (reasons.includes('model_unavailable')) { - return t('agent.error.modelUnavailable'); - } - return t('subAgentPool.tooltip.unavailableAgent'); // fallback + const labels = getUnavailableReasonLabels(reasons, t); + return labels.join(", ") || t('subAgentPool.tooltip.unavailableAgent'); })()} > diff --git a/frontend/app/[locale]/chat/components/chatAgentSelector.tsx b/frontend/app/[locale]/chat/components/chatAgentSelector.tsx index b67aa491e..f7a540172 100644 --- a/frontend/app/[locale]/chat/components/chatAgentSelector.tsx +++ b/frontend/app/[locale]/chat/components/chatAgentSelector.tsx @@ -11,6 +11,7 @@ import { ChatAgentSelectorProps } from "@/types/chat"; import { Agent } from "@/types/agentConfig"; import { clearAgentNewMark } from "@/services/agentConfigService"; import { usePublishedAgentList } from "@/hooks/agent/usePublishedAgentList"; +import { getUnavailableReasonLabels } from "@/lib/agentLabelMapper"; export function ChatAgentSelector({ selectedAgentId, @@ -355,7 +356,11 @@ export function ChatAgentSelector({ if (isDuplicateDisabled) { unavailableReason = t("subAgentPool.tooltip.duplicateNameDisabled"); } else if (!isAvailableTool) { - unavailableReason = t("subAgentPool.tooltip.hasUnavailableTools"); + const reasons = agent.unavailable_reasons || []; + const labels = getUnavailableReasonLabels(reasons, t); + unavailableReason = labels.length > 0 + ? labels.join(", ") + : t("agentSelector.agentUnavailable"); } } diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 11391c133..94a869301 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -829,6 +829,7 @@ export const ModelAddDialog = ({ }; const isEmbeddingModel = form.type === MODEL_TYPES.EMBEDDING; + const isRerankModel = form.type === MODEL_TYPES.RERANK; const isSTTModel = form.type === MODEL_TYPES.STT; return ( @@ -1717,4 +1718,4 @@ export const ModelAddDialog = ({ ); -}; +}; \ No newline at end of file diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index ad3cf0391..894f50907 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -52,11 +52,9 @@ export const ModelDeleteDialog = ({ const [isConfirmLoading, setIsConfirmLoading] = useState(false); const [maxTokens, setMaxTokens] = useState(0); - // Settings modal state - const [settingsModalVisible, setSettingsModalVisible] = useState(false); - const [selectedModelForSettings, setSelectedModelForSettings] = - useState(null); - const [modelMaxTokens, setModelMaxTokens] = useState("4096"); + // Single model settings modal state + const [isSingleModelSettingsOpen, setIsSingleModelSettingsOpen] = useState(false); + const [selectedSingleModel, setSelectedSingleModel] = useState(null); const [providerModelSearchTerm, setProviderModelSearchTerm] = useState(""); // Embedding model chunk config modal state @@ -589,9 +587,13 @@ export const ModelDeleteDialog = ({ const handleProviderConfigSave = async ({ apiKey, maxTokens, + timeoutSeconds, + concurrencyLimit, }: { apiKey: string; maxTokens: number; + timeoutSeconds?: number; + concurrencyLimit?: number; }) => { setMaxTokens(maxTokens); if ( @@ -624,6 +626,8 @@ export const ModelDeleteDialog = ({ model_id: String(m.id), apiKey: apiKey || m.apiKey, maxTokens: maxTokens || m.maxTokens, + ...(timeoutSeconds !== undefined ? { timeoutSeconds } : {}), + ...(concurrencyLimit !== undefined ? { concurrencyLimit } : {}), })); await modelService.updateBatchModel( @@ -639,6 +643,8 @@ export const ModelDeleteDialog = ({ prev.map((model) => ({ ...model, max_tokens: maxTokens || model.max_tokens || 4096, + timeout_seconds: timeoutSeconds || model.timeout_seconds, + concurrency_limit: concurrencyLimit !== undefined ? concurrencyLimit : model.concurrency_limit, })) ); } catch (e) { @@ -649,29 +655,6 @@ export const ModelDeleteDialog = ({ setIsProviderConfigOpen(false); }; - // Handle settings button click - const handleSettingsClick = (model: any) => { - setSelectedModelForSettings(model); - setModelMaxTokens(model.max_tokens?.toString() || "4096"); - setSettingsModalVisible(true); - }; - - // Handle settings save - const handleSettingsSave = () => { - if (selectedModelForSettings) { - // Update the model in the list with new max_tokens - setProviderModels((prev) => - prev.map((model) => - model.id === selectedModelForSettings.id - ? { ...model, max_tokens: parseInt(modelMaxTokens) || 4096 } - : model - ) - ); - } - setSettingsModalVisible(false); - setSelectedModelForSettings(null); - }; - // Handle embedding model click to open config modal const handleEmbeddingModelClick = (model: ModelOption | any) => { const isEmbeddingModel = @@ -729,6 +712,12 @@ export const ModelDeleteDialog = ({ } }; + // Handle single model settings button click + const handleSingleModelSettingsClick = (model: any) => { + setSelectedSingleModel(model); + setIsSingleModelSettingsOpen(true); + }; + // Handle embedding config save const handleEmbeddingConfigSave = async () => { if (!selectedEmbeddingModel) return; @@ -1330,7 +1319,7 @@ export const ModelDeleteDialog = ({ size="small" onClick={(e) => { e.stopPropagation(); // Prevent switch toggle - handleSettingsClick(providerModel); + handleSingleModelSettingsClick(providerModel); }} /> @@ -1516,34 +1505,78 @@ export const ModelDeleteDialog = ({ m.source === (selectedSource || MODEL_SOURCES.SILICON) )?.maxTokens || 4096 ).toString()} + initialTimeoutSeconds={( + models.find( + (m) => + m.type === deletingModelType && + m.source === (selectedSource || MODEL_SOURCES.SILICON) + )?.timeoutSeconds?.toString() || "120" + )} + initialConcurrencyLimit={( + models.find( + (m) => + m.type === deletingModelType && + m.source === (selectedSource || MODEL_SOURCES.SILICON) + )?.concurrencyLimit?.toString() || "" + )} modelType={deletingModelType || undefined} onSave={handleProviderConfigSave} /> - {/* Settings Modal */} - setSettingsModalVisible(false)} - onOk={handleSettingsSave} - cancelText={t("common.button.cancel")} - okText={t("common.button.save")} - destroyOnHidden - > -
-
- - setModelMaxTokens(e.target.value)} - placeholder={t("model.dialog.placeholder.maxTokens")} - /> -
-
-
+ {/* Single Model Settings Modal */} + { + setIsSingleModelSettingsOpen(false); + setSelectedSingleModel(null); + }} + initialMaxTokens={selectedSingleModel?.max_tokens?.toString() || "4096"} + initialTimeoutSeconds={selectedSingleModel?.timeout_seconds?.toString() || "120"} + initialConcurrencyLimit={selectedSingleModel?.concurrency_limit?.toString() || ""} + modelType={deletingModelType || undefined} + showApiKeyField={false} + onSave={async (config) => { + if (!selectedSingleModel) return; + try { + const modelName = selectedSingleModel.model_name || selectedSingleModel.id; + + const updatePayload: any = { + model_id: modelName, + maxTokens: config.maxTokens, + timeoutSeconds: config.timeoutSeconds, + concurrencyLimit: config.concurrencyLimit, + }; + + if (config.apiKey) { + updatePayload.apiKey = config.apiKey; + } + + await modelService.updateBatchModel( + [updatePayload], + selectedSingleModel.model_factory + ); + + // Update the model in the list + setProviderModels((prev) => + prev.map((model) => + model.id === selectedSingleModel.id + ? { + ...model, + max_tokens: config.maxTokens, + timeout_seconds: config.timeoutSeconds, + concurrency_limit: config.concurrencyLimit, + } + : model + ) + ); + + message.success(t("model.message.updateSuccess") || "Update successful"); + } catch (error) { + console.error("Failed to update model settings:", error); + message.error(t("model.message.updateFailed") || "Failed to update settings"); + } + }} + /> {/* Embedding Model Config Modal */} { setForm((prev) => ({ ...prev, [field]: value })); // If the key configuration item changes, clear the verification status - if (["url", "apiKey", "maxTokens", "vectorDimension"].includes(field)) { + if (["url", "apiKey", "maxTokens", "timeoutSeconds", "vectorDimension"].includes(field)) { setConnectivityStatus({ status: null, message: "" }); } }; @@ -176,6 +180,8 @@ export const ModelEditDialog = ({ expectedChunkSize: isEmbeddingModel ? form.chunkSizeRange[0] : undefined, maximumChunkSize: isEmbeddingModel ? form.chunkSizeRange[1] : undefined, chunkingBatchSize: isEmbeddingModel ? parseInt(form.chunkingBatchSize) || 10 : undefined, + timeoutSeconds: !isEmbeddingModel && !isRerankModel ? parseInt(form.timeoutSeconds) || 120 : undefined, + concurrencyLimit: !isEmbeddingModel && !isRerankModel ? (form.concurrencyLimit ? parseInt(form.concurrencyLimit) : undefined) : undefined, }); } else { await modelService.updateSingleModel({ @@ -196,6 +202,13 @@ export const ModelEditDialog = ({ chunkingBatchSize: parseInt(form.chunkingBatchSize) || 10, } : {}), + // Send timeout for non-embedding models + ...(!isEmbeddingModel && !isRerankModel + ? { + timeoutSeconds: parseInt(form.timeoutSeconds) || 120, + concurrencyLimit: form.concurrencyLimit ? parseInt(form.concurrencyLimit) : undefined, + } + : {}), }); } @@ -306,6 +319,40 @@ export const ModelEditDialog = ({ )} + {/* Timeout Seconds */} + {!isEmbeddingModel && !isRerankModel && ( +
+ + handleFormChange("timeoutSeconds", e.target.value)} + /> +
+ )} + + {/* Concurrency Limit */} + {!isEmbeddingModel && !isRerankModel && ( +
+ + handleFormChange("concurrencyLimit", e.target.value)} + placeholder={t("model.dialog.placeholder.concurrencyLimit")} + /> +
+ {t("model.dialog.hint.concurrencyLimit")} +
+
+ )} + {/* Chunk Size Range for embedding models */} {isEmbeddingModel && (
@@ -408,28 +455,38 @@ interface ProviderConfigEditDialogProps { isOpen: boolean initialApiKey?: string initialMaxTokens?: string + initialTimeoutSeconds?: string + initialConcurrencyLimit?: string modelType?: ModelType + showApiKeyField?: boolean // Whether to show API Key field (default: true) onClose: () => void - onSave: (config: { apiKey: string; maxTokens: number }) => Promise | void + onSave: (config: { apiKey: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number }) => Promise | void } export const ProviderConfigEditDialog = ({ isOpen, initialApiKey = '', initialMaxTokens = '4096', + initialTimeoutSeconds = '120', + initialConcurrencyLimit = '', modelType, + showApiKeyField = true, onClose, onSave, }: ProviderConfigEditDialogProps) => { const { t } = useTranslation() const [apiKey, setApiKey] = useState(initialApiKey) const [maxTokens, setMaxTokens] = useState(initialMaxTokens) + const [timeoutSeconds, setTimeoutSeconds] = useState(initialTimeoutSeconds) + const [concurrencyLimit, setConcurrencyLimit] = useState(initialConcurrencyLimit) const [saving, setSaving] = useState(false) useEffect(() => { setApiKey(initialApiKey) setMaxTokens(initialMaxTokens) - }, [initialApiKey, initialMaxTokens]) + setTimeoutSeconds(initialTimeoutSeconds) + setConcurrencyLimit(initialConcurrencyLimit) + }, [initialApiKey, initialMaxTokens, initialTimeoutSeconds, initialConcurrencyLimit]) const valid = () => { const parsed = parseInt(maxTokens) @@ -440,7 +497,14 @@ export const ProviderConfigEditDialog = ({ if (!valid()) return try { setSaving(true) - await onSave({ apiKey: apiKey.trim() === '' ? 'sk-no-api-key' : apiKey, maxTokens: parseInt(maxTokens) }) + const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING + const isRerankModel = modelType === MODEL_TYPES.RERANK + await onSave({ + apiKey: showApiKeyField ? (apiKey.trim() === '' ? 'sk-no-api-key' : apiKey) : '', + maxTokens: parseInt(maxTokens), + ...(!isEmbeddingModel && !isRerankModel ? { timeoutSeconds: parseInt(timeoutSeconds) || 120 } : {}), + ...(!isEmbeddingModel && !isRerankModel ? { concurrencyLimit: concurrencyLimit ? parseInt(concurrencyLimit) : undefined } : {}), + }) onClose() } finally { setSaving(false) @@ -448,6 +512,7 @@ export const ProviderConfigEditDialog = ({ } const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING + const isRerankModel = modelType === MODEL_TYPES.RERANK return (
-
- - setApiKey(e.target.value)} visibilityToggle={false} /> -
+ {showApiKeyField && ( +
+ + setApiKey(e.target.value)} visibilityToggle={false} /> +
+ )} {!isEmbeddingModel && (
)} + {!isEmbeddingModel && !isRerankModel && ( +
+ + setTimeoutSeconds(e.target.value)} + /> +
+ )} + {!isEmbeddingModel && !isRerankModel && ( +
+ + setConcurrencyLimit(e.target.value)} + placeholder={t("model.dialog.placeholder.concurrencyLimit")} + /> +
+ {t("model.dialog.hint.concurrencyLimit")} +
+
+ )}