diff --git a/.github/workflows/docker-deploy.yml b/.github/workflows/docker-deploy.yml index 9d04c8913..a77c2491f 100644 --- a/.github/workflows/docker-deploy.yml +++ b/.github/workflows/docker-deploy.yml @@ -38,7 +38,10 @@ jobs: - name: Check if model is cached locally id: check-model run: | - if [ -f ~/model-assets/clip-vit-base-patch32/config.json ] && [ -d ~/model-assets/nltk_data ]; then + if [ -f ~/model-assets/clip-vit-base-patch32/config.json ] && \ + [ -d ~/model-assets/nltk_data ] && \ + [ -d ~/model-assets/table-transformer-structure-recognition ] && \ + [ -d ~/model-assets/yolox ]; then echo "cache-hit=true" >> "$GITHUB_OUTPUT" cp -r ~/model-assets ./ else @@ -105,4 +108,4 @@ jobs: ./deploy.sh --mode 3 --is-mainland N --enable-terminal N --version 2 --root-dir "$HOME/nexent-production-data" else ./deploy.sh --mode 1 --is-mainland N --enable-terminal N --version 2 --root-dir "$HOME/nexent-development-data" - fi \ No newline at end of file + fi diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 5a11b550b..88d71216a 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -1,4 +1,4 @@ -import threading +import threading import logging from typing import List, Optional from urllib.parse import urljoin @@ -469,6 +469,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int rerank = param_dict.get("rerank", False) rerank_model_name = param_dict.get("rerank_model_name", "") rerank_model = None + is_multimodal = bool(tool_config.params.pop("multimodal", False)) if rerank and rerank_model_name: rerank_model = get_rerank_model( tenant_id=tenant_id, model_name=rerank_model_name diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 578277b6d..677961442 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -126,12 +126,13 @@ async def upload_files( 
@file_management_config_router.post("/process") async def process_files( - files: List[dict] = Body( - ..., description="List of file details to process, including path_or_url and filename"), - chunking_strategy: Optional[str] = Body("basic"), - index_name: str = Body(...), - destination: str = Body(...), - authorization: Optional[str] = Header(None) + files: Annotated[List[dict], Body( + ..., description="List of file details to process, including path_or_url and filename")], + index_name: Annotated[str, Body(...)], + destination: Annotated[str, Body(...)], + chunking_strategy: Annotated[Optional[str], Body(...)] = "basic", + model_id: Annotated[Optional[int], Body(...)] = None, + authorization: Annotated[Optional[str], Header()] = None ): """ Trigger data processing for a list of uploaded files. @@ -144,7 +145,8 @@ async def process_files( chunking_strategy=chunking_strategy, source_type=destination, index_name=index_name, - authorization=authorization + authorization=authorization, + model_id=model_id ) process_result = await trigger_data_process(files, process_params) diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 278b729e8..7029477e6 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -33,7 +33,7 @@ from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder from http import HTTPStatus -from typing import List, Optional +from typing import Annotated, List, Optional from services.model_health_service import ( check_model_connectivity, verify_model_config_connectivity, @@ -297,7 +297,8 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): @router.post("/healthcheck") async def check_model_health( - display_name: str = Query(..., description="Display name to check"), + display_name: Annotated[str, Query(..., description="Display name to check")], + model_type: Annotated[str, Query(..., description="...")], authorization: 
Optional[str] = Header(None) ): """Check and update model connectivity, returning the latest status. @@ -308,7 +309,7 @@ async def check_model_health( """ try: _, tenant_id = get_current_user_id(authorization) - result = await check_model_connectivity(display_name, tenant_id) + result = await check_model_connectivity(display_name, tenant_id, model_type) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Successfully checked model connectivity", "data": result diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 6f4232afd..d28f59822 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -82,11 +82,13 @@ def create_new_index( # Extract optional fields from request body ingroup_permission = None group_ids = None - embedding_model_name = None + embedding_model_name: Optional[str] = None + is_multimodal: Optional[bool] = None if request: ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") - embedding_model_name = request.get("embedding_model_name") + embedding_model_name = request.get("embeddingModel") + is_multimodal = request.get("is_multimodal") # Treat path parameter as user-facing knowledge base name for new creations return ElasticSearchService.create_knowledge_base( @@ -98,6 +100,7 @@ def create_new_index( ingroup_permission=ingroup_permission, group_ids=group_ids, embedding_model_name=embedding_model_name, + is_multimodal=is_multimodal, ) except Exception as e: raise HTTPException( @@ -664,6 +667,7 @@ def update_chunk( chunk_request=payload, vdb_core=vdb_core, user_id=user_id, + tenant_id=tenant_id, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) except ValueError as e: @@ -730,8 +734,17 @@ async def hybrid_search( """Run a hybrid (accurate + semantic) search across indices.""" try: _, tenant_id = get_current_user_id(authorization) + resolved_index_names: List[str] = [] + for requested_name in payload.index_names: 
+ try: + resolved_name = get_index_name_by_knowledge_name( + requested_name, tenant_id + ) + except Exception: + resolved_name = requested_name + resolved_index_names.append(resolved_name) result = ElasticSearchService.search_hybrid( - index_names=payload.index_names, + index_names=resolved_index_names, query=payload.query, tenant_id=tenant_id, top_k=payload.top_k, diff --git a/backend/consts/const.py b/backend/consts/const.py index 680bc78db..9f84208da 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -31,6 +31,10 @@ class VectorDatabaseType(str, Enum): # Data Processing Service Configuration DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE") CLIP_MODEL_PATH = os.getenv("CLIP_MODEL_PATH") +TABLE_TRANSFORMER_MODEL_PATH = os.getenv("TABLE_TRANSFORMER_MODEL_PATH") +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = os.getenv( + "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" +) # Upload Configuration @@ -129,6 +133,7 @@ class VectorDatabaseType(str, Enum): MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY") MINIO_REGION = os.getenv("MINIO_REGION") MINIO_DEFAULT_BUCKET = os.getenv("MINIO_DEFAULT_BUCKET") +S3_URL_PREFIX = "s3://" # Postgres Configuration diff --git a/backend/consts/model.py b/backend/consts/model.py index 2f1d7aae3..2c5117ee3 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -300,6 +300,7 @@ class ProcessParams(BaseModel): source_type: str index_name: str authorization: Optional[str] = None + model_id: Optional[int] = None class OpinionRequest(BaseModel): diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py index 0dea828ce..c3879c007 100644 --- a/backend/data_process/ray_actors.py +++ b/backend/data_process/ray_actors.py @@ -1,3 +1,4 @@ +from io import BytesIO import logging import json import time @@ -5,8 +6,15 @@ import ray -from consts.const import RAY_ACTOR_NUM_CPUS, REDIS_BACKEND_URL, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE -from 
database.attachment_db import get_file_stream +from consts.const import ( + RAY_ACTOR_NUM_CPUS, + REDIS_BACKEND_URL, + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + TABLE_TRANSFORMER_MODEL_PATH, + UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH, +) +from database.attachment_db import build_s3_url, get_file_stream, upload_fileobj from database.model_management_db import get_model_by_model_id from nexent.data_process import DataProcessCore @@ -43,35 +51,16 @@ def _prepare_process_params( Normalize task/model-related processing params. """ process_params = dict(params) + self._apply_model_paths(process_params) if task_id: process_params["task_id"] = task_id - if not (model_id and tenant_id): - return process_params - - try: - model_record = get_model_by_model_id( - model_id=model_id, tenant_id=tenant_id) - if not model_record: - logger.warning( - f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") - return process_params - - expected_chunk_size = model_record.get( - "expected_chunk_size", DEFAULT_EXPECTED_CHUNK_SIZE) - maximum_chunk_size = model_record.get( - "maximum_chunk_size", DEFAULT_MAXIMUM_CHUNK_SIZE) - model_name = model_record.get("display_name") - - process_params["max_characters"] = maximum_chunk_size - process_params["new_after_n_chars"] = expected_chunk_size - - logger.info( - f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " - f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") - except Exception as e: - logger.warning( - f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. 
Using default chunk sizes") + # Reuse shared model param logic so we also keep extra fields + self._apply_model_chunk_sizes( + model_id=model_id, + tenant_id=tenant_id, + params=process_params, + ) return process_params def _run_file_process( @@ -82,24 +71,19 @@ def _run_file_process( process_params: Dict[str, Any], log_subject: str, ) -> List[Dict[str, Any]]: - chunks = self._processor.file_process( + result = self._processor.file_process( file_data=file_data, filename=filename, chunking_strategy=chunking_strategy, **process_params ) - - if chunks is None: - logger.warning( - f"[RayActor] file_process returned None for {log_subject}='{filename}'") - return [] - if not isinstance(chunks, list): - logger.error( - f"[RayActor] file_process returned non-list type {type(chunks)} for {log_subject}='{filename}'") - return [] - if len(chunks) == 0: - logger.warning( - f"[RayActor] file_process returned empty list for {log_subject}='{filename}'") + + chunks, images_info = self._normalize_processor_result(result) + if images_info: + self._append_image_chunks( + source=filename, chunks=chunks, images_info=images_info) + chunks = self._validate_chunks(chunks, filename) + if not chunks: return [] logger.info( @@ -161,8 +145,129 @@ def process_file( chunking_strategy=chunking_strategy, process_params=process_params, log_subject="source", - ) + ) + + def _apply_model_paths(self, params: Dict[str, Any]) -> None: + params["table_transformer_model_path"] = TABLE_TRANSFORMER_MODEL_PATH + params[ + "unstructured_default_model_initialize_params_json_path" + ] = UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH + + def _apply_model_chunk_sizes( + self, + model_id: Optional[int], + tenant_id: Optional[str], + params: Dict[str, Any], + ) -> None: + if not (model_id and tenant_id): + return + + try: + model_record = get_model_by_model_id( + model_id=model_id, tenant_id=tenant_id) + if not model_record: + logger.warning( + f"[RayActor] Embedding model with ID {model_id} not found for 
tenant '{tenant_id}', using default chunk sizes") + return + + expected_chunk_size = model_record.get( + 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) + maximum_chunk_size = model_record.get( + 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) + model_name = model_record.get('display_name') + model_type = model_record.get('model_type') + + params['max_characters'] = maximum_chunk_size + params['new_after_n_chars'] = expected_chunk_size + if model_type: + params['model_type'] = model_type + + logger.info( + f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " + f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") + except Exception as e: + logger.warning( + f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes") + + def _read_file_bytes(self, source: str) -> bytes: + try: + file_stream = get_file_stream(source) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {source}") + return file_stream.read() + except Exception as e: + logger.error(f"Failed to fetch file from {source}: {e}") + raise + def _normalize_processor_result( + self, result: Any + ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + if isinstance(result, tuple) and len(result) == 2: + chunks, images_info = result + return chunks or [], images_info or [] + return result or [], [] + + def _append_image_chunks( + self, + source: str, + chunks: List[Dict[str, Any]], + images_info: List[Dict[str, Any]], + ) -> None: + folder = "images_in_attachments" + for index, image_data in enumerate(images_info): + if not isinstance(image_data, dict): + logger.warning( + f"[RayActor] Skipping image entry at index {index}: unexpected type {type(image_data)}" + ) + continue + if "image_bytes" not in image_data: + logger.warning( + f"[RayActor] Skipping image entry at index {index}: missing image_bytes" + ) + continue + + img_obj = 
BytesIO(image_data["image_bytes"]) + result = upload_fileobj( + file_obj=img_obj, + file_name=f"{index}.{image_data['image_format']}", + prefix=folder) + image_url = build_s3_url(result.get("object_name", "")) + + image_data["source_file"] = source + image_data["image_url"] = image_url + + chunks.append({ + "content": json.dumps({ + "source_file": source, + "position": image_data["position"], + "image_url": image_url, + }), + "filename": source, + "metadata": { + "chunk_index": len(chunks) + index, + "process_source": "UniversalImageExtractor", + "image_url": image_url, + } + }) + + def _validate_chunks( + self, chunks: Any, source: str + ) -> List[Dict[str, Any]]: + if chunks is None: + logger.warning( + f"[RayActor] file_process returned None for source='{source}'") + return [] + if not isinstance(chunks, list): + logger.error( + f"[RayActor] file_process returned non-list type {type(chunks)} for source='{source}'") + return [] + if len(chunks) == 0: + logger.warning( + f"[RayActor] file_process returned empty list for source='{source}'") + return [] + return chunks + def process_bytes( self, file_bytes: bytes, diff --git a/backend/data_process/tasks.py b/backend/data_process/tasks.py index 71d83b090..f2a30f9b7 100644 --- a/backend/data_process/tasks.py +++ b/backend/data_process/tasks.py @@ -379,11 +379,11 @@ def _extract_error_code_from_es_response( def _send_chunks_to_es( chunks: List[Dict[str, Any]], index_name: str, - authorization: Optional[str], - task_id: Optional[str], - source: Optional[str], - original_filename: Optional[str], - large_mode: Optional[bool] = None, + authorization: str | None, + task_id: Optional[str] = None, + source: str = "", + original_filename: str = "", + large_mode: bool = False, ) -> Dict[str, Any]: async def _post(): elasticsearch_url = ELASTICSEARCH_SERVICE @@ -405,12 +405,10 @@ async def _post(): connector = aiohttp.TCPConnector(verify_ssl=False) timeout = aiohttp.ClientTimeout(total=600) - large_mode_value = "true" if 
large_mode else "false" if large_mode is not None else None - request_params = ( - { - "large_mode": large_mode_value - } - ) + request_params: Dict[str, str] = {} + + if large_mode: + request_params["large_mode"] = "true" async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: async with session.post( @@ -761,7 +759,7 @@ def forward_part( original_filename: Optional[str] = None, batch_index: Optional[int] = None, total_batches: Optional[int] = None, - large_mode: Optional[bool] = None, + large_mode: Optional[bool] = False, ) -> Dict[str, Any]: """ Forward sub-task that indexes a chunk batch. @@ -1645,7 +1643,7 @@ def forward( task_id=task_id, source=original_source, original_filename=original_filename, - large_mode=None, + large_mode=False, ) else: batches = _build_balanced_batches( diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index 187381cd2..06b84e5ac 100644 --- a/backend/database/attachment_db.py +++ b/backend/database/attachment_db.py @@ -2,13 +2,66 @@ import os import uuid from datetime import datetime -from typing import Any, BinaryIO, Dict, List, Optional +from typing import Any, BinaryIO, Dict, List, Optional, Tuple from .client import minio_client +from consts.const import S3_URL_PREFIX from consts.const import NORTHBOUND_EXTERNAL_URL from urllib.parse import quote +def _normalize_object_and_bucket(object_name: str, bucket: Optional[str] = None) -> Tuple[str, Optional[str]]: + """ + Normalize object_name + bucket from supported URL styles. 
+ + Supports: + - s3://bucket/key + - /bucket/key + - key (uses provided bucket or default bucket) + """ + if not object_name: + return object_name, bucket + + if object_name.startswith(S3_URL_PREFIX): + s3_path = object_name[len(S3_URL_PREFIX) :] + parts = s3_path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + return object_name, bucket + + +def build_s3_url(object_name: str, bucket: Optional[str] = None) -> str: + """ + Build an s3://bucket/key style URL from an object name (or passthrough if already s3://). + """ + if not object_name: + return "" + + if object_name.startswith(S3_URL_PREFIX): + return object_name + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + if len(parts) == 2: + return f"{S3_URL_PREFIX}{parts[0]}/{parts[1]}" + return f"{S3_URL_PREFIX}{parts[0]}/" + + resolved_bucket = bucket or minio_client.default_bucket + if resolved_bucket: + return f"{S3_URL_PREFIX}{resolved_bucket}/{object_name}" + return f"{S3_URL_PREFIX}{object_name}" + + def _build_mcp_presigned_url(presigned_url: str) -> str: """ Build northbound API proxy URL for MCP tools. 
@@ -217,6 +270,7 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) -> """ Get file size by object name """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) # Ensure minio_client is initialized before accessing storage_config minio_client._ensure_initialized() bucket = bucket or minio_client.storage_config.default_bucket @@ -235,6 +289,7 @@ def file_exists(object_name: str, bucket: Optional[str] = None) -> bool: bool: True if file exists, False otherwise """ try: + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) return minio_client.file_exists(object_name, bucket) except Exception: return False @@ -252,6 +307,8 @@ def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None Returns: Dict[str, Any]: Result containing success flag and error message (if any) """ + source_object, bucket = _normalize_object_and_bucket(source_object, bucket) + dest_object, bucket = _normalize_object_and_bucket(dest_object, bucket) success, result = minio_client.copy_file(source_object, dest_object, bucket) if success: return {"success": True, "object_name": result} @@ -296,6 +353,7 @@ def delete_file(object_name: str, bucket: Optional[str] = None) -> Dict[str, Any Returns: Dict[str, Any]: Delete result, containing success flag and error message (if any) """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) if not bucket: minio_client._ensure_initialized() bucket = minio_client.storage_config.default_bucket @@ -320,6 +378,7 @@ def get_file_stream(object_name: str, bucket: Optional[str] = None) -> Optional[ Returns: Optional[BinaryIO]: Standard BinaryIO stream object, or None if failed """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) success, result = minio_client.get_file_stream(object_name, bucket) if not success: return None diff --git a/backend/database/client.py b/backend/database/client.py index 05f8940b9..e095c5636 100644 --- 
a/backend/database/client.py +++ b/backend/database/client.py @@ -89,6 +89,9 @@ def __init__(self): if MinioClient._initialized: return MinioClient._initialized = True + # Explicitly initialize attributes so external callers never hit missing-attribute errors. + self._storage_client = None + self.storage_config = None def _ensure_initialized(self): """Lazily initialize the storage client on first use.""" @@ -108,6 +111,23 @@ def _ensure_initialized(self): return True return False + @property + def default_bucket(self) -> Optional[str]: + """ + Resolve default bucket safely for callers that need bucket info. + Falls back to configured constant when lazy init has not run yet. + """ + try: + self._ensure_initialized() + except Exception: + # Keep this accessor resilient; operational methods can still raise + # detailed storage errors when invoked. + pass + + if getattr(self, "storage_config", None) is not None: + return self.storage_config.default_bucket + return MINIO_DEFAULT_BUCKET + def upload_file( self, file_path: str, diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index 8674bb4fb..9a8b1c8c1 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -183,7 +183,7 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool: # Update group IDs if query.get("group_ids") is not None: record.group_ids = query["group_ids"] - + # Update timestamp and user if query.get("user_id"): record.updated_by = query["user_id"] @@ -259,7 +259,7 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An if 'tenant_id' in query and query['tenant_id'] is not None: db_query = db_query.filter( KnowledgeRecord.tenant_id == query['tenant_id']) - + result = db_query.first() if result: @@ -404,14 +404,25 @@ def get_index_name_by_knowledge_name(knowledge_name: str, tenant_id: str) -> str """ try: with get_db_session() as session: + # First try resolving by user-facing knowledge_name. 
result = session.query(KnowledgeRecord).filter( KnowledgeRecord.knowledge_name == knowledge_name, KnowledgeRecord.tenant_id == tenant_id, KnowledgeRecord.delete_flag != 'Y' ).first() - if result: return result.index_name + + # Backward/forward compatibility: if caller already passes internal index_name, + # accept it directly by resolving on index_name as well. + index_result = session.query(KnowledgeRecord).filter( + KnowledgeRecord.index_name == knowledge_name, + KnowledgeRecord.tenant_id == tenant_id, + KnowledgeRecord.delete_flag != 'Y' + ).first() + if index_result: + return index_result.index_name + raise ValueError( f"Knowledge base '{knowledge_name}' not found for the current tenant" ) diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..61753f52f 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -170,7 +170,7 @@ def get_model_records(filters: Optional[Dict[str, Any]], tenant_id: str) -> List return result_list -def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dict[str, Any]]: +def get_model_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[Dict[str, Any]]: """ Get a model record by display name @@ -179,6 +179,11 @@ def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dic tenant_id: """ filters = {'display_name': display_name} + + if model_type in ["multiEmbedding", "multi_embedding"]: + filters['model_type'] = "multi_embedding" + elif model_type == "embedding": + filters['model_type'] = "embedding" records = get_model_records(filters, tenant_id) if not records: @@ -203,7 +208,7 @@ def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[s return get_model_records(filters, tenant_id) -def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]: +def get_model_id_by_display_name(display_name: str, tenant_id: str, 
model_type: str = None) -> Optional[int]: """ Get a model ID by display name @@ -214,7 +219,7 @@ def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[ Returns: Optional[int]: Model ID """ - model = get_model_by_display_name(display_name, tenant_id) + model = get_model_by_display_name(display_name, tenant_id, model_type) return model["model_id"] if model else None diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 0ed29bfc5..81bc9078b 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -99,7 +99,7 @@ async def save_config_impl(config, tenant_id, user_id): config_key = get_env_key(model_type) + "_ID" model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_display_name, tenant_id, model_type=model_type) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index a024089a3..ce0e3f993 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -296,6 +296,17 @@ async def load_image(self, image_url: str) -> Optional[Image.Image]: async def _load_image(self, session: aiohttp.ClientSession, path: str) -> Optional[Image.Image]: """Internal method to load an image from various sources""" try: + if path.startswith('s3://'): + # Fetch from MinIO using s3://bucket/key + file_stream = get_file_stream(object_name=path) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {path}") + file_data = file_stream.read() + image_based64_str = base64.b64encode( + file_data).decode('utf-8') + path = f"data:image/jpeg;base64,{image_based64_str}" + # Check if input is base64 encoded if path.startswith('data:image'): # Extract the base64 data after the comma @@ -504,6 +515,8 @@ async def create_batch_tasks_impl(self, authorization: 
Optional[str], request: B chunking_strategy = source_config.get('chunking_strategy') index_name = source_config.get('index_name') original_filename = source_config.get('original_filename') + embedding_model_id = source_config.get('embedding_model_id') + tenant_id = source_config.get('tenant_id') # Validate required fields if not source: @@ -522,7 +535,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B source_type=source_type, chunking_strategy=chunking_strategy, index_name=index_name, - original_filename=original_filename + original_filename=original_filename, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id ).set(queue='process_q'), forward.s( index_name=index_name, @@ -600,7 +615,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c } async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: - """Full conversion pipeline: download → convert → upload → validate → cleanup. + """Full conversion pipeline: download -> convert -> upload -> validate -> cleanup. All five steps run inside data-process so that LibreOffice only needs to be installed in this container. 
diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py index 776e0eb1d..41858440b 100644 --- a/backend/services/datamate_service.py +++ b/backend/services/datamate_service.py @@ -51,7 +51,7 @@ async def _create_datamate_knowledge_records(knowledge_base_ids: List[str], "tenant_id": tenant_id, "user_id": user_id, # Use datamate as embedding model name - "embedding_model_name": embedding_model_names[i] + "embedding_model_name": embedding_model_names[i], } # Run synchronous database operation in executor to avoid blocking diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index b5cd048bf..d86c310f4 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -58,6 +58,7 @@ def check_file_access(object_name: str, user_id: Optional[str]) -> bool: Access rules: - knowledge_base/*: All authenticated users can access - attachments/{user_id}/*: Only the owner (user_id) can access + - images_in_attachments/*: All authenticated users can access - preview/*: Accessible if the original file is accessible Args: @@ -74,6 +75,11 @@ def check_file_access(object_name: str, user_id: Optional[str]) -> bool: # Knowledge base files: all authenticated users can access return True + if object_name.startswith("images_in_attachments/"): + # Extracted image files used by knowledge-base image chunks. + # Keep them readable for authenticated users to avoid broken image citations. 
+ return True + # Check if file is in user's attachments folder # Pattern: attachments/{user_id}/* if object_name.startswith(f"attachments/{user_id}/"): diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index a20b2a6ca..5e960d3da 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -170,10 +170,10 @@ async def _perform_connectivity_check( return connectivity -async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: +async def check_model_connectivity(display_name: str, tenant_id: str, model_type: str = None) -> dict: try: # Query the database using display_name and tenant context from app layer - model = get_model_by_display_name(display_name, tenant_id=tenant_id) + model = get_model_by_display_name(display_name, tenant_id=tenant_id, model_type=model_type) if not model: raise LookupError( f"Model configuration not found for {display_name}") diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index 5e5229ff6..9c008c07f 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -151,6 +151,10 @@ def get_local_tools() -> List[ToolInfo]: else: param_info["default"] = param.default.default param_info["optional"] = True + if getattr(param.default, "json_schema_extra", None): + optional_override = param.default.json_schema_extra.get("optional") + if optional_override is not None: + param_info["optional"] = optional_override init_params_list.append(param_info) @@ -681,6 +685,8 @@ def _validate_local_tool( if not tool_class: raise NotFoundException(f"Tool class not found for {tool_name}") + runtime_inputs = dict(inputs or {}) + # Parse instantiation parameters first instantiation_params = params or {} # Get signature and extract default values for all parameters @@ -704,6 +710,7 @@ def _validate_local_tool( if tool_name == 
"knowledge_base_search": index_names = instantiation_params.get("index_names", []) + is_multimodal = instantiation_params.pop("multimodal", False) # Must have embedding model for knowledge base search if not index_names or not tenant_id: @@ -799,7 +806,18 @@ def _validate_local_tool( else: tool_instance = tool_class(**instantiation_params) - result = tool_instance.forward(**(inputs or {})) + # Only pass declared runtime inputs to forward() to avoid unexpected kwargs. + declared_inputs = getattr(tool_class, "inputs", {}) or {} + allowed_input_names = ( + set(declared_inputs.keys()) if isinstance(declared_inputs, dict) else set() + ) + filtered_runtime_inputs = ( + {k: v for k, v in runtime_inputs.items() if k in allowed_input_names} + if allowed_input_names + else runtime_inputs + ) + + result = tool_instance.forward(**filtered_runtime_inputs) return result except Exception as e: logger.error(f"Local tool validation failed for {tool_name}: {e}") diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 8ad9b54e2..9fc8c9112 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -28,7 +28,7 @@ from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ from consts.model import ChunkCreateRequest, ChunkUpdateRequest -from database.attachment_db import delete_file +from database.attachment_db import delete_file, get_file_stream from database.knowledge_db import ( create_knowledge_record, delete_knowledge_record, @@ -101,6 +101,28 @@ def _get_embedding_model_display_name(model_id: Optional[int], tenant_id: str) - return "" +def _is_multimodal_by_model_id(model_id: Optional[int], tenant_id: str) -> bool: + """ + Determine whether an embedding model is multimodal based on model_id. + + Args: + model_id: The embedding model ID. + tenant_id: Tenant ID for model lookup. 
+ + Returns: + True when the model type is `multi_embedding`, otherwise False. + """ + if model_id is None: + return False + try: + model = get_model_by_model_id(model_id, tenant_id) + if model: + return model.get("model_type") == "multi_embedding" + except Exception as e: + logger.warning(f"Failed to determine multimodal flag for model_id {model_id}: {e}") + return False + + class KnowledgeBaseNeedsModelConfigError(Exception): """Exception raised when a knowledge base needs an embedding model to be configured.""" def __init__(self, index_name: str, message: str = None): @@ -283,8 +305,42 @@ def check_knowledge_base_exist_impl(knowledge_name: str, vdb_core: VectorDatabas # Case B: Name is available in this tenant return {"status": "available"} - -def get_embedding_model(tenant_id: str, model_name: Optional[str] = None) -> tuple[Optional[Any], Optional[int]]: +def _normalize_model_type(raw_model_type: Optional[str]) -> Optional[str]: + if raw_model_type in ["multiEmbedding", "multi_embedding"]: + return "multi_embedding" + if raw_model_type == "embedding": + return "embedding" + return None + +def _build_model_config(model: dict) -> dict: + return { + "model_repo": model.get("model_repo", ""), + "model_name": model["model_name"], + "api_key": model.get("api_key", ""), + "base_url": model.get("base_url", ""), + "model_type": model.get("model_type", "embedding"), + "max_tokens": model.get("max_tokens", 1024), + "ssl_verify": model.get("ssl_verify", True), + } + +def _create_embedding_model(model: dict) -> Any: + model_config = _build_model_config(model) + common_kwargs = { + "api_key": model_config.get("api_key", ""), + "base_url": model_config.get("base_url", ""), + "model_name": get_model_name_from_config(model_config) or "", + "embedding_dim": model_config.get("max_tokens", 1024), + "ssl_verify": model_config.get("ssl_verify", True), + } + if model.get("model_type", "embedding") == "multi_embedding": + return JinaEmbedding(**common_kwargs) + return 
OpenAICompatibleEmbedding(**common_kwargs) + +def get_embedding_model( + tenant_id: str, + model_name: Optional[str] = None, + model_type: Optional[str] = None +) -> tuple[Optional[Any], Optional[int]]: """ Get the embedding model for the tenant, optionally using a specific model name. @@ -296,40 +352,19 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None) -> tup Returns: Tuple of (embedding model instance or None, model_id or None) """ - # If model_name is provided, find the model by display_name if model_name: try: - model = get_model_by_display_name(model_name, tenant_id) - if model and model.get("model_type") in ["embedding", "multi_embedding"]: - model_config = { - "model_repo": model.get("model_repo", ""), - "model_name": model["model_name"], - "api_key": model.get("api_key", ""), - "base_url": model.get("base_url", ""), - "model_type": model.get("model_type", "embedding"), - "max_tokens": model.get("max_tokens", 1024), - "ssl_verify": model.get("ssl_verify", True), - } - model_type = model.get("model_type", "embedding") - if model_type == "multi_embedding": - embedding_model = JinaEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) - else: - embedding_model = OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) - return embedding_model, model.get("model_id") + normalized_model_type = _normalize_model_type(model_type) + if normalized_model_type: + model = get_model_by_display_name(model_name, tenant_id, normalized_model_type) else: + model = get_model_by_display_name(model_name, tenant_id) + + 
if not model or model.get("model_type") not in ["embedding", "multi_embedding"]: logger.warning(f"Model '{model_name}' not found or is not an embedding model") + return None, None + + return _create_embedding_model(model), model.get("model_id") except Exception as e: logger.warning(f"Failed to get embedding model by name {model_name}: {e}") @@ -595,6 +630,7 @@ def create_knowledge_base( ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, embedding_model_name: Optional[str] = None, + is_multimodal: Optional[bool] = None, ): """ Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. @@ -620,7 +656,17 @@ def create_knowledge_base( """ try: # Get embedding model - use user-selected model if provided, otherwise use tenant default - embedding_model, model_id = get_embedding_model(tenant_id, embedding_model_name) + selected_model_type = None + if is_multimodal is True: + selected_model_type = "multi_embedding" + elif is_multimodal is False and embedding_model_name: + selected_model_type = "embedding" + + embedding_model, model_id = get_embedding_model( + tenant_id, + embedding_model_name, + selected_model_type + ) # Determine the embedding model name to save: use user-provided name if available, # otherwise use the model's display name @@ -1002,6 +1048,7 @@ def list_indices( model_id = record.get("embedding_model_id") tenant_id = record.get("tenant_id") or target_tenant_id embedding_model_display_name = _get_embedding_model_display_name(model_id, tenant_id) + is_multimodal = _is_multimodal_by_model_id(model_id, tenant_id) stats_info.append({ # Internal index name (used as ID) @@ -1013,6 +1060,7 @@ def list_indices( # knowledge source and ingroup permission from DB record "knowledge_sources": record["knowledge_sources"], "ingroup_permission": record["ingroup_permission"], + "is_multimodal": is_multimodal, "tenant_id": record.get("tenant_id"), # Embedding model info: display_name from model_id 
"embedding_model_name": embedding_model_display_name or record.get("embedding_model_name", ""), @@ -1122,12 +1170,27 @@ def index_documents( "author": author, "date": date, "content": text, - "process_source": "Unstructured", + "process_source": metadata.get("process_source", "Unstructured"), "file_size": file_size, "create_time": create_time, "languages": metadata.get("languages", []), "embedding_model_name": embedding_model_name } + + image_url = metadata.get("image_url", "") + if len(image_url) > 0: + # Fetch image bytes from MinIO (supports s3://bucket/key or /bucket/key) + try: + file_stream = get_file_stream( + object_name=image_url) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {image_url}") + document["image_bytes"] = file_stream.read() + except Exception as e: + logger.error( + f"Failed to fetch file from {image_url}: {e}") + raise documents.append(document) @@ -1148,8 +1211,9 @@ def index_documents( 'tenant_id') if knowledge_record else None if tenant_id: + model_type = "EMBEDDING_ID" if embedding_model.model_type == "text" else "MULTI_EMBEDDING_ID" model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) + key=model_type, tenant_id=tenant_id) embedding_batch_size = model_config.get("chunk_batch", 10) if embedding_batch_size is None: embedding_batch_size = 10 @@ -1867,6 +1931,7 @@ def update_chunk( chunk_request: ChunkUpdateRequest, vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), user_id: Optional[str] = None, + tenant_id: Optional[str] = None, ): """ Update a chunk document. 
diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 7d31a74bb..37681d9c7 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -15,7 +15,6 @@ from consts.model import ProcessParams from database.attachment_db import get_file_size_from_minio from utils.auth_utils import get_current_user_id -from utils.config_utils import tenant_config_manager logger = logging.getLogger("file_management_utils") @@ -45,18 +44,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) if not files: return None - # Get chunking size according to the embedding model - embedding_model_id = None + # Get tenant_id from authorization for downstream task processing + embedding_model_id = process_params.model_id tenant_id = None try: _, tenant_id = get_current_user_id(process_params.authorization) - # Get embedding model ID from tenant config - tenant_config = tenant_config_manager.load_config(tenant_id) - embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None - if embedding_model_id_str: - embedding_model_id = int(embedding_model_id_str) except Exception as e: - logger.warning(f"Failed to get embedding model ID for tenant: {e}") + logger.warning(f"Failed to get tenant_id from authorization: {e}") # Build headers with authorization headers = { diff --git a/docker/.env.bak b/docker/.env.bak deleted file mode 100644 index 24b53751b..000000000 --- a/docker/.env.bak +++ /dev/null @@ -1,168 +0,0 @@ -# ===== Necessary Configs (Necessary till now, will be migrated to frontend page) ===== - -# Voice Service Config -APPID=app_id -TOKEN=token - -# ===== Non-essential Configs (Modify if you know what you are doing) ===== - -CLUSTER=volcano_tts -VOICE_TYPE=zh_male_jieshuonansheng_mars_bigtts -SPEED_RATIO=1.3 - -# ===== Proxy Configuration (Optional) ===== - -# HTTP_PROXY=http://proxy-server:port -# HTTPS_PROXY=http://proxy-server:port -# 
NO_PROXY=localhost,127.0.0.1 - -# ===== Backend Configuration (No need to modify at all) ===== - -# Model Path Config -CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32 -NLTK_DATA=/opt/models/nltk_data - -# Elasticsearch Service -ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200 -ELASTIC_PASSWORD=nexent@2025 - -# Elasticsearch Memory Configuration -ES_JAVA_OPTS="-Xms2g -Xmx2g" - -# Elasticsearch Disk Watermark Configuration -ES_DISK_WATERMARK_LOW=85% -ES_DISK_WATERMARK_HIGH=90% -ES_DISK_WATERMARK_FLOOD_STAGE=95% - -# Main Services -# Config service (port 5010) - Main API service for config operations -CONFIG_SERVICE_URL=http://nexent-config:5010 -ELASTICSEARCH_SERVICE=http://nexent-config:5010/api - -# Runtime service (port 5014) - Runtime execution service for agent operations -RUNTIME_SERVICE_URL=http://nexent-runtime:5014 - -# MCP service (port 5011) - MCP protocol service -NEXENT_MCP_SERVER=http://nexent-mcp:5011 -MCP_MANAGEMENT_API=http://nexent-mcp:5015 - -# Data process service (port 5012) - Data processing service -DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api - -# Northbound service (port 5013) - Northbound API service -NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api - -# Postgres Config -POSTGRES_HOST=nexent-postgresql -POSTGRES_USER=root -NEXENT_POSTGRES_PASSWORD=nexent@4321 -POSTGRES_DB=nexent -POSTGRES_PORT=5432 - -# Minio Config -MINIO_ENDPOINT=http://nexent-minio:9000 -MINIO_ROOT_USER=nexent -MINIO_ROOT_PASSWORD=nexent@4321 -MINIO_REGION=cn-north-1 -MINIO_DEFAULT_BUCKET=nexent - -# Redis Config -REDIS_URL=redis://redis:6379/0 -REDIS_BACKEND_URL=redis://redis:6379/1 - -# Model Engine Config -MODEL_ENGINE_ENABLED=false - -# Supabase Config -DASHBOARD_USERNAME=supabase -DASHBOARD_PASSWORD=Huawei123 - -# Supabase db Config -SUPABASE_POSTGRES_PASSWORD=Huawei123 -SUPABASE_POSTGRES_HOST=db -SUPABASE_POSTGRES_DB=supabase -SUPABASE_POSTGRES_PORT=5436 - -# Supabase Auth Config -SITE_URL=http://localhost:3011 
-SUPABASE_URL=http://supabase-kong-mini:8000 -API_EXTERNAL_URL=http://supabase-kong-mini:8000 -DISABLE_SIGNUP=false -JWT_EXPIRY=3600 -DEBUG_JWT_EXPIRE_SECONDS=0 - -# Supabase Configuration -ENABLE_EMAIL_SIGNUP=true -ENABLE_EMAIL_AUTOCONFIRM=true -ENABLE_ANONYMOUS_USERS=false - -# Supabase Phone Config -ENABLE_PHONE_SIGNUP=false -ENABLE_PHONE_AUTOCONFIRM=false - -MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify" -MAILER_URLPATHS_INVITE="/auth/v1/verify" -MAILER_URLPATHS_RECOVERY="/auth/v1/verify" -MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify" - -INVITE_CODE=nexent2025 - -# Terminal Tool SSH Key Path -SSH_PRIVATE_KEY_PATH=/path/to/openssh-server/ssh-keys/openssh_server_key - -# ===== Data Processing Service Configuration ===== - -# Redis Port -REDIS_PORT=6379 - -# Flower Monitoring -FLOWER_PORT=5555 - -# Ray Configuration -RAY_ACTOR_NUM_CPUS=2 -RAY_DASHBOARD_PORT=8265 -RAY_DASHBOARD_HOST=0.0.0.0 -RAY_NUM_CPUS=4 -RAY_OBJECT_STORE_MEMORY_GB=0.25 -RAY_TEMP_DIR=/tmp/ray -RAY_LOG_LEVEL=INFO - -# Service Control Flags -DISABLE_RAY_DASHBOARD=true -DISABLE_CELERY_FLOWER=true -DOCKER_ENVIRONMENT=false -ENABLE_UPLOAD_IMAGE=false - -# Celery Configuration -CELERY_WORKER_PREFETCH_MULTIPLIER=1 -CELERY_TASK_TIME_LIMIT=3600 -ELASTICSEARCH_REQUEST_TIMEOUT=30 - -# Worker Configuration -QUEUES=process_q,forward_q -WORKER_NAME= -WORKER_CONCURRENCY=4 - -# Skills Configuration -SKILLS_PATH=/mnt/nexent/skills - -# Telemetry and Monitoring Configuration -ENABLE_TELEMETRY=false -SERVICE_NAME=nexent-backend -JAEGER_ENDPOINT=http://localhost:14268/api/traces -PROMETHEUS_PORT=8000 -TELEMETRY_SAMPLE_RATE=1.0 -LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0 -LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0 - -# Market Backend Address -MARKET_BACKEND=http://60.204.251.153:8010 -DEPLOYMENT_VERSION="speed" -# Root dir -ROOT_DIR="/c/Users/18270/nexent-data" -TERMINAL_MOUNT_DIR="/opt/terminal" -SSH_USERNAME="root" -SSH_PASSWORD="731215" -NEXENT_MCP_DOCKER_IMAGE="ccr.ccs.tencentyun.com/nexent-hub/nexent-mcp:v2.0.1" 
-MINIO_ACCESS_KEY="72c31cb5b521511cea652723" -MINIO_SECRET_KEY="m5gcSuKzZnp84CqmG7z5VKnd2C+H5U3PSr7eoJeygmI=" diff --git a/docker/.env.example b/docker/.env.example index e55bba45a..25fe228de 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -21,6 +21,8 @@ SPEED_RATIO=1.3 # Model Path Config CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32 NLTK_DATA=/opt/models/nltk_data +TABLE_TRANSFORMER_MODEL_PATH=/opt/models/table-transformer-structure-recognition +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH=/opt/models/yolox/config.json # Elasticsearch Service ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200 diff --git a/docker/deploy.sh b/docker/deploy.sh index 7fb78aa90..8b012f5c6 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -17,6 +17,7 @@ DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options" MODE_CHOICE_SAVED="" VERSION_CHOICE_SAVED="" IS_MAINLAND_SAVED="" +ENABLE_EXTRACTED_IMAGE_MODELS_SAVED="N" ENABLE_SKILLS_SAVED="Y" ENABLE_TERMINAL_SAVED="N" TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}" @@ -85,6 +86,58 @@ is_windows_env() { return 1 } +detect_os_type() { + # Return: windows | mac | linux | unknown + local os_name + os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]') + case "$os_name" in + mingw*|msys*|cygwin*) + echo "windows" + ;; + darwin*) + echo "mac" + ;; + linux*) + echo "linux" + ;; + *) + echo "unknown" + ;; + esac + return 0 +} + +format_path_for_env() { + # Convert path to OS-specific format for .env values + local input_path="$1" + local os_type + os_type=$(detect_os_type) + + if [[ "$os_type" = "windows" ]]; then + if command -v cygpath >/dev/null 2>&1; then + cygpath -w "$input_path" + return 0 + fi + + if [[ "$input_path" =~ ^/([a-zA-Z])/(.*)$ ]]; then + local drive="${BASH_REMATCH[1]}" + local rest="${BASH_REMATCH[2]}" + rest="${rest//\//\\}" + printf "%s:\\%s" "$(echo "$drive" | tr '[:lower:]' '[:upper:]')" "$rest" + return 0 + fi + fi + + printf "%s" "$input_path" +} + +escape_backslashes() { + # Escape 
backslashes for safe writing into .env or JSON + local input_path="$1" + printf "%s" "$input_path" | sed 's/\\/\\\\/g' + return 0 +} + is_port_in_use() { # Check if a TCP port is already in use (Linux/macOS/Windows Git Bash) local port="$1" @@ -272,6 +325,7 @@ persist_deploy_options() { echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\"" echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\"" echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\"" + echo "ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=\"${ENABLE_EXTRACTED_IMAGE_MODELS_SAVED}\"" echo "ENABLE_SKILLS=\"${ENABLE_SKILLS_SAVED}\"" echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\"" echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\"" @@ -414,7 +468,7 @@ get_compose_version() { # Function to get the version of docker compose if command -v docker &> /dev/null; then version_output=$(docker compose version 2>/dev/null) - # 修改点:放宽正则匹配,允许版本号后面跟随其他字符(如 -desktop.1) + if [[ $version_output =~ v([0-9]+\.[0-9]+\.[0-9]+) ]]; then echo "v2 ${BASH_REMATCH[1]}" return 0 @@ -423,7 +477,7 @@ get_compose_version() { if command -v docker-compose &> /dev/null; then version_output=$(docker-compose --version 2>/dev/null) - # 同样放宽这里的匹配规则,以防万一 + if [[ $version_output =~ ([0-9]+\.[0-9]+\.[0-9]+) ]]; then echo "v1 ${BASH_REMATCH[1]}" return 0 @@ -537,6 +591,43 @@ select_deployment_mode() { echo "" } + +# Extracted image models selection +select_extracted_image_models_mode() { + echo "" + + local input_choice="" + read -r -p "Do you want to enable pre-extracted image models mode? [Y/N] (default: N): " input_choice + echo "" + + if [[ $input_choice =~ ^[Yy]$ ]]; then + ENABLE_EXTRACTED_IMAGE_MODELS_SAVED="Y" + echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=Y, deployment will not change model path variables." + else + ENABLE_EXTRACTED_IMAGE_MODELS_SAVED="N" + echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=N, deployment will clear model path variables." 
+ fi + echo "----------------------------------------" + echo "" + return 0 +} + +configure_extracted_image_models_env() { + # New behavior: + # - ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=N: clear the two model-path values in .env + # - ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=Y: do nothing (no .env update) + if [[ "$ENABLE_EXTRACTED_IMAGE_MODELS_SAVED" =~ ^[Yy]$ ]]; then + echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=Y, skip model handling (no .env update)." + return 0 + fi + + echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=N, clearing model path variables in .env..." + update_env_var "TABLE_TRANSFORMER_MODEL_PATH" "" + update_env_var "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" "" + echo "INFO: Cleared TABLE_TRANSFORMER_MODEL_PATH and UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH." + return 0 +} + clean() { export MINIO_ACCESS_KEY= export MINIO_SECRET_KEY= @@ -609,6 +700,13 @@ prepare_directory_and_data() { create_dir_with_permission "$ROOT_DIR/minio" 775 create_dir_with_permission "$ROOT_DIR/redis" 775 + echo "📦 Check the status of model configuration..." + configure_extracted_image_models_env || { + echo "⚠️ A warning occurred during the model configuration step, but subsequent deployment will proceed..." + # Do not exit here; the user may choose N or prefer to continue after a model-path handling failure. + } + echo "" + cp -rn volumes $ROOT_DIR chmod -R 775 $ROOT_DIR/volumes echo " 📁 Directory $ROOT_DIR/volumes has been created and permissions set to 775." 
@@ -1301,6 +1399,8 @@ main_deploy() { choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } select_skills_installation || { echo "❌ Skills installation selection failed"; exit 1; } + select_extracted_image_models_mode || { echo "❌ Extracted image models configuration failed"; exit 1;} + # Set NEXENT_MCP_DOCKER_IMAGE in .env file if [ -n "${NEXENT_MCP_DOCKER_IMAGE:-}" ]; then update_env_var "NEXENT_MCP_DOCKER_IMAGE" "${NEXENT_MCP_DOCKER_IMAGE}" @@ -1459,7 +1559,7 @@ docker_compose_command="" case $version_type in "v1") echo "Detected Docker Compose V1, version: $version_number" - # The version ​​v1.28.0​​ is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} + # The version 1.28.0 is the minimum requirement in Docker Compose v1 for default interpolation syntax. if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" exit 1 diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 909592345..c109f5722 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -102,7 +102,8 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig(); + const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable; // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -363,7 +364,10 @@ export default function ToolManagement({ tool.id ); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const 
isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -468,7 +472,10 @@ export default function ToolManagement({ {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index 53c6d3f03..fe1ac6e2b 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -1,4 +1,4 @@ -"use client"; +"use client"; import { useState, useEffect, useCallback, useMemo, useRef } from "react"; import { useTranslation } from "react-i18next"; @@ -36,6 +36,10 @@ import { API_ENDPOINTS } from "@/services/api"; import knowledgeBaseService from "@/services/knowledgeBaseService"; import log from "@/lib/logger"; import { isZhLocale, getLocalizedDescription } from "@/lib/utils"; +import { + isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase, + isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase, +} from "@/lib/knowledgeBaseCompatibility"; export interface ToolConfigModalProps { isOpen: boolean; @@ -524,6 
+528,86 @@ export default function ToolConfigModal({ } }, [configData]); + const currentMultiEmbeddingModel = useMemo(() => { + try { + const modelConfig = configData?.models; + return ( + modelConfig?.multiEmbedding?.modelName || + modelConfig?.multiEmbedding?.displayName || + null + ); + } catch { + return null; + } + }, [configData]); + + const hasEmbeddingModel = Boolean(currentEmbeddingModel); + const hasMultiEmbeddingModel = Boolean(currentMultiEmbeddingModel); + const canToggleMultimodalParam = hasEmbeddingModel && hasMultiEmbeddingModel; + const forcedMultimodalValue = useMemo(() => { + if (!hasEmbeddingModel && hasMultiEmbeddingModel) { + return true; + } + if (hasEmbeddingModel && !hasMultiEmbeddingModel) { + return false; + } + return null; + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); + + const toolMultimodal = useMemo(() => { + const multimodalParam = currentParams.find( + (param) => param.name === "multimodal" + ); + const value = multimodalParam?.value; + if (typeof value === "boolean") { + return value; + } + if (typeof value === "string") { + const normalized = value.trim().toLowerCase(); + if (["true", "1", "yes", "y"].includes(normalized)) return true; + if (["false", "0", "no", "n"].includes(normalized)) return false; + } + return null; + }, [currentParams]); + + useEffect(() => { + if (tool?.name !== "knowledge_base_search") return; + if (forcedMultimodalValue === null) return; + + const index = currentParams.findIndex( + (param) => param.name === "multimodal" + ); + if (index < 0) return; + + const param = currentParams[index]; + if (param.value === forcedMultimodalValue) return; + + const updatedParams = [...currentParams]; + updatedParams[index] = { ...param, value: forcedMultimodalValue }; + setCurrentParams(updatedParams); + + const fieldName = `param_${index}`; + form.setFieldValue(fieldName, forcedMultimodalValue); + }, [tool?.name, forcedMultimodalValue, currentParams, form]); + + const isMultimodalConstraintMismatch = useCallback( 
+ (kb: KnowledgeBase) => { + return isMultimodalConstraintMismatchBase(kb, toolMultimodal); + }, + [toolMultimodal] + ); + + const isEmbeddingModelCompatible = useCallback( + (kb: KnowledgeBase) => { + return isEmbeddingModelCompatibleBase( + kb, + currentEmbeddingModel, + currentMultiEmbeddingModel + ); + }, + [currentEmbeddingModel, currentMultiEmbeddingModel] + ); + // Check if a knowledge base can be selected const canSelectKnowledgeBase = useCallback( (kb: KnowledgeBase): boolean => { @@ -534,9 +618,16 @@ export default function ToolConfigModal({ return false; } + if (kb.source === "nexent") { + if (isMultimodalConstraintMismatch(kb)) { + return false; + } + return isEmbeddingModelCompatible(kb); + } + return true; }, - [currentEmbeddingModel] + [isEmbeddingModelCompatible, isMultimodalConstraintMismatch] ); // Track whether this is the first time opening the modal (reset when modal closes) @@ -1451,7 +1542,7 @@ export default function ToolConfigModal({ })} options={options.map((option) => ({ value: option, - label: option, + label: String(option), }))} /> ); @@ -1866,6 +1957,8 @@ export default function ToolConfigModal({ syncLoading={kbLoading} isSelectable={canSelectKnowledgeBase} currentEmbeddingModel={currentEmbeddingModel} + currentMultiEmbeddingModel={currentMultiEmbeddingModel} + toolMultimodal={toolMultimodal} difyConfig={ toolKbType === "dify_search" ? 
difyConfig diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx index b6af20594..767fae3a7 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx @@ -2,9 +2,8 @@ import { useState, useEffect, useRef } from "react"; import { useTranslation } from "react-i18next"; -import { Input, Button, Card, Typography, Tooltip, Modal, Form, Tag, Skeleton } from "antd"; +import { Input, Button, Card, Typography, Tooltip, Modal, Form } from "antd"; import { Settings, PenLine, X } from "lucide-react"; -import { CloseOutlined } from "@ant-design/icons"; import { Tool, ToolParam } from "@/types/agentConfig"; import { KnowledgeBase } from "@/types/knowledgeBase"; @@ -59,15 +58,8 @@ export default function ToolTestPanel({ configParams, onClose, toolRequiresKbSelection = false, - knowledgeBases = [], - kbLoading = false, - onOpenKbSelector, selectedKbIds = [], - selectedKbDisplayNames = [], - onKbSelectionChange, - onRemoveKb, toolKbType = null, - haotianKnowledgeSets = [], }: ToolTestPanelProps) { const { t } = useTranslation("common"); const [form] = Form.useForm(); @@ -605,28 +597,10 @@ export default function ToolTestPanel({ // Haotian uses dataset_ids, others use index_names const isKbSelectorParam = (paramName === "index_names" || paramName === "dataset_ids") && toolRequiresKbSelection; - // Get display names based on selected KB IDs and knowledge bases - let displayNames: string[] = []; - if (isKbSelectorParam && selectedKbIds.length > 0) { - if (toolKbType === "haotian_search" && haotianKnowledgeSets.length > 0) { - // Haotian: resolve names from haotianKnowledgeSets - displayNames = selectedKbIds.map((id) => { - const cleanId = id.trim(); - for (const ks of haotianKnowledgeSets) { - const kb = (ks.knowledge_bases || []).find( - (b) => 
String(b.dify_dataset_id) === cleanId - ); - if (kb) return kb.name; - } - return cleanId; - }); - } else if (knowledgeBases.length > 0) { - displayNames = selectedKbIds.map((id) => { - const cleanId = id.trim(); - const kb = knowledgeBases.find((k) => k.id === cleanId); - return kb?.display_name || kb?.name || cleanId; - }); - } + // KB selection is configured in the upper config area. + // Do not render duplicated KB params in the test input area. + if (isKbSelectorParam) { + return null; } // Add type-specific validation rules @@ -681,83 +655,6 @@ export default function ToolTestPanel({ break; } - // Render knowledge base selector for index_names parameter - if (isKbSelectorParam) { - return ( - - {paramName} - - } - name={fieldName} - rules={rules} - tooltip={{ - title: getLocalizedDescription(description, description_zh), - placement: "topLeft", - styles: { root: { maxWidth: 400 } }, - }} - > -
-
onOpenKbSelector?.(-1)} // -1 indicates this is from test panel - style={{ - width: "100%", - minHeight: "32px", - display: "flex", - flexWrap: "wrap", - alignItems: "center", - gap: "4px", - }} - title={displayNames.join(", ")} - > - {kbLoading && knowledgeBases.length === 0 ? ( -
- -
- ) : displayNames.length > 0 ? ( - displayNames.map((name, i) => ( - - - - } - onClose={(e) => { - e.stopPropagation(); - onRemoveKb?.(i, -1); // -1 indicates this is from test panel - }} - style={{ - marginRight: 0, - display: "inline-flex", - alignItems: "center", - lineHeight: "20px", - padding: "0 8px", - fontSize: "13px", - }} - > - {name} - - )) - ) : ( - - {t("toolConfig.input.knowledgeBaseSelector.placeholder", { - name: getLocalizedDescription(description, description_zh) || paramName, - })} - - )} -
-
-
- ); - } - return ( + (type || "").trim().toLowerCase(); + +const toEmbeddingModelOptionValue = (displayName: string, type: string) => + `${displayName}${EMBEDDING_MODEL_OPTION_DELIMITER}${type}`; + +const parseEmbeddingModelOptionValue = (value: string) => { + const normalizedValue = (value || "").trim(); + const delimiterIndex = normalizedValue.lastIndexOf( + EMBEDDING_MODEL_OPTION_DELIMITER + ); + if (delimiterIndex >= 0) { + const displayName = normalizedValue.slice(0, delimiterIndex); + const type = normalizedValue.slice( + delimiterIndex + EMBEDDING_MODEL_OPTION_DELIMITER.length + ); + return { + displayName: displayName || "", + type: (type || "").trim(), + isMultimodal: + normalizeEmbeddingModelType(type || "") === "multi_embedding", + }; + } + return { + displayName: normalizedValue || "", + type: "", + isMultimodal: false, + }; +}; + // EmptyState component defined directly in this file interface EmptyStateProps { icon?: React.ReactNode | string; @@ -55,7 +87,7 @@ interface EmptyStateProps { } const EmptyState: React.FC = ({ - icon = "📋", + icon = "馃搵", title, description, action, @@ -129,8 +161,7 @@ function DataConfig({ isActive }: DataConfigProps) { const { token } = theme.useToken(); // Get available embedding models for knowledge base creation - const { availableEmbeddingModels } = useModelList({ enabled: true }); - + const { models } = useModelList({ enabled: true }); // Clear cache when component initializes useEffect(() => { localStorage.removeItem("preloaded_kb_data"); @@ -198,11 +229,59 @@ function DataConfig({ isActive }: DataConfigProps) { const [modelFilter, setModelFilter] = useState([]); const contentRef = useRef(null); - // Open warning modal when single Embedding model is not configured (ignore multi-embedding) + const availableEmbeddingModels = useMemo(() => { + const embeddingRelatedModels = models.filter( + (model) => model.type === "embedding" || model.type === "multi_embedding" + ); + const availableKeys = new Set( + 
embeddingRelatedModels + .filter((model) => model.connect_status === "available") + .map((model) => `${model.displayName}::${model.type}`) + ); + + return embeddingRelatedModels.filter((model) => { + if (model.connect_status === "available") { + return true; + } + + // For paired records created from a multi-embedding model, mirror availability by display name. + if (model.type === "embedding") { + return availableKeys.has(`${model.displayName}::multi_embedding`); + } + if (model.type === "multi_embedding") { + return availableKeys.has(`${model.displayName}::embedding`); + } + return false; + }); + }, [models]); + + const resolveEmbeddingModelId = useCallback( + ({ + displayName, + isMultimodal, + }: { + displayName?: string; + isMultimodal?: boolean; + }) => { + const normalizedDisplayName = (displayName || "").trim(); + if (!normalizedDisplayName) return undefined; + + const modelType = isMultimodal ? "multi_embedding" : "embedding"; + return availableEmbeddingModels.find( + (model) => + model.displayName === normalizedDisplayName && model.type === modelType + )?.id; + }, + [availableEmbeddingModels] + ); + + // Open warning modal only when neither embedding nor multi-embedding is configured. 
useEffect(() => { - const singleEmbeddingModelName = modelConfig?.embedding?.modelName; - setShowEmbeddingWarning(!singleEmbeddingModelName); - }, [modelConfig?.embedding?.modelName]); + const singleEmbeddingModelName = modelConfig?.embedding?.modelName?.trim(); + const multiEmbeddingModelName = + modelConfig?.multiEmbedding?.modelName?.trim(); + setShowEmbeddingWarning(!singleEmbeddingModelName && !multiEmbeddingModelName); + }, [modelConfig?.embedding?.modelName, modelConfig?.multiEmbedding?.modelName]); // Add event listener for selecting new knowledge base useEffect(() => { @@ -370,11 +449,11 @@ function DataConfig({ isActive }: DataConfigProps) { // Directly call fetchKnowledgeBases to update knowledge base list data await fetchKnowledgeBases(false, true); } catch (error) { - log.error("获取知识库最新数据失败:", error); + log.error("获取知识库最新数据失败:", error); } }, 100); } catch (error) { - log.error("获取文档列表失败:", error); + log.error("获取文档列表失败:", error); message.error(t("knowledgeBase.message.getDocumentsFailed")); docDispatch({ type: "ERROR", @@ -619,11 +698,30 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(defaultName); setNewKbIngroupPermission("READ_ONLY"); setNewKbGroupIds([]); - // Set default embedding model - prioritize config's default model, fall back to first available model - const configModel = modelConfig?.embedding?.modelName; - const defaultModel = configModel || (availableEmbeddingModels.length > 0 - ? availableEmbeddingModels[0].displayName - : ""); + // Set default embedding model: + // 1) configured embedding model, 2) configured multimodal model, 3) first available option.
+ const configEmbeddingModel = modelConfig?.embedding?.modelName?.trim() || ""; + const configMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const preferredModel = [ + { modelName: configEmbeddingModel, type: "embedding" }, + { modelName: configMultiEmbeddingModel, type: "multi_embedding" }, + ].find( + ({ modelName, type }) => + !!modelName && + availableEmbeddingModels.some( + (model) => model.displayName === modelName && model.type === type + ) + ); + const defaultModel = + (preferredModel && + toEmbeddingModelOptionValue(preferredModel.modelName, preferredModel.type)) || + (availableEmbeddingModels[0] + ? toEmbeddingModelOptionValue( + availableEmbeddingModels[0].displayName, + availableEmbeddingModels[0].type + ) + : ""); setNewKbEmbeddingModel(defaultModel); setIsCreatingMode(true); setHasClickedUpload(false); // Reset upload button click state @@ -682,13 +780,22 @@ function DataConfig({ isActive }: DataConfigProps) { return; } + const parsedSelectedModel = + parseEmbeddingModelOptionValue(newKbEmbeddingModel); + const isMultimodal = parsedSelectedModel.isMultimodal; + const selectedModelId = resolveEmbeddingModelId({ + displayName: parsedSelectedModel.displayName, + isMultimodal: parsedSelectedModel.isMultimodal, + }); + const newKB = await createKnowledgeBase( newKbName.trim(), t("knowledgeBase.description.default"), "elasticsearch", newKbIngroupPermission, newKbGroupIds, - newKbEmbeddingModel + parsedSelectedModel.displayName, + isMultimodal ); if (!newKB) { @@ -703,7 +810,7 @@ function DataConfig({ isActive }: DataConfigProps) { setHasClickedUpload(false); setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created - await uploadDocuments(newKB.id, filesToUpload); + await uploadDocuments(newKB.id, filesToUpload, selectedModelId); setUploadFiles([]); knowledgeBasePollingService @@ -739,7 +846,12 @@ function DataConfig({ isActive }: DataConfigProps) { } try { - await uploadDocuments(kbId, filesToUpload); + const 
activeKbModelId = resolveEmbeddingModelId({ + displayName: kbState.activeKnowledgeBase?.embeddingModel, + isMultimodal: kbState.activeKnowledgeBase?.is_multimodal, + }); + + await uploadDocuments(kbId, filesToUpload, activeKbModelId); setUploadFiles([]); knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true); @@ -888,7 +1000,7 @@ function DataConfig({ isActive }: DataConfigProps) { = ({ knowledgeBaseId, documents, getFileIcon, - currentEmbeddingModel = null, - knowledgeBaseEmbeddingModel = "", + currentEmbeddingModel, + knowledgeBaseEmbeddingModel, onChunkCountChange, permission, }) => { @@ -128,55 +128,31 @@ const DocumentChunk: React.FC = ({ setTooltipResetKey((prev) => prev + 1); }, []); + const effectiveIndexName = knowledgeBaseId || knowledgeBaseName; + + const hasKnowledgeBaseModel = + Boolean(knowledgeBaseEmbeddingModel) && + knowledgeBaseEmbeddingModel !== "unknown"; + const hasCurrentModel = Boolean(currentEmbeddingModel); + // Determine if embedding models mismatch (specific condition for tooltip) const isEmbeddingModelMismatch = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { + if (!hasKnowledgeBaseModel) { return false; } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + ]); // Determine if in read-only mode (embedding model mismatch OR user has READ_ONLY permission) // Note: isReadOnlyMode is broader, includes model mismatch and other conditions const isReadOnlyMode = React.useMemo(() => { - // Check if user has READ_ONLY permission - if (permission === "READ_ONLY") { - return true; - } - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if 
(knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel, permission]); - - // Determine if search should be disabled (only when embedding model mismatch, NOT for READ_ONLY permission) - // This allows READ_ONLY users to still perform search - const isSearchDisabled = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); - - // Disabled tooltip message when embedding model mismatch - const disabledTooltipMessage = React.useMemo(() => { - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - return t("document.chunk.tooltip.disabledDueToModelMismatch", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - }); - } - return ""; - }, [isEmbeddingModelMismatch, currentEmbeddingModel, knowledgeBaseEmbeddingModel, t]); + return permission === "READ_ONLY" || isEmbeddingModelMismatch; + }, [permission, isEmbeddingModelMismatch]); // Set active document when documents change useEffect(() => { @@ -201,14 +177,14 @@ const DocumentChunk: React.FC = ({ // Load chunks for active document with server-side pagination const loadChunks = React.useCallback(async () => { - if (!knowledgeBaseName || !activeDocumentKey) { + if (!effectiveIndexName || !activeDocumentKey) { return; } setLoading(true); try { const result = await knowledgeBaseService.previewChunksPaginated( - knowledgeBaseName, + effectiveIndexName, pagination.page, pagination.pageSize, activeDocumentKey @@ -240,7 +216,7 @@ const DocumentChunk: React.FC = ({ setLoading(false); } }, [ - knowledgeBaseName, + effectiveIndexName, 
activeDocumentKey, pagination.page, pagination.pageSize, @@ -321,16 +297,7 @@ const DocumentChunk: React.FC = ({ return; } - // Check embedding model consistency before searching - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - message.error(t("document.chunk.error.searchFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); - return; - } - - if (!knowledgeBaseName) { + if (!effectiveIndexName) { message.error(t("document.chunk.error.searchFailed")); return; } @@ -340,7 +307,7 @@ const DocumentChunk: React.FC = ({ try { const response = await knowledgeBaseService.hybridSearch( - knowledgeBaseId, + effectiveIndexName, trimmedValue, { topK: pagination.pageSize, @@ -352,11 +319,14 @@ const DocumentChunk: React.FC = ({ return { id: item.id || "", content: item.content || "", - path_or_url: item.path_or_url, + path_or_url: item.path_or_url || item.url || item.pathOrUrl, filename: item.filename, create_time: item.create_time, score: item.score, // Preserve search score for display - source_type: item.source_type, // Preserve source type for display + source_type: + item.source_type === "local" || item.source_type === "minio" + ? 
"file" + : item.source_type, // Preserve source type for display }; }); @@ -373,16 +343,12 @@ const DocumentChunk: React.FC = ({ setChunkSearchLoading(false); } }, [ - knowledgeBaseName, - knowledgeBaseId, + effectiveIndexName, message, pagination.pageSize, resetChunkSearch, searchValue, t, - isEmbeddingModelMismatch, - currentEmbeddingModel, - knowledgeBaseEmbeddingModel, ]); const refreshChunks = React.useCallback(async () => { @@ -454,7 +420,7 @@ const DocumentChunk: React.FC = ({ }; const handleChunkSubmit = async () => { - if (!knowledgeBaseName) { + if (!effectiveIndexName) { message.error(t("document.chunk.error.loadFailed")); return; } @@ -463,26 +429,12 @@ const DocumentChunk: React.FC = ({ return; } - // Check embedding model consistency before creating chunk - if (chunkModalMode === "create") { - if (knowledgeBaseEmbeddingModel && - knowledgeBaseEmbeddingModel !== "unknown" && - currentEmbeddingModel && - currentEmbeddingModel !== knowledgeBaseEmbeddingModel) { - message.error(t("document.chunk.error.createFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); - return; - } - } - try { const values = await chunkForm.validateFields(); setChunkSubmitting(true); if (chunkModalMode === "create") { const filenamePayload = values.filename?.trim() || undefined; - await knowledgeBaseService.createChunk(knowledgeBaseName, { + await knowledgeBaseService.createChunk(effectiveIndexName, { content: values.content, filename: filenamePayload, path_or_url: activeDocumentKey, @@ -503,7 +455,7 @@ const DocumentChunk: React.FC = ({ return; } await knowledgeBaseService.updateChunk( - knowledgeBaseName, + effectiveIndexName, editingChunk.id, { content: values.content, @@ -541,7 +493,7 @@ const DocumentChunk: React.FC = ({ message.error(t("document.chunk.error.missingChunkId")); return; } - if (!knowledgeBaseName) { + if (!effectiveIndexName) { message.error(t("document.chunk.error.deleteFailed")); return; } @@ -556,7 +508,7 
@@ const DocumentChunk: React.FC = ({ danger: true, onOk: async () => { try { - await knowledgeBaseService.deleteChunk(knowledgeBaseName, chunk.id); + await knowledgeBaseService.deleteChunk(effectiveIndexName, chunk.id); message.success(t("document.chunk.success.delete")); forceCloseTooltips(); // Update chunk count immediately for better UX @@ -761,11 +713,11 @@ const DocumentChunk: React.FC = ({
{chunk.source_type === "datamate" - ? t("document.chunk.source.datamate", "来源: Datamate") + ? t("document.chunk.source.datamate", "\u6765\u6e90: Datamate") : chunk.source_type === "file" || chunk.source_type === "minio" || chunk.source_type === "local" - ? t("document.chunk.source.nexent", "来源: Nexent") + ? t("document.chunk.source.nexent", "\u6765\u6e90: Nexent") : ""}
@@ -805,57 +757,37 @@ const DocumentChunk: React.FC = ({ {/* Search and Add Button Bar */}
- {/* Wrap search input with tooltip when model mismatch */} - {isEmbeddingModelMismatch ? ( - - - setSearchValue(e.target.value)} - onPressEnter={() => { + setSearchValue(e.target.value)} + onPressEnter={() => { + void handleSearch(); + }} + style={{ width: 320 }} + suffix={ +
+ {searchValue && ( +
- } - /> - )} +
+ } + />
{/* Create Chunk button - hide when user has READ_ONLY permission */} {!isReadOnlyMode && ( @@ -864,7 +796,6 @@ const DocumentChunk: React.FC = ({ type="text" icon={} onClick={openCreateChunkModal} - disabled={isEmbeddingModelMismatch} > )} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index 023f2205a..3590db86b 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -80,6 +80,8 @@ interface DocumentListProps { availableEmbeddingModels?: ModelOption[]; selectedEmbeddingModel?: string; onEmbeddingModelChange?: (value: string) => void; + isMultimodal?: boolean; + onMultimodalChange?: (value: boolean) => void; permission?: string; // User's permission for this knowledge base (READ_ONLY, EDIT, etc.) // Auto-summary frequency @@ -127,6 +129,8 @@ const DocumentListContainer = forwardRef( availableEmbeddingModels, selectedEmbeddingModel, onEmbeddingModelChange, + isMultimodal = false, + onMultimodalChange, permission, // Auto-summary frequency @@ -248,6 +252,8 @@ const [isSummarizing, setIsSummarizing] = useState(false); // Determine if user has read-only permission const isReadOnlyMode = permission === "READ_ONLY"; + const canToggleMultimodal = + isCreatingMode && typeof onMultimodalChange === "function"; // Permission options with icons shown inside dropdown const permissionOptions = [ @@ -313,6 +319,26 @@ const [isSummarizing, setIsSummarizing] = useState(false); // Check if group select should be disabled (when permission is PRIVATE) const isGroupSelectDisabled = ingroupPermission === "PRIVATE"; + const embeddingModelsForOptions = availableEmbeddingModels || []; + const availableEmbeddingModelKeys = new Set( + embeddingModelsForOptions + .filter((model) => model.connect_status === "available") + .map((model) => `${model.displayName}::${model.type}`) + ); + 
const isEmbeddingModelSelectable = (model: ModelOption): boolean => { + if (model.connect_status === "available") return true; + if (model.type === "embedding") { + return availableEmbeddingModelKeys.has( + `${model.displayName}::multi_embedding` + ); + } + if (model.type === "multi_embedding") { + return availableEmbeddingModelKeys.has( + `${model.displayName}::embedding` + ); + } + return false; + }; // Load frequency options from backend API useEffect(() => { @@ -533,11 +559,29 @@ const [isSummarizing, setIsSummarizing] = useState(false); onChange={onEmbeddingModelChange} style={{ minWidth: 200, justifyContent: "center", alignItems: "flex-end" }} placeholder={t("knowledgeBase.create.embeddingModelPlaceholder") || "Select embedding model"} - options={(availableEmbeddingModels || []).map((model) => ({ - value: model.displayName, - label: model.displayName, - disabled: model.connect_status === "unavailable", - }))} + allowClear={false} + options={[ + { + label: t("modelConfig.option.embeddingModel"), + options: embeddingModelsForOptions + .filter((model) => model.type === "embedding") + .map((model) => ({ + value: `${model.displayName}::${model.type}`, + label: model.displayName, + disabled: !isEmbeddingModelSelectable(model), + })), + }, + { + label: t("modelConfig.option.multiEmbeddingModel"), + options: embeddingModelsForOptions + .filter((model) => model.type === "multi_embedding") + .map((model) => ({ + value: `${model.displayName}::${model.type}`, + label: model.displayName, + disabled: !isEmbeddingModelSelectable(model), + })), + }, + ].filter((group) => group.options.length > 0)} /> )} {/* User groups multi-select */} @@ -645,7 +689,7 @@ const [isSummarizing, setIsSummarizing] = useState(false);
; isLoading?: boolean; syncLoading?: boolean; onClick: (kb: KnowledgeBase) => void; @@ -57,7 +60,7 @@ interface KnowledgeBaseListProps { const KnowledgeBaseList: React.FC = ({ knowledgeBases, activeKnowledgeBase, - currentEmbeddingModel, + configuredEmbeddingModels = [], isLoading = false, syncLoading = false, onClick, @@ -128,6 +131,34 @@ const KnowledgeBaseList: React.FC = ({ return `knowledgeBase.ingroup.permission.${permission || "DEFAULT"}`; }; + const configuredModelTypesByName = useMemo(() => { + const map = new Map>(); + configuredEmbeddingModels.forEach((model) => { + const modelName = (model.displayName || "").trim(); + const modelType = (model.type || "").trim().toLowerCase(); + if (!modelName) return; + if (modelType !== "embedding" && modelType !== "multi_embedding") return; + if (!map.has(modelName)) { + map.set(modelName, new Set()); + } + map.get(modelName)!.add(modelType); + }); + return map; + }, [configuredEmbeddingModels]); + + const isModelMismatch = (kb: KnowledgeBase) => { + if (kb.embeddingModel === "unknown") return false; + if (kb.source === "datamate") return false; + const modelTypes = configuredModelTypesByName.get( + (kb.embeddingModel || "").trim() + ); + return !modelTypes; + }; + + const hasIndexedDocumentsAndChunks = (kb: KnowledgeBase) => { + return (kb.documentCount || 0) > 0 && (kb.chunkCount || 0) > 0; + }; + // Search and filter states const [searchKeyword, setSearchKeyword] = useState(""); const [selectedSources, setSelectedSources] = useState([]); @@ -580,6 +611,21 @@ const KnowledgeBaseList: React.FC = ({ })} )} + {kb.is_multimodal && + hasIndexedDocumentsAndChunks(kb) && ( + + multimodal + + )} + {isModelMismatch(kb) && ( + + {t("knowledgeBase.tag.modelMismatch")} + + )} {/* User group tags - only show when not PRIVATE */} diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx index b956dd919..63d9ad1c2 100644 --- 
a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx @@ -112,7 +112,7 @@ export const DocumentContext = createContext<{ state: DocumentState; dispatch: React.Dispatch; fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise; - uploadDocuments: (kbId: string, files: File[]) => Promise; + uploadDocuments: (kbId: string, files: File[], modelId?: number) => Promise; deleteDocument: (kbId: string, docId: string) => Promise; }>({ state: { @@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children }) }, [state.loadingKbIds, state.documentsMap, t]); // Upload documents to a knowledge base - const uploadDocuments = useCallback(async (kbId: string, files: File[]) => { + const uploadDocuments = useCallback(async (kbId: string, files: File[], modelId?: number) => { dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true }); try { - await knowledgeBaseService.uploadDocuments(kbId, files); + await knowledgeBaseService.uploadDocuments(kbId, files, undefined, modelId); // Set loading state before fetching latest documents dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true }); @@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children }) {children} ); -}; \ No newline at end of file +}; diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 3c5946bd4..9733d44c4 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -117,7 +117,8 @@ export const KnowledgeBaseContext = createContext<{ source?: string, ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + embeddingModel?: string, + is_multimodal?: boolean ) => Promise; deleteKnowledgeBase: (id: string) => Promise; selectKnowledgeBase: (id: 
string) => void; @@ -133,6 +134,7 @@ export const KnowledgeBaseContext = createContext<{ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -168,6 +170,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -177,11 +180,6 @@ export const KnowledgeBaseProvider: React.FC = ({ // Check if knowledge base is selectable - memoized with useCallback const isKnowledgeBaseSelectable = useCallback( (kb: KnowledgeBase): boolean => { - // If no current embedding model is set, not selectable - if (!state.currentEmbeddingModel) { - return false; - } - // Check if knowledge base has content (documents or chunks) const hasContent = (kb.documentCount || 0) > 0 || (kb.chunkCount || 0) > 0; @@ -196,22 +194,46 @@ export const KnowledgeBaseProvider: React.FC = ({ return true; } - // For local knowledge bases, only selectable when model exactly matches current model - return ( - kb.embeddingModel === "unknown" || - kb.embeddingModel === state.currentEmbeddingModel - ); + if (kb.embeddingModel === "unknown") { + return true; + } + + const currentEmbeddingModel = state.currentEmbeddingModel?.trim() || ""; + const currentMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + + if (kb.is_multimodal) { + // Multimodal KB is selectable as long as current multimodal model is configured. + return !!currentMultiEmbeddingModel; + } + + // Text KB is selectable as long as current embedding model is configured. 
+ return !!currentEmbeddingModel; }, - [state.currentEmbeddingModel] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Check if knowledge base has model mismatch (for display purposes) - // Note: Always return false to remove model mismatch restrictions const hasKnowledgeBaseModelMismatch = useCallback( (kb: KnowledgeBase): boolean => { - return false; + if (kb.embeddingModel === "unknown") { + return false; + } + if (kb.source === "datamate") { + return false; + } + + if (kb.is_multimodal) { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + // Only show warning when the required current model is not configured. + return !multiEmbeddingModel; + } + + // Only show warning when the required current model is not configured. + return !state.currentEmbeddingModel; }, - [] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback @@ -325,17 +347,31 @@ export const KnowledgeBaseProvider: React.FC = ({ source: string = "elasticsearch", ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + embeddingModel?: string, + is_multimodal?: boolean ) => { try { + const selectedEmbeddingModel = embeddingModel?.trim() || ""; + const defaultMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const resolvedIsMultimodal = + typeof is_multimodal === "boolean" + ? is_multimodal + : !!defaultMultiEmbeddingModel && + selectedEmbeddingModel === defaultMultiEmbeddingModel; + const fallbackEmbeddingModel = resolvedIsMultimodal + ? 
defaultMultiEmbeddingModel + : state.currentEmbeddingModel || ""; + const resolvedEmbeddingModel = + selectedEmbeddingModel || fallbackEmbeddingModel; const newKB = await knowledgeBaseService.createKnowledgeBase({ name, description, source, - // Use provided embeddingModel if available, otherwise fall back to current model or default - embeddingModel: embeddingModel || state.currentEmbeddingModel || "", + embeddingModel: resolvedEmbeddingModel, ingroup_permission, group_ids, + is_multimodal: resolvedIsMultimodal, }); return newKB; } catch (error) { @@ -347,7 +383,7 @@ export const KnowledgeBaseProvider: React.FC = ({ return null; } }, - [state.currentEmbeddingModel, t] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel, t] ); // Delete knowledge base - memoized with useCallback diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 11391c133..01a11210e 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -1,4 +1,4 @@ -import { useMemo, useState, useCallback, useEffect } from "react"; +import { useMemo, useState, useCallback, useEffect } from "react"; import { useTranslation } from "react-i18next"; import { Modal, Select, Input, Button, Switch, Tooltip, App } from "antd"; @@ -483,7 +483,8 @@ export const ModelAddDialog = ({ if (tenantId) { connectivity = await modelService.checkManageTenantModelConnectivity( tenantId, - form.displayName || form.name + form.displayName || form.name, + modelType ); } else { // For STT models, build the appropriate config based on provider diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 3114c5535..ab66100aa 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ 
b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -1,4 +1,4 @@ -import { useState, useEffect } from 'react' +import { useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Modal, Input, Button, App } from "antd"; @@ -481,4 +481,4 @@ export const ProviderConfigEditDialog = ({
) -} \ No newline at end of file +} diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index 07eee5c06..b298a4e30 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -527,6 +527,7 @@ export const ModelConfigSection = forwardRef< try { const isConnected = await modelService.verifyCustomModel( modelName, + modelType, signal ); @@ -603,7 +604,7 @@ export const ModelConfigSection = forwardRef< throttleTimerRef.current = setTimeout(async () => { try { // Use modelService to verify model - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await modelService.verifyCustomModel(displayName, modelType); // Update model status updateModelStatus( diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 786048894..c4ffaf8ca 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -1,4 +1,4 @@ -"use client"; +"use client"; import React, { useState, useMemo } from "react"; import { useTranslation } from "react-i18next"; @@ -130,7 +130,8 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { } }; - const handleCheckConnectivity = async (displayName: string) => { + // Handle checking model connectivity + const handleCheckConnectivity = async (displayName: string, modelType: string) => { if (!tenantId) { message.error(t("tenantResources.tenants.tenantIdRequired")); return; @@ -138,7 +139,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { setCheckingConnectivity((prev) => new Set(prev).add(displayName)); try { - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = 
await modelService.verifyCustomModel(displayName, modelType); if (isConnected) { message.success(t("tenantResources.models.connectivitySuccess")); } else { @@ -299,7 +300,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {