From a76c71dc5d44e4319369a9b9ae7b15757cf69719 Mon Sep 17 00:00:00 2001 From: wyxkerry <1012700194@qq.com> Date: Tue, 24 Mar 2026 23:59:08 +0800 Subject: [PATCH 01/27] =?UTF-8?q?=E2=9C=A8add=5Fimage=5Fretrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 7 +- backend/apps/file_management_app.py | 16 +- backend/apps/model_managment_app.py | 6 +- backend/apps/vectordatabase_app.py | 39 +- backend/consts/const.py | 4 + backend/consts/model.py | 1 + backend/data_process/ray_actors.py | 166 +++++-- backend/database/attachment_db.py | 64 ++- backend/database/client.py | 20 + backend/database/db_models.py | 1 + backend/database/knowledge_db.py | 9 + backend/database/model_management_db.py | 11 +- backend/services/config_sync_service.py | 2 +- backend/services/data_process_service.py | 19 +- backend/services/datamate_service.py | 3 +- backend/services/model_health_service.py | 4 +- .../services/tool_configuration_service.py | 7 +- backend/services/vectordatabase_service.py | 284 +++++++++--- backend/utils/file_management_utils.py | 12 +- docker/.env.bak | 174 ++++++++ docker/deploy.sh | 284 +++++++++++- docker/init.sql | 2 + ...dd_is_multimodal_to_knowledge_record_t.sql | 5 + .../components/agentConfig/ToolManagement.tsx | 13 +- .../agentConfig/tool/ToolConfigModal.tsx | 71 ++- .../agentConfig/tool/ToolTestPanel.tsx | 29 +- .../knowledges/KnowledgeBaseConfiguration.tsx | 141 +++++- .../components/document/DocumentChunk.tsx | 160 ++----- .../components/document/DocumentList.tsx | 28 +- .../knowledge/KnowledgeBaseList.tsx | 50 ++- .../knowledges/contexts/DocumentContext.tsx | 8 +- .../contexts/KnowledgeBaseContext.tsx | 75 +++- .../components/model/ModelAddDialog.tsx | 2 +- .../components/model/ModelEditDialog.tsx | 4 +- .../models/components/modelConfig.tsx | 3 +- .../components/resources/ModelList.tsx | 8 +- .../KnowledgeBaseSelectorModal.tsx | 96 +++- frontend/const/agentConfig.ts | 4 +- frontend/const/knowledgeBaseLayout.ts | 2 + frontend/hooks/useConfig.ts | 4 + frontend/public/locales/en/common.json | 4 +- frontend/public/locales/zh/common.json | 3 +- frontend/services/api.ts | 4 +- frontend/services/knowledgeBaseService.ts | 27 +- frontend/services/modelService.ts | 5 +- frontend/tsconfig.json | 2 +- frontend/types/knowledgeBase.ts | 3 + sdk/nexent/core/models/embedding_model.py | 8 +- .../core/tools/knowledge_base_search_tool.py | 6 - sdk/nexent/core/utils/favicon_extractor.py | 48 +- sdk/nexent/data_process/core.py | 43 +- sdk/nexent/data_process/extract_image.py | 413 ++++++++++++++++++ .../vector_database/elasticsearch_core.py | 200 +++++++-- test/backend/agents/test_create_agent_info.py | 57 ++- test/backend/app/test_model_managment_app.py | 8 +- test/backend/app/test_vectordatabase_app.py | 70 ++- test/backend/data_process/test_ray_actors.py | 80 ++++ test/backend/database/test_attachment_db.py | 46 +- test/backend/database/test_knowledge_db.py | 106 ++++- .../database/test_model_managment_db.py | 28 ++ .../services/test_config_sync_service.py | 32 +- .../services/test_data_process_service.py | 15 + .../backend/services/test_datamate_service.py | 1 + .../services/test_model_health_service.py | 85 +--- .../test_tool_configuration_service.py | 65 ++- .../services/test_vectordatabase_service.py | 179 +++++++- test/sdk/core/models/test_embedding_model.py | 16 + .../tools/test_knowledge_base_search_tool.py | 30 ++ test/sdk/core/utils/test_favicon_extractor.py | 38 ++ test/sdk/data_process/test_core.py | 106 ++++- test/sdk/data_process/test_extract_image.py | 118 +++++ .../test_elasticsearch_core.py | 107 ++++- 72 files changed, 3220 insertions(+), 571 deletions(-) create mode 100644 docker/.env.bak create mode 100644 docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql create mode 100644 sdk/nexent/data_process/extract_image.py create mode 100644 test/sdk/core/utils/test_favicon_extractor.py create mode 100644 test/sdk/data_process/test_extract_image.py diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 4246cac91..04d7150ad 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -1,4 +1,4 @@ -import threading +import threading import logging from typing import List, Optional from urllib.parse import urljoin @@ -450,6 +450,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int rerank = param_dict.get("rerank", False) rerank_model_name = param_dict.get("rerank_model_name", "") rerank_model = None + is_multimodal = bool(tool_config.params.pop("multimodal", False)) if rerank and rerank_model_name: rerank_model = get_rerank_model( tenant_id=tenant_id, model_name=rerank_model_name @@ -457,7 +458,9 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = { "vdb_core": get_vector_db_core(), - "embedding_model": get_embedding_model(tenant_id=tenant_id), + "embedding_model": get_embedding_model( + tenant_id=tenant_id, is_multimodal=is_multimodal + ), "rerank_model": rerank_model, } elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 50224c952..bb5cbb318 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -116,12 +116,13 @@ async def upload_files( @file_management_config_router.post("/process") async def process_files( - files: List[dict] = Body( - ..., description="List of file details to process, including path_or_url and filename"), - chunking_strategy: Optional[str] = Body("basic"), - index_name: str = Body(...), - destination: str = Body(...), - authorization: Optional[str] = Header(None) + files: Annotated[List[dict], Body( + ..., description="List of file details to process, including path_or_url and filename")], + index_name: Annotated[str, Body(...)], + destination: Annotated[str, Body(...)], + chunking_strategy: Annotated[Optional[str], Body(...)] = "basic", + model_id: Annotated[Optional[int], Body(...)] = None, + authorization: Annotated[Optional[str], Header()] = None ): """ Trigger data processing for a list of uploaded files. @@ -134,7 +135,8 @@ async def process_files( chunking_strategy=chunking_strategy, source_type=destination, index_name=index_name, - authorization=authorization + authorization=authorization, + model_id=model_id ) process_result = await trigger_data_process(files, process_params) diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 0a5a04139..4cfcb7d69 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -298,6 +298,10 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): @router.post("/healthcheck") async def check_model_health( display_name: str = Query(..., description="Display name to check"), + modelType: Optional[str] = Query( + None, + description="Optional model type filter (e.g., llm/embedding/multi_embedding)", + ), authorization: Optional[str] = Header(None) ): """Check and update model connectivity, returning the latest status. @@ -308,7 +312,7 @@ async def check_model_health( """ try: _, tenant_id = get_current_user_id(authorization) - result = await check_model_connectivity(display_name, tenant_id) + result = await check_model_connectivity(display_name, tenant_id, modelType) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Successfully checked model connectivity", "data": result diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 872b5387b..7f948e625 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -65,11 +65,13 @@ def create_new_index( # Extract optional fields from request body ingroup_permission = None group_ids = None - embedding_model_name = None + is_multimodal = False + embedding_model_name: Optional[str] = None if request: ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") - embedding_model_name = request.get("embedding_model_name") + is_multimodal = request.get("is_multimodal", False) + embedding_model_name = request.get("embeddingModel") # Treat path parameter as user-facing knowledge base name for new creations return ElasticSearchService.create_knowledge_base( @@ -81,6 +83,7 @@ def create_new_index( ingroup_permission=ingroup_permission, group_ids=group_ids, embedding_model_name=embedding_model_name, + is_multimodal=is_multimodal, ) except Exception as e: raise HTTPException( @@ -124,6 +127,7 @@ async def update_index( knowledge_name = request.get("knowledge_name") ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") + is_multimodal = request.get("is_multimodal") # Call service layer to update knowledge base result = ElasticSearchService.update_knowledge_base( @@ -131,6 +135,7 @@ async def update_index( knowledge_name=knowledge_name, ingroup_permission=ingroup_permission, group_ids=group_ids, + is_multimodal=is_multimodal, tenant_id=tenant_id, user_id=user_id, ) @@ -200,13 +205,23 @@ def create_index_documents( user_id, tenant_id = get_current_user_id(authorization) # Get the knowledge base record to retrieve the saved embedding model - knowledge_record = get_knowledge_record({'index_name': index_name}) + knowledge_record = get_knowledge_record( + {"index_name": index_name, "tenant_id": tenant_id} + ) saved_embedding_model_name = None if knowledge_record: saved_embedding_model_name = knowledge_record.get('embedding_model_name') - - # Use the saved model from knowledge base, fallback to tenant default if not set - embedding_model = get_embedding_model(tenant_id, saved_embedding_model_name) + is_multimodal = ( + True if knowledge_record and knowledge_record.get('is_multimodal') == 'Y' else False + ) + + # Use the saved model from knowledge base, fallback to tenant default if not set. + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=saved_embedding_model_name, + strict_model_name=bool(saved_embedding_model_name), + ) return ElasticSearchService.index_documents( embedding_model=embedding_model, @@ -463,6 +478,7 @@ def update_chunk( chunk_request=payload, vdb_core=vdb_core, user_id=user_id, + tenant_id=tenant_id, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) except ValueError as e: @@ -529,8 +545,17 @@ async def hybrid_search( """Run a hybrid (accurate + semantic) search across indices.""" try: _, tenant_id = get_current_user_id(authorization) + resolved_index_names: List[str] = [] + for requested_name in payload.index_names: + try: + resolved_name = get_index_name_by_knowledge_name( + requested_name, tenant_id + ) + except Exception: + resolved_name = requested_name + resolved_index_names.append(resolved_name) result = ElasticSearchService.search_hybrid( - index_names=payload.index_names, + index_names=resolved_index_names, query=payload.query, tenant_id=tenant_id, top_k=payload.top_k, diff --git a/backend/consts/const.py b/backend/consts/const.py index 223a1d00b..e10fabac9 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -28,6 +28,10 @@ class VectorDatabaseType(str, Enum): # Data Processing Service Configuration DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE") CLIP_MODEL_PATH = os.getenv("CLIP_MODEL_PATH") +TABLE_TRANSFORMER_MODEL_PATH = os.getenv("TABLE_TRANSFORMER_MODEL_PATH") +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = os.getenv( + "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" +) # Upload Configuration diff --git a/backend/consts/model.py b/backend/consts/model.py index 707802957..c8acfa3d1 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -234,6 +234,7 @@ class ProcessParams(BaseModel): source_type: str index_name: str authorization: Optional[str] = None + model_id: Optional[int] = None class OpinionRequest(BaseModel): diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py index 2fa590bec..b54525f62 100644 --- a/backend/data_process/ray_actors.py +++ b/backend/data_process/ray_actors.py @@ -1,11 +1,19 @@ +from io import BytesIO import logging import json from typing import Any, Dict, List, Optional import ray -from consts.const import RAY_ACTOR_NUM_CPUS, REDIS_BACKEND_URL, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE -from database.attachment_db import get_file_stream +from consts.const import ( + RAY_ACTOR_NUM_CPUS, + REDIS_BACKEND_URL, + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + TABLE_TRANSFORMER_MODEL_PATH, + UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH, +) +from database.attachment_db import build_s3_url, get_file_stream, upload_fileobj from database.model_management_db import get_model_by_model_id from nexent.data_process import DataProcessCore @@ -58,49 +66,137 @@ def process_file( if task_id: params['task_id'] = task_id - # Get chunk size parameters from embedding model if model_id is provided - if model_id and tenant_id: - try: - # Get embedding model details directly by model_id - model_record = get_model_by_model_id( - model_id=model_id, tenant_id=tenant_id) - if model_record: - expected_chunk_size = model_record.get( - 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) - maximum_chunk_size = model_record.get( - 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) - model_name = model_record.get('display_name') - - # Pass chunk sizes to processing parameters - params['max_characters'] = maximum_chunk_size - params['new_after_n_chars'] = expected_chunk_size - - logger.info( - f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " - f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") - else: - logger.warning( - f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") - except Exception as e: + self._apply_model_chunk_sizes( + model_id=model_id, tenant_id=tenant_id, params=params) + self._apply_model_paths(params) + file_data = self._read_file_bytes(source) + + result = self._processor.file_process( + file_data=file_data, + filename=source, + chunking_strategy=chunking_strategy, + **params + ) + chunks, images_info = self._normalize_processor_result(result) + if images_info: + self._append_image_chunks( + source=source, chunks=chunks, images_info=images_info) + + chunks = self._validate_chunks(chunks, source) + if not chunks: + return [] + + logger.info( + f"[RayActor] Processing done: produced {len(chunks)} chunks for source='{source}'") + return chunks + + def _apply_model_paths(self, params: Dict[str, Any]) -> None: + params["table_transformer_model_path"] = TABLE_TRANSFORMER_MODEL_PATH + params[ + "unstructured_default_model_initialize_params_json_path" + ] = UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH + + def _apply_model_chunk_sizes( + self, + model_id: Optional[int], + tenant_id: Optional[str], + params: Dict[str, Any], + ) -> None: + if not (model_id and tenant_id): + return + + try: + model_record = get_model_by_model_id( + model_id=model_id, tenant_id=tenant_id) + if not model_record: logger.warning( - f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes") + f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") + return + + expected_chunk_size = model_record.get( + 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) + maximum_chunk_size = model_record.get( + 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) + model_name = model_record.get('display_name') + model_type = model_record.get('model_type') + params['max_characters'] = maximum_chunk_size + params['new_after_n_chars'] = expected_chunk_size + if model_type: + params['model_type'] = model_type + + logger.info( + f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " + f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") + except Exception as e: + logger.warning( + f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes") + + def _read_file_bytes(self, source: str) -> bytes: try: file_stream = get_file_stream(source) if file_stream is None: raise FileNotFoundError( f"Unable to fetch file from URL: {source}") - file_data = file_stream.read() + return file_stream.read() except Exception as e: logger.error(f"Failed to fetch file from {source}: {e}") raise - chunks = self._processor.file_process( - file_data=file_data, - filename=source, - chunking_strategy=chunking_strategy, - **params - ) + def _normalize_processor_result( + self, result: Any + ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + if isinstance(result, tuple) and len(result) == 2: + chunks, images_info = result + return chunks or [], images_info or [] + return result or [], [] + + def _append_image_chunks( + self, + source: str, + chunks: List[Dict[str, Any]], + images_info: List[Dict[str, Any]], + ) -> None: + folder = "images_in_attachments" + for index, image_data in enumerate(images_info): + if not isinstance(image_data, dict): + logger.warning( + f"[RayActor] Skipping image entry at index {index}: unexpected type {type(image_data)}" + ) + continue + if "image_bytes" not in image_data: + logger.warning( + f"[RayActor] Skipping image entry at index {index}: missing image_bytes" + ) + continue + + img_obj = BytesIO(image_data["image_bytes"]) + result = upload_fileobj( + file_obj=img_obj, + file_name=f"{index}.{image_data['image_format']}", + prefix=folder) + image_url = build_s3_url(result.get("object_name", "")) + + image_data["source_file"] = source + image_data["image_url"] = image_url + + chunks.append({ + "content": json.dumps({ + "source_file": source, + "position": image_data["position"], + "image_url": image_url, + }), + "filename": source, + "metadata": { + "chunk_index": len(chunks) + index, + "process_source": "UniversalImageExtractor", + "image_url": image_url, + } + }) + + def _validate_chunks( + self, chunks: Any, source: str + ) -> List[Dict[str, Any]]: if chunks is None: logger.warning( @@ -115,8 +211,6 @@ def process_file( f"[RayActor] file_process returned empty list for source='{source}'") return [] - logger.info( - f"[RayActor] Processing done: produced {len(chunks)} chunks for source='{source}'") return chunks def store_chunks_in_redis(self, redis_key: str, chunks: List[Dict[str, Any]]) -> bool: diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index 1faabac23..85dc2d2b2 100644 --- a/backend/database/attachment_db.py +++ b/backend/database/attachment_db.py @@ -2,11 +2,63 @@ import os import uuid from datetime import datetime -from typing import Any, BinaryIO, Dict, List, Optional +from typing import Any, BinaryIO, Dict, List, Optional, Tuple from .client import minio_client +def _normalize_object_and_bucket(object_name: str, bucket: Optional[str] = None) -> Tuple[str, Optional[str]]: + """ + Normalize object_name + bucket from supported URL styles. + + Supports: + - s3://bucket/key + - /bucket/key + - key (uses provided bucket or default bucket) + """ + if not object_name: + return object_name, bucket + + if object_name.startswith("s3://"): + s3_path = object_name[len("s3://") :] + parts = s3_path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + return object_name, bucket + + +def build_s3_url(object_name: str, bucket: Optional[str] = None) -> str: + """ + Build an s3://bucket/key style URL from an object name (or passthrough if already s3://). + """ + if not object_name: + return "" + + if object_name.startswith("s3://"): + return object_name + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + if len(parts) == 2: + return f"s3://{parts[0]}/{parts[1]}" + return f"s3://{parts[0]}/" + + resolved_bucket = bucket or minio_client.default_bucket + if resolved_bucket: + return f"s3://{resolved_bucket}/{object_name}" + return f"s3://{object_name}" + + def generate_object_name(file_name: str, prefix: str = "attachments") -> str: """ Generate a unique object name @@ -165,7 +217,8 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) -> """ Get file size by object name """ - bucket = bucket or minio_client.storage_config.default_bucket + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) + bucket = bucket or minio_client.default_bucket return minio_client.get_file_size(object_name, bucket) @@ -181,6 +234,7 @@ def file_exists(object_name: str, bucket: Optional[str] = None) -> bool: bool: True if file exists, False otherwise """ try: + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) return minio_client.file_exists(object_name, bucket) except Exception: return False @@ -198,6 +252,8 @@ def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None Returns: Dict[str, Any]: Result containing success flag and error message (if any) """ + source_object, bucket = _normalize_object_and_bucket(source_object, bucket) + dest_object, bucket = _normalize_object_and_bucket(dest_object, bucket) success, result = minio_client.copy_file(source_object, dest_object, bucket) if success: return {"success": True, "object_name": result} @@ -242,8 +298,9 @@ def delete_file(object_name: str, bucket: Optional[str] = None) -> Dict[str, Any Returns: Dict[str, Any]: Delete result, containing success flag and error message (if any) """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) if not bucket: - bucket = minio_client.storage_config.default_bucket + bucket = minio_client.default_bucket success, result = minio_client.delete_file(object_name, bucket) response = {"success": success, "object_name": object_name} @@ -265,6 +322,7 @@ def get_file_stream(object_name: str, bucket: Optional[str] = None) -> Optional[ Returns: Optional[BinaryIO]: Standard BinaryIO stream object, or None if failed """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) success, result = minio_client.get_file_stream(object_name, bucket) if not success: return None diff --git a/backend/database/client.py b/backend/database/client.py index 9b0b97a52..8885ea694 100644 --- a/backend/database/client.py +++ b/backend/database/client.py @@ -89,6 +89,9 @@ def __init__(self): if MinioClient._initialized: return MinioClient._initialized = True + # Explicitly initialize attributes so external callers never hit missing-attribute errors. + self._storage_client = None + self.storage_config = None def _ensure_initialized(self): """Lazily initialize the storage client on first use.""" @@ -108,6 +111,23 @@ def _ensure_initialized(self): return True return False + @property + def default_bucket(self) -> Optional[str]: + """ + Resolve default bucket safely for callers that need bucket info. + Falls back to configured constant when lazy init has not run yet. + """ + try: + self._ensure_initialized() + except Exception: + # Keep this accessor resilient; operational methods can still raise + # detailed storage errors when invoked. + pass + + if getattr(self, "storage_config", None) is not None: + return self.storage_config.default_bucket + return MINIO_DEFAULT_BUCKET + def upload_file( self, file_path: str, diff --git a/backend/database/db_models.py b/backend/database/db_models.py index bc95a5e68..688743343 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -280,6 +280,7 @@ class KnowledgeRecord(TableBase): group_ids = Column(String, doc="Knowledge base group IDs list") ingroup_permission = Column( String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") + is_multimodal = Column(String(1), default="N", doc="Whether it is multimodal. Optional values: Y/N") class TenantConfig(TableBase): diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index df42e1888..40f4ca718 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -52,6 +52,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "knowledge_name": knowledge_name, "group_ids": convert_list_to_string(group_ids) if isinstance(group_ids, list) else group_ids, "ingroup_permission": query.get("ingroup_permission"), + "is_multimodal": 'Y' if query.get("is_multimodal") else 'N' } # For backward compatibility: if caller explicitly provides index_name, @@ -178,6 +179,9 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool: if query.get("group_ids") is not None: record.group_ids = query["group_ids"] + if query.get("is_multimodal"): + record.is_multimodal = 'Y' if query["is_multimodal"] else 'N' + # Update timestamp and user if query.get("user_id"): record.updated_by = query["user_id"] @@ -254,6 +258,11 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An db_query = db_query.filter( KnowledgeRecord.tenant_id == query['tenant_id']) + if 'is_multimodal' in query: + db_query = db_query.filter( + KnowledgeRecord.is_multimodal == query['is_multimodal'] + ) + result = db_query.first() if result: diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..61753f52f 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -170,7 +170,7 @@ def get_model_records(filters: Optional[Dict[str, Any]], tenant_id: str) -> List return result_list -def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dict[str, Any]]: +def get_model_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[Dict[str, Any]]: """ Get a model record by display name @@ -179,6 +179,11 @@ def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dic tenant_id: """ filters = {'display_name': display_name} + + if model_type in ["multiEmbedding", "multi_embedding"]: + filters['model_type'] = "multi_embedding" + elif model_type == "embedding": + filters['model_type'] = "embedding" records = get_model_records(filters, tenant_id) if not records: @@ -203,7 +208,7 @@ def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[s return get_model_records(filters, tenant_id) -def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]: +def get_model_id_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[int]: """ Get a model ID by display name @@ -214,7 +219,7 @@ def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[ Returns: Optional[int]: Model ID """ - model = get_model_by_display_name(display_name, tenant_id) + model = get_model_by_display_name(display_name, tenant_id, model_type) return model["model_id"] if model else None diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 9fe50813a..c484ca23f 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -99,7 +99,7 @@ async def save_config_impl(config, tenant_id, user_id): config_key = get_env_key(model_type) + "_ID" model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_display_name, tenant_id, model_type=model_type) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index 2b222a584..17e64a697 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -255,6 +255,17 @@ async def load_image(self, image_url: str) -> Optional[Image.Image]: async def _load_image(self, session: aiohttp.ClientSession, path: str) -> Optional[Image.Image]: """Internal method to load an image from various sources""" try: + if path.startswith('s3://'): + # Fetch from MinIO using s3://bucket/key + file_stream = get_file_stream(object_name=path) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {path}") + file_data = file_stream.read() + image_based64_str = base64.b64encode( + file_data).decode('utf-8') + path = f"data:image/jpeg;base64,{image_based64_str}" + # Check if input is base64 encoded if path.startswith('data:image'): # Extract the base64 data after the comma @@ -463,6 +474,8 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B chunking_strategy = source_config.get('chunking_strategy') index_name = source_config.get('index_name') original_filename = source_config.get('original_filename') + embedding_model_id = source_config.get('embedding_model_id') + tenant_id = source_config.get('tenant_id') # Validate required fields if not source: @@ -481,7 +494,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B source_type=source_type, chunking_strategy=chunking_strategy, index_name=index_name, - original_filename=original_filename + original_filename=original_filename, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id ).set(queue='process_q'), forward.s( index_name=index_name, @@ -559,7 +574,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c } async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: - """Full conversion pipeline: download → convert → upload → validate → cleanup. + """Full conversion pipeline: download -> convert -> upload -> validate -> cleanup. All five steps run inside data-process so that LibreOffice only needs to be installed in this container. diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py index 776e0eb1d..26e777eba 100644 --- a/backend/services/datamate_service.py +++ b/backend/services/datamate_service.py @@ -51,7 +51,8 @@ async def _create_datamate_knowledge_records(knowledge_base_ids: List[str], "tenant_id": tenant_id, "user_id": user_id, # Use datamate as embedding model name - "embedding_model_name": embedding_model_names[i] + "embedding_model_name": embedding_model_names[i], + "is_multimodal": False, } # Run synchronous database operation in executor to avoid blocking diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 9214a1ffa..5b8e27f07 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -128,10 +128,10 @@ async def _perform_connectivity_check( return connectivity -async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: +async def check_model_connectivity(display_name: str, tenant_id: str, model_type: str = None) -> dict: try: # Query the database using display_name and tenant context from app layer - model = get_model_by_display_name(display_name, tenant_id=tenant_id) + model = get_model_by_display_name(display_name, tenant_id=tenant_id, model_type=model_type) if not model: raise LookupError(f"Model configuration not found for {display_name}") diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index d7240db26..c49724f60 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -152,6 +152,10 @@ def get_local_tools() -> List[ToolInfo]: else: param_info["default"] = param.default.default param_info["optional"] = True + if getattr(param.default, "json_schema_extra", None): + optional_override = param.default.json_schema_extra.get("optional") + if optional_override is not None: + param_info["optional"] = optional_override init_params_list.append(param_info) @@ -704,7 +708,8 @@ def _validate_local_tool( instantiation_params[param_name] = param.default if tool_name == "knowledge_base_search": - embedding_model = get_embedding_model(tenant_id=tenant_id) + is_multimodal = instantiation_params.pop("multimodal", False) + embedding_model = get_embedding_model(tenant_id=tenant_id, is_multimodal=is_multimodal) vdb_core = get_vector_db_core() # Get rerank configuration diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 5639103de..284f6fb7e 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -28,7 +28,7 @@ from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ from consts.model import ChunkCreateRequest, ChunkUpdateRequest -from database.attachment_db import delete_file +from database.attachment_db import delete_file, get_file_stream from database.knowledge_db import ( create_knowledge_record, delete_knowledge_record, @@ -176,7 +176,80 @@ def check_knowledge_base_exist_impl(knowledge_name: str, vdb_core: VectorDatabas return {"status": "available"} -def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): +def _build_embedding_from_config(model_config: Dict[str, Any]) -> Optional[BaseEmbedding]: + model_type = model_config.get("model_type", "") + if model_type == "embedding": + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + if model_type == "multi_embedding": + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + return None + + +def _find_model_record( + tenant_id: str, + is_multimodal: bool, + model_name: str, +) -> Optional[Dict[str, Any]]: + model_type = "multi_embedding" if is_multimodal else "embedding" + models = get_model_records({"model_type": model_type}, tenant_id) + for model in models: + model_display_name = ( + f"{model.get('model_repo')}/{model['model_name']}" + if model.get("model_repo") + else model["model_name"] + ) + if model_display_name == model_name: + return model + return None + + +def _build_embedding_from_record( + model_record: Dict[str, Any], + is_multimodal: bool, +) -> BaseEmbedding: + model_config = { + "model_repo": model_record.get("model_repo", ""), + "model_name": model_record["model_name"], + "api_key": model_record.get("api_key", ""), + "base_url": model_record.get("base_url", ""), + "model_type": "embedding", + "max_tokens": model_record.get("max_tokens", 1024), + "ssl_verify": model_record.get("ssl_verify", True), + } + if not is_multimodal: + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + +def get_embedding_model( + tenant_id: str, + is_multimodal: bool = False, + model_name: Optional[str] = None, + strict_model_name: bool = False, +): """ Get the embedding model for the tenant, optionally using a specific model name. @@ -188,58 +261,50 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): Returns: Embedding model instance or None """ - # If model_name is provided, try to find it in the tenant's models + # If model_name is provided, try to find it in the tenant's models. + if model_name is None and (isinstance(is_multimodal, str) or is_multimodal is None): + model_name = is_multimodal + is_multimodal = False if model_name: try: - models = get_model_records({"model_type": "embedding"}, tenant_id) - for model in models: - model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"] - if model_display_name == model_name: - # Found the model, create embedding instance - model_config = { - "model_repo": model.get("model_repo", ""), - "model_name": model["model_name"], - "api_key": model.get("api_key", ""), - "base_url": model.get("base_url", ""), - "model_type": "embedding", - "max_tokens": model.get("max_tokens", 1024), - "ssl_verify": model.get("ssl_verify", True), - } - return OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) + model_record = _find_model_record( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=model_name, + ) + if model_record: + return _build_embedding_from_record( + model_record=model_record, + is_multimodal=is_multimodal, + ) except Exception as e: logger.warning(f"Failed to get embedding model by name {model_name}: {e}") + if strict_model_name: + raise ValueError( + f"Embedding model '{model_name}' is not configured for current tenant" + ) # Fall back to default embedding model (current behavior) model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) - - model_type = model_config.get("model_type", "") - - if model_type == "embedding": - # Get the es core - return OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) - elif model_type == "multi_embedding": - return JinaEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), + key="MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID", + tenant_id=tenant_id, + ) + return _build_embedding_from_config(model_config) + + +def _resolve_embedding_model( + tenant_id: str, + is_multimodal: bool, + embedding_model_name: Optional[str], +) -> Optional[BaseEmbedding]: + if embedding_model_name: + return get_embedding_model( + tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=True, ) - else: - return None + return get_embedding_model(tenant_id, is_multimodal=is_multimodal) def get_rerank_model(tenant_id: str, model_name: Optional[str] = None): @@ -406,6 +471,7 @@ async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCo @staticmethod def create_index( + embedding_model: BaseEmbedding, index_name: str = Path(..., description="Name of the index to create"), embedding_dim: Optional[int] = Query( @@ -419,15 +485,24 @@ def create_index( try: if vdb_core.check_index_exists(index_name): raise Exception(f"Index {index_name} already exists") - embedding_model = get_embedding_model(tenant_id) + if not embedding_model: + embedding_model = get_embedding_model(tenant_id) success = vdb_core.create_index(index_name, embedding_dim=embedding_dim or ( embedding_model.embedding_dim if embedding_model else 1024)) if not success: raise Exception(f"Failed to create index {index_name}") - knowledge_data = {"index_name": index_name, - "created_by": user_id, - "tenant_id": tenant_id, - "embedding_model_name": embedding_model.model} + is_multimodal = ( + True + if embedding_model and getattr(embedding_model, "model_type", None) == "multimodal" + else False + ) + knowledge_data = { + "index_name": index_name, + "created_by": user_id, + "tenant_id": tenant_id, + "embedding_model_name": embedding_model.model, + "is_multimodal": is_multimodal, + } create_knowledge_record(knowledge_data) return {"status": "success", "message": f"Index {index_name} created successfully"} except Exception as e: @@ -443,6 +518,7 @@ def create_knowledge_base( ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, embedding_model_name: Optional[str] = None, + is_multimodal: bool = False, ): """ Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. @@ -468,7 +544,18 @@ def create_knowledge_base( """ try: # Get embedding model - use user-selected model if provided, otherwise use tenant default - embedding_model = get_embedding_model(tenant_id, embedding_model_name) + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + ) + + # If caller did not provide an explicit flag, infer multimodal from model metadata. + resolved_is_multimodal = is_multimodal or ( + True + if embedding_model and getattr(embedding_model, "model_type", None) == "multimodal" + else False + ) # Determine the embedding model name to save: use user-provided name if available, # otherwise use the model's display name @@ -483,6 +570,7 @@ def create_knowledge_base( "user_id": user_id, "tenant_id": tenant_id, "embedding_model_name": saved_embedding_model_name, + "is_multimodal": resolved_is_multimodal, } # Add group permission and group IDs if provided @@ -519,6 +607,7 @@ def update_knowledge_base( knowledge_name: Optional[str] = None, ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, + is_multimodal: bool = False, tenant_id: Optional[str] = None, user_id: Optional[str] = None, ) -> bool: @@ -549,6 +638,7 @@ def update_knowledge_base( update_data = { "index_name": index_name, "updated_by": user_id, + "is_multimodal": is_multimodal, } if knowledge_name is not None: @@ -784,6 +874,7 @@ def list_indices( # knowledge source and ingroup permission from DB record "knowledge_sources": record["knowledge_sources"], "ingroup_permission": record["ingroup_permission"], + "is_multimodal": record.get("is_multimodal"), "tenant_id": record.get("tenant_id"), # Update time for sorting and display "update_time": record.get("update_time"), @@ -882,12 +973,27 @@ def index_documents( "author": author, "date": date, "content": text, - "process_source": "Unstructured", + "process_source": metadata.get("process_source", "Unstructured"), "file_size": file_size, "create_time": create_time, "languages": metadata.get("languages", []), "embedding_model_name": embedding_model_name } + + image_url = metadata.get("image_url", "") + if len(image_url) > 0: + # Fetch image bytes from MinIO (supports s3://bucket/key or /bucket/key) + try: + file_stream = get_file_stream( + object_name=image_url) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {image_url}") + document["image_bytes"] = file_stream.read() + except Exception as e: + logger.error( + f"Failed to fetch file from {image_url}: {e}") + raise documents.append(document) @@ -908,8 +1014,9 @@ def index_documents( 'tenant_id') if knowledge_record else None if tenant_id: + model_type = "EMBEDDING_ID" if embedding_model.model_type == "text" else "MULTI_EMBEDDING_ID" model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) + key=model_type, tenant_id=tenant_id) embedding_batch_size = model_config.get("chunk_batch", 10) if embedding_batch_size is None: embedding_batch_size = 10 @@ -1552,6 +1659,7 @@ def create_chunk( try: # Get knowledge base's embedding model name embedding_model_name = None + is_multimodal = False if tenant_id: try: knowledge_record = get_knowledge_record({ @@ -1559,6 +1667,11 @@ def create_chunk( "tenant_id": tenant_id }) embedding_model_name = knowledge_record.get("embedding_model_name") if knowledge_record else None + is_multimodal = ( + True + if knowledge_record and knowledge_record.get("is_multimodal") == "Y" + else False + ) except Exception as e: logger.warning(f"Failed to get embedding model name for index {index_name}: {e}") @@ -1566,7 +1679,16 @@ def create_chunk( embedding_vector = None if chunk_request.content: try: - embedding_model = get_embedding_model(tenant_id, embedding_model_name) if tenant_id else None + embedding_model = ( + get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=bool(embedding_model_name), + ) + if tenant_id + else None + ) if embedding_model: embeddings = embedding_model.get_embeddings(chunk_request.content) if embeddings and len(embeddings) > 0: @@ -1577,6 +1699,8 @@ def create_chunk( else: logger.warning(f"No embedding model available for index {index_name}") except Exception as e: + if embedding_model_name: + raise logger.warning(f"Failed to generate embedding for chunk: {e}") # Build chunk payload @@ -1617,6 +1741,7 @@ def update_chunk( chunk_request: ChunkUpdateRequest, vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), user_id: Optional[str] = None, + tenant_id: Optional[str] = None, ): """ Update a chunk document. @@ -1625,6 +1750,37 @@ def update_chunk( update_fields = chunk_request.dict( exclude_unset=True, exclude={"metadata"}) metadata = chunk_request.metadata or {} + + if "content" in update_fields and update_fields.get("content"): + embedding_model_name = None + is_multimodal = False + if tenant_id: + knowledge_record = get_knowledge_record( + {"index_name": index_name, "tenant_id": tenant_id} + ) + embedding_model_name = ( + knowledge_record.get("embedding_model_name") + if knowledge_record + else None + ) + is_multimodal = bool( + knowledge_record and knowledge_record.get("is_multimodal") == "Y" + ) + + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=bool(embedding_model_name), + ) + embeddings = embedding_model.get_embeddings( + update_fields["content"] + ) + if embeddings and len(embeddings) > 0: + update_fields["embedding"] = embeddings[0] + if embedding_model_name: + update_fields["embedding_model_name"] = embedding_model_name + update_payload = ElasticSearchService._build_chunk_payload( base_fields={ **update_fields, @@ -1700,7 +1856,23 @@ def search_hybrid( if weight_accurate < 0 or weight_accurate > 1: raise ValueError("weight_accurate must be between 0 and 1") - embedding_model = get_embedding_model(tenant_id) + embedding_model_name = None + is_multimodal = False + for index_name in index_names: + knowledge_record = get_knowledge_record( + {"index_name": index_name, "tenant_id": tenant_id} + ) + if knowledge_record: + embedding_model_name = knowledge_record.get("embedding_model_name") + is_multimodal = knowledge_record.get("is_multimodal") == "Y" + break + + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=bool(embedding_model_name), + ) if not embedding_model: raise ValueError( "No embedding model configured for the current tenant") diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 7d31a74bb..37681d9c7 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -15,7 +15,6 @@ from consts.model import ProcessParams from database.attachment_db import get_file_size_from_minio from utils.auth_utils import get_current_user_id -from utils.config_utils import tenant_config_manager logger = logging.getLogger("file_management_utils") @@ -45,18 +44,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) if not files: return None - # Get chunking size according to the embedding model - embedding_model_id = None + # Get tenant_id from authorization for downstream task processing + embedding_model_id = process_params.model_id tenant_id = None try: _, tenant_id = get_current_user_id(process_params.authorization) - # Get embedding model ID from tenant config - tenant_config = tenant_config_manager.load_config(tenant_id) - embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None - if embedding_model_id_str: - embedding_model_id = int(embedding_model_id_str) except Exception as e: - logger.warning(f"Failed to get embedding model ID for tenant: {e}") + logger.warning(f"Failed to get tenant_id from authorization: {e}") # Build headers with authorization headers = { diff --git a/docker/.env.bak b/docker/.env.bak new file mode 100644 index 000000000..77eb8bf79 --- /dev/null +++ b/docker/.env.bak @@ -0,0 +1,174 @@ +# ===== Necessary Configs (Neccessary till now, will be migrated to frontend page) ===== + +# Voice Service Config +APPID=app_id +TOKEN=token + +# ===== Non-essential Configs (Modify if you know what you are doing) ===== + +CLUSTER=volcano_tts +VOICE_TYPE=zh_male_jieshuonansheng_mars_bigtts +SPEED_RATIO=1.3 + +# ===== Proxy Configuration (Optional) ===== + +# HTTP_PROXY=http://proxy-server:port +# HTTPS_PROXY=http://proxy-server:port +# NO_PROXY=localhost,127.0.0.1 + +# ===== Backend Configuration (No need to modify at all) ===== + +# Model Path Config +CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32 +NLTK_DATA=/opt/models/nltk_data + +# Elasticsearch Service +ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200 +ELASTIC_PASSWORD=nexent@2025 + +# Elasticsearch Memory Configuration +ES_JAVA_OPTS="-Xms2g -Xmx2g" + +# Elasticsearch Disk Watermark Configuration +ES_DISK_WATERMARK_LOW=85% +ES_DISK_WATERMARK_HIGH=90% +ES_DISK_WATERMARK_FLOOD_STAGE=95% + +# Main Services +# Config service (port 5010) - Main API service for config operations +CONFIG_SERVICE_URL=http://nexent-config:5010 +ELASTICSEARCH_SERVICE=http://nexent-config:5010/api + +# Runtime service (port 5014) - Runtime execution service for agent operations +RUNTIME_SERVICE_URL=http://nexent-runtime:5014 + +# MCP service (port 5011) - MCP protocol service +NEXENT_MCP_SERVER=http://nexent-mcp:5011 + +# Data process service (port 5012) - Data processing service +DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api + +# Northbound service (port 5013) - Northbound API service +NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api + +# Postgres Config +POSTGRES_HOST=nexent-postgresql +POSTGRES_USER=root +NEXENT_POSTGRES_PASSWORD=nexent@4321 +POSTGRES_DB=nexent +POSTGRES_PORT=5432 + +# Minio Config +MINIO_ENDPOINT=http://nexent-minio:9000 +MINIO_ROOT_USER=nexent +MINIO_ROOT_PASSWORD=nexent@4321 +MINIO_REGION=cn-north-1 +MINIO_DEFAULT_BUCKET=nexent + +# Redis Config +REDIS_URL=redis://redis:6379/0 +REDIS_BACKEND_URL=redis://redis:6379/1 + +# Model Engine Config +MODEL_ENGINE_ENABLED=false + +# Supabase Config +DASHBOARD_USERNAME=supabase +DASHBOARD_PASSWORD=Huawei123 + +# Supabase db Config +SUPABASE_POSTGRES_PASSWORD=Huawei123 +SUPABASE_POSTGRES_HOST=db +SUPABASE_POSTGRES_DB=supabase +SUPABASE_POSTGRES_PORT=5436 + +# Supabase Auth Config +SITE_URL=http://localhost:3011 +SUPABASE_URL=http://supabase-kong-mini:8000 +API_EXTERNAL_URL=http://supabase-kong-mini:8000 +DISABLE_SIGNUP=false +JWT_EXPIRY=3600 +DEBUG_JWT_EXPIRE_SECONDS=0 + +# Supabase Configuration +ENABLE_EMAIL_SIGNUP=true +ENABLE_EMAIL_AUTOCONFIRM=true +ENABLE_ANONYMOUS_USERS=false + +# Supabase Phone Config +ENABLE_PHONE_SIGNUP=false +ENABLE_PHONE_AUTOCONFIRM=false + +MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify" +MAILER_URLPATHS_INVITE="/auth/v1/verify" +MAILER_URLPATHS_RECOVERY="/auth/v1/verify" +MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify" + +INVITE_CODE=nexent2025 + +# Terminal Tool SSH Key Path +SSH_PRIVATE_KEY_PATH=/path/to/openssh-server/ssh-keys/openssh_server_key + +# ===== Data Processing Service Configuration ===== + +# Redis Port +REDIS_PORT=6379 + +# Flower Monitoring +FLOWER_PORT=5555 + +# Ray Configuration +RAY_ACTOR_NUM_CPUS=2 +RAY_DASHBOARD_PORT=8265 +RAY_DASHBOARD_HOST=0.0.0.0 +RAY_NUM_CPUS=4 +RAY_OBJECT_STORE_MEMORY_GB=0.25 +RAY_TEMP_DIR=/tmp/ray +RAY_LOG_LEVEL=INFO + +# Service Control Flags +DISABLE_RAY_DASHBOARD=true +DISABLE_CELERY_FLOWER=true +DOCKER_ENVIRONMENT=false +ENABLE_UPLOAD_IMAGE=false + +# Celery Configuration +CELERY_WORKER_PREFETCH_MULTIPLIER=1 +CELERY_TASK_TIME_LIMIT=3600 +ELASTICSEARCH_REQUEST_TIMEOUT=30 + +# Worker Configuration +QUEUES=process_q,forward_q +WORKER_NAME= +WORKER_CONCURRENCY=4 + + +# Telemetry and Monitoring Configuration +ENABLE_TELEMETRY=false +SERVICE_NAME=nexent-backend +JAEGER_ENDPOINT=http://localhost:14268/api/traces +PROMETHEUS_PORT=8000 +TELEMETRY_SAMPLE_RATE=1.0 +LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0 +LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0 + +# Market Backend Address +MARKET_BACKEND=https://market.nexent.tech +DEPLOYMENT_VERSION="full" +# Root dir +ROOT_DIR="/e/nexent-data" +# ROOT_DIR="/e/aaa/nexent-data" +NEXENT_MCP_DOCKER_IMAGE="nexent/nexent-mcp:v2.0.1" +MINIO_ACCESS_KEY="d5e4c27903857e33c7b22ace" +MINIO_SECRET_KEY="lVvTKIaN1iR5RwimAb4q0r1Zt15XgHPhpkEVnskGwXM=" +JWT_SECRET="wc2Bcv8+4FC4qzMBVAnsWjpvhnCjSnA1fSfA1DEndBg=" +SECRET_KEY_BASE="Iw/qi8PIR4I+82ezPXT9+YJcQ83FPzMK4Q05TxfqDDQQ7qaVv2zWlfuuklrU1VI/BS2pT/f8VAHcsw3UNVX93g==" +VAULT_ENC_KEY="sPcMDblQo0CaA04K5s4sB4w6UuqN/OOBp1Q/EjVAeBM=" +SUPABASE_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiYW5vbiIsImlzcyI6InN1cGFiYXNlIiwiaWF0IjoxNzc2NTI2NjE4LCJleHAiOjE5MzQyMDY2MTh9.IKxXdpmk1nekjln8Pdq9hZgENsJJAUve3nfbnz4RoUU" +SERVICE_ROLE_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoic2VydmljZV9yb2xlIiwiaXNzIjoic3VwYWJhc2UiLCJpYXQiOjE3NzUzNzYxNDcsImV4cCI6MTkzMzA1NjE0N30.XqDKEp-UmCwMEaPL77BarvirJqXcURH5UwqSaxvaiN4" +ELASTICSEARCH_API_KEY="S2hLcVhKMEJfWkZpV0wycjdDekQ6ejB1UHJvei1RTS1vb0RnbTVFNnhmdw==" + +# TABLE_TRANSFORMER_MODEL_PATH="E:\\nexent-data\\model\\table-transformer-structure-recognition" +# UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH="E:\\nexent-data\\model\\config.json" +TABLE_TRANSFORMER_MODEL_PATH="E:\\nexent-data\\model\\table-transformer-structure-recognition" +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH="E:\\nexent-data\\model\\config.json" diff --git a/docker/deploy.sh b/docker/deploy.sh index e30e6e75a..f65dd240b 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -17,6 +17,7 @@ DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options" MODE_CHOICE_SAVED="" VERSION_CHOICE_SAVED="" IS_MAINLAND_SAVED="" +DOWNLOAD_MODELS="N" ENABLE_TERMINAL_SAVED="N" TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}" APP_VERSION="" @@ -79,6 +80,56 @@ is_windows_env() { return 1 } +detect_os_type() { + # Return: windows | mac | linux | unknown + local os_name + os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]') + case "$os_name" in + mingw*|msys*|cygwin*) + echo "windows" + ;; + darwin*) + echo "mac" + ;; + linux*) + echo "linux" + ;; + *) + echo "unknown" + ;; + esac +} + +format_path_for_env() { + # Convert path to OS-specific format for .env values + local input_path="$1" + local os_type + os_type=$(detect_os_type) + + if [ "$os_type" = "windows" ]; then + if command -v cygpath >/dev/null 2>&1; then + cygpath -w "$input_path" + return 0 + fi + + if [[ "$input_path" =~ ^/([a-zA-Z])/(.*)$ ]]; then + local drive="${BASH_REMATCH[1]}" + local rest="${BASH_REMATCH[2]}" + rest="${rest//\//\\}" + printf "%s:\\%s" "$(echo "$drive" | tr '[:lower:]' '[:upper:]')" "$rest" + return 0 + fi + fi + + printf "%s" "$input_path" +} + +escape_backslashes() { + # Escape backslashes for safe writing into .env or JSON + local input_path="$1" + printf "%s" "$input_path" | sed 's/\\/\\\\/g' +} + is_port_in_use() { # Check if a TCP port is already in use (Linux/macOS/Windows Git Bash) local port="$1" @@ -266,6 +317,7 @@ persist_deploy_options() { echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\"" echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\"" echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\"" + echo "DOWNLOAD_MODELS=\"${DOWNLOAD_MODELS}\"" echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\"" echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\"" } > "$DEPLOY_OPTIONS_FILE" @@ -528,6 +580,227 @@ select_deployment_mode() { echo "" } + +# Model download selection +select_model_download() { + echo "" + + local input_choice="" + read -r -p "Do you want to download AI model files (table-transformer and yolox)? [Y/N] (default: N): " input_choice + echo "" + + if [[ $input_choice =~ ^[Yy]$ ]]; then + DOWNLOAD_MODELS="Y" + echo "INFO: Model download will be performed." + else + DOWNLOAD_MODELS="N" + echo "INFO: Skipping model download." + fi + echo "----------------------------------------" + echo "" +} + +# kerry + +download_and_config_models() { + if [ "$DOWNLOAD_MODELS" != "Y" ]; then + echo "INFO: Model download skipped by user choice." + return 0 + fi + + echo "INFO: Downloading AI model files (this may take a while)..." + + local ENV_FILE_DIR="$SCRIPT_DIR" + local ENV_FILE_PATH="$ENV_FILE_DIR/.env" + local ORIGINAL_DIR="$(pwd)" + + MODEL_ROOT="$ROOT_DIR/model" + mkdir -p "$MODEL_ROOT" + echo "INFO: Model directory: $MODEL_ROOT" + + export HF_ENDPOINT="https://hf-mirror.com" + + command -v git >/dev/null || { echo "ERROR: git is required but not found."; return 1; } + + # ========================================== + # 1. Table Transformer (table-structure recognition) + echo "INFO: Downloading table-transformer-structure-recognition..." + + TT_MODEL_DIR_NAME="table-transformer-structure-recognition" + TT_MODEL_DIR_PATH="$MODEL_ROOT/$TT_MODEL_DIR_NAME" + TT_MODEL_FILE_CHECK="$TT_MODEL_DIR_PATH/model.safetensors" + + cd "$MODEL_ROOT" || return 1 + + if [ -d "$TT_MODEL_DIR_PATH" ] && [ -f "$TT_MODEL_FILE_CHECK" ]; then + FILE_SIZE=$(stat -c%s "$TT_MODEL_FILE_CHECK" 2>/dev/null || stat -f%z "$TT_MODEL_FILE_CHECK" 2>/dev/null) + if [ "$FILE_SIZE" -gt 1000000 ]; then + echo "INFO: Table Transformer already exists." + else + echo "WARN: Existing model file looks incomplete, re-downloading..." + rm -rf "$TT_MODEL_DIR_NAME" + fi + fi + + if [ ! -f "$TT_MODEL_FILE_CHECK" ]; then + if [ -d "$TT_MODEL_DIR_NAME" ]; then + echo "WARN: Removing existing directory before re-download..." + rm -rf "$TT_MODEL_DIR_NAME" + fi + + echo "INFO: Step 1/2: Clone repo (skip LFS files)..." + if ! GIT_LFS_SKIP_SMUDGE=1 git clone "$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME" "$TT_MODEL_DIR_NAME"; then + echo "ERROR: Failed to clone repository." + cd "$ORIGINAL_DIR" + return 1 + fi + + cd "$TT_MODEL_DIR_NAME" || return 1 + + echo "INFO: Step 2/2: Download model.safetensors..." + LARGE_FILE_URL="$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME/resolve/main/model.safetensors" + + if command -v curl &> /dev/null; then + curl -L -o "model.safetensors" "$LARGE_FILE_URL" --progress-bar + elif command -v wget &> /dev/null; then + wget "$LARGE_FILE_URL" -O "model.safetensors" + else + echo "ERROR: curl or wget is required to download model files." + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$ORIGINAL_DIR"; return 1 + fi + + if [ ! -f "model.safetensors" ]; then + echo "ERROR: model.safetensors download failed." + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$ORIGINAL_DIR"; return 1 + fi + + FILE_SIZE=$(stat -c%s "model.safetensors" 2>/dev/null || stat -f%z "model.safetensors" 2>/dev/null) + if [ "$FILE_SIZE" -lt 1000000 ]; then + echo "ERROR: model.safetensors seems too small (size: $FILE_SIZE bytes)." + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$ORIGINAL_DIR"; return 1 + fi + + echo "INFO: model.safetensors downloaded (size: $(du -h model.safetensors | cut -f1))" + cd "$MODEL_ROOT" + fi + + echo "INFO: Table Transformer OK" + + # ========================================== + # 2. YOLOX (layout detection model) + echo "INFO: Downloading yolox_l0.05.onnx" + + YOLOX_MODEL_FILE="$MODEL_ROOT/yolox_l0.05.onnx" + MIN_YOLOX_SIZE=50000000 + + NEED_DOWNLOAD=false + + if [ -f "$YOLOX_MODEL_FILE" ]; then + CURRENT_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [ "$CURRENT_SIZE" -lt "$MIN_YOLOX_SIZE" ]; then + echo "WARN: Existing YOLOX file looks incomplete (size: $(numfmt --to=iec-i --suffix=B $CURRENT_SIZE 2>/dev/null || echo $CURRENT_SIZE)). Re-downloading..." + NEED_DOWNLOAD=true + else + echo "INFO: YOLOX already exists." + fi + else + NEED_DOWNLOAD=true + fi + + if [ "$NEED_DOWNLOAD" = true ]; then + ONNX_URL="$HF_ENDPOINT/unstructuredio/yolo_x_layout/resolve/main/yolox_l0.05.onnx" + + if command -v curl &> /dev/null; then + echo "INFO: Downloading with curl (supports resume -C -)..." + if curl -L -C - -o "$YOLOX_MODEL_FILE" "$ONNX_URL" --progress-bar; then + echo "INFO: curl download completed" + else + echo "ERROR: curl download failed." + cd "$ORIGINAL_DIR" + return 1 + fi + elif command -v wget &> /dev/null; then + echo "INFO: Downloading with wget (supports resume -c)..." + wget -c "$ONNX_URL" -O "$YOLOX_MODEL_FILE" + else + echo "ERROR: curl or wget is required to download model files." + cd "$ORIGINAL_DIR" + return 1 + fi + + if [ -f "$YOLOX_MODEL_FILE" ]; then + FINAL_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [ "$FINAL_SIZE" -lt "$MIN_YOLOX_SIZE" ]; then + echo "ERROR: YOLOX file seems too small (size: $FINAL_SIZE bytes)." + cd "$ORIGINAL_DIR" + return 1 + else + echo "INFO: YOLOX downloaded (size: $(numfmt --to=iec-i --suffix=B $FINAL_SIZE 2>/dev/null || echo $FINAL_SIZE))" + fi + else + echo "ERROR: YOLOX download failed: file not found." + cd "$ORIGINAL_DIR" + return 1 + fi + fi + + echo "INFO: YOLOX OK" + + # ========================================== + # 3. config.json + CONFIG_FILE="$MODEL_ROOT/config.json" + YOLOX_ABS_PATH=$(cd "$(dirname "$YOLOX_MODEL_FILE")" && pwd)/$(basename "$YOLOX_MODEL_FILE") + YOLOX_OS_PATH=$(format_path_for_env "$YOLOX_ABS_PATH") + YOLOX_CONFIG_PATH=$(escape_backslashes "$YOLOX_OS_PATH") + + cat > "$CONFIG_FILE" < /dev/null; then + update_env_var "TABLE_TRANSFORMER_MODEL_PATH" "$TT_MODEL_DIR_ENV_PATH" + update_env_var "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" "$CONFIG_FILE_ENV_PATH" + else + sed -i.bak "/^TABLE_TRANSFORMER_MODEL_PATH=/d" "$ENV_FILE_PATH" 2>/dev/null || true + echo "TABLE_TRANSFORMER_MODEL_PATH="$TT_MODEL_DIR_ENV_PATH"" >> "$ENV_FILE_PATH" + + sed -i.bak "/^UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH=/d" "$ENV_FILE_PATH" 2>/dev/null || true + echo "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH="$CONFIG_FILE_ENV_PATH"" >> "$ENV_FILE_PATH" + rm -f "$ENV_FILE_PATH.bak" 2>/dev/null + fi + + echo "INFO: Environment file updated" + cd "$ORIGINAL_DIR" +} + clean() { export MINIO_ACCESS_KEY= export MINIO_SECRET_KEY= @@ -600,6 +873,13 @@ prepare_directory_and_data() { create_dir_with_permission "$ROOT_DIR/minio" 775 create_dir_with_permission "$ROOT_DIR/redis" 775 + echo "📦 Check the status of model configuration..." + download_and_config_models || { + echo "⚠️ A warning occurred during the model configuration step, but subsequent deployment will proceed..." + # Do not exit here; the user may choose N or prefer to continue after a download failure. + } + echo "" + cp -rn volumes $ROOT_DIR chmod -R 775 $ROOT_DIR/volumes echo " 📁 Directory $ROOT_DIR/volumes has been created and permissions set to 775." @@ -1057,6 +1337,8 @@ main_deploy() { select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 1; } choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } + select_model_download || { echo "❌ Model download failed"; exit 1;} + # Set NEXENT_MCP_DOCKER_IMAGE in .env file if [ -n "${NEXENT_MCP_DOCKER_IMAGE:-}" ]; then update_env_var "NEXENT_MCP_DOCKER_IMAGE" "${NEXENT_MCP_DOCKER_IMAGE}" @@ -1142,7 +1424,7 @@ docker_compose_command="" case $version_type in "v1") echo "Detected Docker Compose V1, version: $version_number" - # The version ​​v1.28.0​​ is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} + # The version 1.28.0 is the minimum requirement in Docker Compose v1 for default interpolation syntax. if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" exit 1 diff --git a/docker/init.sql b/docker/init.sql index 26c345b69..6d274c8e1 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -213,6 +213,7 @@ CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( "embedding_model_name" varchar(200) COLLATE "pg_catalog"."default", "group_ids" varchar, "ingroup_permission" varchar(30), + "is_multimodal" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, "create_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "update_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "delete_flag" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, @@ -230,6 +231,7 @@ COMMENT ON COLUMN "knowledge_record_t"."knowledge_sources" IS 'Knowledge base so COMMENT ON COLUMN "knowledge_record_t"."embedding_model_name" IS 'Embedding model name, used to record the embedding model used by the knowledge base'; COMMENT ON COLUMN "knowledge_record_t"."group_ids" IS 'Knowledge base group IDs list'; COMMENT ON COLUMN "knowledge_record_t"."ingroup_permission" IS 'In-group permission: EDIT, READ_ONLY, PRIVATE'; +COMMENT ON COLUMN "knowledge_record_t"."is_multimodal" IS 'whether it is multimodal'; COMMENT ON COLUMN "knowledge_record_t"."create_time" IS 'Creation time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."update_time" IS 'Update time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."delete_flag" IS 'When deleted by user frontend, delete flag will be set to true, achieving soft delete effect. Optional values Y/N'; diff --git a/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql new file mode 100644 index 000000000..d5b14bfbb --- /dev/null +++ b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql @@ -0,0 +1,5 @@ +-- Add is_multimodal column to knowledge_record_t table +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS is_multimodal varchar(1) DEFAULT 'N'; + +COMMENT ON COLUMN nexent.knowledge_record_t.is_multimodal IS 'whether it is multimodal'; diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index c97536b92..481d71920 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -100,7 +100,8 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig(); + const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable; // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -363,7 +364,10 @@ export default function ToolManagement({ tool.id ); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -468,7 +472,10 @@ export default function ToolManagement({ {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index d09a06039..606d1da6a 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -1,4 +1,4 @@ -"use client"; +"use client"; import { useState, useEffect, useCallback, useMemo, useRef } from "react"; import { useTranslation } from "react-i18next"; @@ -459,6 +459,35 @@ export default function ToolConfigModal({ } }, [configData]); + const currentMultiEmbeddingModel = useMemo(() => { + try { + const modelConfig = configData?.models; + return ( + modelConfig?.multiEmbedding?.modelName || + modelConfig?.multiEmbedding?.displayName || + null + ); + } catch { + return null; + } + }, [configData]); + + const toolMultimodal = useMemo(() => { + const multimodalParam = currentParams.find( + (param) => param.name === "multimodal" + ); + const value = multimodalParam?.value; + if (typeof value === "boolean") { + return value; + } + if (typeof value === "string") { + const normalized = value.trim().toLowerCase(); + if (["true", "1", "yes", "y"].includes(normalized)) return true; + if (["false", "0", "no", "n"].includes(normalized)) return false; + } + return null; + }, [currentParams]); + // Check if a knowledge base can be selected const canSelectKnowledgeBase = useCallback( (kb: KnowledgeBase): boolean => { @@ -469,9 +498,43 @@ export default function ToolConfigModal({ return false; } + if (kb.source === "nexent") { + const hasMultimodalConstraintMismatch = + toolMultimodal !== null && + ((toolMultimodal && !kb.is_multimodal) || + (!toolMultimodal && kb.is_multimodal)); + if (hasMultimodalConstraintMismatch) { + return false; + } + + if (kb.is_multimodal) { + if (!currentMultiEmbeddingModel) { + return false; + } + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentMultiEmbeddingModel + ) { + return false; + } + } else { + if (!currentEmbeddingModel) { + return false; + } + if ( + kb.embeddingModel && + kb.embeddingModel !== "unknown" && + kb.embeddingModel !== currentEmbeddingModel + ) { + return false; + } + } + } + return true; }, - [currentEmbeddingModel] + [currentEmbeddingModel, currentMultiEmbeddingModel, toolMultimodal] ); // Track whether this is the first time opening the modal (reset when modal closes) @@ -1290,7 +1353,7 @@ export default function ToolConfigModal({ })} options={options.map((option) => ({ value: option, - label: option, + label: String(option), }))} /> ); @@ -1684,6 +1747,8 @@ export default function ToolConfigModal({ syncLoading={kbLoading} isSelectable={canSelectKnowledgeBase} currentEmbeddingModel={currentEmbeddingModel} + currentMultiEmbeddingModel={currentMultiEmbeddingModel} + toolMultimodal={toolMultimodal} difyConfig={ toolKbType === "dify_search" ? difyConfig diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx index f2bcc7f9e..99e21e8f8 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx @@ -133,12 +133,12 @@ export default function ToolTestPanel({ // Check if this is the index_names parameter and KB selection is enabled const isIndexNamesParam = paramName === "index_names" && toolRequiresKbSelection; + if (isIndexNamesParam) { + // index_names is provided by KB selector config, no need to duplicate in input params. + return; + } - if (isIndexNamesParam && selectedKbIds.length > 0) { - // Use the selected KB IDs from configParams as default - parameterValues[paramName] = selectedKbIds; - formValues[`param_${paramName}`] = selectedKbIds; - } else if ( + if ( paramInfo && typeof paramInfo === "object" && paramInfo.default != null @@ -211,25 +211,6 @@ export default function ToolTestPanel({ if (!idsMatch) { form.setFieldValue(fieldName, selectedKbIds); - - // Also update the parameter values - if (selectedKbIds.length > 0) { - setParameterValues((prev) => ({ - ...prev, - index_names: selectedKbIds, - })); - // Update manual JSON input while preserving other values - setManualJsonInput((prev) => { - try { - const parsed = JSON.parse(prev); - parsed.index_names = selectedKbIds; - return JSON.stringify(parsed, null, 2); - } catch { - // If JSON is invalid, keep the current value - return prev; - } - }); - } } }, [selectedKbIds, toolRequiresKbSelection, form]); diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index a5e7d52d1..1dc71ffa1 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -7,6 +7,7 @@ import { useRef, useLayoutEffect, useCallback, + useMemo, } from "react"; import { useTranslation } from "react-i18next"; @@ -45,6 +46,37 @@ import { } from "./contexts/DocumentContext"; import { useUIContext, UIProvider } from "./contexts/UIStateContext"; +const EMBEDDING_MODEL_OPTION_DELIMITER = "::"; +const normalizeEmbeddingModelType = (type: string) => + (type || "").trim().toLowerCase(); + +const toEmbeddingModelOptionValue = (displayName: string, type: string) => + `${displayName}${EMBEDDING_MODEL_OPTION_DELIMITER}${type}`; + +const parseEmbeddingModelOptionValue = (value: string) => { + const normalizedValue = (value || "").trim(); + const delimiterIndex = normalizedValue.lastIndexOf( + EMBEDDING_MODEL_OPTION_DELIMITER + ); + if (delimiterIndex >= 0) { + const displayName = normalizedValue.slice(0, delimiterIndex); + const type = normalizedValue.slice( + delimiterIndex + EMBEDDING_MODEL_OPTION_DELIMITER.length + ); + return { + displayName: displayName || "", + type: (type || "").trim(), + isMultimodal: + normalizeEmbeddingModelType(type || "") === "multi_embedding", + }; + } + return { + displayName: normalizedValue || "", + type: "", + isMultimodal: false, + }; +}; + // EmptyState component defined directly in this file interface EmptyStateProps { icon?: React.ReactNode | string; @@ -55,7 +87,7 @@ interface EmptyStateProps { } const EmptyState: React.FC = ({ - icon = "📋", + icon = "馃搵", title, description, action, @@ -129,7 +161,7 @@ function DataConfig({ isActive }: DataConfigProps) { const { token } = theme.useToken(); // Get available embedding models for knowledge base creation - const { availableEmbeddingModels } = useModelList({ enabled: true }); + const { models } = useModelList({ enabled: true }); // Clear cache when component initializes useEffect(() => { @@ -197,11 +229,41 @@ function DataConfig({ isActive }: DataConfigProps) { const [modelFilter, setModelFilter] = useState([]); const contentRef = useRef(null); - // Open warning modal when single Embedding model is not configured (ignore multi-embedding) + const availableEmbeddingModels = useMemo(() => { + return models.filter( + (model) => + (model.type === "embedding" || model.type === "multi_embedding") && + model.connect_status === "available" + ); + }, [models]); + + const resolveEmbeddingModelId = useCallback( + ({ + displayName, + isMultimodal, + }: { + displayName?: string; + isMultimodal?: boolean; + }) => { + const normalizedDisplayName = (displayName || "").trim(); + if (!normalizedDisplayName) return undefined; + + const modelType = isMultimodal ? "multi_embedding" : "embedding"; + return availableEmbeddingModels.find( + (model) => + model.displayName === normalizedDisplayName && model.type === modelType + )?.id; + }, + [availableEmbeddingModels] + ); + + // Open warning modal only when neither embedding nor multi-embedding is configured. useEffect(() => { - const singleEmbeddingModelName = modelConfig?.embedding?.modelName; - setShowEmbeddingWarning(!singleEmbeddingModelName); - }, [modelConfig?.embedding?.modelName]); + const singleEmbeddingModelName = modelConfig?.embedding?.modelName?.trim(); + const multiEmbeddingModelName = + modelConfig?.multiEmbedding?.modelName?.trim(); + setShowEmbeddingWarning(!singleEmbeddingModelName && !multiEmbeddingModelName); + }, [modelConfig?.embedding?.modelName, modelConfig?.multiEmbedding?.modelName]); // Add event listener for selecting new knowledge base useEffect(() => { @@ -369,11 +431,11 @@ function DataConfig({ isActive }: DataConfigProps) { // Directly call fetchKnowledgeBases to update knowledge base list data await fetchKnowledgeBases(false, true); } catch (error) { - log.error("获取知识库最新数据失败:", error); + log.error("鑾峰彇鐭ヨ瘑搴撴渶鏂版暟鎹け璐?", error); } }, 100); } catch (error) { - log.error("获取文档列表失败:", error); + log.error("鑾峰彇鏂囨。鍒楄〃澶辫触:", error); message.error(t("knowledgeBase.message.getDocumentsFailed")); docDispatch({ type: "ERROR", @@ -618,11 +680,30 @@ function DataConfig({ isActive }: DataConfigProps) { setNewKbName(defaultName); setNewKbIngroupPermission("READ_ONLY"); setNewKbGroupIds([]); - // Set default embedding model - prioritize config's default model, fall back to first available model - const configModel = modelConfig?.embedding?.modelName; - const defaultModel = configModel || (availableEmbeddingModels.length > 0 - ? availableEmbeddingModels[0].displayName - : ""); + // Set default embedding model: + // 1) configured embedding model, 2) configured multimodal model, 3) first available option. + const configEmbeddingModel = modelConfig?.embedding?.modelName?.trim() || ""; + const configMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const preferredModel = [ + { modelName: configEmbeddingModel, type: "embedding" }, + { modelName: configMultiEmbeddingModel, type: "multi_embedding" }, + ].find( + ({ modelName, type }) => + !!modelName && + availableEmbeddingModels.some( + (model) => model.displayName === modelName && model.type === type + ) + ); + const defaultModel = + (preferredModel && + toEmbeddingModelOptionValue(preferredModel.modelName, preferredModel.type)) || + (availableEmbeddingModels[0] + ? toEmbeddingModelOptionValue( + availableEmbeddingModels[0].displayName, + availableEmbeddingModels[0].type + ) + : ""); setNewKbEmbeddingModel(defaultModel); setIsCreatingMode(true); setHasClickedUpload(false); // Reset upload button click state @@ -681,13 +762,22 @@ function DataConfig({ isActive }: DataConfigProps) { return; } + const parsedSelectedModel = + parseEmbeddingModelOptionValue(newKbEmbeddingModel); + const isMultimodal = parsedSelectedModel.isMultimodal; + const selectedModelId = resolveEmbeddingModelId({ + displayName: parsedSelectedModel.displayName, + isMultimodal: parsedSelectedModel.isMultimodal, + }); + const newKB = await createKnowledgeBase( newKbName.trim(), t("knowledgeBase.description.default"), "elasticsearch", newKbIngroupPermission, newKbGroupIds, - newKbEmbeddingModel + parsedSelectedModel.displayName, + isMultimodal ); if (!newKB) { @@ -702,7 +792,7 @@ function DataConfig({ isActive }: DataConfigProps) { setHasClickedUpload(false); setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created - await uploadDocuments(newKB.id, filesToUpload); + await uploadDocuments(newKB.id, filesToUpload, selectedModelId); setUploadFiles([]); knowledgeBasePollingService @@ -738,7 +828,12 @@ function DataConfig({ isActive }: DataConfigProps) { } try { - await uploadDocuments(kbId, filesToUpload); + const activeKbModelId = resolveEmbeddingModelId({ + displayName: kbState.activeKnowledgeBase?.embeddingModel, + isMultimodal: kbState.activeKnowledgeBase?.is_multimodal, + }); + + await uploadDocuments(kbId, filesToUpload, activeKbModelId); setUploadFiles([]); knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true); @@ -887,7 +982,7 @@ function DataConfig({ isActive }: DataConfigProps) { = ({ knowledgeBaseId, documents, getFileIcon, - currentEmbeddingModel = null, - knowledgeBaseEmbeddingModel = "", onChunkCountChange, permission, }) => { @@ -128,55 +126,10 @@ const DocumentChunk: React.FC = ({ setTooltipResetKey((prev) => prev + 1); }, []); - // Determine if embedding models mismatch (specific condition for tooltip) - const isEmbeddingModelMismatch = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); - - // Determine if in read-only mode (embedding model mismatch OR user has READ_ONLY permission) - // Note: isReadOnlyMode is broader, includes model mismatch and other conditions + // Determine if in read-only mode (user permission only) const isReadOnlyMode = React.useMemo(() => { - // Check if user has READ_ONLY permission - if (permission === "READ_ONLY") { - return true; - } - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel, permission]); - - // Determine if search should be disabled (only when embedding model mismatch, NOT for READ_ONLY permission) - // This allows READ_ONLY users to still perform search - const isSearchDisabled = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); - - // Disabled tooltip message when embedding model mismatch - const disabledTooltipMessage = React.useMemo(() => { - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - return t("document.chunk.tooltip.disabledDueToModelMismatch", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - }); - } - return ""; - }, [isEmbeddingModelMismatch, currentEmbeddingModel, knowledgeBaseEmbeddingModel, t]); + return permission === "READ_ONLY"; + }, [permission]); // Set active document when documents change useEffect(() => { @@ -321,15 +274,6 @@ const DocumentChunk: React.FC = ({ return; } - // Check embedding model consistency before searching - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - message.error(t("document.chunk.error.searchFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); - return; - } - if (!knowledgeBaseName) { message.error(t("document.chunk.error.searchFailed")); return; @@ -380,9 +324,6 @@ const DocumentChunk: React.FC = ({ resetChunkSearch, searchValue, t, - isEmbeddingModelMismatch, - currentEmbeddingModel, - knowledgeBaseEmbeddingModel, ]); const refreshChunks = React.useCallback(async () => { @@ -463,20 +404,6 @@ const DocumentChunk: React.FC = ({ return; } - // Check embedding model consistency before creating chunk - if (chunkModalMode === "create") { - if (knowledgeBaseEmbeddingModel && - knowledgeBaseEmbeddingModel !== "unknown" && - currentEmbeddingModel && - currentEmbeddingModel !== knowledgeBaseEmbeddingModel) { - message.error(t("document.chunk.error.createFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); - return; - } - } - try { const values = await chunkForm.validateFields(); setChunkSubmitting(true); @@ -761,11 +688,11 @@ const DocumentChunk: React.FC = ({
{chunk.source_type === "datamate" - ? t("document.chunk.source.datamate", "来源: Datamate") + ? t("document.chunk.source.datamate", "\u6765\u6e90: Datamate") : chunk.source_type === "file" || chunk.source_type === "minio" || chunk.source_type === "local" - ? t("document.chunk.source.nexent", "来源: Nexent") + ? t("document.chunk.source.nexent", "\u6765\u6e90: Nexent") : ""}
@@ -805,57 +732,37 @@ const DocumentChunk: React.FC = ({ {/* Search and Add Button Bar */}
- {/* Wrap search input with tooltip when model mismatch */} - {isEmbeddingModelMismatch ? ( - - - setSearchValue(e.target.value)} - onPressEnter={() => { + setSearchValue(e.target.value)} + onPressEnter={() => { + void handleSearch(); + }} + style={{ width: 320 }} + suffix={ +
+ {searchValue && ( +
- } - /> - )} +
+ } + />
{/* Create Chunk button - hide when user has READ_ONLY permission */} {!isReadOnlyMode && ( @@ -864,7 +771,6 @@ const DocumentChunk: React.FC = ({ type="text" icon={} onClick={openCreateChunkModal} - disabled={isEmbeddingModelMismatch} > )} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index 06940d9f0..dd51c4c3a 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -503,11 +503,29 @@ const DocumentListContainer = forwardRef( onChange={onEmbeddingModelChange} style={{ minWidth: 200, justifyContent: "center", alignItems: "flex-end" }} placeholder={t("knowledgeBase.create.embeddingModelPlaceholder") || "Select embedding model"} - options={(availableEmbeddingModels || []).map((model) => ({ - value: model.displayName, - label: model.displayName, - disabled: model.connect_status === "unavailable", - }))} + allowClear={false} + options={[ + { + label: t("modelConfig.option.embeddingModel"), + options: (availableEmbeddingModels || []) + .filter((model) => model.type === "embedding") + .map((model) => ({ + value: `${model.displayName}::${model.type}`, + label: model.displayName, + disabled: model.connect_status === "unavailable", + })), + }, + { + label: t("modelConfig.option.multiEmbeddingModel"), + options: (availableEmbeddingModels || []) + .filter((model) => model.type === "multi_embedding") + .map((model) => ({ + value: `${model.displayName}::${model.type}`, + label: model.displayName, + disabled: model.connect_status === "unavailable", + })), + }, + ].filter((group) => group.options.length > 0)} /> )} {/* User groups multi-select */} diff --git a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx index cbff0297b..670a413d6 100644 --- a/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx +++ b/frontend/app/[locale]/knowledges/components/knowledge/KnowledgeBaseList.tsx @@ -31,7 +31,10 @@ import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout"; interface KnowledgeBaseListProps { knowledgeBases: KnowledgeBase[]; activeKnowledgeBase: KnowledgeBase | null; - currentEmbeddingModel: string | null; + configuredEmbeddingModels?: Array<{ + displayName: string; + type: string; + }>; isLoading?: boolean; syncLoading?: boolean; onClick: (kb: KnowledgeBase) => void; @@ -56,7 +59,7 @@ interface KnowledgeBaseListProps { const KnowledgeBaseList: React.FC = ({ knowledgeBases, activeKnowledgeBase, - currentEmbeddingModel, + configuredEmbeddingModels = [], isLoading = false, syncLoading = false, onClick, @@ -127,6 +130,34 @@ const KnowledgeBaseList: React.FC = ({ return `knowledgeBase.ingroup.permission.${permission || "DEFAULT"}`; }; + const configuredModelTypesByName = useMemo(() => { + const map = new Map>(); + configuredEmbeddingModels.forEach((model) => { + const modelName = (model.displayName || "").trim(); + const modelType = (model.type || "").trim().toLowerCase(); + if (!modelName) return; + if (modelType !== "embedding" && modelType !== "multi_embedding") return; + if (!map.has(modelName)) { + map.set(modelName, new Set()); + } + map.get(modelName)!.add(modelType); + }); + return map; + }, [configuredEmbeddingModels]); + + const isModelMismatch = (kb: KnowledgeBase) => { + if (kb.embeddingModel === "unknown") return false; + if (kb.source === "datamate") return false; + const modelTypes = configuredModelTypesByName.get( + (kb.embeddingModel || "").trim() + ); + return !modelTypes; + }; + + const hasIndexedDocumentsAndChunks = (kb: KnowledgeBase) => { + return (kb.documentCount || 0) > 0 && (kb.chunkCount || 0) > 0; + }; + // Search and filter states const [searchKeyword, setSearchKeyword] = useState(""); const [selectedSources, setSelectedSources] = useState([]); @@ -579,6 +610,21 @@ const KnowledgeBaseList: React.FC = ({ })} )} + {kb.is_multimodal && + hasIndexedDocumentsAndChunks(kb) && ( + + multimodal + + )} + {isModelMismatch(kb) && ( + + {t("knowledgeBase.tag.modelMismatch")} + + )} {/* User group tags - only show when not PRIVATE */} diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx index b956dd919..63d9ad1c2 100644 --- a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx @@ -112,7 +112,7 @@ export const DocumentContext = createContext<{ state: DocumentState; dispatch: React.Dispatch; fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise; - uploadDocuments: (kbId: string, files: File[]) => Promise; + uploadDocuments: (kbId: string, files: File[], modelId?: number) => Promise; deleteDocument: (kbId: string, docId: string) => Promise; }>({ state: { @@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children }) }, [state.loadingKbIds, state.documentsMap, t]); // Upload documents to a knowledge base - const uploadDocuments = useCallback(async (kbId: string, files: File[]) => { + const uploadDocuments = useCallback(async (kbId: string, files: File[], modelId?: number) => { dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true }); try { - await knowledgeBaseService.uploadDocuments(kbId, files); + await knowledgeBaseService.uploadDocuments(kbId, files, undefined, modelId); // Set loading state before fetching latest documents dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true }); @@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children }) {children} ); -}; \ No newline at end of file +}; diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 5985c4b08..947cac8aa 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -110,7 +110,8 @@ export const KnowledgeBaseContext = createContext<{ source?: string, ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + embeddingModel?: string, + is_multimodal?: boolean ) => Promise; deleteKnowledgeBase: (id: string) => Promise; selectKnowledgeBase: (id: string) => void; @@ -125,6 +126,7 @@ export const KnowledgeBaseContext = createContext<{ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -159,6 +161,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -168,11 +171,6 @@ export const KnowledgeBaseProvider: React.FC = ({ // Check if knowledge base is selectable - memoized with useCallback const isKnowledgeBaseSelectable = useCallback( (kb: KnowledgeBase): boolean => { - // If no current embedding model is set, not selectable - if (!state.currentEmbeddingModel) { - return false; - } - // Check if knowledge base has content (documents or chunks) const hasContent = (kb.documentCount || 0) > 0 || (kb.chunkCount || 0) > 0; @@ -187,22 +185,46 @@ export const KnowledgeBaseProvider: React.FC = ({ return true; } - // For local knowledge bases, only selectable when model exactly matches current model - return ( - kb.embeddingModel === "unknown" || - kb.embeddingModel === state.currentEmbeddingModel - ); + if (kb.embeddingModel === "unknown") { + return true; + } + + const currentEmbeddingModel = state.currentEmbeddingModel?.trim() || ""; + const currentMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + + if (kb.is_multimodal) { + // Multimodal KB is selectable as long as current multimodal model is configured. + return !!currentMultiEmbeddingModel; + } + + // Text KB is selectable as long as current embedding model is configured. + return !!currentEmbeddingModel; }, - [state.currentEmbeddingModel] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Check if knowledge base has model mismatch (for display purposes) - // Note: Always return false to remove model mismatch restrictions const hasKnowledgeBaseModelMismatch = useCallback( (kb: KnowledgeBase): boolean => { - return false; + if (kb.embeddingModel === "unknown") { + return false; + } + if (kb.source === "datamate") { + return false; + } + + if (kb.is_multimodal) { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + // Only show warning when the required current model is not configured. + return !multiEmbeddingModel; + } + + // Only show warning when the required current model is not configured. + return !state.currentEmbeddingModel; }, - [] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback @@ -311,17 +333,31 @@ export const KnowledgeBaseProvider: React.FC = ({ source: string = "elasticsearch", ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + embeddingModel?: string, + is_multimodal?: boolean ) => { try { + const selectedEmbeddingModel = embeddingModel?.trim() || ""; + const defaultMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const resolvedIsMultimodal = + typeof is_multimodal === "boolean" + ? is_multimodal + : !!defaultMultiEmbeddingModel && + selectedEmbeddingModel === defaultMultiEmbeddingModel; + const fallbackEmbeddingModel = resolvedIsMultimodal + ? defaultMultiEmbeddingModel + : state.currentEmbeddingModel || ""; + const resolvedEmbeddingModel = + selectedEmbeddingModel || fallbackEmbeddingModel; const newKB = await knowledgeBaseService.createKnowledgeBase({ name, description, source, - // Use provided embeddingModel if available, otherwise fall back to current model or default - embeddingModel: embeddingModel || state.currentEmbeddingModel || "", + embeddingModel: resolvedEmbeddingModel, ingroup_permission, group_ids, + is_multimodal: resolvedIsMultimodal, }); return newKB; } catch (error) { @@ -333,7 +369,7 @@ export const KnowledgeBaseProvider: React.FC = ({ return null; } }, - [state.currentEmbeddingModel, t] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel, t] ); // Delete knowledge base - memoized with useCallback @@ -609,6 +645,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectKnowledgeBase, setActiveKnowledgeBase, isKnowledgeBaseSelectable, + hasKnowledgeBaseModelMismatch, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, ] diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 7cbf5192e..b0d86d2b8 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -1,4 +1,4 @@ -import { useMemo, useState, useCallback, useEffect } from "react"; +import { useMemo, useState, useCallback, useEffect } from "react"; import { useTranslation } from "react-i18next"; import { Modal, Select, Input, Button, Switch, Tooltip, App } from "antd"; diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 5e498e8de..7b8479385 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -1,4 +1,4 @@ -import { useState, useEffect } from 'react' +import { useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Modal, Input, Button, App } from "antd"; @@ -480,4 +480,4 @@ export const ProviderConfigEditDialog = ({ ) -} \ No newline at end of file +} diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index e20e74876..5e91f71f1 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -527,6 +527,7 @@ export const ModelConfigSection = forwardRef< try { const isConnected = await modelService.verifyCustomModel( modelName, + modelType, signal ); @@ -603,7 +604,7 @@ export const ModelConfigSection = forwardRef< throttleTimerRef.current = setTimeout(async () => { try { // Use modelService to verify model - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await modelService.verifyCustomModel(displayName, modelType); // Update model status updateModelStatus( diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 42ca403e2..260e83b3b 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -1,4 +1,4 @@ -"use client"; +"use client"; import React, { useState } from "react"; import { useTranslation } from "react-i18next"; @@ -92,7 +92,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { }; // Handle checking model connectivity - const handleCheckConnectivity = async (displayName: string) => { + const handleCheckConnectivity = async (displayName: string, modelType: string) => { if (!tenantId) { message.error(t("tenantResources.tenants.tenantIdRequired")); return; @@ -100,7 +100,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { setCheckingConnectivity((prev) => new Set(prev).add(displayName)); try { - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await modelService.verifyCustomModel(displayName, modelType); if (isConnected) { message.success(t("tenantResources.models.connectivitySuccess")); } else { @@ -194,7 +194,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {