diff --git a/.github/workflows/docker-deploy.yml b/.github/workflows/docker-deploy.yml
index 9d04c8913..a77c2491f 100644
--- a/.github/workflows/docker-deploy.yml
+++ b/.github/workflows/docker-deploy.yml
@@ -38,7 +38,10 @@ jobs:
- name: Check if model is cached locally
id: check-model
run: |
- if [ -f ~/model-assets/clip-vit-base-patch32/config.json ] && [ -d ~/model-assets/nltk_data ]; then
+ if [ -f ~/model-assets/clip-vit-base-patch32/config.json ] && \
+ [ -d ~/model-assets/nltk_data ] && \
+ [ -d ~/model-assets/table-transformer-structure-recognition ] && \
+ [ -d ~/model-assets/yolox ]; then
echo "cache-hit=true" >> "$GITHUB_OUTPUT"
cp -r ~/model-assets ./
else
@@ -105,4 +108,4 @@ jobs:
./deploy.sh --mode 3 --is-mainland N --enable-terminal N --version 2 --root-dir "$HOME/nexent-production-data"
else
./deploy.sh --mode 1 --is-mainland N --enable-terminal N --version 2 --root-dir "$HOME/nexent-development-data"
- fi
\ No newline at end of file
+ fi
diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py
index 5a11b550b..88d71216a 100644
--- a/backend/agents/create_agent_info.py
+++ b/backend/agents/create_agent_info.py
@@ -1,4 +1,4 @@
-import threading
+import threading
import logging
from typing import List, Optional
from urllib.parse import urljoin
@@ -469,6 +469,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
rerank = param_dict.get("rerank", False)
rerank_model_name = param_dict.get("rerank_model_name", "")
rerank_model = None
+ is_multimodal = bool(tool_config.params.pop("multimodal", False))
if rerank and rerank_model_name:
rerank_model = get_rerank_model(
tenant_id=tenant_id, model_name=rerank_model_name
diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py
index 578277b6d..677961442 100644
--- a/backend/apps/file_management_app.py
+++ b/backend/apps/file_management_app.py
@@ -126,12 +126,13 @@ async def upload_files(
@file_management_config_router.post("/process")
async def process_files(
- files: List[dict] = Body(
- ..., description="List of file details to process, including path_or_url and filename"),
- chunking_strategy: Optional[str] = Body("basic"),
- index_name: str = Body(...),
- destination: str = Body(...),
- authorization: Optional[str] = Header(None)
+ files: Annotated[List[dict], Body(
+ ..., description="List of file details to process, including path_or_url and filename")],
+ index_name: Annotated[str, Body(...)],
+ destination: Annotated[str, Body(...)],
+ chunking_strategy: Annotated[Optional[str], Body(...)] = "basic",
+ model_id: Annotated[Optional[int], Body(...)] = None,
+ authorization: Annotated[Optional[str], Header()] = None
):
"""
Trigger data processing for a list of uploaded files.
@@ -144,7 +145,8 @@ async def process_files(
chunking_strategy=chunking_strategy,
source_type=destination,
index_name=index_name,
- authorization=authorization
+ authorization=authorization,
+ model_id=model_id
)
process_result = await trigger_data_process(files, process_params)
diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py
index 278b729e8..7029477e6 100644
--- a/backend/apps/model_managment_app.py
+++ b/backend/apps/model_managment_app.py
@@ -33,7 +33,7 @@
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from http import HTTPStatus
-from typing import List, Optional
+from typing import Annotated, List, Optional
from services.model_health_service import (
check_model_connectivity,
verify_model_config_connectivity,
@@ -297,7 +297,8 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)):
@router.post("/healthcheck")
async def check_model_health(
- display_name: str = Query(..., description="Display name to check"),
+ display_name: Annotated[str, Query(..., description="Display name to check")],
+ model_type: Annotated[str, Query(..., description="...")],
authorization: Optional[str] = Header(None)
):
"""Check and update model connectivity, returning the latest status.
@@ -308,7 +309,7 @@ async def check_model_health(
"""
try:
_, tenant_id = get_current_user_id(authorization)
- result = await check_model_connectivity(display_name, tenant_id)
+ result = await check_model_connectivity(display_name, tenant_id, model_type)
return JSONResponse(status_code=HTTPStatus.OK, content={
"message": "Successfully checked model connectivity",
"data": result
diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py
index 6f4232afd..d28f59822 100644
--- a/backend/apps/vectordatabase_app.py
+++ b/backend/apps/vectordatabase_app.py
@@ -82,11 +82,13 @@ def create_new_index(
# Extract optional fields from request body
ingroup_permission = None
group_ids = None
- embedding_model_name = None
+ embedding_model_name: Optional[str] = None
+ is_multimodal: Optional[bool] = None
if request:
ingroup_permission = request.get("ingroup_permission")
group_ids = request.get("group_ids")
- embedding_model_name = request.get("embedding_model_name")
+ embedding_model_name = request.get("embeddingModel")
+ is_multimodal = request.get("is_multimodal")
# Treat path parameter as user-facing knowledge base name for new creations
return ElasticSearchService.create_knowledge_base(
@@ -98,6 +100,7 @@ def create_new_index(
ingroup_permission=ingroup_permission,
group_ids=group_ids,
embedding_model_name=embedding_model_name,
+ is_multimodal=is_multimodal,
)
except Exception as e:
raise HTTPException(
@@ -664,6 +667,7 @@ def update_chunk(
chunk_request=payload,
vdb_core=vdb_core,
user_id=user_id,
+ tenant_id=tenant_id,
)
return JSONResponse(status_code=HTTPStatus.OK, content=result)
except ValueError as e:
@@ -730,8 +734,17 @@ async def hybrid_search(
"""Run a hybrid (accurate + semantic) search across indices."""
try:
_, tenant_id = get_current_user_id(authorization)
+ resolved_index_names: List[str] = []
+ for requested_name in payload.index_names:
+ try:
+ resolved_name = get_index_name_by_knowledge_name(
+ requested_name, tenant_id
+ )
+ except Exception:
+ resolved_name = requested_name
+ resolved_index_names.append(resolved_name)
result = ElasticSearchService.search_hybrid(
- index_names=payload.index_names,
+ index_names=resolved_index_names,
query=payload.query,
tenant_id=tenant_id,
top_k=payload.top_k,
diff --git a/backend/consts/const.py b/backend/consts/const.py
index 680bc78db..9f84208da 100644
--- a/backend/consts/const.py
+++ b/backend/consts/const.py
@@ -31,6 +31,10 @@ class VectorDatabaseType(str, Enum):
# Data Processing Service Configuration
DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE")
CLIP_MODEL_PATH = os.getenv("CLIP_MODEL_PATH")
+TABLE_TRANSFORMER_MODEL_PATH = os.getenv("TABLE_TRANSFORMER_MODEL_PATH")
+UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = os.getenv(
+ "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH"
+)
# Upload Configuration
@@ -129,6 +133,7 @@ class VectorDatabaseType(str, Enum):
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY")
MINIO_REGION = os.getenv("MINIO_REGION")
MINIO_DEFAULT_BUCKET = os.getenv("MINIO_DEFAULT_BUCKET")
+S3_URL_PREFIX = "s3://"
# Postgres Configuration
diff --git a/backend/consts/model.py b/backend/consts/model.py
index 2f1d7aae3..2c5117ee3 100644
--- a/backend/consts/model.py
+++ b/backend/consts/model.py
@@ -300,6 +300,7 @@ class ProcessParams(BaseModel):
source_type: str
index_name: str
authorization: Optional[str] = None
+ model_id: Optional[int] = None
class OpinionRequest(BaseModel):
diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py
index 0dea828ce..c3879c007 100644
--- a/backend/data_process/ray_actors.py
+++ b/backend/data_process/ray_actors.py
@@ -1,3 +1,4 @@
+from io import BytesIO
import logging
import json
import time
@@ -5,8 +6,15 @@
import ray
-from consts.const import RAY_ACTOR_NUM_CPUS, REDIS_BACKEND_URL, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE
-from database.attachment_db import get_file_stream
+from consts.const import (
+ RAY_ACTOR_NUM_CPUS,
+ REDIS_BACKEND_URL,
+ DEFAULT_EXPECTED_CHUNK_SIZE,
+ DEFAULT_MAXIMUM_CHUNK_SIZE,
+ TABLE_TRANSFORMER_MODEL_PATH,
+ UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH,
+)
+from database.attachment_db import build_s3_url, get_file_stream, upload_fileobj
from database.model_management_db import get_model_by_model_id
from nexent.data_process import DataProcessCore
@@ -43,35 +51,16 @@ def _prepare_process_params(
Normalize task/model-related processing params.
"""
process_params = dict(params)
+ self._apply_model_paths(process_params)
if task_id:
process_params["task_id"] = task_id
- if not (model_id and tenant_id):
- return process_params
-
- try:
- model_record = get_model_by_model_id(
- model_id=model_id, tenant_id=tenant_id)
- if not model_record:
- logger.warning(
- f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes")
- return process_params
-
- expected_chunk_size = model_record.get(
- "expected_chunk_size", DEFAULT_EXPECTED_CHUNK_SIZE)
- maximum_chunk_size = model_record.get(
- "maximum_chunk_size", DEFAULT_MAXIMUM_CHUNK_SIZE)
- model_name = model_record.get("display_name")
-
- process_params["max_characters"] = maximum_chunk_size
- process_params["new_after_n_chars"] = expected_chunk_size
-
- logger.info(
- f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): "
- f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}")
- except Exception as e:
- logger.warning(
- f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes")
+ # Reuse shared model param logic so we also keep extra fields
+ self._apply_model_chunk_sizes(
+ model_id=model_id,
+ tenant_id=tenant_id,
+ params=process_params,
+ )
return process_params
def _run_file_process(
@@ -82,24 +71,19 @@ def _run_file_process(
process_params: Dict[str, Any],
log_subject: str,
) -> List[Dict[str, Any]]:
- chunks = self._processor.file_process(
+ result = self._processor.file_process(
file_data=file_data,
filename=filename,
chunking_strategy=chunking_strategy,
**process_params
)
-
- if chunks is None:
- logger.warning(
- f"[RayActor] file_process returned None for {log_subject}='{filename}'")
- return []
- if not isinstance(chunks, list):
- logger.error(
- f"[RayActor] file_process returned non-list type {type(chunks)} for {log_subject}='{filename}'")
- return []
- if len(chunks) == 0:
- logger.warning(
- f"[RayActor] file_process returned empty list for {log_subject}='{filename}'")
+
+ chunks, images_info = self._normalize_processor_result(result)
+ if images_info:
+ self._append_image_chunks(
+ source=filename, chunks=chunks, images_info=images_info)
+ chunks = self._validate_chunks(chunks, filename)
+ if not chunks:
return []
logger.info(
@@ -161,8 +145,129 @@ def process_file(
chunking_strategy=chunking_strategy,
process_params=process_params,
log_subject="source",
- )
+ )
+
+    def _apply_model_paths(self, params: Dict[str, Any]) -> None:
+        """
+        Inject the configured model-asset paths into the processing params.
+
+        Both entries are written unconditionally, replacing any value already
+        stored under the same key.
+
+        Args:
+            params: Processing params dict, mutated in place.
+        """
+        params.update({
+            "table_transformer_model_path": TABLE_TRANSFORMER_MODEL_PATH,
+            "unstructured_default_model_initialize_params_json_path":
+                UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH,
+        })
+
+    def _apply_model_chunk_sizes(
+        self,
+        model_id: Optional[int],
+        tenant_id: Optional[str],
+        params: Dict[str, Any],
+    ) -> None:
+        """
+        Copy chunk-size settings from an embedding model record into params.
+
+        Does nothing when model_id or tenant_id is falsy. A missing record or
+        a failed lookup is logged as a warning instead of being raised.
+
+        Args:
+            model_id: ID of the embedding model to look up.
+            tenant_id: Tenant used for the model lookup.
+            params: Processing params dict, mutated in place with
+                max_characters, new_after_n_chars and (when the record
+                carries one) model_type.
+        """
+        if not model_id or not tenant_id:
+            return
+
+        try:
+            record = get_model_by_model_id(
+                model_id=model_id, tenant_id=tenant_id)
+            if not record:
+                logger.warning(
+                    f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes")
+                return
+
+            soft_limit = record.get(
+                'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE)
+            hard_limit = record.get(
+                'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE)
+            display_name = record.get('display_name')
+            record_model_type = record.get('model_type')
+
+            params.update({
+                'max_characters': hard_limit,
+                'new_after_n_chars': soft_limit,
+            })
+            # Only forward the model type when the record actually has one.
+            if record_model_type:
+                params['model_type'] = record_model_type
+
+            logger.info(
+                f"[RayActor] Using chunk sizes from embedding model '{display_name}' (ID: {model_id}): "
+                f"max_characters={hard_limit}, new_after_n_chars={soft_limit}")
+        except Exception as e:
+            logger.warning(
+                f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes")
+
+    def _read_file_bytes(self, source: str) -> bytes:
+        """
+        Read the full content of a stored file into memory.
+
+        Args:
+            source: Object name or URL passed to get_file_stream.
+
+        Returns:
+            The bytes read from the stream.
+
+        Raises:
+            FileNotFoundError: If get_file_stream returns None.
+            Exception: Any error raised while fetching or reading is logged
+                and re-raised unchanged.
+        """
+        try:
+            file_stream = get_file_stream(source)
+            if file_stream is None:
+                raise FileNotFoundError(
+                    f"Unable to fetch file from URL: {source}")
+            try:
+                return file_stream.read()
+            finally:
+                # Always close the stream so it is not leaked when read()
+                # succeeds or fails.
+                file_stream.close()
+        except Exception as e:
+            logger.error(f"Failed to fetch file from {source}: {e}")
+            raise
+
+    def _normalize_processor_result(
+        self, result: Any
+    ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+        """
+        Split a processor result into (chunks, images_info).
+
+        A 2-tuple is unpacked as (chunks, images_info); anything else is
+        treated as the chunks value with no images. Falsy parts are replaced
+        by empty lists.
+        """
+        is_pair = isinstance(result, tuple) and len(result) == 2
+        if not is_pair:
+            return (result or []), []
+
+        text_chunks, extracted_images = result
+        return (text_chunks or []), (extracted_images or [])
+
+    def _append_image_chunks(
+        self,
+        source: str,
+        chunks: List[Dict[str, Any]],
+        images_info: List[Dict[str, Any]],
+    ) -> None:
+        """
+        Upload extracted images and append one chunk per image to chunks.
+
+        Each valid image entry is uploaded through upload_fileobj under the
+        "images_in_attachments" prefix, its URL is built with build_s3_url,
+        and a chunk whose content is a JSON object (source_file, position,
+        image_url) is appended. The entry itself is also annotated in place
+        with source_file and image_url.
+
+        Entries that are not dicts, or that lack image_bytes, image_format or
+        position, are skipped with a warning so one malformed entry does not
+        abort the remaining images.
+
+        Args:
+            source: Name of the file the images were extracted from.
+            chunks: Chunk list to extend in place.
+            images_info: Image entries produced by the processor.
+        """
+        folder = "images_in_attachments"
+        required_keys = ("image_bytes", "image_format", "position")
+        for index, image_data in enumerate(images_info):
+            if not isinstance(image_data, dict):
+                logger.warning(
+                    f"[RayActor] Skipping image entry at index {index}: unexpected type {type(image_data)}"
+                )
+                continue
+            missing_keys = [key for key in required_keys if key not in image_data]
+            if missing_keys:
+                logger.warning(
+                    f"[RayActor] Skipping image entry at index {index}: missing {', '.join(missing_keys)}"
+                )
+                continue
+
+            img_obj = BytesIO(image_data["image_bytes"])
+            result = upload_fileobj(
+                file_obj=img_obj,
+                file_name=f"{index}.{image_data['image_format']}",
+                prefix=folder)
+            # NOTE(review): an upload result without "object_name" yields an
+            # empty image_url here — confirm upload_fileobj's failure contract.
+            image_url = build_s3_url(result.get("object_name", ""))
+
+            image_data["source_file"] = source
+            image_data["image_url"] = image_url
+
+            chunks.append({
+                "content": json.dumps({
+                    "source_file": source,
+                    "position": image_data["position"],
+                    "image_url": image_url,
+                }),
+                "filename": source,
+                "metadata": {
+                    # chunks grows on every append, so its current length is
+                    # already the next index; adding the loop index on top
+                    # would double-count and leave gaps.
+                    "chunk_index": len(chunks),
+                    "process_source": "UniversalImageExtractor",
+                    "image_url": image_url,
+                }
+            })
+
+    def _validate_chunks(
+        self, chunks: Any, source: str
+    ) -> List[Dict[str, Any]]:
+        """
+        Return chunks when it is a non-empty list, otherwise an empty list.
+
+        None and an empty list are logged as warnings; any non-list value is
+        logged as an error.
+        """
+        if chunks is None:
+            logger.warning(
+                f"[RayActor] file_process returned None for source='{source}'")
+        elif not isinstance(chunks, list):
+            logger.error(
+                f"[RayActor] file_process returned non-list type {type(chunks)} for source='{source}'")
+        elif len(chunks) == 0:
+            logger.warning(
+                f"[RayActor] file_process returned empty list for source='{source}'")
+        else:
+            return chunks
+        return []
+
def process_bytes(
self,
file_bytes: bytes,
diff --git a/backend/data_process/tasks.py b/backend/data_process/tasks.py
index 71d83b090..f2a30f9b7 100644
--- a/backend/data_process/tasks.py
+++ b/backend/data_process/tasks.py
@@ -379,11 +379,11 @@ def _extract_error_code_from_es_response(
def _send_chunks_to_es(
chunks: List[Dict[str, Any]],
index_name: str,
- authorization: Optional[str],
- task_id: Optional[str],
- source: Optional[str],
- original_filename: Optional[str],
- large_mode: Optional[bool] = None,
+ authorization: str | None,
+ task_id: Optional[str] = None,
+ source: str = "",
+ original_filename: str = "",
+ large_mode: bool = False,
) -> Dict[str, Any]:
async def _post():
elasticsearch_url = ELASTICSEARCH_SERVICE
@@ -405,12 +405,10 @@ async def _post():
connector = aiohttp.TCPConnector(verify_ssl=False)
timeout = aiohttp.ClientTimeout(total=600)
- large_mode_value = "true" if large_mode else "false" if large_mode is not None else None
- request_params = (
- {
- "large_mode": large_mode_value
- }
- )
+ request_params: Dict[str, str] = {}
+
+ if large_mode:
+ request_params["large_mode"] = "true"
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
async with session.post(
@@ -761,7 +759,7 @@ def forward_part(
original_filename: Optional[str] = None,
batch_index: Optional[int] = None,
total_batches: Optional[int] = None,
- large_mode: Optional[bool] = None,
+ large_mode: Optional[bool] = False,
) -> Dict[str, Any]:
"""
Forward sub-task that indexes a chunk batch.
@@ -1645,7 +1643,7 @@ def forward(
task_id=task_id,
source=original_source,
original_filename=original_filename,
- large_mode=None,
+ large_mode=False,
)
else:
batches = _build_balanced_batches(
diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py
index 187381cd2..06b84e5ac 100644
--- a/backend/database/attachment_db.py
+++ b/backend/database/attachment_db.py
@@ -2,13 +2,66 @@
import os
import uuid
from datetime import datetime
-from typing import Any, BinaryIO, Dict, List, Optional
+from typing import Any, BinaryIO, Dict, List, Optional, Tuple
from .client import minio_client
+from consts.const import S3_URL_PREFIX
from consts.const import NORTHBOUND_EXTERNAL_URL
from urllib.parse import quote
+def _normalize_object_and_bucket(object_name: str, bucket: Optional[str] = None) -> Tuple[str, Optional[str]]:
+    """
+    Split a storage reference into (object key, bucket).
+
+    Accepted forms:
+    - s3://bucket/key
+    - /bucket/key
+    - key (returned unchanged together with the provided bucket)
+
+    For the two URL-like forms the first path segment is taken as the bucket
+    and overrides the bucket argument; the argument is only used when that
+    segment is empty. An empty object_name is returned as-is.
+    """
+    if not object_name:
+        return object_name, bucket
+
+    if object_name.startswith(S3_URL_PREFIX):
+        remainder = object_name[len(S3_URL_PREFIX):]
+    elif object_name.startswith("/"):
+        remainder = object_name.lstrip("/")
+    else:
+        # Plain key: nothing to parse.
+        return object_name, bucket
+
+    # Everything before the first "/" is the bucket, the rest is the key.
+    bucket_part, _, key_part = remainder.partition("/")
+    return key_part, (bucket_part or bucket)
+
+
+def build_s3_url(object_name: str, bucket: Optional[str] = None) -> str:
+    """
+    Produce an s3://bucket/key URL for an object name.
+
+    - Empty input yields an empty string.
+    - Input already starting with the s3:// prefix is returned unchanged.
+    - Input starting with "/" is read as /bucket/key.
+    - Any other input is treated as a key inside the given bucket, falling
+      back to minio_client.default_bucket; with no bucket at all the key is
+      appended directly to the prefix.
+    """
+    if not object_name:
+        return ""
+    if object_name.startswith(S3_URL_PREFIX):
+        return object_name
+
+    if object_name.startswith("/"):
+        bucket_part, _, key_part = object_name.lstrip("/").partition("/")
+        return f"{S3_URL_PREFIX}{bucket_part}/{key_part}"
+
+    target_bucket = bucket or minio_client.default_bucket
+    if not target_bucket:
+        return f"{S3_URL_PREFIX}{object_name}"
+    return f"{S3_URL_PREFIX}{target_bucket}/{object_name}"
+
+
def _build_mcp_presigned_url(presigned_url: str) -> str:
"""
Build northbound API proxy URL for MCP tools.
@@ -217,6 +270,7 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) ->
"""
Get file size by object name
"""
+ object_name, bucket = _normalize_object_and_bucket(object_name, bucket)
# Ensure minio_client is initialized before accessing storage_config
minio_client._ensure_initialized()
bucket = bucket or minio_client.storage_config.default_bucket
@@ -235,6 +289,7 @@ def file_exists(object_name: str, bucket: Optional[str] = None) -> bool:
bool: True if file exists, False otherwise
"""
try:
+ object_name, bucket = _normalize_object_and_bucket(object_name, bucket)
return minio_client.file_exists(object_name, bucket)
except Exception:
return False
@@ -252,6 +307,8 @@ def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None
Returns:
Dict[str, Any]: Result containing success flag and error message (if any)
"""
+ source_object, bucket = _normalize_object_and_bucket(source_object, bucket)
+ dest_object, bucket = _normalize_object_and_bucket(dest_object, bucket)
success, result = minio_client.copy_file(source_object, dest_object, bucket)
if success:
return {"success": True, "object_name": result}
@@ -296,6 +353,7 @@ def delete_file(object_name: str, bucket: Optional[str] = None) -> Dict[str, Any
Returns:
Dict[str, Any]: Delete result, containing success flag and error message (if any)
"""
+ object_name, bucket = _normalize_object_and_bucket(object_name, bucket)
if not bucket:
minio_client._ensure_initialized()
bucket = minio_client.storage_config.default_bucket
@@ -320,6 +378,7 @@ def get_file_stream(object_name: str, bucket: Optional[str] = None) -> Optional[
Returns:
Optional[BinaryIO]: Standard BinaryIO stream object, or None if failed
"""
+ object_name, bucket = _normalize_object_and_bucket(object_name, bucket)
success, result = minio_client.get_file_stream(object_name, bucket)
if not success:
return None
diff --git a/backend/database/client.py b/backend/database/client.py
index 05f8940b9..e095c5636 100644
--- a/backend/database/client.py
+++ b/backend/database/client.py
@@ -89,6 +89,9 @@ def __init__(self):
if MinioClient._initialized:
return
MinioClient._initialized = True
+ # Explicitly initialize attributes so external callers never hit missing-attribute errors.
+ self._storage_client = None
+ self.storage_config = None
def _ensure_initialized(self):
"""Lazily initialize the storage client on first use."""
@@ -108,6 +111,23 @@ def _ensure_initialized(self):
return True
return False
+    @property
+    def default_bucket(self) -> Optional[str]:
+        """
+        Return the default bucket name without raising.
+
+        Initialization is attempted first; any error from it is deliberately
+        ignored so this accessor stays usable, leaving detailed storage errors
+        to the operational methods. When no storage config is available the
+        MINIO_DEFAULT_BUCKET constant is returned instead.
+        """
+        try:
+            self._ensure_initialized()
+        except Exception:
+            # Best-effort only: fall through to whatever config is present.
+            pass
+
+        config = getattr(self, "storage_config", None)
+        if config is None:
+            return MINIO_DEFAULT_BUCKET
+        return config.default_bucket
+
def upload_file(
self,
file_path: str,
diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py
index 8674bb4fb..9a8b1c8c1 100644
--- a/backend/database/knowledge_db.py
+++ b/backend/database/knowledge_db.py
@@ -183,7 +183,7 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool:
# Update group IDs
if query.get("group_ids") is not None:
record.group_ids = query["group_ids"]
-
+
# Update timestamp and user
if query.get("user_id"):
record.updated_by = query["user_id"]
@@ -259,7 +259,7 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An
if 'tenant_id' in query and query['tenant_id'] is not None:
db_query = db_query.filter(
KnowledgeRecord.tenant_id == query['tenant_id'])
-
+
result = db_query.first()
if result:
@@ -404,14 +404,25 @@ def get_index_name_by_knowledge_name(knowledge_name: str, tenant_id: str) -> str
"""
try:
with get_db_session() as session:
+ # First try resolving by user-facing knowledge_name.
result = session.query(KnowledgeRecord).filter(
KnowledgeRecord.knowledge_name == knowledge_name,
KnowledgeRecord.tenant_id == tenant_id,
KnowledgeRecord.delete_flag != 'Y'
).first()
-
if result:
return result.index_name
+
+ # Backward/forward compatibility: if caller already passes internal index_name,
+ # accept it directly by resolving on index_name as well.
+ index_result = session.query(KnowledgeRecord).filter(
+ KnowledgeRecord.index_name == knowledge_name,
+ KnowledgeRecord.tenant_id == tenant_id,
+ KnowledgeRecord.delete_flag != 'Y'
+ ).first()
+ if index_result:
+ return index_result.index_name
+
raise ValueError(
f"Knowledge base '{knowledge_name}' not found for the current tenant"
)
diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py
index cb1c6c69f..61753f52f 100644
--- a/backend/database/model_management_db.py
+++ b/backend/database/model_management_db.py
@@ -170,7 +170,7 @@ def get_model_records(filters: Optional[Dict[str, Any]], tenant_id: str) -> List
return result_list
-def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dict[str, Any]]:
+def get_model_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[Dict[str, Any]]:
"""
Get a model record by display name
@@ -179,6 +179,11 @@ def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dic
tenant_id:
"""
filters = {'display_name': display_name}
+
+ if model_type in ["multiEmbedding", "multi_embedding"]:
+ filters['model_type'] = "multi_embedding"
+ elif model_type == "embedding":
+ filters['model_type'] = "embedding"
records = get_model_records(filters, tenant_id)
if not records:
@@ -203,7 +208,7 @@ def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[s
return get_model_records(filters, tenant_id)
-def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]:
+def get_model_id_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[int]:
"""
Get a model ID by display name
@@ -214,7 +219,7 @@ def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[
Returns:
Optional[int]: Model ID
"""
- model = get_model_by_display_name(display_name, tenant_id)
+ model = get_model_by_display_name(display_name, tenant_id, model_type)
return model["model_id"] if model else None
diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py
index 0ed29bfc5..81bc9078b 100644
--- a/backend/services/config_sync_service.py
+++ b/backend/services/config_sync_service.py
@@ -99,7 +99,7 @@ async def save_config_impl(config, tenant_id, user_id):
config_key = get_env_key(model_type) + "_ID"
model_id = get_model_id_by_display_name(
- model_display_name, tenant_id)
+ model_display_name, tenant_id, model_type=model_type)
handle_model_config(tenant_id, user_id, config_key,
model_id, tenant_config_dict)
diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py
index a024089a3..ce0e3f993 100644
--- a/backend/services/data_process_service.py
+++ b/backend/services/data_process_service.py
@@ -296,6 +296,17 @@ async def load_image(self, image_url: str) -> Optional[Image.Image]:
async def _load_image(self, session: aiohttp.ClientSession, path: str) -> Optional[Image.Image]:
"""Internal method to load an image from various sources"""
try:
+ if path.startswith('s3://'):
+ # Fetch from MinIO using s3://bucket/key
+ file_stream = get_file_stream(object_name=path)
+ if file_stream is None:
+ raise FileNotFoundError(
+ f"Unable to fetch file from URL: {path}")
+ file_data = file_stream.read()
+ image_based64_str = base64.b64encode(
+ file_data).decode('utf-8')
+ path = f"data:image/jpeg;base64,{image_based64_str}"
+
# Check if input is base64 encoded
if path.startswith('data:image'):
# Extract the base64 data after the comma
@@ -504,6 +515,8 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B
chunking_strategy = source_config.get('chunking_strategy')
index_name = source_config.get('index_name')
original_filename = source_config.get('original_filename')
+ embedding_model_id = source_config.get('embedding_model_id')
+ tenant_id = source_config.get('tenant_id')
# Validate required fields
if not source:
@@ -522,7 +535,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B
source_type=source_type,
chunking_strategy=chunking_strategy,
index_name=index_name,
- original_filename=original_filename
+ original_filename=original_filename,
+ embedding_model_id=embedding_model_id,
+ tenant_id=tenant_id
).set(queue='process_q'),
forward.s(
index_name=index_name,
@@ -600,7 +615,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c
}
async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None:
- """Full conversion pipeline: download → convert → upload → validate → cleanup.
+ """Full conversion pipeline: download -> convert -> upload -> validate -> cleanup.
All five steps run inside data-process so that LibreOffice only needs to be
installed in this container.
diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py
index 776e0eb1d..41858440b 100644
--- a/backend/services/datamate_service.py
+++ b/backend/services/datamate_service.py
@@ -51,7 +51,7 @@ async def _create_datamate_knowledge_records(knowledge_base_ids: List[str],
"tenant_id": tenant_id,
"user_id": user_id,
# Use datamate as embedding model name
- "embedding_model_name": embedding_model_names[i]
+ "embedding_model_name": embedding_model_names[i],
}
# Run synchronous database operation in executor to avoid blocking
diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py
index b5cd048bf..d86c310f4 100644
--- a/backend/services/file_management_service.py
+++ b/backend/services/file_management_service.py
@@ -58,6 +58,7 @@ def check_file_access(object_name: str, user_id: Optional[str]) -> bool:
Access rules:
- knowledge_base/*: All authenticated users can access
- attachments/{user_id}/*: Only the owner (user_id) can access
+ - images_in_attachments/*: All authenticated users can access
- preview/*: Accessible if the original file is accessible
Args:
@@ -74,6 +75,11 @@ def check_file_access(object_name: str, user_id: Optional[str]) -> bool:
# Knowledge base files: all authenticated users can access
return True
+ if object_name.startswith("images_in_attachments/"):
+ # Extracted image files used by knowledge-base image chunks.
+ # Keep them readable for authenticated users to avoid broken image citations.
+ return True
+
# Check if file is in user's attachments folder
# Pattern: attachments/{user_id}/*
if object_name.startswith(f"attachments/{user_id}/"):
diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py
index a20b2a6ca..5e960d3da 100644
--- a/backend/services/model_health_service.py
+++ b/backend/services/model_health_service.py
@@ -170,10 +170,10 @@ async def _perform_connectivity_check(
return connectivity
-async def check_model_connectivity(display_name: str, tenant_id: str) -> dict:
+async def check_model_connectivity(display_name: str, tenant_id: str, model_type: str = None) -> dict:
try:
# Query the database using display_name and tenant context from app layer
- model = get_model_by_display_name(display_name, tenant_id=tenant_id)
+ model = get_model_by_display_name(display_name, tenant_id=tenant_id, model_type=model_type)
if not model:
raise LookupError(
f"Model configuration not found for {display_name}")
diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py
index 5e5229ff6..9c008c07f 100644
--- a/backend/services/tool_configuration_service.py
+++ b/backend/services/tool_configuration_service.py
@@ -151,6 +151,10 @@ def get_local_tools() -> List[ToolInfo]:
else:
param_info["default"] = param.default.default
param_info["optional"] = True
+ if getattr(param.default, "json_schema_extra", None):
+ optional_override = param.default.json_schema_extra.get("optional")
+ if optional_override is not None:
+ param_info["optional"] = optional_override
init_params_list.append(param_info)
@@ -681,6 +685,8 @@ def _validate_local_tool(
if not tool_class:
raise NotFoundException(f"Tool class not found for {tool_name}")
+ runtime_inputs = dict(inputs or {})
+
# Parse instantiation parameters first
instantiation_params = params or {}
# Get signature and extract default values for all parameters
@@ -704,6 +710,7 @@ def _validate_local_tool(
if tool_name == "knowledge_base_search":
index_names = instantiation_params.get("index_names", [])
+ is_multimodal = instantiation_params.pop("multimodal", False)
# Must have embedding model for knowledge base search
if not index_names or not tenant_id:
@@ -799,7 +806,18 @@ def _validate_local_tool(
else:
tool_instance = tool_class(**instantiation_params)
- result = tool_instance.forward(**(inputs or {}))
+ # Only pass declared runtime inputs to forward() to avoid unexpected kwargs.
+ declared_inputs = getattr(tool_class, "inputs", {}) or {}
+ allowed_input_names = (
+ set(declared_inputs.keys()) if isinstance(declared_inputs, dict) else set()
+ )
+ filtered_runtime_inputs = (
+ {k: v for k, v in runtime_inputs.items() if k in allowed_input_names}
+ if allowed_input_names
+ else runtime_inputs
+ )
+
+ result = tool_instance.forward(**filtered_runtime_inputs)
return result
except Exception as e:
logger.error(f"Local tool validation failed for {tool_name}: {e}")
diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py
index 8ad9b54e2..9fc8c9112 100644
--- a/backend/services/vectordatabase_service.py
+++ b/backend/services/vectordatabase_service.py
@@ -28,7 +28,7 @@
from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ
from consts.model import ChunkCreateRequest, ChunkUpdateRequest
-from database.attachment_db import delete_file
+from database.attachment_db import delete_file, get_file_stream
from database.knowledge_db import (
create_knowledge_record,
delete_knowledge_record,
@@ -101,6 +101,28 @@ def _get_embedding_model_display_name(model_id: Optional[int], tenant_id: str) -
return ""
+def _is_multimodal_by_model_id(model_id: Optional[int], tenant_id: str) -> bool:
+ """
+ Determine whether an embedding model is multimodal based on model_id.
+
+ Args:
+ model_id: The embedding model ID.
+ tenant_id: Tenant ID for model lookup.
+
+ Returns:
+ True when the model type is `multi_embedding`, otherwise False.
+ """
+ if model_id is None:
+ return False
+ try:
+ model = get_model_by_model_id(model_id, tenant_id)
+ if model:
+ return model.get("model_type") == "multi_embedding"
+ except Exception as e:
+ logger.warning(f"Failed to determine multimodal flag for model_id {model_id}: {e}")
+ return False
+
+
class KnowledgeBaseNeedsModelConfigError(Exception):
"""Exception raised when a knowledge base needs an embedding model to be configured."""
def __init__(self, index_name: str, message: str = None):
@@ -283,8 +305,42 @@ def check_knowledge_base_exist_impl(knowledge_name: str, vdb_core: VectorDatabas
# Case B: Name is available in this tenant
return {"status": "available"}
-
-def get_embedding_model(tenant_id: str, model_name: Optional[str] = None) -> tuple[Optional[Any], Optional[int]]:
+def _normalize_model_type(raw_model_type: Optional[str]) -> Optional[str]:
+ if raw_model_type in ["multiEmbedding", "multi_embedding"]:
+ return "multi_embedding"
+ if raw_model_type == "embedding":
+ return "embedding"
+ return None
+
+def _build_model_config(model: dict) -> dict:
+ return {
+ "model_repo": model.get("model_repo", ""),
+ "model_name": model["model_name"],
+ "api_key": model.get("api_key", ""),
+ "base_url": model.get("base_url", ""),
+ "model_type": model.get("model_type", "embedding"),
+ "max_tokens": model.get("max_tokens", 1024),
+ "ssl_verify": model.get("ssl_verify", True),
+ }
+
+def _create_embedding_model(model: dict) -> Any:
+ model_config = _build_model_config(model)
+ common_kwargs = {
+ "api_key": model_config.get("api_key", ""),
+ "base_url": model_config.get("base_url", ""),
+ "model_name": get_model_name_from_config(model_config) or "",
+ "embedding_dim": model_config.get("max_tokens", 1024),
+ "ssl_verify": model_config.get("ssl_verify", True),
+ }
+ if model.get("model_type", "embedding") == "multi_embedding":
+ return JinaEmbedding(**common_kwargs)
+ return OpenAICompatibleEmbedding(**common_kwargs)
+
+def get_embedding_model(
+ tenant_id: str,
+ model_name: Optional[str] = None,
+ model_type: Optional[str] = None
+) -> tuple[Optional[Any], Optional[int]]:
"""
Get the embedding model for the tenant, optionally using a specific model name.
@@ -296,40 +352,19 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None) -> tup
Returns:
Tuple of (embedding model instance or None, model_id or None)
"""
- # If model_name is provided, find the model by display_name
if model_name:
try:
- model = get_model_by_display_name(model_name, tenant_id)
- if model and model.get("model_type") in ["embedding", "multi_embedding"]:
- model_config = {
- "model_repo": model.get("model_repo", ""),
- "model_name": model["model_name"],
- "api_key": model.get("api_key", ""),
- "base_url": model.get("base_url", ""),
- "model_type": model.get("model_type", "embedding"),
- "max_tokens": model.get("max_tokens", 1024),
- "ssl_verify": model.get("ssl_verify", True),
- }
- model_type = model.get("model_type", "embedding")
- if model_type == "multi_embedding":
- embedding_model = JinaEmbedding(
- api_key=model_config.get("api_key", ""),
- base_url=model_config.get("base_url", ""),
- model_name=get_model_name_from_config(model_config) or "",
- embedding_dim=model_config.get("max_tokens", 1024),
- ssl_verify=model_config.get("ssl_verify", True),
- )
- else:
- embedding_model = OpenAICompatibleEmbedding(
- api_key=model_config.get("api_key", ""),
- base_url=model_config.get("base_url", ""),
- model_name=get_model_name_from_config(model_config) or "",
- embedding_dim=model_config.get("max_tokens", 1024),
- ssl_verify=model_config.get("ssl_verify", True),
- )
- return embedding_model, model.get("model_id")
+ normalized_model_type = _normalize_model_type(model_type)
+ if normalized_model_type:
+ model = get_model_by_display_name(model_name, tenant_id, normalized_model_type)
else:
+ model = get_model_by_display_name(model_name, tenant_id)
+
+ if not model or model.get("model_type") not in ["embedding", "multi_embedding"]:
logger.warning(f"Model '{model_name}' not found or is not an embedding model")
+ return None, None
+
+ return _create_embedding_model(model), model.get("model_id")
except Exception as e:
logger.warning(f"Failed to get embedding model by name {model_name}: {e}")
@@ -595,6 +630,7 @@ def create_knowledge_base(
ingroup_permission: Optional[str] = None,
group_ids: Optional[List[int]] = None,
embedding_model_name: Optional[str] = None,
+ is_multimodal: Optional[bool] = None,
):
"""
Create a new knowledge base with a user-facing name and an internal Elasticsearch index name.
@@ -620,7 +656,17 @@ def create_knowledge_base(
"""
try:
# Get embedding model - use user-selected model if provided, otherwise use tenant default
- embedding_model, model_id = get_embedding_model(tenant_id, embedding_model_name)
+ selected_model_type = None
+ if is_multimodal is True:
+ selected_model_type = "multi_embedding"
+ elif is_multimodal is False and embedding_model_name:
+ selected_model_type = "embedding"
+
+ embedding_model, model_id = get_embedding_model(
+ tenant_id,
+ embedding_model_name,
+ selected_model_type
+ )
# Determine the embedding model name to save: use user-provided name if available,
# otherwise use the model's display name
@@ -1002,6 +1048,7 @@ def list_indices(
model_id = record.get("embedding_model_id")
tenant_id = record.get("tenant_id") or target_tenant_id
embedding_model_display_name = _get_embedding_model_display_name(model_id, tenant_id)
+ is_multimodal = _is_multimodal_by_model_id(model_id, tenant_id)
stats_info.append({
# Internal index name (used as ID)
@@ -1013,6 +1060,7 @@ def list_indices(
# knowledge source and ingroup permission from DB record
"knowledge_sources": record["knowledge_sources"],
"ingroup_permission": record["ingroup_permission"],
+ "is_multimodal": is_multimodal,
"tenant_id": record.get("tenant_id"),
# Embedding model info: display_name from model_id
"embedding_model_name": embedding_model_display_name or record.get("embedding_model_name", ""),
@@ -1122,12 +1170,27 @@ def index_documents(
"author": author,
"date": date,
"content": text,
- "process_source": "Unstructured",
+ "process_source": metadata.get("process_source", "Unstructured"),
"file_size": file_size,
"create_time": create_time,
"languages": metadata.get("languages", []),
"embedding_model_name": embedding_model_name
}
+
+ image_url = metadata.get("image_url", "")
+ if len(image_url) > 0:
+ # Fetch image bytes from MinIO (supports s3://bucket/key or /bucket/key)
+ try:
+ file_stream = get_file_stream(
+ object_name=image_url)
+ if file_stream is None:
+ raise FileNotFoundError(
+ f"Unable to fetch file from URL: {image_url}")
+ document["image_bytes"] = file_stream.read()
+ except Exception as e:
+ logger.error(
+ f"Failed to fetch file from {image_url}: {e}")
+ raise
documents.append(document)
@@ -1148,8 +1211,9 @@ def index_documents(
'tenant_id') if knowledge_record else None
if tenant_id:
+ model_type = "EMBEDDING_ID" if embedding_model.model_type == "text" else "MULTI_EMBEDDING_ID"
model_config = tenant_config_manager.get_model_config(
- key="EMBEDDING_ID", tenant_id=tenant_id)
+ key=model_type, tenant_id=tenant_id)
embedding_batch_size = model_config.get("chunk_batch", 10)
if embedding_batch_size is None:
embedding_batch_size = 10
@@ -1867,6 +1931,7 @@ def update_chunk(
chunk_request: ChunkUpdateRequest,
vdb_core: VectorDatabaseCore = Depends(get_vector_db_core),
user_id: Optional[str] = None,
+ tenant_id: Optional[str] = None,
):
"""
Update a chunk document.
diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py
index 7d31a74bb..37681d9c7 100644
--- a/backend/utils/file_management_utils.py
+++ b/backend/utils/file_management_utils.py
@@ -15,7 +15,6 @@
from consts.model import ProcessParams
from database.attachment_db import get_file_size_from_minio
from utils.auth_utils import get_current_user_id
-from utils.config_utils import tenant_config_manager
logger = logging.getLogger("file_management_utils")
@@ -45,18 +44,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams)
if not files:
return None
- # Get chunking size according to the embedding model
- embedding_model_id = None
+ # Get tenant_id from authorization for downstream task processing
+ embedding_model_id = process_params.model_id
tenant_id = None
try:
_, tenant_id = get_current_user_id(process_params.authorization)
- # Get embedding model ID from tenant config
- tenant_config = tenant_config_manager.load_config(tenant_id)
- embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None
- if embedding_model_id_str:
- embedding_model_id = int(embedding_model_id_str)
except Exception as e:
- logger.warning(f"Failed to get embedding model ID for tenant: {e}")
+ logger.warning(f"Failed to get tenant_id from authorization: {e}")
# Build headers with authorization
headers = {
diff --git a/docker/.env.bak b/docker/.env.bak
deleted file mode 100644
index 24b53751b..000000000
--- a/docker/.env.bak
+++ /dev/null
@@ -1,168 +0,0 @@
-# ===== Necessary Configs (Necessary till now, will be migrated to frontend page) =====
-
-# Voice Service Config
-APPID=app_id
-TOKEN=token
-
-# ===== Non-essential Configs (Modify if you know what you are doing) =====
-
-CLUSTER=volcano_tts
-VOICE_TYPE=zh_male_jieshuonansheng_mars_bigtts
-SPEED_RATIO=1.3
-
-# ===== Proxy Configuration (Optional) =====
-
-# HTTP_PROXY=http://proxy-server:port
-# HTTPS_PROXY=http://proxy-server:port
-# NO_PROXY=localhost,127.0.0.1
-
-# ===== Backend Configuration (No need to modify at all) =====
-
-# Model Path Config
-CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32
-NLTK_DATA=/opt/models/nltk_data
-
-# Elasticsearch Service
-ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200
-ELASTIC_PASSWORD=nexent@2025
-
-# Elasticsearch Memory Configuration
-ES_JAVA_OPTS="-Xms2g -Xmx2g"
-
-# Elasticsearch Disk Watermark Configuration
-ES_DISK_WATERMARK_LOW=85%
-ES_DISK_WATERMARK_HIGH=90%
-ES_DISK_WATERMARK_FLOOD_STAGE=95%
-
-# Main Services
-# Config service (port 5010) - Main API service for config operations
-CONFIG_SERVICE_URL=http://nexent-config:5010
-ELASTICSEARCH_SERVICE=http://nexent-config:5010/api
-
-# Runtime service (port 5014) - Runtime execution service for agent operations
-RUNTIME_SERVICE_URL=http://nexent-runtime:5014
-
-# MCP service (port 5011) - MCP protocol service
-NEXENT_MCP_SERVER=http://nexent-mcp:5011
-MCP_MANAGEMENT_API=http://nexent-mcp:5015
-
-# Data process service (port 5012) - Data processing service
-DATA_PROCESS_SERVICE=http://nexent-data-process:5012/api
-
-# Northbound service (port 5013) - Northbound API service
-NORTHBOUND_API_SERVER=http://nexent-northbound:5013/api
-
-# Postgres Config
-POSTGRES_HOST=nexent-postgresql
-POSTGRES_USER=root
-NEXENT_POSTGRES_PASSWORD=nexent@4321
-POSTGRES_DB=nexent
-POSTGRES_PORT=5432
-
-# Minio Config
-MINIO_ENDPOINT=http://nexent-minio:9000
-MINIO_ROOT_USER=nexent
-MINIO_ROOT_PASSWORD=nexent@4321
-MINIO_REGION=cn-north-1
-MINIO_DEFAULT_BUCKET=nexent
-
-# Redis Config
-REDIS_URL=redis://redis:6379/0
-REDIS_BACKEND_URL=redis://redis:6379/1
-
-# Model Engine Config
-MODEL_ENGINE_ENABLED=false
-
-# Supabase Config
-DASHBOARD_USERNAME=supabase
-DASHBOARD_PASSWORD=Huawei123
-
-# Supabase db Config
-SUPABASE_POSTGRES_PASSWORD=Huawei123
-SUPABASE_POSTGRES_HOST=db
-SUPABASE_POSTGRES_DB=supabase
-SUPABASE_POSTGRES_PORT=5436
-
-# Supabase Auth Config
-SITE_URL=http://localhost:3011
-SUPABASE_URL=http://supabase-kong-mini:8000
-API_EXTERNAL_URL=http://supabase-kong-mini:8000
-DISABLE_SIGNUP=false
-JWT_EXPIRY=3600
-DEBUG_JWT_EXPIRE_SECONDS=0
-
-# Supabase Configuration
-ENABLE_EMAIL_SIGNUP=true
-ENABLE_EMAIL_AUTOCONFIRM=true
-ENABLE_ANONYMOUS_USERS=false
-
-# Supabase Phone Config
-ENABLE_PHONE_SIGNUP=false
-ENABLE_PHONE_AUTOCONFIRM=false
-
-MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
-MAILER_URLPATHS_INVITE="/auth/v1/verify"
-MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
-MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
-
-INVITE_CODE=nexent2025
-
-# Terminal Tool SSH Key Path
-SSH_PRIVATE_KEY_PATH=/path/to/openssh-server/ssh-keys/openssh_server_key
-
-# ===== Data Processing Service Configuration =====
-
-# Redis Port
-REDIS_PORT=6379
-
-# Flower Monitoring
-FLOWER_PORT=5555
-
-# Ray Configuration
-RAY_ACTOR_NUM_CPUS=2
-RAY_DASHBOARD_PORT=8265
-RAY_DASHBOARD_HOST=0.0.0.0
-RAY_NUM_CPUS=4
-RAY_OBJECT_STORE_MEMORY_GB=0.25
-RAY_TEMP_DIR=/tmp/ray
-RAY_LOG_LEVEL=INFO
-
-# Service Control Flags
-DISABLE_RAY_DASHBOARD=true
-DISABLE_CELERY_FLOWER=true
-DOCKER_ENVIRONMENT=false
-ENABLE_UPLOAD_IMAGE=false
-
-# Celery Configuration
-CELERY_WORKER_PREFETCH_MULTIPLIER=1
-CELERY_TASK_TIME_LIMIT=3600
-ELASTICSEARCH_REQUEST_TIMEOUT=30
-
-# Worker Configuration
-QUEUES=process_q,forward_q
-WORKER_NAME=
-WORKER_CONCURRENCY=4
-
-# Skills Configuration
-SKILLS_PATH=/mnt/nexent/skills
-
-# Telemetry and Monitoring Configuration
-ENABLE_TELEMETRY=false
-SERVICE_NAME=nexent-backend
-JAEGER_ENDPOINT=http://localhost:14268/api/traces
-PROMETHEUS_PORT=8000
-TELEMETRY_SAMPLE_RATE=1.0
-LLM_SLOW_REQUEST_THRESHOLD_SECONDS=5.0
-LLM_SLOW_TOKEN_RATE_THRESHOLD=10.0
-
-# Market Backend Address
-MARKET_BACKEND=http://60.204.251.153:8010
-DEPLOYMENT_VERSION="speed"
-# Root dir
-ROOT_DIR="/c/Users/18270/nexent-data"
-TERMINAL_MOUNT_DIR="/opt/terminal"
-SSH_USERNAME="root"
-SSH_PASSWORD="731215"
-NEXENT_MCP_DOCKER_IMAGE="ccr.ccs.tencentyun.com/nexent-hub/nexent-mcp:v2.0.1"
-MINIO_ACCESS_KEY="72c31cb5b521511cea652723"
-MINIO_SECRET_KEY="m5gcSuKzZnp84CqmG7z5VKnd2C+H5U3PSr7eoJeygmI="
diff --git a/docker/.env.example b/docker/.env.example
index e55bba45a..25fe228de 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -21,6 +21,8 @@ SPEED_RATIO=1.3
# Model Path Config
CLIP_MODEL_PATH=/opt/models/clip-vit-base-patch32
NLTK_DATA=/opt/models/nltk_data
+TABLE_TRANSFORMER_MODEL_PATH=/opt/models/table-transformer-structure-recognition
+UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH=/opt/models/yolox/config.json
# Elasticsearch Service
ELASTICSEARCH_HOST=http://nexent-elasticsearch:9200
diff --git a/docker/deploy.sh b/docker/deploy.sh
index 7fb78aa90..8b012f5c6 100755
--- a/docker/deploy.sh
+++ b/docker/deploy.sh
@@ -17,6 +17,7 @@ DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options"
MODE_CHOICE_SAVED=""
VERSION_CHOICE_SAVED=""
IS_MAINLAND_SAVED=""
+ENABLE_EXTRACTED_IMAGE_MODELS_SAVED="N"
ENABLE_SKILLS_SAVED="Y"
ENABLE_TERMINAL_SAVED="N"
TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}"
@@ -85,6 +86,58 @@ is_windows_env() {
return 1
}
+detect_os_type() {
+ # Return: windows | mac | linux | unknown
+ local os_name
+ os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]')
+ case "$os_name" in
+ mingw*|msys*|cygwin*)
+ echo "windows"
+ ;;
+ darwin*)
+ echo "mac"
+ ;;
+ linux*)
+ echo "linux"
+ ;;
+ *)
+ echo "unknown"
+ ;;
+ esac
+ return 0
+}
+
+format_path_for_env() {
+ # Convert path to OS-specific format for .env values
+ local input_path="$1"
+ local os_type
+ os_type=$(detect_os_type)
+
+ if [[ "$os_type" = "windows" ]]; then
+ if command -v cygpath >/dev/null 2>&1; then
+ cygpath -w "$input_path"
+ return 0
+ fi
+
+ if [[ "$input_path" =~ ^/([a-zA-Z])/(.*)$ ]]; then
+ local drive="${BASH_REMATCH[1]}"
+ local rest="${BASH_REMATCH[2]}"
+ rest="${rest//\//\\}"
+ printf "%s:\\%s" "$(echo "$drive" | tr '[:lower:]' '[:upper:]')" "$rest"
+ return 0
+ fi
+ fi
+
+ printf "%s" "$input_path"
+}
+
+escape_backslashes() {
+ # Escape backslashes for safe writing into .env or JSON
+ local input_path="$1"
+ printf "%s" "$input_path" | sed 's/\\/\\\\/g'
+ return 0
+}
+
is_port_in_use() {
# Check if a TCP port is already in use (Linux/macOS/Windows Git Bash)
local port="$1"
@@ -272,6 +325,7 @@ persist_deploy_options() {
echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\""
echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\""
echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\""
+ echo "ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=\"${ENABLE_EXTRACTED_IMAGE_MODELS_SAVED}\""
echo "ENABLE_SKILLS=\"${ENABLE_SKILLS_SAVED}\""
echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\""
echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\""
@@ -414,7 +468,7 @@ get_compose_version() {
# Function to get the version of docker compose
if command -v docker &> /dev/null; then
version_output=$(docker compose version 2>/dev/null)
- # 修改点:放宽正则匹配,允许版本号后面跟随其他字符(如 -desktop.1)
+
if [[ $version_output =~ v([0-9]+\.[0-9]+\.[0-9]+) ]]; then
echo "v2 ${BASH_REMATCH[1]}"
return 0
@@ -423,7 +477,7 @@ get_compose_version() {
if command -v docker-compose &> /dev/null; then
version_output=$(docker-compose --version 2>/dev/null)
- # 同样放宽这里的匹配规则,以防万一
+
if [[ $version_output =~ ([0-9]+\.[0-9]+\.[0-9]+) ]]; then
echo "v1 ${BASH_REMATCH[1]}"
return 0
@@ -537,6 +591,43 @@ select_deployment_mode() {
echo ""
}
+
+# Extracted image models selection
+select_extracted_image_models_mode() {
+ echo ""
+
+ local input_choice=""
+ read -r -p "Do you want to enable pre-extracted image models mode? [Y/N] (default: N): " input_choice
+ echo ""
+
+ if [[ $input_choice =~ ^[Yy]$ ]]; then
+ ENABLE_EXTRACTED_IMAGE_MODELS_SAVED="Y"
+ echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=Y, deployment will not change model path variables."
+ else
+ ENABLE_EXTRACTED_IMAGE_MODELS_SAVED="N"
+ echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=N, deployment will clear model path variables."
+ fi
+ echo "----------------------------------------"
+ echo ""
+ return 0
+}
+
+configure_extracted_image_models_env() {
+ # New behavior:
+ # - ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=N: clear the two model-path values in .env
+ # - ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=Y: do nothing (no .env update)
+ if [[ "$ENABLE_EXTRACTED_IMAGE_MODELS_SAVED" =~ ^[Yy]$ ]]; then
+ echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=Y, skip model handling (no .env update)."
+ return 0
+ fi
+
+ echo "INFO: ENABLE_EXTRACTED_IMAGE_MODELS_SAVED=N, clearing model path variables in .env..."
+ update_env_var "TABLE_TRANSFORMER_MODEL_PATH" ""
+ update_env_var "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" ""
+ echo "INFO: Cleared TABLE_TRANSFORMER_MODEL_PATH and UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH."
+ return 0
+}
+
clean() {
export MINIO_ACCESS_KEY=
export MINIO_SECRET_KEY=
@@ -609,6 +700,13 @@ prepare_directory_and_data() {
create_dir_with_permission "$ROOT_DIR/minio" 775
create_dir_with_permission "$ROOT_DIR/redis" 775
+ echo "📦 Check the status of model configuration..."
+ configure_extracted_image_models_env || {
+ echo "⚠️ A warning occurred during the model configuration step, but subsequent deployment will proceed..."
+ # Do not exit here; the user may choose N or prefer to continue after a model-path handling failure.
+ }
+ echo ""
+
cp -rn volumes $ROOT_DIR
chmod -R 775 $ROOT_DIR/volumes
echo " 📁 Directory $ROOT_DIR/volumes has been created and permissions set to 775."
@@ -1301,6 +1399,8 @@ main_deploy() {
choose_image_env || { echo "❌ Image environment setup failed"; exit 1; }
select_skills_installation || { echo "❌ Skills installation selection failed"; exit 1; }
+ select_extracted_image_models_mode || { echo "❌ Extracted image models configuration failed"; exit 1;}
+
# Set NEXENT_MCP_DOCKER_IMAGE in .env file
if [ -n "${NEXENT_MCP_DOCKER_IMAGE:-}" ]; then
update_env_var "NEXENT_MCP_DOCKER_IMAGE" "${NEXENT_MCP_DOCKER_IMAGE}"
@@ -1459,7 +1559,7 @@ docker_compose_command=""
case $version_type in
"v1")
echo "Detected Docker Compose V1, version: $version_number"
- # The version v1.28.0 is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default}
+ # The version 1.28.0 is the minimum requirement in Docker Compose v1 for default interpolation syntax.
if [[ $version_number < "1.28.0" ]]; then
echo "Warning: V1 version is too old, consider upgrading to V2"
exit 1
diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
index 909592345..c109f5722 100644
--- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
+++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx
@@ -102,7 +102,8 @@ export default function ToolManagement({
// Use tool list hook for data management
const { availableTools } = useToolList();
- const { isVlmAvailable, isEmbeddingAvailable } = useConfig();
+ const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig();
+ const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable;
// Prefetch knowledge bases for KB tools
const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases();
@@ -363,7 +364,10 @@ export default function ToolManagement({
tool.id
);
const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable);
- const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable);
+ const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(
+ tool.name,
+ isEmbeddingOrMultiAvailable
+ );
const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly;
// Tooltip priority: permission > VLM > Embedding
const tooltipTitle = isReadOnly
@@ -468,7 +472,10 @@ export default function ToolManagement({
{group.tools.map((tool) => {
const isSelected = originalSelectedToolIdsSet.has(tool.id);
const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable);
- const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable);
+ const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(
+ tool.name,
+ isEmbeddingOrMultiAvailable
+ );
const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly;
// Tooltip priority: permission > VLM > Embedding
const tooltipTitle = isReadOnly
diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx
index 53c6d3f03..fe1ac6e2b 100644
--- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx
+++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx
@@ -1,4 +1,4 @@
-"use client";
+"use client";
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
import { useTranslation } from "react-i18next";
@@ -36,6 +36,10 @@ import { API_ENDPOINTS } from "@/services/api";
import knowledgeBaseService from "@/services/knowledgeBaseService";
import log from "@/lib/logger";
import { isZhLocale, getLocalizedDescription } from "@/lib/utils";
+import {
+ isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase,
+ isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase,
+} from "@/lib/knowledgeBaseCompatibility";
export interface ToolConfigModalProps {
isOpen: boolean;
@@ -524,6 +528,86 @@ export default function ToolConfigModal({
}
}, [configData]);
+ const currentMultiEmbeddingModel = useMemo(() => {
+ try {
+ const modelConfig = configData?.models;
+ return (
+ modelConfig?.multiEmbedding?.modelName ||
+ modelConfig?.multiEmbedding?.displayName ||
+ null
+ );
+ } catch {
+ return null;
+ }
+ }, [configData]);
+
+ const hasEmbeddingModel = Boolean(currentEmbeddingModel);
+ const hasMultiEmbeddingModel = Boolean(currentMultiEmbeddingModel);
+ const canToggleMultimodalParam = hasEmbeddingModel && hasMultiEmbeddingModel;
+ const forcedMultimodalValue = useMemo(() => {
+ if (!hasEmbeddingModel && hasMultiEmbeddingModel) {
+ return true;
+ }
+ if (hasEmbeddingModel && !hasMultiEmbeddingModel) {
+ return false;
+ }
+ return null;
+ }, [hasEmbeddingModel, hasMultiEmbeddingModel]);
+
+ const toolMultimodal = useMemo(() => {
+ const multimodalParam = currentParams.find(
+ (param) => param.name === "multimodal"
+ );
+ const value = multimodalParam?.value;
+ if (typeof value === "boolean") {
+ return value;
+ }
+ if (typeof value === "string") {
+ const normalized = value.trim().toLowerCase();
+ if (["true", "1", "yes", "y"].includes(normalized)) return true;
+ if (["false", "0", "no", "n"].includes(normalized)) return false;
+ }
+ return null;
+ }, [currentParams]);
+
+ useEffect(() => {
+ if (tool?.name !== "knowledge_base_search") return;
+ if (forcedMultimodalValue === null) return;
+
+ const index = currentParams.findIndex(
+ (param) => param.name === "multimodal"
+ );
+ if (index < 0) return;
+
+ const param = currentParams[index];
+ if (param.value === forcedMultimodalValue) return;
+
+ const updatedParams = [...currentParams];
+ updatedParams[index] = { ...param, value: forcedMultimodalValue };
+ setCurrentParams(updatedParams);
+
+ const fieldName = `param_${index}`;
+ form.setFieldValue(fieldName, forcedMultimodalValue);
+ }, [tool?.name, forcedMultimodalValue, currentParams, form]);
+
+ const isMultimodalConstraintMismatch = useCallback(
+ (kb: KnowledgeBase) => {
+ return isMultimodalConstraintMismatchBase(kb, toolMultimodal);
+ },
+ [toolMultimodal]
+ );
+
+ const isEmbeddingModelCompatible = useCallback(
+ (kb: KnowledgeBase) => {
+ return isEmbeddingModelCompatibleBase(
+ kb,
+ currentEmbeddingModel,
+ currentMultiEmbeddingModel
+ );
+ },
+ [currentEmbeddingModel, currentMultiEmbeddingModel]
+ );
+
// Check if a knowledge base can be selected
const canSelectKnowledgeBase = useCallback(
(kb: KnowledgeBase): boolean => {
@@ -534,9 +618,16 @@ export default function ToolConfigModal({
return false;
}
+ if (kb.source === "nexent") {
+ if (isMultimodalConstraintMismatch(kb)) {
+ return false;
+ }
+ return isEmbeddingModelCompatible(kb);
+ }
+
return true;
},
- [currentEmbeddingModel]
+ [isEmbeddingModelCompatible, isMultimodalConstraintMismatch]
);
// Track whether this is the first time opening the modal (reset when modal closes)
@@ -1451,7 +1542,7 @@ export default function ToolConfigModal({
})}
options={options.map((option) => ({
value: option,
- label: option,
+ label: String(option),
}))}
/>
);
@@ -1866,6 +1957,8 @@ export default function ToolConfigModal({
syncLoading={kbLoading}
isSelectable={canSelectKnowledgeBase}
currentEmbeddingModel={currentEmbeddingModel}
+ currentMultiEmbeddingModel={currentMultiEmbeddingModel}
+ toolMultimodal={toolMultimodal}
difyConfig={
toolKbType === "dify_search"
? difyConfig
diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx
index b6af20594..767fae3a7 100644
--- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx
+++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx
@@ -2,9 +2,8 @@
import { useState, useEffect, useRef } from "react";
import { useTranslation } from "react-i18next";
-import { Input, Button, Card, Typography, Tooltip, Modal, Form, Tag, Skeleton } from "antd";
+import { Input, Button, Card, Typography, Tooltip, Modal, Form } from "antd";
import { Settings, PenLine, X } from "lucide-react";
-import { CloseOutlined } from "@ant-design/icons";
import { Tool, ToolParam } from "@/types/agentConfig";
import { KnowledgeBase } from "@/types/knowledgeBase";
@@ -59,15 +58,8 @@ export default function ToolTestPanel({
configParams,
onClose,
toolRequiresKbSelection = false,
- knowledgeBases = [],
- kbLoading = false,
- onOpenKbSelector,
selectedKbIds = [],
- selectedKbDisplayNames = [],
- onKbSelectionChange,
- onRemoveKb,
toolKbType = null,
- haotianKnowledgeSets = [],
}: ToolTestPanelProps) {
const { t } = useTranslation("common");
const [form] = Form.useForm();
@@ -605,28 +597,10 @@ export default function ToolTestPanel({
// Haotian uses dataset_ids, others use index_names
const isKbSelectorParam = (paramName === "index_names" || paramName === "dataset_ids") && toolRequiresKbSelection;
- // Get display names based on selected KB IDs and knowledge bases
- let displayNames: string[] = [];
- if (isKbSelectorParam && selectedKbIds.length > 0) {
- if (toolKbType === "haotian_search" && haotianKnowledgeSets.length > 0) {
- // Haotian: resolve names from haotianKnowledgeSets
- displayNames = selectedKbIds.map((id) => {
- const cleanId = id.trim();
- for (const ks of haotianKnowledgeSets) {
- const kb = (ks.knowledge_bases || []).find(
- (b) => String(b.dify_dataset_id) === cleanId
- );
- if (kb) return kb.name;
- }
- return cleanId;
- });
- } else if (knowledgeBases.length > 0) {
- displayNames = selectedKbIds.map((id) => {
- const cleanId = id.trim();
- const kb = knowledgeBases.find((k) => k.id === cleanId);
- return kb?.display_name || kb?.name || cleanId;
- });
- }
+ // KB selection is configured in the upper config area.
+ // Do not render duplicated KB params in the test input area.
+ if (isKbSelectorParam) {
+ return null;
}
// Add type-specific validation rules
@@ -681,83 +655,6 @@ export default function ToolTestPanel({
break;
}
- // Render knowledge base selector for index_names parameter
- if (isKbSelectorParam) {
- return (
-
- {paramName}
-
- }
- name={fieldName}
- rules={rules}
- tooltip={{
- title: getLocalizedDescription(description, description_zh),
- placement: "topLeft",
- styles: { root: { maxWidth: 400 } },
- }}
- >
-
-
onOpenKbSelector?.(-1)} // -1 indicates this is from test panel
- style={{
- width: "100%",
- minHeight: "32px",
- display: "flex",
- flexWrap: "wrap",
- alignItems: "center",
- gap: "4px",
- }}
- title={displayNames.join(", ")}
- >
- {kbLoading && knowledgeBases.length === 0 ? (
-
-
-
- ) : displayNames.length > 0 ? (
- displayNames.map((name, i) => (
-
-
-
- }
- onClose={(e) => {
- e.stopPropagation();
- onRemoveKb?.(i, -1); // -1 indicates this is from test panel
- }}
- style={{
- marginRight: 0,
- display: "inline-flex",
- alignItems: "center",
- lineHeight: "20px",
- padding: "0 8px",
- fontSize: "13px",
- }}
- >
- {name}
-
- ))
- ) : (
-
- {t("toolConfig.input.knowledgeBaseSelector.placeholder", {
- name: getLocalizedDescription(description, description_zh) || paramName,
- })}
-
- )}
-
-
-
- );
- }
-
return (
+ (type || "").trim().toLowerCase();
+
+const toEmbeddingModelOptionValue = (displayName: string, type: string) =>
+ `${displayName}${EMBEDDING_MODEL_OPTION_DELIMITER}${type}`;
+
+const parseEmbeddingModelOptionValue = (value: string) => {
+ const normalizedValue = (value || "").trim();
+ const delimiterIndex = normalizedValue.lastIndexOf(
+ EMBEDDING_MODEL_OPTION_DELIMITER
+ );
+ if (delimiterIndex >= 0) {
+ const displayName = normalizedValue.slice(0, delimiterIndex);
+ const type = normalizedValue.slice(
+ delimiterIndex + EMBEDDING_MODEL_OPTION_DELIMITER.length
+ );
+ return {
+ displayName: displayName || "",
+ type: (type || "").trim(),
+ isMultimodal:
+ normalizeEmbeddingModelType(type || "") === "multi_embedding",
+ };
+ }
+ return {
+ displayName: normalizedValue || "",
+ type: "",
+ isMultimodal: false,
+ };
+};
+
// EmptyState component defined directly in this file
interface EmptyStateProps {
icon?: React.ReactNode | string;
@@ -55,7 +87,7 @@ interface EmptyStateProps {
}
const EmptyState: React.FC = ({
- icon = "📋",
+ icon = "📋",
title,
description,
action,
@@ -129,8 +161,7 @@ function DataConfig({ isActive }: DataConfigProps) {
const { token } = theme.useToken();
// Get available embedding models for knowledge base creation
- const { availableEmbeddingModels } = useModelList({ enabled: true });
-
+ const { models } = useModelList({ enabled: true });
// Clear cache when component initializes
useEffect(() => {
localStorage.removeItem("preloaded_kb_data");
@@ -198,11 +229,59 @@ function DataConfig({ isActive }: DataConfigProps) {
const [modelFilter, setModelFilter] = useState([]);
const contentRef = useRef(null);
- // Open warning modal when single Embedding model is not configured (ignore multi-embedding)
+ const availableEmbeddingModels = useMemo(() => {
+ const embeddingRelatedModels = models.filter(
+ (model) => model.type === "embedding" || model.type === "multi_embedding"
+ );
+ const availableKeys = new Set(
+ embeddingRelatedModels
+ .filter((model) => model.connect_status === "available")
+ .map((model) => `${model.displayName}::${model.type}`)
+ );
+
+ return embeddingRelatedModels.filter((model) => {
+ if (model.connect_status === "available") {
+ return true;
+ }
+
+ // For paired records created from a multi-embedding model, mirror availability by display name.
+ if (model.type === "embedding") {
+ return availableKeys.has(`${model.displayName}::multi_embedding`);
+ }
+ if (model.type === "multi_embedding") {
+ return availableKeys.has(`${model.displayName}::embedding`);
+ }
+ return false;
+ });
+ }, [models]);
+
+ const resolveEmbeddingModelId = useCallback(
+ ({
+ displayName,
+ isMultimodal,
+ }: {
+ displayName?: string;
+ isMultimodal?: boolean;
+ }) => {
+ const normalizedDisplayName = (displayName || "").trim();
+ if (!normalizedDisplayName) return undefined;
+
+ const modelType = isMultimodal ? "multi_embedding" : "embedding";
+ return availableEmbeddingModels.find(
+ (model) =>
+ model.displayName === normalizedDisplayName && model.type === modelType
+ )?.id;
+ },
+ [availableEmbeddingModels]
+ );
+
+ // Open warning modal only when neither embedding nor multi-embedding is configured.
useEffect(() => {
- const singleEmbeddingModelName = modelConfig?.embedding?.modelName;
- setShowEmbeddingWarning(!singleEmbeddingModelName);
- }, [modelConfig?.embedding?.modelName]);
+ const singleEmbeddingModelName = modelConfig?.embedding?.modelName?.trim();
+ const multiEmbeddingModelName =
+ modelConfig?.multiEmbedding?.modelName?.trim();
+ setShowEmbeddingWarning(!singleEmbeddingModelName && !multiEmbeddingModelName);
+ }, [modelConfig?.embedding?.modelName, modelConfig?.multiEmbedding?.modelName]);
// Add event listener for selecting new knowledge base
useEffect(() => {
@@ -370,11 +449,11 @@ function DataConfig({ isActive }: DataConfigProps) {
// Directly call fetchKnowledgeBases to update knowledge base list data
await fetchKnowledgeBases(false, true);
} catch (error) {
- log.error("获取知识库最新数据失败:", error);
+ log.error("获取知识库最新数据失败:", error);
}
}, 100);
} catch (error) {
- log.error("获取文档列表失败:", error);
+ log.error("获取文档列表失败:", error);
message.error(t("knowledgeBase.message.getDocumentsFailed"));
docDispatch({
type: "ERROR",
@@ -619,11 +698,30 @@ function DataConfig({ isActive }: DataConfigProps) {
setNewKbName(defaultName);
setNewKbIngroupPermission("READ_ONLY");
setNewKbGroupIds([]);
- // Set default embedding model - prioritize config's default model, fall back to first available model
- const configModel = modelConfig?.embedding?.modelName;
- const defaultModel = configModel || (availableEmbeddingModels.length > 0
- ? availableEmbeddingModels[0].displayName
- : "");
+ // Set default embedding model:
+ // 1) configured embedding model, 2) configured multimodal model, 3) first available option.
+ const configEmbeddingModel = modelConfig?.embedding?.modelName?.trim() || "";
+ const configMultiEmbeddingModel =
+ modelConfig?.multiEmbedding?.modelName?.trim() || "";
+ const preferredModel = [
+ { modelName: configEmbeddingModel, type: "embedding" },
+ { modelName: configMultiEmbeddingModel, type: "multi_embedding" },
+ ].find(
+ ({ modelName, type }) =>
+ !!modelName &&
+ availableEmbeddingModels.some(
+ (model) => model.displayName === modelName && model.type === type
+ )
+ );
+ const defaultModel =
+ (preferredModel &&
+ toEmbeddingModelOptionValue(preferredModel.modelName, preferredModel.type)) ||
+ (availableEmbeddingModels[0]
+ ? toEmbeddingModelOptionValue(
+ availableEmbeddingModels[0].displayName,
+ availableEmbeddingModels[0].type
+ )
+ : "");
setNewKbEmbeddingModel(defaultModel);
setIsCreatingMode(true);
setHasClickedUpload(false); // Reset upload button click state
@@ -682,13 +780,22 @@ function DataConfig({ isActive }: DataConfigProps) {
return;
}
+ const parsedSelectedModel =
+ parseEmbeddingModelOptionValue(newKbEmbeddingModel);
+ const isMultimodal = parsedSelectedModel.isMultimodal;
+ const selectedModelId = resolveEmbeddingModelId({
+ displayName: parsedSelectedModel.displayName,
+ isMultimodal: parsedSelectedModel.isMultimodal,
+ });
+
const newKB = await createKnowledgeBase(
newKbName.trim(),
t("knowledgeBase.description.default"),
"elasticsearch",
newKbIngroupPermission,
newKbGroupIds,
- newKbEmbeddingModel
+ parsedSelectedModel.displayName,
+ isMultimodal
);
if (!newKB) {
@@ -703,7 +810,7 @@ function DataConfig({ isActive }: DataConfigProps) {
setHasClickedUpload(false);
setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created
- await uploadDocuments(newKB.id, filesToUpload);
+ await uploadDocuments(newKB.id, filesToUpload, selectedModelId);
setUploadFiles([]);
knowledgeBasePollingService
@@ -739,7 +846,12 @@ function DataConfig({ isActive }: DataConfigProps) {
}
try {
- await uploadDocuments(kbId, filesToUpload);
+ const activeKbModelId = resolveEmbeddingModelId({
+ displayName: kbState.activeKnowledgeBase?.embeddingModel,
+ isMultimodal: kbState.activeKnowledgeBase?.is_multimodal,
+ });
+
+ await uploadDocuments(kbId, filesToUpload, activeKbModelId);
setUploadFiles([]);
knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true);
@@ -888,7 +1000,7 @@ function DataConfig({ isActive }: DataConfigProps) {
= ({
knowledgeBaseId,
documents,
getFileIcon,
- currentEmbeddingModel = null,
- knowledgeBaseEmbeddingModel = "",
+ currentEmbeddingModel,
+ knowledgeBaseEmbeddingModel,
onChunkCountChange,
permission,
}) => {
@@ -128,55 +128,31 @@ const DocumentChunk: React.FC = ({
setTooltipResetKey((prev) => prev + 1);
}, []);
+ const effectiveIndexName = knowledgeBaseId || knowledgeBaseName;
+
+ const hasKnowledgeBaseModel =
+ Boolean(knowledgeBaseEmbeddingModel) &&
+ knowledgeBaseEmbeddingModel !== "unknown";
+ const hasCurrentModel = Boolean(currentEmbeddingModel);
+
// Determine if embedding models mismatch (specific condition for tooltip)
const isEmbeddingModelMismatch = React.useMemo(() => {
- if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) {
+ if (!hasKnowledgeBaseModel) {
return false;
}
- if (knowledgeBaseEmbeddingModel === "unknown") {
- return false;
- }
- return currentEmbeddingModel !== knowledgeBaseEmbeddingModel;
- }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]);
+ return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel;
+ }, [
+ currentEmbeddingModel,
+ hasCurrentModel,
+ hasKnowledgeBaseModel,
+ knowledgeBaseEmbeddingModel,
+ ]);
// Determine if in read-only mode (embedding model mismatch OR user has READ_ONLY permission)
// Note: isReadOnlyMode is broader, includes model mismatch and other conditions
const isReadOnlyMode = React.useMemo(() => {
- // Check if user has READ_ONLY permission
- if (permission === "READ_ONLY") {
- return true;
- }
- if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) {
- return false;
- }
- if (knowledgeBaseEmbeddingModel === "unknown") {
- return false;
- }
- return currentEmbeddingModel !== knowledgeBaseEmbeddingModel;
- }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel, permission]);
-
- // Determine if search should be disabled (only when embedding model mismatch, NOT for READ_ONLY permission)
- // This allows READ_ONLY users to still perform search
- const isSearchDisabled = React.useMemo(() => {
- if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) {
- return false;
- }
- if (knowledgeBaseEmbeddingModel === "unknown") {
- return false;
- }
- return currentEmbeddingModel !== knowledgeBaseEmbeddingModel;
- }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]);
-
- // Disabled tooltip message when embedding model mismatch
- const disabledTooltipMessage = React.useMemo(() => {
- if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") {
- return t("document.chunk.tooltip.disabledDueToModelMismatch", {
- currentModel: currentEmbeddingModel,
- knowledgeBaseModel: knowledgeBaseEmbeddingModel
- });
- }
- return "";
- }, [isEmbeddingModelMismatch, currentEmbeddingModel, knowledgeBaseEmbeddingModel, t]);
+ return permission === "READ_ONLY" || isEmbeddingModelMismatch;
+ }, [permission, isEmbeddingModelMismatch]);
// Set active document when documents change
useEffect(() => {
@@ -201,14 +177,14 @@ const DocumentChunk: React.FC = ({
// Load chunks for active document with server-side pagination
const loadChunks = React.useCallback(async () => {
- if (!knowledgeBaseName || !activeDocumentKey) {
+ if (!effectiveIndexName || !activeDocumentKey) {
return;
}
setLoading(true);
try {
const result = await knowledgeBaseService.previewChunksPaginated(
- knowledgeBaseName,
+ effectiveIndexName,
pagination.page,
pagination.pageSize,
activeDocumentKey
@@ -240,7 +216,7 @@ const DocumentChunk: React.FC = ({
setLoading(false);
}
}, [
- knowledgeBaseName,
+ effectiveIndexName,
activeDocumentKey,
pagination.page,
pagination.pageSize,
@@ -321,16 +297,7 @@ const DocumentChunk: React.FC = ({
return;
}
- // Check embedding model consistency before searching
- if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") {
- message.error(t("document.chunk.error.searchFailed", {
- currentModel: currentEmbeddingModel,
- knowledgeBaseModel: knowledgeBaseEmbeddingModel
- }));
- return;
- }
-
- if (!knowledgeBaseName) {
+ if (!effectiveIndexName) {
message.error(t("document.chunk.error.searchFailed"));
return;
}
@@ -340,7 +307,7 @@ const DocumentChunk: React.FC = ({
try {
const response = await knowledgeBaseService.hybridSearch(
- knowledgeBaseId,
+ effectiveIndexName,
trimmedValue,
{
topK: pagination.pageSize,
@@ -352,11 +319,14 @@ const DocumentChunk: React.FC = ({
return {
id: item.id || "",
content: item.content || "",
- path_or_url: item.path_or_url,
+ path_or_url: item.path_or_url || item.url || item.pathOrUrl,
filename: item.filename,
create_time: item.create_time,
score: item.score, // Preserve search score for display
- source_type: item.source_type, // Preserve source type for display
+ source_type:
+ item.source_type === "local" || item.source_type === "minio"
+ ? "file"
+ : item.source_type, // Normalize "local"/"minio" to "file"; keep any other source type as-is for display
};
});
@@ -373,16 +343,12 @@ const DocumentChunk: React.FC = ({
setChunkSearchLoading(false);
}
}, [
- knowledgeBaseName,
- knowledgeBaseId,
+ effectiveIndexName,
message,
pagination.pageSize,
resetChunkSearch,
searchValue,
t,
- isEmbeddingModelMismatch,
- currentEmbeddingModel,
- knowledgeBaseEmbeddingModel,
]);
const refreshChunks = React.useCallback(async () => {
@@ -454,7 +420,7 @@ const DocumentChunk: React.FC = ({
};
const handleChunkSubmit = async () => {
- if (!knowledgeBaseName) {
+ if (!effectiveIndexName) {
message.error(t("document.chunk.error.loadFailed"));
return;
}
@@ -463,26 +429,12 @@ const DocumentChunk: React.FC = ({
return;
}
- // Check embedding model consistency before creating chunk
- if (chunkModalMode === "create") {
- if (knowledgeBaseEmbeddingModel &&
- knowledgeBaseEmbeddingModel !== "unknown" &&
- currentEmbeddingModel &&
- currentEmbeddingModel !== knowledgeBaseEmbeddingModel) {
- message.error(t("document.chunk.error.createFailed", {
- currentModel: currentEmbeddingModel,
- knowledgeBaseModel: knowledgeBaseEmbeddingModel
- }));
- return;
- }
- }
-
try {
const values = await chunkForm.validateFields();
setChunkSubmitting(true);
if (chunkModalMode === "create") {
const filenamePayload = values.filename?.trim() || undefined;
- await knowledgeBaseService.createChunk(knowledgeBaseName, {
+ await knowledgeBaseService.createChunk(effectiveIndexName, {
content: values.content,
filename: filenamePayload,
path_or_url: activeDocumentKey,
@@ -503,7 +455,7 @@ const DocumentChunk: React.FC = ({
return;
}
await knowledgeBaseService.updateChunk(
- knowledgeBaseName,
+ effectiveIndexName,
editingChunk.id,
{
content: values.content,
@@ -541,7 +493,7 @@ const DocumentChunk: React.FC = ({
message.error(t("document.chunk.error.missingChunkId"));
return;
}
- if (!knowledgeBaseName) {
+ if (!effectiveIndexName) {
message.error(t("document.chunk.error.deleteFailed"));
return;
}
@@ -556,7 +508,7 @@ const DocumentChunk: React.FC = ({
danger: true,
onOk: async () => {
try {
- await knowledgeBaseService.deleteChunk(knowledgeBaseName, chunk.id);
+ await knowledgeBaseService.deleteChunk(effectiveIndexName, chunk.id);
message.success(t("document.chunk.success.delete"));
forceCloseTooltips();
// Update chunk count immediately for better UX
@@ -761,11 +713,11 @@ const DocumentChunk: React.FC = ({
{chunk.source_type === "datamate"
- ? t("document.chunk.source.datamate", "来源: Datamate")
+ ? t("document.chunk.source.datamate", "\u6765\u6e90: Datamate")
: chunk.source_type === "file" ||
chunk.source_type === "minio" ||
chunk.source_type === "local"
- ? t("document.chunk.source.nexent", "来源: Nexent")
+ ? t("document.chunk.source.nexent", "\u6765\u6e90: Nexent")
: ""}
@@ -805,57 +757,37 @@ const DocumentChunk: React.FC = ({
{/* Search and Add Button Bar */}
{/* Create Chunk button - hide when user has READ_ONLY permission */}
{!isReadOnlyMode && (
@@ -864,7 +796,6 @@ const DocumentChunk: React.FC
= ({
type="text"
icon={}
onClick={openCreateChunkModal}
- disabled={isEmbeddingModelMismatch}
>
)}
diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx
index 023f2205a..3590db86b 100644
--- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx
+++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx
@@ -80,6 +80,8 @@ interface DocumentListProps {
availableEmbeddingModels?: ModelOption[];
selectedEmbeddingModel?: string;
onEmbeddingModelChange?: (value: string) => void;
+ isMultimodal?: boolean;
+ onMultimodalChange?: (value: boolean) => void;
permission?: string; // User's permission for this knowledge base (READ_ONLY, EDIT, etc.)
// Auto-summary frequency
@@ -127,6 +129,8 @@ const DocumentListContainer = forwardRef(
availableEmbeddingModels,
selectedEmbeddingModel,
onEmbeddingModelChange,
+ isMultimodal = false,
+ onMultimodalChange,
permission,
// Auto-summary frequency
@@ -248,6 +252,8 @@ const [isSummarizing, setIsSummarizing] = useState(false);
// Determine if user has read-only permission
const isReadOnlyMode = permission === "READ_ONLY";
+ const canToggleMultimodal =
+ isCreatingMode && typeof onMultimodalChange === "function";
// Permission options with icons shown inside dropdown
const permissionOptions = [
@@ -313,6 +319,26 @@ const [isSummarizing, setIsSummarizing] = useState(false);
// Check if group select should be disabled (when permission is PRIVATE)
const isGroupSelectDisabled = ingroupPermission === "PRIVATE";
+ const embeddingModelsForOptions = availableEmbeddingModels || [];
+ const availableEmbeddingModelKeys = new Set(
+ embeddingModelsForOptions
+ .filter((model) => model.connect_status === "available")
+ .map((model) => `${model.displayName}::${model.type}`)
+ );
+ const isEmbeddingModelSelectable = (model: ModelOption): boolean => {
+ if (model.connect_status === "available") return true;
+ if (model.type === "embedding") {
+ return availableEmbeddingModelKeys.has(
+ `${model.displayName}::multi_embedding`
+ );
+ }
+ if (model.type === "multi_embedding") {
+ return availableEmbeddingModelKeys.has(
+ `${model.displayName}::embedding`
+ );
+ }
+ return false;
+ };
// Load frequency options from backend API
useEffect(() => {
@@ -533,11 +559,29 @@ const [isSummarizing, setIsSummarizing] = useState(false);
onChange={onEmbeddingModelChange}
style={{ minWidth: 200, justifyContent: "center", alignItems: "flex-end" }}
placeholder={t("knowledgeBase.create.embeddingModelPlaceholder") || "Select embedding model"}
- options={(availableEmbeddingModels || []).map((model) => ({
- value: model.displayName,
- label: model.displayName,
- disabled: model.connect_status === "unavailable",
- }))}
+ allowClear={false}
+ options={[
+ {
+ label: t("modelConfig.option.embeddingModel"),
+ options: embeddingModelsForOptions
+ .filter((model) => model.type === "embedding")
+ .map((model) => ({
+ value: `${model.displayName}::${model.type}`,
+ label: model.displayName,
+ disabled: !isEmbeddingModelSelectable(model),
+ })),
+ },
+ {
+ label: t("modelConfig.option.multiEmbeddingModel"),
+ options: embeddingModelsForOptions
+ .filter((model) => model.type === "multi_embedding")
+ .map((model) => ({
+ value: `${model.displayName}::${model.type}`,
+ label: model.displayName,
+ disabled: !isEmbeddingModelSelectable(model),
+ })),
+ },
+ ].filter((group) => group.options.length > 0)}
/>
)}
{/* User groups multi-select */}
@@ -645,7 +689,7 @@ const [isSummarizing, setIsSummarizing] = useState(false);
;
isLoading?: boolean;
syncLoading?: boolean;
onClick: (kb: KnowledgeBase) => void;
@@ -57,7 +60,7 @@ interface KnowledgeBaseListProps {
const KnowledgeBaseList: React.FC = ({
knowledgeBases,
activeKnowledgeBase,
- currentEmbeddingModel,
+ configuredEmbeddingModels = [],
isLoading = false,
syncLoading = false,
onClick,
@@ -128,6 +131,34 @@ const KnowledgeBaseList: React.FC = ({
return `knowledgeBase.ingroup.permission.${permission || "DEFAULT"}`;
};
+ const configuredModelTypesByName = useMemo(() => {
+ const map = new Map>();
+ configuredEmbeddingModels.forEach((model) => {
+ const modelName = (model.displayName || "").trim();
+ const modelType = (model.type || "").trim().toLowerCase();
+ if (!modelName) return;
+ if (modelType !== "embedding" && modelType !== "multi_embedding") return;
+ if (!map.has(modelName)) {
+ map.set(modelName, new Set());
+ }
+ map.get(modelName)!.add(modelType);
+ });
+ return map;
+ }, [configuredEmbeddingModels]);
+
+ const isModelMismatch = (kb: KnowledgeBase) => {
+ if (kb.embeddingModel === "unknown") return false;
+ if (kb.source === "datamate") return false;
+ const modelTypes = configuredModelTypesByName.get(
+ (kb.embeddingModel || "").trim()
+ );
+ return !modelTypes;
+ };
+
+ const hasIndexedDocumentsAndChunks = (kb: KnowledgeBase) => {
+ return (kb.documentCount || 0) > 0 && (kb.chunkCount || 0) > 0;
+ };
+
// Search and filter states
const [searchKeyword, setSearchKeyword] = useState("");
const [selectedSources, setSelectedSources] = useState([]);
@@ -580,6 +611,21 @@ const KnowledgeBaseList: React.FC = ({
})}
)}
+ {kb.is_multimodal &&
+ hasIndexedDocumentsAndChunks(kb) && (
+
+ multimodal
+
+ )}
+ {isModelMismatch(kb) && (
+
+ {t("knowledgeBase.tag.modelMismatch")}
+
+ )}
{/* User group tags - only show when not PRIVATE */}
diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx
index b956dd919..63d9ad1c2 100644
--- a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx
+++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx
@@ -112,7 +112,7 @@ export const DocumentContext = createContext<{
state: DocumentState;
dispatch: React.Dispatch;
fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise;
- uploadDocuments: (kbId: string, files: File[]) => Promise;
+ uploadDocuments: (kbId: string, files: File[], modelId?: number) => Promise;
deleteDocument: (kbId: string, docId: string) => Promise;
}>({
state: {
@@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children })
}, [state.loadingKbIds, state.documentsMap, t]);
// Upload documents to a knowledge base
- const uploadDocuments = useCallback(async (kbId: string, files: File[]) => {
+ const uploadDocuments = useCallback(async (kbId: string, files: File[], modelId?: number) => {
dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true });
try {
- await knowledgeBaseService.uploadDocuments(kbId, files);
+ await knowledgeBaseService.uploadDocuments(kbId, files, undefined, modelId);
// Set loading state before fetching latest documents
dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true });
@@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children })
{children}
);
-};
\ No newline at end of file
+};
diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx
index 3c5946bd4..9733d44c4 100644
--- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx
+++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx
@@ -117,7 +117,8 @@ export const KnowledgeBaseContext = createContext<{
source?: string,
ingroup_permission?: string,
group_ids?: number[],
- embeddingModel?: string
+ embeddingModel?: string,
+ is_multimodal?: boolean
) => Promise;
deleteKnowledgeBase: (id: string) => Promise;
selectKnowledgeBase: (id: string) => void;
@@ -133,6 +134,7 @@ export const KnowledgeBaseContext = createContext<{
selectedIds: [],
activeKnowledgeBase: null,
currentEmbeddingModel: null,
+ currentMultiEmbeddingModel: null,
isLoading: false,
syncLoading: false,
error: null,
@@ -168,6 +170,7 @@ export const KnowledgeBaseProvider: React.FC = ({
selectedIds: [],
activeKnowledgeBase: null,
currentEmbeddingModel: null,
+ currentMultiEmbeddingModel: null,
isLoading: false,
syncLoading: false,
error: null,
@@ -177,11 +180,6 @@ export const KnowledgeBaseProvider: React.FC = ({
// Check if knowledge base is selectable - memoized with useCallback
const isKnowledgeBaseSelectable = useCallback(
(kb: KnowledgeBase): boolean => {
- // If no current embedding model is set, not selectable
- if (!state.currentEmbeddingModel) {
- return false;
- }
-
// Check if knowledge base has content (documents or chunks)
const hasContent =
(kb.documentCount || 0) > 0 || (kb.chunkCount || 0) > 0;
@@ -196,22 +194,46 @@ export const KnowledgeBaseProvider: React.FC = ({
return true;
}
- // For local knowledge bases, only selectable when model exactly matches current model
- return (
- kb.embeddingModel === "unknown" ||
- kb.embeddingModel === state.currentEmbeddingModel
- );
+ if (kb.embeddingModel === "unknown") {
+ return true;
+ }
+
+ const currentEmbeddingModel = state.currentEmbeddingModel?.trim() || "";
+ const currentMultiEmbeddingModel =
+ modelConfig?.multiEmbedding?.modelName?.trim() || "";
+
+ if (kb.is_multimodal) {
+ // Multimodal KB is selectable as long as current multimodal model is configured.
+ return !!currentMultiEmbeddingModel;
+ }
+
+ // Text KB is selectable as long as current embedding model is configured.
+ return !!currentEmbeddingModel;
},
- [state.currentEmbeddingModel]
+ [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel]
);
// Check if knowledge base has model mismatch (for display purposes)
- // Note: Always return false to remove model mismatch restrictions
const hasKnowledgeBaseModelMismatch = useCallback(
(kb: KnowledgeBase): boolean => {
- return false;
+ if (kb.embeddingModel === "unknown") {
+ return false;
+ }
+ if (kb.source === "datamate") {
+ return false;
+ }
+
+ if (kb.is_multimodal) {
+ const multiEmbeddingModel =
+ modelConfig?.multiEmbedding?.modelName?.trim() || "";
+ // Only show warning when the required current model is not configured.
+ return !multiEmbeddingModel;
+ }
+
+ // Only show warning when the required current model is not configured.
+ return !state.currentEmbeddingModel;
},
- []
+ [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel]
);
// Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback
@@ -325,17 +347,31 @@ export const KnowledgeBaseProvider: React.FC = ({
source: string = "elasticsearch",
ingroup_permission?: string,
group_ids?: number[],
- embeddingModel?: string
+ embeddingModel?: string,
+ is_multimodal?: boolean
) => {
try {
+ const selectedEmbeddingModel = embeddingModel?.trim() || "";
+ const defaultMultiEmbeddingModel =
+ modelConfig?.multiEmbedding?.modelName?.trim() || "";
+ const resolvedIsMultimodal =
+ typeof is_multimodal === "boolean"
+ ? is_multimodal
+ : !!defaultMultiEmbeddingModel &&
+ selectedEmbeddingModel === defaultMultiEmbeddingModel;
+ const fallbackEmbeddingModel = resolvedIsMultimodal
+ ? defaultMultiEmbeddingModel
+ : state.currentEmbeddingModel || "";
+ const resolvedEmbeddingModel =
+ selectedEmbeddingModel || fallbackEmbeddingModel;
const newKB = await knowledgeBaseService.createKnowledgeBase({
name,
description,
source,
- // Use provided embeddingModel if available, otherwise fall back to current model or default
- embeddingModel: embeddingModel || state.currentEmbeddingModel || "",
+ embeddingModel: resolvedEmbeddingModel,
ingroup_permission,
group_ids,
+ is_multimodal: resolvedIsMultimodal,
});
return newKB;
} catch (error) {
@@ -347,7 +383,7 @@ export const KnowledgeBaseProvider: React.FC = ({
return null;
}
},
- [state.currentEmbeddingModel, t]
+ [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel, t]
);
// Delete knowledge base - memoized with useCallback
diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
index 11391c133..01a11210e 100644
--- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx
@@ -1,4 +1,4 @@
-import { useMemo, useState, useCallback, useEffect } from "react";
+import { useMemo, useState, useCallback, useEffect } from "react";
import { useTranslation } from "react-i18next";
import { Modal, Select, Input, Button, Switch, Tooltip, App } from "antd";
@@ -483,7 +483,8 @@ export const ModelAddDialog = ({
if (tenantId) {
connectivity = await modelService.checkManageTenantModelConnectivity(
tenantId,
- form.displayName || form.name
+ form.displayName || form.name,
+ modelType
);
} else {
// For STT models, build the appropriate config based on provider
diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
index 3114c5535..ab66100aa 100644
--- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
+++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx
@@ -1,4 +1,4 @@
-import { useState, useEffect } from 'react'
+import { useState, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { Modal, Input, Button, App } from "antd";
@@ -481,4 +481,4 @@ export const ProviderConfigEditDialog = ({
)
-}
\ No newline at end of file
+}
diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx
index 07eee5c06..b298a4e30 100644
--- a/frontend/app/[locale]/models/components/modelConfig.tsx
+++ b/frontend/app/[locale]/models/components/modelConfig.tsx
@@ -527,6 +527,7 @@ export const ModelConfigSection = forwardRef<
try {
const isConnected = await modelService.verifyCustomModel(
modelName,
+ modelType,
signal
);
@@ -603,7 +604,7 @@ export const ModelConfigSection = forwardRef<
throttleTimerRef.current = setTimeout(async () => {
try {
// Use modelService to verify model
- const isConnected = await modelService.verifyCustomModel(displayName);
+ const isConnected = await modelService.verifyCustomModel(displayName, modelType);
// Update model status
updateModelStatus(
diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx
index 786048894..c4ffaf8ca 100644
--- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx
+++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx
@@ -1,4 +1,4 @@
-"use client";
+"use client";
import React, { useState, useMemo } from "react";
import { useTranslation } from "react-i18next";
@@ -130,7 +130,8 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {
}
};
- const handleCheckConnectivity = async (displayName: string) => {
+ // Handle checking model connectivity
+ const handleCheckConnectivity = async (displayName: string, modelType: string) => {
if (!tenantId) {
message.error(t("tenantResources.tenants.tenantIdRequired"));
return;
@@ -138,7 +139,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {
setCheckingConnectivity((prev) => new Set(prev).add(displayName));
try {
- const isConnected = await modelService.verifyCustomModel(displayName);
+ const isConnected = await modelService.verifyCustomModel(displayName, modelType);
if (isConnected) {
message.success(t("tenantResources.models.connectivitySuccess"));
} else {
@@ -299,7 +300,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {
: }
- onClick={() => handleCheckConnectivity(record.displayName)}
+ onClick={() => handleCheckConnectivity(record.displayName, record.type)}
size="small"
loading={checkingConnectivity.has(record.displayName)}
/>
diff --git a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
index 9e30f323a..a532af843 100644
--- a/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
+++ b/frontend/components/tool-config/KnowledgeBaseSelectorModal.tsx
@@ -21,6 +21,10 @@ import {
import { KnowledgeBase } from "@/types/knowledgeBase";
import { ToolKbType } from "@/hooks/useKnowledgeBaseConfigChangeHandler";
import { KB_LAYOUT, KB_TAG_VARIANTS } from "@/const/knowledgeBaseLayout";
+import {
+ isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase,
+ isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase,
+} from "@/lib/knowledgeBaseCompatibility";
import { useModelList } from "@/hooks/model/useModelList";
import knowledgeBaseService from "@/services/knowledgeBaseService";
import log from "@/lib/logger";
@@ -77,6 +81,8 @@ interface KnowledgeBaseSelectorModalProps extends KnowledgeBaseSelectorProps {
// Selection validation props
isSelectable?: (kb: KnowledgeBase) => boolean;
currentEmbeddingModel?: string | null;
+ currentMultiEmbeddingModel?: string | null;
+ toolMultimodal?: boolean | null;
// Dify/iData configuration for fetching knowledge bases
difyConfig?: {
serverUrl?: string;
@@ -103,6 +109,8 @@ export default function KnowledgeBaseSelectorModal({
syncLoading = false,
isSelectable,
currentEmbeddingModel = null,
+ currentMultiEmbeddingModel = null,
+ toolMultimodal = null,
difyConfig,
}: KnowledgeBaseSelectorModalProps) {
const { t } = useTranslation("common");
@@ -222,6 +230,24 @@ export default function KnowledgeBaseSelectorModal({
}
}, []);
+ const isMultimodalConstraintMismatch = useCallback(
+ (kb: KnowledgeBase) => {
+ return isMultimodalConstraintMismatchBase(kb, toolMultimodal);
+ },
+ [toolMultimodal]
+ );
+
+ const isEmbeddingModelCompatible = useCallback(
+ (kb: KnowledgeBase) => {
+ return isEmbeddingModelCompatibleBase(
+ kb,
+ currentEmbeddingModel,
+ currentMultiEmbeddingModel
+ );
+ },
+ [currentEmbeddingModel, currentMultiEmbeddingModel]
+ );
+
// Check if a knowledge base can be selected
const checkCanSelect = useCallback(
(kb: KnowledgeBase): boolean => {
@@ -238,9 +264,53 @@ export default function KnowledgeBaseSelectorModal({
return false;
}
+ // For nexent source, check model matching against current tenant config and tool multimodal constraint.
+ if (kb.source === "nexent") {
+ if (isMultimodalConstraintMismatch(kb)) {
+ return false;
+ }
+ return isEmbeddingModelCompatible(kb);
+ }
+
return true;
},
- [isSelectable]
+ [
+ isSelectable,
+ isEmbeddingModelCompatible,
+ isMultimodalConstraintMismatch,
+ ]
+ );
+
+  // Whether to show the "model mismatch" tag for a knowledge base.
+  // Only "nexent" knowledge bases are checked. A knowledge base is flagged when
+  // it violates the tool's multimodal constraint, or when its embedding model
+  // is not compatible with the currently configured embedding models. Both
+  // checks go through the shared compatibility helpers instead of repeating
+  // their rules here, so the tag cannot drift from the selectability rules
+  // (e.g. a multimodal knowledge base with an unknown embedding model while no
+  // multimodal embedding model is configured is now flagged as well).
+  const getModelMismatch = useCallback(
+    (kb: KnowledgeBase): boolean => {
+      if (kb.source !== "nexent") {
+        return false;
+      }
+
+      return (
+        isMultimodalConstraintMismatchBase(kb, toolMultimodal) ||
+        !isEmbeddingModelCompatibleBase(
+          kb,
+          currentEmbeddingModel,
+          currentMultiEmbeddingModel
+        )
+      );
+    },
+    [currentEmbeddingModel, currentMultiEmbeddingModel, toolMultimodal]
);
// Filter knowledge bases based on tool type, search, and filters
@@ -787,6 +857,7 @@ export default function KnowledgeBaseSelectorModal({
String(selectedId).trim() === String(kb.id).trim()
);
const canSelect = checkCanSelect(kb);
+ const hasModelMismatch = getModelMismatch(kb);
return (
)}
+ {kb.is_multimodal && (
+
+ multimodal
+
+ )}
+ {hasModelMismatch && (
+
+ {t("knowledgeBase.tag.modelMismatch")}
+
+ )}
diff --git a/frontend/const/agentConfig.ts b/frontend/const/agentConfig.ts
index 7ccc21bd7..aed7b6404 100644
--- a/frontend/const/agentConfig.ts
+++ b/frontend/const/agentConfig.ts
@@ -97,6 +97,7 @@ export const TOOL_PARAM_OPTIONS = {
// Knowledge base search tool
knowledge_base_search: {
search_mode: ["hybrid", "accurate", "semantic"],
+ multimodal: [true, false],
},
// Dify search tool
dify_search: {
@@ -126,11 +127,12 @@ export const TOOL_PARAM_OPTIONS = {
export function getToolParamOptions(
toolName: string,
paramName: string
-): string[] | undefined {
+): string[] | boolean[] | undefined {
const toolOptions =
TOOL_PARAM_OPTIONS[toolName as keyof typeof TOOL_PARAM_OPTIONS];
if (!toolOptions) return undefined;
return toolOptions[paramName as keyof typeof toolOptions] as
| string[]
+ | boolean[]
| undefined;
}
diff --git a/frontend/const/knowledgeBaseLayout.ts b/frontend/const/knowledgeBaseLayout.ts
index 082c40be5..550ee6dc1 100644
--- a/frontend/const/knowledgeBaseLayout.ts
+++ b/frontend/const/knowledgeBaseLayout.ts
@@ -56,4 +56,6 @@ export const KB_TAG_VARIANTS = {
model: "bg-green-50 text-green-700 border border-green-200",
// Yellow tag for model mismatch
warning: "bg-yellow-100 text-yellow-800 border border-yellow-200",
+ // Red tag for multimodal models
+ red: "bg-red-50 text-red-700 border border-red-200",
} as const;
diff --git a/frontend/hooks/useConfig.ts b/frontend/hooks/useConfig.ts
index 8d4c4ccea..7f616e65c 100644
--- a/frontend/hooks/useConfig.ts
+++ b/frontend/hooks/useConfig.ts
@@ -290,6 +290,9 @@ export function useConfig() {
// Whether config has selected an Embedding model
const isEmbeddingAvailable = !!(config?.models?.embedding?.modelName || config?.models?.embedding?.displayName);
+ // Whether config has selected a Multi-Embedding model
+ const isMultiEmbeddingAvailable = !!(config?.models?.multiEmbedding?.modelName || config?.models?.multiEmbedding?.displayName);
+
// Default LLM model name from config (modelName or displayName)
const defaultLlmModelName = config?.models?.llm?.modelName || config?.models?.llm?.displayName || "";
@@ -369,6 +372,7 @@ export function useConfig() {
modelConfig: config?.models,
isVlmAvailable,
isEmbeddingAvailable,
+ isMultiEmbeddingAvailable,
defaultLlmModelName,
updateAppConfig,
updateModelConfig,
diff --git a/frontend/lib/knowledgeBaseCompatibility.ts b/frontend/lib/knowledgeBaseCompatibility.ts
new file mode 100644
index 000000000..36ab959b3
--- /dev/null
+++ b/frontend/lib/knowledgeBaseCompatibility.ts
@@ -0,0 +1,47 @@
+import { KnowledgeBase } from "@/types/knowledgeBase";
+
+/**
+ * Report whether a knowledge base violates the tool's multimodal constraint.
+ *
+ * A `null` constraint means "no constraint" and never mismatches; otherwise
+ * the knowledge base's multimodal flag must equal the tool's setting.
+ */
+export const isMultimodalConstraintMismatch = (
+  kb: KnowledgeBase,
+  toolMultimodal: boolean | null
+): boolean => {
+  if (toolMultimodal === null) {
+    return false;
+  }
+  return Boolean(toolMultimodal) !== Boolean(kb.is_multimodal);
+};
+
+/**
+ * Report whether a knowledge base can be used with the configured models.
+ *
+ * Multimodal knowledge bases are compared against the multimodal embedding
+ * model and are incompatible when none is configured. Other knowledge bases
+ * are compared against the text embedding model and are treated as compatible
+ * when none is configured. A knowledge base whose own embedding model is
+ * empty or "unknown" is compatible with whatever model is configured.
+ */
+export const isEmbeddingModelCompatible = (
+  kb: KnowledgeBase,
+  currentEmbeddingModel: string | null,
+  currentMultiEmbeddingModel: string | null
+): boolean => {
+  const expectedModel = kb.is_multimodal
+    ? currentMultiEmbeddingModel
+    : currentEmbeddingModel;
+
+  if (!expectedModel) {
+    // Nothing configured: only the multimodal case counts as incompatible.
+    return !kb.is_multimodal;
+  }
+
+  const kbModel = kb.embeddingModel;
+  return !kbModel || kbModel === "unknown" || kbModel === expectedModel;
+};
diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json
index da9e59822..cc7831d91 100644
--- a/frontend/public/locales/en/common.json
+++ b/frontend/public/locales/en/common.json
@@ -587,6 +587,7 @@
"knowledgeBase.modal.dataMateConfig.urlPlaceholder": "Enter DataMate server address",
"knowledgeBase.modal.dataMateConfig.description": "Configure the DataMate server address for synchronizing external knowledge base data.",
"knowledgeBase.message.nameRequired": "Please enter knowledge base name",
+ "knowledgeBase.message.embeddingModelRequired": "Please select a vector model",
"knowledgeBase.message.nameExists": "Knowledge base {{name}} already exists, please use a different name",
"knowledgeBase.error.nameExistsInOtherTenant": "Knowledge base {{name}} is used by another tenant, please use a different name",
"knowledgeBase.message.createError": "Failed to create knowledge base",
@@ -701,7 +702,7 @@
"document.chunk.error.updateFailed": "Failed to update chunk",
"document.chunk.error.deleteFailed": "Failed to delete chunk",
"document.chunk.error.missingChunkId": "Chunk identifier is missing",
- "document.chunk.tooltip.disabledDueToModelMismatch": "The currently configured embedding model ({{currentModel}}) does not match the knowledge base model ({{knowledgeBaseModel}}). You cannot create chunks or perform retrieval until you use the same embedding model as the knowledge base.",
+ "document.chunk.tooltip.disabledDueToModelMismatch": "The currently configured embedding model ({{currentModel}}) does not match the knowledge base model ({{knowledgeBaseModel}}).",
"document.chunk.form.createTitle": "Create chunk",
"document.chunk.form.editTitle": "Edit chunk",
"document.chunk.form.documentName": "Document",
@@ -2185,6 +2186,7 @@
"errorCode.990105": "Internal server error. Please try again later.",
"errorCode.990201": "Configuration not found.",
"errorCode.990202": "Configuration update failed.",
+ "embedding.model.notConfigured": "Not configured",
"a2a.discovery.title": "A2A Agent Discovery",
"a2a.discovery.tab.url": "URL Discovery",
diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json
index d2fd6136f..21f935ca0 100644
--- a/frontend/public/locales/zh/common.json
+++ b/frontend/public/locales/zh/common.json
@@ -702,7 +702,7 @@
"document.chunk.error.updateFailed": "分片更新失败",
"document.chunk.error.deleteFailed": "分片删除失败",
"document.chunk.error.missingChunkId": "缺少分片 ID",
- "document.chunk.tooltip.disabledDueToModelMismatch": "当前配置的向量模型 ({{currentModel}}) 与创建知识库所用模型 ({{knowledgeBaseModel}}) 不一致,无法创建分片或召回检索。",
+ "document.chunk.tooltip.disabledDueToModelMismatch": "当前配置的向量模型 ({{currentModel}}) 与创建知识库所用模型 ({{knowledgeBaseModel}}) 不一致。",
"document.chunk.form.createTitle": "新建分片",
"document.chunk.form.editTitle": "编辑分片",
"document.chunk.form.documentName": "所属文档",
@@ -2257,6 +2257,7 @@
"errorCode.990105": "服务器内部错误,请稍后重试",
"errorCode.990201": "配置不存在",
"errorCode.990202": "配置更新失败",
+ "embedding.model.notConfigured": "未配置",
"a2a.discovery.title": "A2A Agent 发现",
"a2a.discovery.tab.url": "URL 发现",
diff --git a/frontend/services/api.ts b/frontend/services/api.ts
index 0314e0a82..6c832e025 100644
--- a/frontend/services/api.ts
+++ b/frontend/services/api.ts
@@ -149,10 +149,10 @@ export const API_ENDPOINTS = {
`${API_BASE_URL}/model/delete?display_name=${encodeURIComponent(
displayName
)}`,
- customModelHealthcheck: (displayName: string) =>
+ customModelHealthcheck: (displayName: string, modelType: string) =>
`${API_BASE_URL}/model/healthcheck?display_name=${encodeURIComponent(
displayName
- )}`,
+ )}&model_type=${encodeURIComponent(modelType)}`,
verifyModelConfig: `${API_BASE_URL}/model/temporary_healthcheck`,
updateSingleModel: (displayName: string) =>
`${API_BASE_URL}/model/update?display_name=${encodeURIComponent(displayName)}`,
diff --git a/frontend/services/knowledgeBaseService.ts b/frontend/services/knowledgeBaseService.ts
index 797d45f40..bd13de32d 100644
--- a/frontend/services/knowledgeBaseService.ts
+++ b/frontend/services/knowledgeBaseService.ts
@@ -19,6 +19,20 @@ import log from "@/lib/logger";
// @ts-ignore
const fetch: typeof fetchWithAuth = fetchWithAuth;
+// Coerce a loosely typed `is_multimodal` value to a boolean.
+// Booleans pass through; the strings "y", "true" and "yes" (trimmed,
+// case-insensitive) and the number 1 are true; everything else is false.
+const normalizeIsMultimodal = (value: unknown): boolean => {
+  switch (typeof value) {
+    case "boolean":
+      return value;
+    case "string": {
+      const text = value.trim().toLowerCase();
+      return text === "y" || text === "true" || text === "yes";
+    }
+    case "number":
+      return value === 1;
+    default:
+      return false;
+  }
+};
+
+// Resolve the multimodal flag for a knowledge base: prefer the value on the
+// index info record and fall back to the stats record when the former is
+// null or undefined, then coerce the result to a boolean.
+const resolveIsMultimodal = (indexInfo: any, stats: any): boolean => {
+  const rawFlag = indexInfo.is_multimodal ?? stats.is_multimodal;
+  return normalizeIsMultimodal(rawFlag);
+};
+
// Knowledge base service class
class KnowledgeBaseService {
// Check Elasticsearch health (force refresh, no caching for setup page)
@@ -545,6 +559,7 @@ class KnowledgeBaseService {
stats.update_date ||
stats.creation_date ||
null,
+ is_multimodal: resolveIsMultimodal(indexInfo, stats),
// Use embedding_model_name (display_name) from backend, fallback to ES stats
embeddingModel: indexInfo.embedding_model_name || stats.embedding_model || "unknown",
summaryFrequency: indexInfo.summary_frequency || null,
@@ -616,6 +631,7 @@ class KnowledgeBaseService {
createdAt: stats.creation_date || null,
updatedAt: stats.update_date || stats.creation_date || null,
embeddingModel: stats.embedding_model || "unknown",
+ is_multimodal: resolveIsMultimodal(indexInfo, stats),
knowledge_sources:
indexInfo.knowledge_sources || "datamate",
ingroup_permission: indexInfo.ingroup_permission || "",
@@ -738,13 +754,15 @@ class KnowledgeBaseService {
const requestBody: {
name: string;
description: string;
- embedding_model_name?: string;
+ embeddingModel?: string;
ingroup_permission?: string;
group_ids?: number[];
+ is_multimodal?: boolean;
} = {
name: params.name,
description: params.description || "",
- embedding_model_name: params.embeddingModel || "",
+ embeddingModel: params.embeddingModel || "",
+ is_multimodal: params.is_multimodal || false,
};
// Include group permission and user groups if provided
@@ -779,6 +797,7 @@ class KnowledgeBaseService {
chunkCount: 0,
createdAt: new Date().toISOString(),
embeddingModel: params.embeddingModel || "",
+ is_multimodal: params.is_multimodal || false,
avatar: "",
chunkNum: 0,
language: "",
@@ -888,7 +907,8 @@ class KnowledgeBaseService {
async uploadDocuments(
kbId: string,
files: File[],
- chunkingStrategy?: string
+ chunkingStrategy?: string,
+ modelId?: number
): Promise {
try {
// Create FormData object
@@ -950,6 +970,7 @@ class KnowledgeBaseService {
files: filesToProcess,
chunking_strategy: chunkingStrategy,
destination: "minio",
+ model_id: modelId,
}),
});
diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts
index e0fefd2db..e7734a135 100644
--- a/frontend/services/modelService.ts
+++ b/frontend/services/modelService.ts
@@ -430,12 +430,13 @@ export const modelService = {
// Verify custom model connection
verifyCustomModel: async (
displayName: string,
+ modelType: string,
signal?: AbortSignal
): Promise => {
try {
if (!displayName) return false;
const response = await fetch(
- API_ENDPOINTS.model.customModelHealthcheck(displayName),
+ API_ENDPOINTS.model.customModelHealthcheck(displayName, modelType),
{
method: "POST",
headers: getAuthHeaders(),
@@ -461,6 +462,7 @@ export const modelService = {
checkManageTenantModelConnectivity: async (
tenantId: string,
displayName: string,
+ modelType: string,
signal?: AbortSignal
): Promise => {
try {
@@ -474,6 +476,7 @@ export const modelService = {
body: JSON.stringify({
tenant_id: tenantId,
display_name: displayName,
+ model_type: modelType
}),
signal,
});
diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json
index d61634fac..75f792957 100644
--- a/frontend/tsconfig.json
+++ b/frontend/tsconfig.json
@@ -8,7 +8,7 @@
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
- "moduleResolution": "node",
+ "moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
diff --git a/frontend/types/knowledgeBase.ts b/frontend/types/knowledgeBase.ts
index 550431a04..7caf4986c 100644
--- a/frontend/types/knowledgeBase.ts
+++ b/frontend/types/knowledgeBase.ts
@@ -20,6 +20,7 @@ export interface KnowledgeBase {
// Last update time of the knowledge base/index (may fall back to createdAt)
updatedAt?: any;
embeddingModel: string;
+ is_multimodal?: boolean;
knowledge_sources?: string;
ingroup_permission?: string;
group_ids?: number[];
@@ -47,6 +48,7 @@ export interface KnowledgeBaseCreateParams {
// Group permission and user groups for new knowledge bases
ingroup_permission?: string;
group_ids?: number[];
+ is_multimodal?: boolean;
}
// Document type
@@ -114,6 +116,7 @@ export interface KnowledgeBaseState {
selectedIds: string[];
activeKnowledgeBase: KnowledgeBase | null;
currentEmbeddingModel: string | null;
+ currentMultiEmbeddingModel: string | null;
isLoading: boolean;
syncLoading: boolean;
error: string | null;
diff --git a/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml b/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml
index b740ec2f1..bc7d63c44 100644
--- a/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml
+++ b/k8s/helm/nexent/charts/nexent-common/templates/configmap.yaml
@@ -53,6 +53,8 @@ data:
# Model Path Config
CLIP_MODEL_PATH: {{ .Values.config.modelPath.clipModelPath | quote }}
NLTK_DATA: {{ .Values.config.modelPath.nltkData | quote }}
+ TABLE_TRANSFORMER_MODEL_PATH: {{ .Values.config.modelPath.tableTransformerModelPath | quote }}
+ UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH: {{ .Values.config.modelPath.unstructuredDefaultModelInitializeParamsJsonPath | quote }}
# Terminal Tool SSH Config
SSH_PRIVATE_KEY_PATH: {{ .Values.config.terminal.sshPrivateKeyPath | quote }}
diff --git a/k8s/helm/nexent/charts/nexent-common/values.yaml b/k8s/helm/nexent/charts/nexent-common/values.yaml
index dc694a4b9..e45d04d84 100644
--- a/k8s/helm/nexent/charts/nexent-common/values.yaml
+++ b/k8s/helm/nexent/charts/nexent-common/values.yaml
@@ -54,6 +54,8 @@ config:
modelPath:
clipModelPath: "/opt/models/clip-vit-base-patch32"
nltkData: "/opt/models/nltk_data"
+ tableTransformerModelPath: "/opt/models/table-transformer-structure-recognition"
+ unstructuredDefaultModelInitializeParamsJsonPath: "/opt/models/yolox"
terminal:
sshPrivateKeyPath: "/path/to/openssh-server/ssh-keys/openssh_server_key"
supabase:
diff --git a/make/data_process/Dockerfile b/make/data_process/Dockerfile
index 7903cfd92..8d9a8a723 100644
--- a/make/data_process/Dockerfile
+++ b/make/data_process/Dockerfile
@@ -8,24 +8,24 @@ USER root
# Configure apt sources based on build argument
RUN if [ "$APT_MIRROR" = "tsinghua" ]; then \
- rm -f /etc/apt/sources.list.d/* && \
- echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \
- echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
- echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
- echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \
+ rm -f /etc/apt/sources.list.d/* && \
+ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \
+ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
+ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
+ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \
fi && \
apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
RUN apt-get update && \
apt-get install -y --no-install-recommends --fix-missing \
- curl \
- libmagic1 \
- libmagic-dev \
- libreoffice \
- libgl1 \
- coreutils \
- fontconfig \
- fonts-noto-cjk \
+ curl \
+ libmagic1 \
+ libmagic-dev \
+ libreoffice \
+ libgl1 \
+ coreutils \
+ fontconfig \
+ fonts-noto-cjk \
&& fc-cache -fv \
&& apt-get autoremove -y \
&& apt-get clean \
@@ -35,6 +35,8 @@ RUN pip install --no-cache-dir uv $(test -n "$MIRROR" && echo "-i $MIRROR")
# Layer 0: copy model assets
COPY model-assets/clip-vit-base-patch32 /opt/models/clip-vit-base-patch32
COPY model-assets/nltk_data /opt/models/nltk_data
+COPY model-assets/table-transformer-structure-recognition /opt/models/table-transformer-structure-recognition
+COPY model-assets/yolox /opt/models/yolox
WORKDIR /opt/backend
# Layer 1: install base dependencies
diff --git a/sdk/nexent/core/models/embedding_model.py b/sdk/nexent/core/models/embedding_model.py
index 092877941..ac8a63186 100644
--- a/sdk/nexent/core/models/embedding_model.py
+++ b/sdk/nexent/core/models/embedding_model.py
@@ -21,6 +21,7 @@ def __init__(
api_key: str = None,
embedding_dim: int = None,
ssl_verify: bool = True,
+ model_type: str = None
):
"""
Initialize the embedding model.
@@ -86,6 +87,7 @@ def __init__(
api_key: str = None,
embedding_dim: int = None,
ssl_verify: bool = True,
+ model_type: str = None
):
super().__init__(model_name, base_url, api_key, embedding_dim, ssl_verify=ssl_verify)
@@ -128,6 +130,7 @@ def __init__(
api_key: str = None,
embedding_dim: int = None,
ssl_verify: bool = True,
+ model_type: str = None
):
super().__init__(model_name, base_url, api_key, embedding_dim, ssl_verify=ssl_verify)
@@ -164,6 +167,7 @@ def __init__(
model_name: str = "jina-clip-v2",
embedding_dim: int = 1024,
ssl_verify: bool = True,
+ model_type: str = "multimodal"
):
"""Initialize JinaEmbedding with configuration."""
self.api_key = api_key
@@ -171,6 +175,7 @@ def __init__(
self.model = model_name
self.embedding_dim = embedding_dim
self.ssl_verify = ssl_verify
+ self.model_type = model_type
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
@@ -322,13 +327,14 @@ async def dimension_check(self, timeout: float = 5.0) -> List[List[float]]:
class OpenAICompatibleEmbedding(TextEmbedding):
- def __init__(self, model_name: str, base_url: str, api_key: str, embedding_dim: int, ssl_verify: bool = True):
+ def __init__(self, model_name: str, base_url: str, api_key: str, embedding_dim: int, ssl_verify: bool = True, model_type: str = "text"):
"""Initialize OpenAICompatibleEmbedding with configuration from environment variables or provided parameters."""
+ # model_type is appended after ssl_verify (as in the sibling embedding
+ # classes) so callers that pass ssl_verify as the fifth positional
+ # argument keep their existing meaning.
self.api_key = api_key
self.api_url = base_url
self.model = model_name
self.embedding_dim = embedding_dim
self.ssl_verify = ssl_verify
+ self.model_type = model_type
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py
index e3fb2916c..4120814c5 100644
--- a/sdk/nexent/core/tools/knowledge_base_search_tool.py
+++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py
@@ -1,19 +1,23 @@
import json
import logging
+import os
from typing import List, Optional
from pydantic import Field
-from smolagents.tools import Tool
from pydantic.fields import FieldInfo
+from smolagents.tools import Tool
+
from ...vector_database.base import VectorDatabaseCore
from ..models.embedding_model import BaseEmbedding
from ..models.rerank_model import BaseRerank
-from ..utils.observer import MessageObserver, ProcessType
from ..utils.constants import RERANK_OVERSEARCH_MULTIPLIER
-from ..utils.tools_common_message import SearchResultTextMessage, ToolCategory, ToolSign
-
+from ..utils.observer import MessageObserver, ProcessType
+from ..utils.tools_common_message import (
+ SearchResultTextMessage,
+ ToolCategory,
+ ToolSign,
+)
-# Get logger instance
logger = logging.getLogger("knowledge_base_search_tool")
@@ -28,62 +32,59 @@ class KnowledgeBaseSearchTool(Tool):
"domain expertise, personal notes, or any information that has been indexed in the knowledge base. "
"Suitable for queries requiring access to stored knowledge that may not be publicly available."
)
-
- description_zh = "基于你的查询词在本地知识库中进行搜索,返回最相关的搜索结果。适用于检索本地知识库中存储的领域专业知识、文档和信息。当用户询问与专业知识、技术文档、领域专长、个人笔记或任何已在知识库中建立索引的信息相关的问题时,请使用此工具。适合需要访问非公开存储知识的查询。"
+ description_zh = "执行本地知识库检索并返回最相关的结果。"
inputs = {
"query": {
"type": "string",
"description": "The search query to perform.",
- "description_zh": "要执行的搜索查询词"
- },
- "index_names": {
- "type": "array",
- "description": "The list of index names to search",
- "description_zh": "要索引的知识库",
- "nullable": True
+ "description_zh": "要执行的搜索查询。",
},
}
init_param_descriptions = {
"top_k": {
"description": "Maximum number of search results",
- "description_zh": "返回搜索结果的最大数量"
+ "description_zh": "返回搜索结果的最大数量。",
},
-
"search_mode": {
"description": "The search mode, optional values: hybrid, accurate, semantic",
- "description_zh": "搜索模式,可选值:hybrid(混合)、accurate(精确)、semantic(语义)"
- }
+ "description_zh": "搜索模式,可选:hybrid、accurate、semantic。",
+ },
}
+
output_type = "string"
category = ToolCategory.SEARCH.value
-
- # Used to distinguish different index sources for summaries
tool_sign = ToolSign.KNOWLEDGE_BASE.value
def __init__(
self,
top_k: int = Field(
- description="Maximum number of search results", default=3),
+ description="Maximum number of search results", default=3
+ ),
index_names: List[str] = Field(
- description="The list of index names to search"),
+ description="The list of index names to search"
+ ),
search_mode: str = Field(
description="the search mode, optional values: hybrid, accurate, semantic",
default="hybrid",
),
rerank: bool = Field(
description="Whether to enable reranking for search results",
- default=False),
+ default=False,
+ ),
rerank_model_name: str = Field(
- description="The name of the rerank model to use",
- default=""),
+ description="The name of the rerank model to use", default=""
+ ),
observer: MessageObserver = Field(
- description="Message observer", default=None, exclude=True),
+ description="Message observer", default=None, exclude=True
+ ),
embedding_model: BaseEmbedding = Field(
- description="The embedding model to use", default=None, exclude=True),
+ description="The embedding model to use", default=None, exclude=True
+ ),
rerank_model: BaseRerank = Field(
- description="The rerank model to use", default=None, exclude=True),
+ description="The rerank model to use", default=None, exclude=True
+ ),
vdb_core: VectorDatabaseCore = Field(
description="Vector database client", default=None, exclude=True),
display_name_to_index_map: dict = Field(
@@ -112,12 +113,28 @@ def __init__(
self.rerank = rerank
self.rerank_model_name = rerank_model_name
self.rerank_model = rerank_model
+ self.data_process_service = os.getenv("DATA_PROCESS_SERVICE")
self.display_name_to_index_map = display_name_to_index_map
- self.record_ops = 1 # To record serial number
+ self.record_ops = 1
self.running_prompt_zh = "知识库检索中..."
self.running_prompt_en = "Searching the knowledge base..."
+    @staticmethod
+    def _resolve_field_value(value, default):
+        """Unwrap an init parameter that may still be a pydantic ``FieldInfo``.
+
+        Args:
+            value: The raw parameter value, or a ``FieldInfo`` placeholder.
+            default: Fallback used when no usable value is present.
+
+        Returns:
+            For a ``FieldInfo``: its ``default`` unless that is ``None``, in
+            which case ``default``. For anything else: ``value`` unless it is
+            ``None``, in which case ``default``.
+        """
+        # NOTE(review): a FieldInfo declared without a default presumably
+        # carries pydantic's "undefined" sentinel rather than None, which is
+        # returned as-is here -- confirm callers tolerate that.
+        candidate = value.default if isinstance(value, FieldInfo) else value
+        if candidate is None:
+            return default
+        return candidate
+
+    def _resolve_index_names(self) -> List[str]:
+        """Return the configured index names as a list of non-empty strings.
+
+        ``self.index_names`` may be a comma-separated string or a list; each
+        entry is converted to ``str``, stripped, and dropped when empty. Any
+        other value (including ``None``) yields an empty list.
+        """
+        configured = self._resolve_field_value(self.index_names, None)
+        if isinstance(configured, str):
+            pieces = configured.split(",")
+        elif isinstance(configured, list):
+            pieces = [str(entry) for entry in configured]
+        else:
+            return []
+        return [piece.strip() for piece in pieces if piece.strip()]
def _convert_to_index_names(self, names: List[str]) -> List[str]:
"""Convert display names (knowledge_name) to index names if necessary.
@@ -148,9 +165,9 @@ def _convert_to_index_names(self, names: List[str]) -> List[str]:
converted_names.append(name)
return converted_names
- def forward(self, query: str, index_names: Optional[List[str]] = None) -> str:
- # Parse index_names from string (always required)
- search_index_names = index_names if index_names is not None else self.index_names
+ def forward(self, query: str) -> str:
+ # index_names is configured in tool init params, not runtime inputs
+ search_index_names = self._resolve_index_names()
# Convert display names to index names if necessary
search_index_names = self._convert_to_index_names(search_index_names)
@@ -158,17 +175,13 @@ def forward(self, query: str, index_names: Optional[List[str]] = None) -> str:
# Use the instance search_mode
search_mode = self.search_mode
- # Send tool run message
- if self.observer:
- running_prompt = self.running_prompt_zh if self.observer.lang == "zh" else self.running_prompt_en
- self.observer.add_message("", ProcessType.TOOL, running_prompt)
- card_content = [{"icon": "search", "text": query}]
- self.observer.add_message("", ProcessType.CARD, json.dumps(
- card_content, ensure_ascii=False))
+ self._notify_search_start(query)
- # Log the index_names being used for this search
logger.info(
- f"KnowledgeBaseSearchTool called with query: '{query}', search_mode: '{search_mode}', index_names: {search_index_names}"
+ "KnowledgeBaseSearchTool called with query: '%s', search_mode: '%s', index_names: %s",
+ query,
+ search_mode,
+ search_index_names,
)
# Compute effective top_k for initial search:
@@ -192,69 +205,125 @@ def forward(self, query: str, index_names: Optional[List[str]] = None) -> str:
if len(search_index_names) == 0:
return json.dumps("No knowledge base selected. No relevant information found.", ensure_ascii=False)
- if search_mode == "hybrid":
- kb_search_data = self.search_hybrid(
- query=query, index_names=search_index_names, top_k=effective_top_k)
- elif search_mode == "accurate":
- kb_search_data = self.search_accurate(
- query=query, index_names=search_index_names, top_k=effective_top_k)
- elif search_mode == "semantic":
- kb_search_data = self.search_semantic(
- query=query, index_names=search_index_names, top_k=effective_top_k)
- else:
- raise Exception(
- f"Invalid search mode: {search_mode}, only support: hybrid, accurate, semantic")
-
+ kb_search_data = self._run_search(
+ query=query,
+ index_names=search_index_names,
+ search_mode=search_mode,
+ top_k=effective_top_k,
+ )
kb_search_results = kb_search_data["results"]
if not kb_search_results:
- raise Exception(
- "No results found! Try a less restrictive/shorter query.")
+ raise Exception("No results found! Try a less restrictive/shorter query.")
- # Apply reranking if enabled
if self.rerank and self.rerank_model and kb_search_results:
- try:
- # Extract document contents for reranking
- documents = [
- result.get("content", "") for result in kb_search_results
- ]
- # Perform reranking on all retrieved candidates
- reranked_results = self.rerank_model.rerank(
- query=query,
- documents=documents,
- top_n=len(documents)
+ kb_search_results = self._apply_rerank(
+ query=query,
+ kb_search_results=kb_search_results,
+ top_k=self.top_k,
+ )
+
+ (
+ search_results_json,
+ search_results_return,
+ images_list_url,
+ ) = self._build_search_results(kb_search_results)
+
+ self.record_ops += len(search_results_return)
+
+ self._record_search_results(
+ search_results_json=search_results_json,
+ images_list_url=images_list_url,
+ query=query,
+ )
+
+ return json.dumps(search_results_return, ensure_ascii=False)
+
+    def _notify_search_start(self, query: str) -> None:
+        """Tell the observer, if one is attached, that a search is starting.
+
+        Sends the running prompt (Chinese when the observer language is
+        ``"zh"``, English otherwise) as a TOOL message, followed by a CARD
+        message carrying the query. Does nothing without an observer.
+        """
+        observer = self.observer
+        if not observer:
+            return
+        if observer.lang == "zh":
+            running_prompt = self.running_prompt_zh
+        else:
+            running_prompt = self.running_prompt_en
+        observer.add_message("", ProcessType.TOOL, running_prompt)
+        observer.add_message(
+            "",
+            ProcessType.CARD,
+            json.dumps([{"icon": "search", "text": query}], ensure_ascii=False),
+        )
+
+    def _run_search(self, query: str, index_names: List[str], search_mode: str, top_k: int):
+        """Dispatch the query to the search method matching ``search_mode``.
+
+        Args:
+            query: The search query.
+            index_names: Indices to search.
+            search_mode: One of ``hybrid``, ``accurate`` or ``semantic``.
+            top_k: Number of results to request from the search method.
+
+        Returns:
+            Whatever the selected search method returns.
+
+        Raises:
+            Exception: If ``search_mode`` is not a supported mode.
+        """
+        dispatch = dict(
+            hybrid=self.search_hybrid,
+            accurate=self.search_accurate,
+            semantic=self.search_semantic,
+        )
+        search_fn = dispatch.get(search_mode)
+        if search_fn:
+            return search_fn(query=query, index_names=index_names, top_k=top_k)
+        raise Exception(
+            f"Invalid search mode: {search_mode}, only support: hybrid, accurate, semantic"
+        )
+
+ def _apply_rerank(
+ self,
+ query: str,
+ kb_search_results: List[dict],
+ top_k: int,
+ ) -> List[dict]:
+ try:
+ documents = [result.get("content", "") for result in kb_search_results]
+ reranked_results = self.rerank_model.rerank(
+ query=query,
+ documents=documents,
+ top_n=len(documents),
+ )
+ if not reranked_results:
+ return kb_search_results
+
+ original_results_map = {
+ i: kb_search_results[i] for i in range(len(kb_search_results))
+ }
+ reranked_top_results = []
+ for reranked_item in reranked_results[:top_k]:
+ orig_idx = reranked_item.get("index")
+ if orig_idx is None or orig_idx not in original_results_map:
+ continue
+ result = original_results_map[orig_idx]
+ result["score"] = reranked_item.get(
+ "relevance_score", result.get("score", 0)
)
- # Reorder and trim to top_k after reranking
- if reranked_results:
- original_results_map = {
- i: kb_search_results[i] for i in range(len(kb_search_results))
- }
- kb_search_results = []
- for reranked_item in reranked_results[: self.top_k]:
- orig_idx = reranked_item.get("index")
- if orig_idx is not None and orig_idx in original_results_map:
- result = original_results_map[orig_idx]
- result["score"] = reranked_item.get(
- "relevance_score", result.get("score", 0)
- )
- kb_search_results.append(result)
- logger.info(
- f"Reranking applied: selected top {self.top_k} from "
- f"{len(documents)} candidates"
- )
- except Exception as e:
- logger.warning(f"Reranking failed, using original results: {str(e)}")
+ reranked_top_results.append(result)
+
+ if reranked_top_results:
+ logger.info(
+ "Reranking applied: selected top %s from %s candidates",
+ top_k,
+ len(documents),
+ )
+ return reranked_top_results
+ return kb_search_results
+ except Exception as e:
+ logger.warning("Reranking failed, using original results: %s", str(e))
+ return kb_search_results
+
+    @staticmethod
+    def _normalize_source_type(source_type: str) -> str:
+        """Map the stored ``local`` and ``minio`` source types to ``file``.
+
+        Any other source type is returned unchanged.
+        """
+        if source_type in ("local", "minio"):
+            return "file"
+        return source_type
+
+ def _build_search_results(self, kb_search_results):
+ search_results_json = []
+ search_results_return = []
+ images_list_url = []
- search_results_json = [] # Organize search results into a unified format
- search_results_return = [] # Format for input to the large model
for index, single_search_result in enumerate(kb_search_results):
- # Temporarily correct the source_type stored in the knowledge base
- source_type = single_search_result.get("source_type", "")
- source_type = "file" if source_type in [
- "local", "minio"] else source_type
- title = single_search_result.get("title")
- if not title:
- title = single_search_result.get("filename", "")
+ source_type = self._normalize_source_type(
+ single_search_result.get("source_type", "")
+ )
+ title = single_search_result.get("title") or single_search_result.get(
+ "filename", ""
+ )
search_result_message = SearchResultTextMessage(
title=title,
text=single_search_result.get("content", ""),
@@ -269,31 +338,72 @@ def forward(self, query: str, index_names: Optional[List[str]] = None) -> str:
tool_sign=self.tool_sign,
)
+ image_url = self._extract_image_url(single_search_result)
+ if image_url:
+ images_list_url.append(image_url)
+
search_results_json.append(search_result_message.to_dict())
search_results_return.append(search_result_message.to_model_dict())
- self.record_ops += len(search_results_return)
+ return search_results_json, search_results_return, images_list_url
+
+    @staticmethod
+    def _extract_image_url(single_search_result):
+        """Return the image URL carried by an image-extractor search hit.
+
+        Only results whose ``process_source`` is ``"UniversalImageExtractor"``
+        are considered; their ``content`` field is parsed as JSON and the
+        ``image_url`` key is read from it.
+
+        Returns:
+            The ``image_url`` value, or ``None`` when the result is not an
+            image hit, the content is not valid JSON, or the decoded JSON is
+            not an object.
+        """
+        if single_search_result.get("process_source") != "UniversalImageExtractor":
+            return None
+        try:
+            meta_data = json.loads(single_search_result.get("content"))
+        except (json.JSONDecodeError, TypeError):
+            logger.error("Failed to parse image metadata")
+            return None
+        # Valid JSON is not necessarily an object (e.g. "null" or "[]");
+        # only a dict can carry an "image_url" key.
+        if not isinstance(meta_data, dict):
+            logger.error("Failed to parse image metadata")
+            return None
+        return meta_data.get("image_url")
+
+ def _record_search_results(
+ self,
+ search_results_json: List[dict],
+ images_list_url: List[str],
+ query: str,
+ ) -> None:
+ if not self.observer:
+ return
+
+ search_results_data = json.dumps(search_results_json, ensure_ascii=False)
+ self.observer.add_message("", ProcessType.SEARCH_CONTENT, search_results_data)
+
+ if not images_list_url:
+ return
+
+ filtered_images = images_list_url
+ image_filter = getattr(self, "_filter_images", None)
+ if callable(image_filter):
+ try:
+ maybe_filtered = image_filter(images_list_url, query)
+ if maybe_filtered:
+ filtered_images = maybe_filtered
+ except Exception as e:
+ logger.warning("Image filtering failed, using original list: %s", str(e))
- # Record the detailed content of this search
- if self.observer:
- search_results_data = json.dumps(
- search_results_json, ensure_ascii=False)
+ if filtered_images:
+ search_images_list_json = json.dumps(
+ {"images_url": filtered_images}, ensure_ascii=False
+ )
self.observer.add_message(
- "", ProcessType.SEARCH_CONTENT, search_results_data)
- return json.dumps(search_results_return, ensure_ascii=False)
+ "", ProcessType.PICTURE_WEB, search_images_list_json
+ )
def search_hybrid(self, query, index_names, top_k):
try:
results = self.vdb_core.hybrid_search(
- index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=top_k
+ index_names=index_names,
+ query_text=query,
+ embedding_model=self.embedding_model,
+ top_k=top_k,
)
- # Format results
formatted_results = []
for result in results:
doc = result["document"]
doc["score"] = result["score"]
- # Include source index in results
doc["index"] = result["index"]
formatted_results.append(doc)
@@ -302,19 +412,20 @@ def search_hybrid(self, query, index_names, top_k):
"total": len(formatted_results),
}
except Exception as e:
- raise Exception(f"Error during semantic search: {str(e)}")
+ raise Exception(f"Error during hybrid search: {str(e)}")
def search_accurate(self, query, index_names, top_k):
try:
results = self.vdb_core.accurate_search(
- index_names=index_names, query_text=query, top_k=top_k)
+ index_names=index_names,
+ query_text=query,
+ top_k=top_k,
+ )
- # Format results
formatted_results = []
for result in results:
doc = result["document"]
doc["score"] = result["score"]
- # Include source index in results
doc["index"] = result["index"]
formatted_results.append(doc)
@@ -323,20 +434,21 @@ def search_accurate(self, query, index_names, top_k):
"total": len(formatted_results),
}
except Exception as e:
- raise Exception(detail=f"Error during accurate search: {str(e)}")
+ raise Exception(f"Error during accurate search: {str(e)}")
def search_semantic(self, query, index_names, top_k):
try:
results = self.vdb_core.semantic_search(
- index_names=index_names, query_text=query, embedding_model=self.embedding_model, top_k=top_k
+ index_names=index_names,
+ query_text=query,
+ embedding_model=self.embedding_model,
+ top_k=top_k,
)
- # Format results
formatted_results = []
for result in results:
doc = result["document"]
doc["score"] = result["score"]
- # Include source index in results
doc["index"] = result["index"]
formatted_results.append(doc)
@@ -345,4 +457,79 @@ def search_semantic(self, query, index_names, top_k):
"total": len(formatted_results),
}
except Exception as e:
- raise Exception(detail=f"Error during semantic search: {str(e)}")
+ raise Exception(f"Error during semantic search: {str(e)}")
+
+    def _filter_images(self, images_list_url, query) -> list:
+        """
+        Execute image filtering operation directly using the data processing service.
+
+        Each URL is posted to
+        ``{self.data_process_service}/tasks/filter_important_image`` with the
+        query as the positive prompt and a fixed negative prompt; URLs whose
+        response has a truthy ``is_important`` field are kept.
+
+        :param images_list_url: List of image URLs to filter
+        :param query: Search query, used to filter images related to the query
+        :return: The URLs reported as important, in input order. An empty list
+            is returned for empty input or when filtering fails as a whole;
+            a failure for a single URL only drops that URL.
+        """
+        import asyncio
+        import aiohttp
+
+        # Nothing to filter: skip creating an event loop and HTTP session.
+        if not images_list_url:
+            return []
+
+        positive_prompt = query
+        negative_prompt = "logo or banner or background or advertisement or icon or avatar"
+        api_url = f"{self.data_process_service}/tasks/filter_important_image"
+
+        async def process_single_image(session, semaphore, img_url):
+            """Return ``img_url`` if the service marks it important, else ``None``."""
+            async with semaphore:
+                try:
+                    data = {
+                        'image_url': img_url,
+                        'positive_prompt': positive_prompt,
+                        'negative_prompt': negative_prompt
+                    }
+                    async with session.post(api_url, data=data) as response:
+                        if response.status != 200:
+                            logger.info(
+                                f"API error for {img_url}: {response.status}")
+                            return None
+                        result = await response.json()
+                    if result.get("is_important", False):
+                        logger.info(f"Important image: {img_url}")
+                        return img_url
+                    return None
+                except Exception as e:
+                    logger.info(
+                        f"Error processing image {img_url}: {str(e)}")
+                    return None
+
+        async def process_images():
+            """Check all URLs concurrently and collect the important ones."""
+            # At most 10 requests are in flight at once.
+            semaphore = asyncio.Semaphore(10)
+            connector = aiohttp.TCPConnector(limit=0)
+            timeout = aiohttp.ClientTimeout(total=2)
+
+            async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
+                results = await asyncio.gather(
+                    *(process_single_image(session, semaphore, url)
+                      for url in images_list_url)
+                )
+            return [url for url in results if url is not None]
+
+        try:
+            # Run on a private loop without installing it as the thread's
+            # current loop, so no closed loop is left behind for later callers.
+            loop = asyncio.new_event_loop()
+            try:
+                return loop.run_until_complete(process_images())
+            finally:
+                loop.close()
+        except Exception as e:
+            logger.info(f"Image filtering error: {str(e)}")
+            return []
+
diff --git a/sdk/nexent/core/utils/favicon_extractor.py b/sdk/nexent/core/utils/favicon_extractor.py
index 17fe675d8..f68cc6d16 100644
--- a/sdk/nexent/core/utils/favicon_extractor.py
+++ b/sdk/nexent/core/utils/favicon_extractor.py
@@ -1,33 +1,29 @@
import requests
from urllib.parse import urlparse
-def get_favicon_url(page_url):
- """
- 从给定网页URL提取favicon图标地址
- 参数:
- page_url (str): 要分析的网页URL
+def get_favicon_url(page_url: str) -> str:
+ """Build the default favicon URL for a given page URL.
- 返回:
- str: favicon图标的完整URL,如果找不到则返回None
- """
+ Args:
+ page_url: Target page URL.
- # 解析输入URL
+ Returns:
+ Default favicon URL.
+ """
parsed_url = urlparse(page_url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
- default_favicon = f"{base_url}/favicon.ico"
- return default_favicon
+ return f"{base_url}/favicon.ico"
-def check_favicon_exists(url):
- """
- 检查给定的favicon URL是否有效
+def check_favicon_exists(url: str) -> bool:
+ """Check whether a favicon URL exists.
- 参数:
- url (str): 要检查的favicon URL
+ Args:
+ url: Favicon URL to check.
- 返回:
- bool: 如果URL存在且返回200状态码则为True
+ Returns:
+ True if the favicon exists, otherwise False.
"""
try:
response = requests.head(url, timeout=3, allow_redirects=True)
@@ -40,17 +36,5 @@ def check_favicon_exists(url):
url = "https://www.travelking.com.tw/zh-cn/tourguide/scenery100577.html"
# url = "https://apps.apple.com/cn/app/wemeeting/id1480497919"
- # 获取favicon URL
- import time
- start = time.time()
- favicon_url = get_favicon_url(url)
-
- if favicon_url:
- print(f"找到favicon: {favicon_url}")
- else:
- print("未找到favicon")
- end = time.time()
- print(str(end - start))
-
- print(check_favicon_exists(favicon_url))
-
+ # Manual smoke check for favicon existence.
+ _ = check_favicon_exists(get_favicon_url(url))
diff --git a/sdk/nexent/data_process/core.py b/sdk/nexent/data_process/core.py
index 4df84de7b..e0685aecd 100644
--- a/sdk/nexent/data_process/core.py
+++ b/sdk/nexent/data_process/core.py
@@ -1,7 +1,9 @@
import logging
import os
+from typing import Any, Dict, List, Optional, Tuple
+
+from .extract_image import UniversalImageExtractor
from io import BytesIO
-from typing import Dict, List, Optional
from .base import FileProcessor
from .file_splitter import FileSplitter
@@ -30,9 +32,12 @@ class DataProcessCore:
# Supported chunking strategies
CHUNKING_STRATEGIES = {"basic", "by_title", "none"}
+
+ EXTRACT_IMAGE_EXTENSIONS = {".pdf", ".doc",
+ ".docx", ".xls", ".xlsx", ".ppt", ".pptx"}
# Supported processors
- PROCESSORS = {"Unstructured", "OpenPyxl"}
+ PROCESSORS = {"Unstructured", "OpenPyxl", "UniversalImageExtractor"}
# Supported split extensions (exclude ppt/pptx/html)
SPLIT_EXTENSIONS = {
@@ -56,6 +61,7 @@ def __init__(self):
self.processors: Dict[str, FileProcessor] = {
"Unstructured": UnstructuredProcessor(),
"OpenPyxl": OpenPyxlProcessor(),
+ "UniversalImageExtractor": UniversalImageExtractor(),
"FileSplitter": FileSplitter(),
}
logger.debug("DataProcessCore initialization completed")
@@ -67,7 +73,7 @@ def file_process(
chunking_strategy: str = "basic",
processor: Optional[str] = None,
**params,
- ) -> List[Dict]:
+ ) -> Tuple[List[Dict], List[Dict]]:
"""
Facade pattern that automatically detects file type and processes files
@@ -80,11 +86,13 @@ def file_process(
**params: Additional processing parameters
Returns:
- List of processed chunks, each dictionary contains the following fields:
+ Tuple[List[Dict], List[Dict]]: (chunks, images_info)
+ chunks: List of processed chunks, each dictionary contains the following fields:
- content: Text content
- filename: Filename
- metadata: Metadata (optional, includes chunk_index, source_type, etc.)
- language: Language identifier (optional)
+ images_info: List of extracted image metadata dicts (may be empty)
Raises:
ValueError: Invalid parameters
@@ -94,18 +102,32 @@ def file_process(
self._validate_parameters(chunking_strategy, processor)
# Select appropriate processor
- processor_name = processor or self._select_processor_by_filename(
- filename)
+ if processor:
+ processor_name = processor
+ _, extractor = self._select_processor_by_filename(filename, params)
+ else:
+ processor_name, extractor = self._select_processor_by_filename(
+ filename, params)
+
processor_instance = self.processors.get(processor_name)
+ extract_image_processor_instance = (
+ self.processors.get(extractor) if extractor else None
+ )
if not processor_instance:
raise ValueError(f"Unsupported processor: {processor_name}")
+
+ if extract_image_processor_instance:
+ img_info = extract_image_processor_instance.process_file(
+ file_data, chunking_strategy, filename, **params)
+ else:
+ img_info = []
# Process in-memory file
logger.info(
f"Processing in-memory file: {filename} with {processor_name} processor")
try:
- return processor_instance.process_file(file_data, chunking_strategy, filename=filename, **params)
+ return processor_instance.process_file(file_data, chunking_strategy, filename=filename, **params), img_info
except Exception as e:
logger.error(f"File processing failed for {filename}: {str(e)}")
raise
@@ -173,14 +195,21 @@ def _validate_parameters(self, chunking_strategy: str, processor: Optional[str])
logger.debug(
f"Parameter validation passed: chunking_strategy={chunking_strategy}, processor={processor}")
-    def _select_processor_by_filename(self, filename: str) -> str:
+    def _select_processor_by_filename(
+        self, filename: str, params: Optional[Dict[str, Any]] = None
+    ) -> Tuple[str, Optional[str]]:
-        """Selects a processor based on the file extension."""
+        """Select the text processor and optional image extractor for a file.
+
+        Args:
+            filename: Name of the file; only its extension is inspected.
+            params: Optional processing parameters. Image extraction is
+                enabled only when ``params["model_type"]`` is
+                ``"multi_embedding"`` and the extension is listed in
+                ``EXTRACT_IMAGE_EXTENSIONS``.
+
+        Returns:
+            ``(processor_name, extractor_name)``. ``processor_name`` is
+            ``"OpenPyxl"`` for extensions in ``EXCEL_EXTENSIONS`` and
+            ``"Unstructured"`` otherwise; ``extractor_name`` is ``None`` when
+            no image extraction applies.
+        """
         _, file_extension = os.path.splitext(filename)
         file_extension = file_extension.lower()
+
+        extract_image = None
+        # ``params`` defaults to None, so guard before reading from it.
+        model_type = (params or {}).get("model_type")
+        if model_type == "multi_embedding" and file_extension in self.EXTRACT_IMAGE_EXTENSIONS:
+            extract_image = "UniversalImageExtractor"
         if file_extension in self.EXCEL_EXTENSIONS:
-            return "OpenPyxl"
+            return "OpenPyxl", extract_image
         else:
-            return "Unstructured"
+            return "Unstructured", extract_image
def get_supported_file_types(self) -> Dict[str, List[str]]:
"""
diff --git a/sdk/nexent/data_process/extract_image.py b/sdk/nexent/data_process/extract_image.py
new file mode 100644
index 000000000..38b452d6d
--- /dev/null
+++ b/sdk/nexent/data_process/extract_image.py
@@ -0,0 +1,438 @@
+import os
+import re
+import base64
+import hashlib
+import tempfile
+import subprocess
+from typing import List, Dict, Any, Optional
+import zipfile
+from xml.etree import ElementTree
+
+from pptx import Presentation
+
+from .base import FileProcessor
+
+from unstructured_inference.logger import logger
+from unstructured_inference.models import tables
+from unstructured.partition.auto import partition
+
+
+tables_agent = tables.tables_agent
+TABLE_TRANSFORMER_MODEL_PATH = ""
+
+def custom_load_table_model():
+    """Initialize the shared table agent from ``TABLE_TRANSFORMER_MODEL_PATH``.
+
+    Does nothing when ``tables_agent.model`` is already set. Otherwise the
+    check is repeated while holding ``tables_agent._lock`` before calling
+    ``tables_agent.initialize`` with the module-level model path.
+    """
+    if getattr(tables_agent, "model", None) is not None:
+        return
+
+    with tables_agent._lock:
+        # Re-check under the lock: another caller may have finished loading
+        # between the first check and acquiring the lock.
+        if getattr(tables_agent, "model", None) is None:
+            logger.info(
+                "Loading the Table agent from %s ...",
+                TABLE_TRANSFORMER_MODEL_PATH,
+            )
+            tables_agent.initialize(TABLE_TRANSFORMER_MODEL_PATH)
+
+tables.load_agent = lambda: custom_load_table_model()
+
+
+class UniversalImageExtractor(FileProcessor):
+ """
+ Multi-format image extractor for PDF, PPT, Excel, and Word.
+ Uses LibreOffice for conversion when needed and reuses PDF extraction logic.
+ """
+
+    @staticmethod
+    def _hash(data: bytes) -> str:
+        """Return the SHA-256 hex digest of ``data``, used as a de-duplication key."""
+        digest = hashlib.sha256()
+        digest.update(data)
+        return digest.hexdigest()
+
+    @staticmethod
+    def _openxml_namespace_maps() -> Dict[str, str]:
+        """Return the prefix-to-URI map used for ElementTree lookups in drawing XML."""
+        openxml_base = "http://schemas.openxmlformats.org"  # NOSONAR
+        namespaces = {
+            "xdr": f"{openxml_base}/drawingml/2006/spreadsheetDrawing",
+            "a": f"{openxml_base}/drawingml/2006/main",
+            "r": f"{openxml_base}/officeDocument/2006/relationships",
+        }
+        return namespaces
+
+
+    def _write_temp_file(self, data: bytes, suffix: str) -> str:
+        """Write ``data`` to a new temporary file and return its path.
+
+        The file is created with ``delete=False``, so it stays on disk after
+        being closed; the caller is responsible for removing it.
+
+        Args:
+            data: Bytes to write.
+            suffix: File name suffix (for example ``".pdf"``).
+
+        Returns:
+            Path of the written file.
+        """
+        # The context manager closes the handle even if the write fails.
+        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+            tmp.write(data)
+            return tmp.name
+
+    @staticmethod
+    def detect_image_format(image_bytes: bytes) -> str:
+        """Guess the image format from its leading magic bytes.
+
+        Returns ``"jpg"`` for data starting with the JPEG marker
+        ``FF D8 FF``; everything else, including PNG data and unrecognized
+        data, is reported as ``"png"``.
+        """
+        jpeg_magic = b"\xFF\xD8\xFF"
+        return "jpg" if image_bytes.startswith(jpeg_magic) else "png"
+
+
+    def _convert_file(self, input_path: str, target_format: str) -> str:
+        """
+        Convert a file to the target format using LibreOffice.
+
+        The ``soffice`` command is run headless with the input file's own
+        directory as output directory; the result is expected next to the
+        input, with the input's extension replaced by ``target_format``.
+
+        Args:
+            input_path: Source file path.
+            target_format: Target format, e.g. "pdf", "pptx", "xlsx".
+
+        Returns:
+            Output file path.
+
+        Raises:
+            RuntimeError: ``soffice`` exited with a non-zero status or did
+                not finish within 60 seconds.
+            FileNotFoundError: ``soffice`` finished but the expected output
+                file does not exist.
+        """
+        out_dir = os.path.dirname(input_path)
+
+        cmd = [
+            "soffice",
+            "--headless",
+            "--invisible",  # Ensure fully headless conversion.
+            "--convert-to", target_format,
+            input_path,
+            "--outdir", out_dir,
+        ]
+
+        try:
+            subprocess.run(
+                cmd,
+                check=True,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+                timeout=60,  # Prevent hanging conversions.
+            )
+        except subprocess.CalledProcessError as e:
+            # Keep the original error as the cause for easier debugging.
+            raise RuntimeError(
+                f"LibreOffice conversion failed for {input_path}: {e}") from e
+        except subprocess.TimeoutExpired as e:
+            raise RuntimeError(
+                f"LibreOffice conversion timed out for {input_path}") from e
+
+        output_path = os.path.splitext(input_path)[0] + f".{target_format}"
+        if not os.path.exists(output_path):
+            raise FileNotFoundError(
+                f"Conversion failed: Output file {output_path} not found.")
+        return output_path
+
+
+    def _extract_pdf(self, pdf_path: str, **params) -> List[Dict]:
+        """Extract de-duplicated images from a PDF using ``partition``.
+
+        Extraction only runs when both ``table_transformer_model_path`` and
+        ``unstructured_default_model_initialize_params_json_path`` are given
+        in ``params``; otherwise an empty list is returned. The table model
+        path is published through the module-level
+        ``TABLE_TRANSFORMER_MODEL_PATH`` before partitioning.
+
+        Returns:
+            One dict per distinct image with ``position`` (page number and
+            bounding box, either of which may be ``None``), ``image_format``
+            and ``image_bytes``.
+        """
+        global TABLE_TRANSFORMER_MODEL_PATH
+
+        model_path = params.get("table_transformer_model_path")
+        init_params_json = params.get(
+            "unstructured_default_model_initialize_params_json_path"
+        )
+        if not (model_path and init_params_json):
+            return []
+        TABLE_TRANSFORMER_MODEL_PATH = model_path
+
+        extracted: List[Dict] = []
+        known_hashes = set()
+
+        parsed_elements = partition(
+            filename=pdf_path,
+            strategy="hi_res",
+            extract_images_in_pdf=True,
+            extract_image_block_to_payload=True,
+        )
+
+        for element in parsed_elements:
+            metadata = element.metadata
+            encoded = getattr(metadata, "image_base64", None)
+            if not encoded:
+                continue
+
+            raw = base64.b64decode(encoded)
+            digest = self._hash(raw)
+            if digest in known_hashes:
+                # Same image bytes already reported.
+                continue
+            known_hashes.add(digest)
+
+            # Collapse the element's corner points into a bounding box.
+            bounding_box = None
+            coordinates = getattr(metadata, "coordinates", None)
+            points = getattr(coordinates, "points", None) if coordinates else None
+            if points:
+                bounding_box = {
+                    "x1": min(point[0] for point in points),
+                    "y1": min(point[1] for point in points),
+                    "x2": max(point[0] for point in points),
+                    "y2": max(point[1] for point in points),
+                }
+
+            extracted.append({
+                "position": {
+                    "page_number": getattr(metadata, "page_number", None),
+                    "coordinates": bounding_box,
+                },
+                "image_format": self.detect_image_format(raw),
+                "image_bytes": raw,
+            })
+
+        return extracted
+
+
+    def _excel_sheet_files(self, z: zipfile.ZipFile) -> List[str]:
+        """List archive members whose name starts with ``xl/worksheets/sheet``."""
+        return list(
+            filter(
+                lambda member: member.startswith("xl/worksheets/sheet"),
+                z.namelist(),
+            )
+        )
+
+
+    def _excel_drawing_file(self, z: zipfile.ZipFile, sheet_file: str) -> Optional[str]:
+        """Resolve the drawing part referenced by a worksheet.
+
+        The sheet's ``drawing`` element supplies a relationship id, which is
+        looked up in the sheet's ``_rels`` part to obtain the drawing path.
+
+        Args:
+            z: Open XLSX archive.
+            sheet_file: Archive path of the worksheet XML.
+
+        Returns:
+            Archive path of the drawing part, or ``None`` when the sheet has
+            no drawing, the drawing has no relationship id, the ``_rels``
+            part is missing, or no relationship with a target matches.
+        """
+        sheet_xml = ElementTree.fromstring(z.read(sheet_file))
+        # Both scheme spellings of the namespace URIs are tried.
+        schemes = ("https", "http")
+
+        drawing = None
+        for scheme in schemes:
+            drawing = sheet_xml.find(
+                f".//{{{scheme}://schemas.openxmlformats.org/spreadsheetml/2006/main}}drawing")
+            if drawing is not None:
+                break
+        if drawing is None:
+            return None
+
+        rel_id = None
+        for scheme in schemes:
+            rel_id = drawing.get(
+                f"{{{scheme}://schemas.openxmlformats.org/officeDocument/2006/relationships}}id")
+            if rel_id is not None:
+                break
+        if rel_id is None:
+            # Comparing ``None`` against ``rel.get("Id")`` would match any
+            # relationship that lacks an Id, so stop here.
+            return None
+
+        rel_path = sheet_file.replace("worksheets", "worksheets/_rels") + ".rels"
+        if rel_path not in z.namelist():
+            return None
+
+        rel_xml = ElementTree.fromstring(z.read(rel_path))
+        for rel in rel_xml:
+            if rel.get("Id") != rel_id:
+                continue
+            target = rel.get("Target")
+            if not target:
+                return None
+            if target.startswith("/"):
+                # Absolute target: already a path from the archive root.
+                return target.lstrip("/")
+            return "xl/" + target.replace("../", "")
+
+        return None
+
+
+    def _excel_rel_map(self, z: zipfile.ZipFile, drawing_file: str) -> Optional[Dict[str, str]]:
+        """Map relationship ids of a drawing part to archive paths.
+
+        Args:
+            z: Open XLSX archive.
+            drawing_file: Archive path of the drawing XML.
+
+        Returns:
+            ``{relationship id: archive path}`` built from the drawing's
+            ``_rels`` part, or ``None`` when that part does not exist.
+            Relationships without an id or target, and relationships marked
+            ``TargetMode="External"``, are left out because they do not name
+            a member of the archive.
+        """
+        rel_file = drawing_file.replace("drawings/", "drawings/_rels/") + ".rels"
+        if rel_file not in z.namelist():
+            return None
+
+        rel_map: Dict[str, str] = {}
+        for rel in ElementTree.fromstring(z.read(rel_file)):
+            rel_id = rel.get("Id")
+            target = rel.get("Target")
+            if not rel_id or not target or rel.get("TargetMode") == "External":
+                continue
+            if target.startswith("/"):
+                # Absolute target: already a path from the archive root.
+                rel_map[rel_id] = target.lstrip("/")
+            else:
+                rel_map[rel_id] = "xl/" + target.replace("../", "")
+        return rel_map
+
+
+    def _excel_anchors(self, z: zipfile.ZipFile, drawing_file: str, ns: Dict[str, str]) -> List[Any]:
+        """Return all two-cell anchors followed by all one-cell anchors of a drawing part."""
+        root = ElementTree.fromstring(z.read(drawing_file))
+        anchors: List[Any] = []
+        for tag in ("twoCellAnchor", "oneCellAnchor"):
+            anchors.extend(root.findall(f".//xdr:{tag}", ns))
+        return anchors
+
+
+    def _excel_anchor_coords(self, anchor: Any, ns: Dict[str, str]) -> Optional[Dict[str, int]]:
+        """Read the cell range covered by a drawing anchor.
+
+        The row and column indices stored in the anchor are incremented by
+        one. When the anchor has no ``xdr:to`` marker the end cell equals the
+        start cell.
+
+        Returns:
+            ``{"row1", "col1", "row2", "col2"}``, or ``None`` when the anchor
+            has no ``xdr:from`` marker.
+        """
+        def shifted(marker: Any, tag: str) -> int:
+            return int(marker.find(f"xdr:{tag}", ns).text) + 1
+
+        start = anchor.find("xdr:from", ns)
+        if start is None:
+            return None
+
+        cell_range = {"row1": shifted(start, "row"), "col1": shifted(start, "col")}
+
+        end = anchor.find("xdr:to", ns)
+        if end is None:
+            cell_range["row2"] = cell_range["row1"]
+            cell_range["col2"] = cell_range["col1"]
+        else:
+            cell_range["row2"] = shifted(end, "row")
+            cell_range["col2"] = shifted(end, "col")
+        return cell_range
+
+
+    def _excel_anchor_embed_id(self, anchor: Any, ns: Dict[str, str]) -> Optional[str]:
+        """Return the relationship id in the ``embed`` attribute of the anchor's ``a:blip``.
+
+        The attribute is looked up under the ``https`` spelling of the
+        relationships namespace first, then the ``http`` one. ``None`` is
+        returned when there is no ``a:blip`` or neither attribute is present.
+        """
+        blip = anchor.find(".//a:blip", ns)
+        if blip is None:
+            return None
+
+        relationships_ns = "schemas.openxmlformats.org/officeDocument/2006/relationships"
+        for scheme in ("https", "http"):
+            embed_id = blip.get(f"{{{scheme}://{relationships_ns}}}embed")
+            if embed_id is not None:
+                return embed_id
+        return None
+
+
+    def _extract_excel_anchors(
+        self,
+        z: zipfile.ZipFile,
+        anchors: List[Any],
+        rel_map: Dict[str, str],
+        sheet_name: str,
+        ns: Dict[str, str],
+        seen: set,
+    ) -> List[Dict[str, Any]]:
+        """Collect the images referenced by a sheet's drawing anchors.
+
+        Anchors without coordinates, without an embed id, or whose id is not
+        in ``rel_map`` are skipped. Image bytes whose hash is already in
+        ``seen`` are skipped as duplicates; new hashes are added to ``seen``,
+        which the caller shares across sheets.
+
+        Returns:
+            One dict per new image with ``position`` (sheet name and cell
+            range), ``image_format`` and ``image_bytes``.
+        """
+        images: List[Dict[str, Any]] = []
+        for anchor in anchors:
+            cell_range = self._excel_anchor_coords(anchor, ns)
+            if cell_range is None:
+                continue
+
+            relationship_id = self._excel_anchor_embed_id(anchor, ns)
+            if not relationship_id:
+                continue
+
+            media_path = rel_map.get(relationship_id)
+            if not media_path:
+                continue
+
+            payload = z.read(media_path)
+            fingerprint = self._hash(payload)
+            if fingerprint in seen:
+                continue
+            seen.add(fingerprint)
+
+            # Columns map to x and rows to y.
+            position = {
+                "sheet_name": sheet_name,
+                "coordinates": {
+                    "x1": cell_range["col1"],
+                    "x2": cell_range["col2"],
+                    "y1": cell_range["row1"],
+                    "y2": cell_range["row2"],
+                },
+            }
+            images.append({
+                "position": position,
+                "image_format": self.detect_image_format(payload),
+                "image_bytes": payload,
+            })
+
+        return images
+
+
+    def _extract_excel_sheet(
+        self,
+        z: zipfile.ZipFile,
+        sheet_file: str,
+        ns: Dict[str, str],
+        seen: set,
+    ) -> List[Dict[str, Any]]:
+        """Extract the images anchored on one worksheet.
+
+        Returns an empty list when the sheet has no drawing part or the
+        drawing has no (or an empty) relationship map. The sheet name
+        reported in the results is the base name of ``sheet_file``.
+        """
+        drawing_part = self._excel_drawing_file(z, sheet_file)
+        relationships = (
+            self._excel_rel_map(z, drawing_part) if drawing_part is not None else None
+        )
+        if not relationships:
+            return []
+
+        return self._extract_excel_anchors(
+            z,
+            self._excel_anchors(z, drawing_part, ns),
+            relationships,
+            os.path.basename(sheet_file),
+            ns,
+            seen,
+        )
+
+
+    def _extract_excel(self, xlsx_path):
+        """Extract images from every worksheet of an XLSX archive.
+
+        A single set of image hashes is shared across sheets, so an image
+        that appears on several sheets is returned once.
+        """
+        seen_hashes = set()
+        namespaces = self._openxml_namespace_maps()
+
+        with zipfile.ZipFile(xlsx_path) as archive:
+            per_sheet = [
+                self._extract_excel_sheet(archive, sheet_file, namespaces, seen_hashes)
+                for sheet_file in self._excel_sheet_files(archive)
+            ]
+
+        return [image for sheet_images in per_sheet for image in sheet_images]
+
+
+    def _extract_pptx(self, pptx_path: str, **params) -> List[Dict]:
+        """Extract de-duplicated images from the shapes of a PPTX file.
+
+        Only shapes directly on a slide are inspected. Shape geometry is
+        converted from EMU to pixels using ``emu_per_inch`` (default 914400)
+        and ``dpi`` (default 96) from ``params``.
+
+        Returns:
+            One dict per distinct image with ``position`` (1-based slide
+            number, pixel bounding box and slide size), ``image_format`` and
+            ``image_bytes``.
+
+        Raises:
+            RuntimeError: ``Presentation`` is not available.
+        """
+        if Presentation is None:
+            raise RuntimeError("python-pptx is required to extract images from PPTX files.")
+        emu_per_inch = params.get("emu_per_inch", 914400)
+        dpi = params.get("dpi", 96)
+
+        def _emu_to_px(emu: int) -> int:
+            return int((emu / emu_per_inch) * dpi)
+
+        prs = Presentation(pptx_path)
+        slide_w = _emu_to_px(prs.slide_width)
+        slide_h = _emu_to_px(prs.slide_height)
+
+        results = []
+        seen = set()
+        for slide_number, slide in enumerate(prs.slides, start=1):
+            for shape in slide.shapes:
+                try:
+                    img_bytes = shape.image.blob
+                except (AttributeError, ValueError):
+                    # Shapes without an ``image`` attribute are not pictures.
+                    # python-pptx raises ValueError for a picture that has no
+                    # embedded image; ``hasattr`` would let that propagate and
+                    # abort the whole extraction.
+                    continue
+
+                digest = self._hash(img_bytes)
+                if digest in seen:
+                    continue
+                seen.add(digest)
+
+                x = _emu_to_px(shape.left)
+                y = _emu_to_px(shape.top)
+                w = _emu_to_px(shape.width)
+                h_px = _emu_to_px(shape.height)
+
+                results.append({
+                    "position": {
+                        "page_number": slide_number,
+                        "coordinates": {
+                            "x1": x,
+                            "y1": y,
+                            "x2": x + w,
+                            "y2": y + h_px,
+                            "slide_width": slide_w,
+                            "slide_height": slide_h,
+                        },
+                    },
+                    "image_format": self.detect_image_format(img_bytes),
+                    "image_bytes": img_bytes
+                })
+
+        return results
+
+
+    def process_file(self, file_bytes: bytes, chunking_strategy: str, filename: str, **params) -> List[Dict[str, Any]]:
+        """Extract images from an in-memory document.
+
+        The bytes are written to a temporary file. ``.xlsx``, ``.pptx`` and
+        ``.pdf`` files are processed directly; ``.xls``, ``.ppt``, ``.docx``
+        and ``.doc`` files are first converted with LibreOffice (to xlsx,
+        pptx and pdf respectively). Temporary and converted files are removed
+        afterwards.
+
+        Args:
+            file_bytes: Raw file content.
+            chunking_strategy: Not used by this method.
+            filename: Original file name; only its extension is used.
+            **params: Forwarded to the PDF and PPTX extractors.
+
+        Returns:
+            Extracted image dicts, or an empty list for unsupported
+            extensions.
+        """
+        suffix = os.path.splitext(filename)[1].lower()
+        temp_path = self._write_temp_file(file_bytes, suffix)
+        converted_path = None
+
+        try:
+            direct_extractors = {
+                ".xlsx": lambda: self._extract_excel(temp_path),
+                ".pptx": lambda: self._extract_pptx(temp_path, **params),
+                ".pdf": lambda: self._extract_pdf(temp_path, **params),
+            }
+            if suffix in direct_extractors:
+                return direct_extractors[suffix]()
+
+            conversions = {
+                ".xls": ("xlsx", lambda path: self._extract_excel(path)),
+                ".ppt": ("pptx", lambda path: self._extract_pptx(path, **params)),
+                ".docx": ("pdf", lambda path: self._extract_pdf(path, **params)),
+                ".doc": ("pdf", lambda path: self._extract_pdf(path, **params)),
+            }
+            if suffix in conversions:
+                target_format, extractor = conversions[suffix]
+                converted_path = self._convert_file(temp_path, target_format)
+                return extractor(converted_path)
+
+            return []
+
+        finally:
+            # Besides the temp and converted files, also sweep sibling files
+            # with the same base name in case a conversion left output behind
+            # before failing.
+            base = os.path.splitext(temp_path)[0]
+            files_to_clean = {temp_path}
+            files_to_clean.update(
+                base + ext for ext in (".docx", ".pptx", ".xlsx", ".pdf")
+            )
+            if converted_path:
+                files_to_clean.add(converted_path)
+
+            for f_path in files_to_clean:
+                if not os.path.exists(f_path):
+                    continue
+                try:
+                    os.remove(f_path)
+                except OSError as e:
+                    # Cleanup is best effort, but leave a trace of leftovers.
+                    logger.warning(
+                        "Failed to remove temporary file %s: %s", f_path, e)
diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py
index 41a3c674d..e8f6ec81a 100644
--- a/sdk/nexent/vector_database/elasticsearch_core.py
+++ b/sdk/nexent/vector_database/elasticsearch_core.py
@@ -1,3 +1,4 @@
+import base64
import json
import logging
import os
@@ -399,22 +400,26 @@ def _small_batch_insert(
) -> int:
"""Small batch insertion: real-time"""
try:
- # Preprocess documents
processed_docs = self._preprocess_documents(
documents, content_field)
-
- # Get embeddings
- inputs = [doc[content_field] for doc in processed_docs]
- embeddings = embedding_model.get_embeddings(inputs)
+
+ # Preprocess documents
+ processed_docs, embeddings = self._prepare_small_batch_embeddings(
+ processed_docs, content_field, embedding_model
+ )
# Prepare bulk operations
- operations = []
- for doc, embedding in zip(processed_docs, embeddings):
- operations.append({"index": {"_index": index_name}})
- doc["embedding"] = embedding
- if "embedding_model_name" not in doc:
- doc["embedding_model_name"] = embedding_model.embedding_model_name
- operations.append(doc)
+ operations = self._build_bulk_operations(
+ index_name=index_name,
+ processed_docs=processed_docs,
+ embeddings=embeddings,
+ embedding_model=embedding_model,
+ )
+
+ indexed_count = len(processed_docs)
+ if indexed_count == 0:
+ logger.info("Small batch insert skipped: no documents to index.")
+ return 0
# Execute bulk insertion, wait for refresh to complete
response = self.client.bulk(
@@ -425,19 +430,70 @@ def _small_batch_insert(
if progress_callback:
try:
- progress_callback(len(documents), len(documents))
+ progress_callback(indexed_count, indexed_count)
except Exception as e:
logger.warning(
f"[VECTORIZE] Progress callback failed in small batch: {str(e)}")
logger.info(
- f"Small batch insert completed: {len(documents)} chunks indexed.")
- return len(documents)
+ f"Small batch insert completed: {indexed_count} chunks indexed.")
+ return indexed_count
except Exception as e:
logger.error(f"Small batch insert failed: {e}")
raise
+    def _prepare_small_batch_embeddings(
+        self,
+        processed_docs: List[Dict[str, Any]],
+        content_field: str,
+        embedding_model: BaseEmbedding,
+    ):
+        """Compute embeddings for a small batch and return them with their documents.
+
+        For a model whose ``model_type`` is ``"multimodal"``, documents
+        produced by ``UniversalImageExtractor`` are embedded from their
+        ``image_bytes`` (removed from the document and sent as a base64 data
+        URI) and all other documents from ``doc[content_field]``. For any
+        other model, image-extractor documents are dropped and the remaining
+        documents are embedded from their text.
+
+        Returns:
+            ``(docs, embeddings)`` where ``docs[i]`` is the document that
+            produced the i-th embedding input. Image documents without image
+            bytes are left out so that both sequences stay aligned.
+        """
+        image_source = "UniversalImageExtractor"
+
+        if embedding_model.model_type == "multimodal":
+            inputs = []
+            docs_for_embeddings = []
+            for doc in processed_docs:
+                if doc.get("process_source") == image_source:
+                    img_bytes = doc.pop("image_bytes", b"")
+                    if not img_bytes:
+                        # Nothing to embed: skip the document as well, so a
+                        # later zip(docs, embeddings) cannot pair an embedding
+                        # with the wrong document.
+                        continue
+                    image_base64_str = base64.b64encode(
+                        img_bytes).decode("utf-8")
+                    data = f"data:image/jpeg;base64,{image_base64_str}"
+                    inputs.append({"image": data})
+                else:
+                    inputs.append({"text": doc[content_field]})
+                docs_for_embeddings.append(doc)
+            if not inputs:
+                return [], []
+            embeddings = embedding_model.get_multimodal_embeddings(inputs)
+            return docs_for_embeddings, embeddings
+
+        filtered_docs = [
+            doc
+            for doc in processed_docs
+            if doc.get("process_source") != image_source
+        ]
+        if not filtered_docs:
+            return [], []
+        inputs = [doc[content_field] for doc in filtered_docs]
+        embeddings = embedding_model.get_embeddings(inputs)
+        return filtered_docs, embeddings
+
+    @staticmethod
+    def _build_bulk_operations(
+        index_name: str,
+        processed_docs: List[Dict[str, Any]],
+        embeddings: List[Any],
+        embedding_model: BaseEmbedding,
+    ) -> List[Dict[str, Any]]:
+        """Build the alternating action/document list for a bulk index request.
+
+        Each document receives its embedding under ``multi_embedding`` when
+        it comes from ``UniversalImageExtractor`` and under ``embedding``
+        otherwise, plus ``embedding_model_name`` if it does not carry one
+        yet. Documents are modified in place.
+        """
+        bulk_body: List[Dict[str, Any]] = []
+        for document, vector in zip(processed_docs, embeddings):
+            if document.get("process_source") == "UniversalImageExtractor":
+                document["multi_embedding"] = vector
+            else:
+                document["embedding"] = vector
+            if "embedding_model_name" not in document:
+                document["embedding_model_name"] = embedding_model.embedding_model_name
+            bulk_body += [{"index": {"_index": index_name}}, document]
+        return bulk_body
+
def _large_batch_insert(
self,
index_name: str,
@@ -455,10 +511,13 @@ def _large_batch_insert(
try:
sub_batch_max_retries = self.max_retries
-
-
processed_docs = self._preprocess_documents(
documents, content_field)
+ if embedding_model.model_type != "multimodal":
+ processed_docs = [
+ doc for doc in processed_docs
+ if doc.get("process_source") != "UniversalImageExtractor"
+ ]
total_indexed = 0
total_vectorized = 0
total_docs = len(processed_docs)
@@ -485,13 +544,31 @@ def _large_batch_insert(
# partial indexing and reports false-negative "failed then ready".
for retry_attempt in range(sub_batch_max_retries):
try:
- inputs = [doc[content_field]
- for doc in embedding_sub_batch]
- embeddings = embedding_model.get_embeddings(inputs)
-
- for doc, embedding in zip(embedding_sub_batch, embeddings):
- doc_embedding_pairs.append((doc, embedding))
-
+ if embedding_model.model_type == "multimodal":
+ inputs = []
+ docs_for_embeddings = []
+ for doc in embedding_sub_batch:
+ if doc.get("process_source") == "UniversalImageExtractor":
+ img_bytes = doc.pop("image_bytes", "")
+ if len(img_bytes) > 0:
+ image_base64_str = base64.b64encode(
+ img_bytes).decode('utf-8')
+ data = f"data:image/jpeg;base64,{image_base64_str}"
+ inputs.append({"image": data})
+ docs_for_embeddings.append(doc)
+ else:
+ inputs.append({"text": doc[content_field]})
+ docs_for_embeddings.append(doc)
+ embeddings = embedding_model.get_multimodal_embeddings(inputs)
+ for doc, embedding in zip(docs_for_embeddings, embeddings):
+ doc_embedding_pairs.append((doc, embedding))
+ else:
+ inputs = [doc[content_field]
+ for doc in embedding_sub_batch]
+ embeddings = embedding_model.get_embeddings(inputs)
+ for doc, embedding in zip(embedding_sub_batch, embeddings):
+ doc_embedding_pairs.append((doc, embedding))
+
total_vectorized += len(embedding_sub_batch)
if progress_callback:
try:
@@ -531,7 +608,8 @@ def _large_batch_insert(
operations = []
for doc, embedding in doc_embedding_pairs:
operations.append({"index": {"_index": index_name}})
- doc["embedding"] = embedding
+ doc["multi_embedding" if doc["process_source"]
+ == "UniversalImageExtractor" else "embedding"] = embedding
if "embedding_model_name" not in doc:
doc["embedding_model_name"] = getattr(
embedding_model, "embedding_model_name", "unknown")
@@ -982,20 +1060,41 @@ def semantic_search(
query_embedding = embedding_model.get_embeddings(query_text)[0]
# Prepare the search query
- search_query = {
- "knn": {
- "field": "embedding",
- "query_vector": query_embedding,
- "k": top_k,
- "num_candidates": top_k * 2,
- },
- "size": top_k,
- "_source": {"excludes": ["embedding"]},
- }
-
- # Execute the search across multiple indices
- raw_results = self.exec_query(index_pattern, search_query)
-
+ if embedding_model.model_type == "multimodal":
+ search_text_query = {
+ "knn": {
+ "field": "embedding",
+ "query_vector": query_embedding,
+ "k": top_k,
+ "num_candidates": top_k * 2,
+ },
+ "size": top_k,
+ "_source": {"excludes": ["embedding"]},
+ }
+ search_image_query = {
+ "knn": {
+ "field": "multi_embedding",
+ "query_vector": query_embedding,
+ "k": top_k,
+ "num_candidates": top_k * 2,
+ },
+ "size": top_k,
+ "_source": {"excludes": ["multi_embedding"]},
+ }
+ raw_results = self.exec_query(index_pattern, search_text_query) + self.exec_query(index_pattern, search_image_query)
+ else:
+ search_query = {
+ "knn": {
+ "field": "embedding",
+ "query_vector": query_embedding,
+ "k": top_k,
+ "num_candidates": top_k * 2,
+ },
+ "size": top_k,
+ "_source": {"excludes": ["embedding"]},
+ }
+ raw_results = self.exec_query(index_pattern, search_query)
+
return raw_results
def hybrid_search(
@@ -1140,6 +1239,13 @@ def hybrid_search(
for r in accurate_results]) if accurate_results else 1
max_semantic = max([r.get("score", 0)
for r in semantic_results]) if semantic_results else 1
+ is_multimodal = embedding_model.model_type == "multimodal"
+ image_semantic_scores = [
+ r.get("score", 0)
+ for r in semantic_results
+ if r.get("document", {}).get("process_source") == "UniversalImageExtractor"
+ ]
+ max_semantic_image = max(image_semantic_scores) if image_semantic_scores else 1
# Calculate combined scores and sort
results = []
@@ -1151,7 +1257,10 @@ def hybrid_search(
# Normalize scores
normalized_accurate = accurate_score / max_accurate if max_accurate > 0 else 0
- normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0
+ if is_multimodal and result.get("document", {}).get("process_source") == "UniversalImageExtractor":
+ normalized_semantic = semantic_score / max_semantic_image if max_semantic_image > 0 else 0
+ else:
+ normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0
# Calculate weighted combined score
combined_score = weight_accurate * normalized_accurate + \
@@ -1171,9 +1280,20 @@ def hybrid_search(
f"Warning: Error processing result for doc_id {doc_id}: {e}")
continue
- # Sort by combined score and return top k results
+ # Sort by combined score and return results
results.sort(key=lambda x: x["score"], reverse=True)
- final_results = results[:top_k]
+ if is_multimodal:
+ text_results = [
+ r for r in results
+ if r.get("document", {}).get("process_source") != "UniversalImageExtractor"
+ ][:top_k]
+ image_results = [
+ r for r in semantic_results
+ if r.get("document", {}).get("process_source") == "UniversalImageExtractor"
+ ]
+ final_results = text_results + image_results
+ else:
+ final_results = results[:top_k]
return final_results
diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py
index 5817fbe27..237ded6d3 100644
--- a/test/backend/agents/test_create_agent_info.py
+++ b/test/backend/agents/test_create_agent_info.py
@@ -690,6 +690,52 @@ async def test_create_tool_config_list_with_knowledge_base_tool(self):
last_call = mock_tool_config.call_args_list[-1]
assert last_call[1]['class_name'] == "KnowledgeBaseSearchTool"
+ @pytest.mark.asyncio
+ async def test_create_tool_config_list_knowledge_base_multimodal(self):
+ """Ensure multimodal param is forwarded to embedding model selection."""
+ mock_tool_instance = MagicMock()
+ mock_tool_instance.class_name = "KnowledgeBaseSearchTool"
+
+ with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \
+ patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \
+ patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \
+ patch('backend.agents.create_agent_info.get_embedding_model_by_index_name') as mock_embedding_by_index, \
+ patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \
+ patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map, \
+ patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config:
+
+ mock_tool_config.return_value = mock_tool_instance
+
+ mock_search_tools.return_value = [
+ {
+ "class_name": "KnowledgeBaseSearchTool",
+ "name": "knowledge_search",
+ "description": "Knowledge search tool",
+ "inputs": "string",
+ "output_type": "string",
+ "params": [
+ {"name": "index_names", "default": ["idx1", "idx2"]},  # added: index names used for the embedding-model lookup
+ {"name": "multimodal", "default": True},
+ {"name": "rerank", "default": False},
+ ],
+ "source": "local",
+ "usage": None
+ }
+ ]
+ mock_get_vector_db_core.return_value = "mock_elastic_core"
+ mock_embedding_by_index.return_value = ("mock_embedding_model", 123, {"status": "ok"})
+ mock_rerank.return_value = None
+ mock_get_knowledge_map.return_value = {"idx1": "KB1", "idx2": "KB2"}
+
+ result = await create_tool_config_list("agent_1", "tenant_1", "user_1")
+
+ assert len(result) == 1
+ # Verify get_embedding_model_by_index_name was called with tenant_id and first index_name
+ mock_embedding_by_index.assert_called_once_with("tenant_1", "idx1")
+
+ # Verify that multimodal parameter was removed from params (popped)
+ assert "multimodal" not in result[0].params
+
@pytest.mark.asyncio
async def test_create_tool_config_list_with_analyze_image_tool(self):
"""Ensure AnalyzeImageTool receives VLM model metadata."""
@@ -768,20 +814,21 @@ async def test_create_tool_config_list_with_analyze_text_file_tool(self):
@pytest.mark.asyncio
async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self):
"""
- Test that KnowledgeBaseSearchTool metadata contains only vdb_core and embedding_model.
- This test verifies the refactored behavior where index_names and name_resolver
- have been removed from the metadata.
+ Test that KnowledgeBaseSearchTool metadata contains vdb_core, embedding_model,
+ rerank_model, display_name_to_index_map, and index_name_to_display_map.
"""
mock_tool_instance = MagicMock()
mock_tool_instance.class_name = "KnowledgeBaseSearchTool"
- mock_tool_config.return_value = mock_tool_instance
with patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \
patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \
patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \
patch('backend.agents.create_agent_info.get_embedding_model_by_index_name') as mock_embedding, \
patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \
- patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map:
+ patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map, \
+ patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config:
+
+ mock_tool_config.return_value = mock_tool_instance
mock_search_tools.return_value = [
{
@@ -791,7 +838,7 @@ async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self):
"inputs": "string",
"output_type": "string",
"params": [
- {"name": "index_names", "default": ["idx_a"]}, # Non-empty index_names
+ {"name": "index_names", "default": ["idx_a"]},
{"name": "rerank", "default": True},
{"name": "rerank_model_name", "default": "gte-rerank-v2"},
],
@@ -805,7 +852,7 @@ async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self):
mock_get_vector_db_core.return_value = mock_vdb_core
mock_embedding.return_value = (mock_embedding_model, 123, {"status": "ok"})
mock_rerank.return_value = mock_rerank_model
- mock_get_knowledge_map.return_value = {"idx_a": "idx_a"}
+ mock_get_knowledge_map.return_value = {"idx_a": "Knowledge Base A"}
result = await create_tool_config_list("agent_1", "tenant_1", "user_1")
@@ -814,17 +861,25 @@ async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self):
# Verify correct functions were called with correct parameters
mock_get_vector_db_core.assert_called_once()
+ # Changed: verify the call uses tenant_id and index_name
mock_embedding.assert_called_once_with("tenant_1", "idx_a")
+ mock_rerank.assert_called_once_with(tenant_id="tenant_1", model_name="gte-rerank-v2")
+ mock_get_knowledge_map.assert_called_once_with(["idx_a"])
- # Verify metadata contains vdb_core, embedding_model, rerank_model and display_name_to_index_map
+ # Verify metadata contains required fields
assert "vdb_core" in mock_tool_instance.metadata
assert "embedding_model" in mock_tool_instance.metadata
assert "rerank_model" in mock_tool_instance.metadata
assert "display_name_to_index_map" in mock_tool_instance.metadata
+ assert "index_name_to_display_map" in mock_tool_instance.metadata
- # Explicitly verify that old fields are NOT present
- assert "index_names" not in mock_tool_instance.metadata
- assert "name_resolver" not in mock_tool_instance.metadata
+ # Verify mappings
+ assert mock_tool_instance.metadata["display_name_to_index_map"] == {
+ "Knowledge Base A": "idx_a"
+ }
+ assert mock_tool_instance.metadata["index_name_to_display_map"] == {
+ "idx_a": "Knowledge Base A"
+ }
@pytest.mark.asyncio
async def test_create_tool_config_list_with_knowledge_base_tool_multiple_tools(self):
@@ -1534,36 +1589,34 @@ async def test_create_agent_config_with_memory(self):
@pytest.mark.asyncio
async def test_create_agent_config_memory_disabled_no_search(self):
- with (
- patch(
- "backend.agents.create_agent_info.search_agent_info_by_agent_id"
- ) as mock_search_agent,
+ with patch(
+ "backend.agents.create_agent_info.search_agent_info_by_agent_id"
+ ) as mock_search_agent, \
patch(
"backend.agents.create_agent_info.query_sub_agents_id_list"
- ) as mock_query_sub,
+ ) as mock_query_sub, \
patch(
"backend.agents.create_agent_info.create_tool_config_list"
- ) as mock_create_tools,
+ ) as mock_create_tools, \
patch(
"backend.agents.create_agent_info.get_agent_prompt_template"
- ) as mock_get_template,
+ ) as mock_get_template, \
patch(
"backend.agents.create_agent_info.tenant_config_manager"
- ) as mock_tenant_config,
+ ) as mock_tenant_config, \
patch(
"backend.agents.create_agent_info.build_memory_context"
- ) as mock_build_memory,
+ ) as mock_build_memory, \
patch(
"backend.agents.create_agent_info.get_model_by_model_id"
- ) as mock_get_model_by_id,
+ ) as mock_get_model_by_id, \
patch(
"backend.agents.create_agent_info.search_memory_in_levels",
new_callable=AsyncMock,
- ) as mock_search_memory,
+ ) as mock_search_memory, \
patch(
"backend.agents.create_agent_info.prepare_prompt_templates"
- ) as mock_prepare_templates,
- ):
+ ) as mock_prepare_templates:
mock_search_agent.return_value = {
"name": "test_agent",
"description": "test description",
diff --git a/test/backend/app/test_agent_app.py b/test/backend/app/test_agent_app.py
index 22365cf0b..ac032c8c4 100644
--- a/test/backend/app/test_agent_app.py
+++ b/test/backend/app/test_agent_app.py
@@ -1,5 +1,6 @@
import atexit
-from unittest.mock import patch, Mock, MagicMock, ANY
+from unittest.mock import AsyncMock, patch, Mock, MagicMock, ANY
+import importlib.machinery
import os
import sys
import types
@@ -20,8 +21,11 @@
sys.path.insert(0, backend_dir)
# Mock boto3 before importing backend modules
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
@@ -167,7 +171,7 @@ def mock_conversation_id():
async def test_agent_run_api(mocker, mock_auth_header):
"""Test agent_run_api endpoint."""
mock_run_agent_stream = mocker.patch(
- "apps.agent_app.run_agent_stream", new_callable=mocker.AsyncMock)
+ "apps.agent_app.run_agent_stream", new_callable=AsyncMock)
# Mock the streaming response
async def mock_stream():
@@ -247,7 +251,7 @@ def test_search_agent_info_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.return_value = {"agent_id": 123, "name": "Test Agent"}
@@ -272,7 +276,7 @@ def test_search_agent_info_api_with_explicit_tenant_id(mocker, mock_auth_header)
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
# Mock return values - auth tenant_id is different from explicit tenant_id
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.return_value = {
@@ -305,7 +309,7 @@ def test_search_agent_info_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.side_effect = Exception("Test error")
@@ -328,7 +332,7 @@ def test_search_agent_info_api_exception_with_explicit_tenant_id(mocker, mock_au
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
# Mock return values and exception
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.side_effect = Exception("Test error with explicit tenant")
@@ -355,7 +359,7 @@ def test_search_agent_info_api_with_version_no(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.return_value = {"agent_id": 123, "name": "Test Agent", "version_no": 2}
@@ -380,7 +384,7 @@ def test_search_agent_info_api_with_version_no_and_tenant_id(mocker, mock_auth_h
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.return_value = {
"agent_id": 456,
@@ -412,7 +416,7 @@ def test_search_agent_info_api_exception_with_version_no(mocker, mock_auth_heade
# Setup mocks using pytest-mock
mock_get_user_id = mocker.patch("apps.agent_app.get_current_user_id")
mock_get_agent_info = mocker.patch(
- "apps.agent_app.get_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_agent_info_impl", new_callable=AsyncMock)
mock_get_user_id.return_value = ("user_id", "auth_tenant_id")
mock_get_agent_info.side_effect = Exception("Test error with version_no")
@@ -433,7 +437,7 @@ def test_search_agent_info_api_exception_with_version_no(mocker, mock_auth_heade
def test_get_creating_sub_agent_info_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_creating_agent = mocker.patch(
- "apps.agent_app.get_creating_sub_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_creating_sub_agent_info_impl", new_callable=AsyncMock)
mock_get_creating_agent.return_value = {"agent_id": 456}
# Test the endpoint - this is a GET request
@@ -452,7 +456,7 @@ def test_get_creating_sub_agent_info_api_success(mocker, mock_auth_header):
def test_get_creating_sub_agent_info_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_creating_agent = mocker.patch(
- "apps.agent_app.get_creating_sub_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.get_creating_sub_agent_info_impl", new_callable=AsyncMock)
mock_get_creating_agent.side_effect = Exception("Test error")
# Test the endpoint - this is a GET request
@@ -469,7 +473,7 @@ def test_get_creating_sub_agent_info_api_exception(mocker, mock_auth_header):
def test_update_agent_info_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_update_agent = mocker.patch(
- "apps.agent_app.update_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.update_agent_info_impl", new_callable=AsyncMock)
mock_update_agent.return_value = None
# Test the endpoint
@@ -489,7 +493,7 @@ def test_update_agent_info_api_success(mocker, mock_auth_header):
def test_update_agent_info_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_update_agent = mocker.patch(
- "apps.agent_app.update_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.update_agent_info_impl", new_callable=AsyncMock)
mock_update_agent.side_effect = Exception("Test error")
# Test the endpoint
@@ -510,7 +514,7 @@ def test_delete_agent_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_delete_agent = mocker.patch(
- "apps.agent_app.delete_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.delete_agent_impl", new_callable=AsyncMock)
# Mock return values
mock_get_user_info.return_value = ("test_user", "test_tenant", "en")
mock_delete_agent.return_value = None
@@ -536,7 +540,7 @@ def test_delete_agent_api_with_explicit_tenant_id(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_delete_agent = mocker.patch(
- "apps.agent_app.delete_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.delete_agent_impl", new_callable=AsyncMock)
# Mock return values - auth tenant_id is different from explicit tenant_id
mock_get_user_info.return_value = ("test_user", "auth_tenant", "en")
mock_delete_agent.return_value = None
@@ -564,7 +568,7 @@ def test_delete_agent_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_delete_agent = mocker.patch(
- "apps.agent_app.delete_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.delete_agent_impl", new_callable=AsyncMock)
mock_logger = mocker.patch("apps.agent_app.logger")
# Mock return values and exception
mock_get_user_info.return_value = ("test_user", "test_tenant", "en")
@@ -592,7 +596,7 @@ def test_delete_agent_api_exception_with_explicit_tenant_id(mocker, mock_auth_he
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_delete_agent = mocker.patch(
- "apps.agent_app.delete_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.delete_agent_impl", new_callable=AsyncMock)
mock_logger = mocker.patch("apps.agent_app.logger")
# Mock return values and exception
mock_get_user_info.return_value = ("test_user", "auth_tenant", "en")
@@ -622,7 +626,7 @@ def test_delete_agent_api_exception_with_explicit_tenant_id(mocker, mock_auth_he
async def test_export_agent_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_export_agent = mocker.patch(
- "apps.agent_app.export_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.export_agent_impl", new_callable=AsyncMock)
mock_export_agent.return_value = '{"agent_id": 123, "name": "Test Agent"}'
# Test the endpoint
@@ -644,7 +648,7 @@ async def test_export_agent_api_success(mocker, mock_auth_header):
async def test_export_agent_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_export_agent = mocker.patch(
- "apps.agent_app.export_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.export_agent_impl", new_callable=AsyncMock)
mock_export_agent.side_effect = Exception("Test error")
# Test the endpoint
@@ -662,7 +666,7 @@ async def test_export_agent_api_exception(mocker, mock_auth_header):
def test_import_agent_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_import_agent = mocker.patch(
- "apps.agent_app.import_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.import_agent_impl", new_callable=AsyncMock)
mock_import_agent.return_value = None
# Test the endpoint - following the ExportAndImportDataFormat structure
@@ -706,7 +710,7 @@ def test_import_agent_api_success(mocker, mock_auth_header):
def test_import_agent_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_import_agent = mocker.patch(
- "apps.agent_app.import_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.import_agent_impl", new_callable=AsyncMock)
mock_import_agent.side_effect = Exception("Test error")
# Test the endpoint - following the ExportAndImportDataFormat structure
@@ -748,7 +752,7 @@ def test_list_all_agent_info_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_list_all_agent = mocker.patch(
- "apps.agent_app.list_all_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.list_all_agent_info_impl", new_callable=AsyncMock)
# Mock return values
mock_get_user_info.return_value = ("test_user", "test_tenant", "en")
mock_list_all_agent.return_value = [
@@ -801,7 +805,7 @@ def test_list_all_agent_info_api_with_explicit_tenant_id(mocker, mock_auth_heade
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_list_all_agent = mocker.patch(
- "apps.agent_app.list_all_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.list_all_agent_info_impl", new_callable=AsyncMock)
# Mock return values - auth tenant_id is different from explicit tenant_id
mock_get_user_info.return_value = ("test_user", "auth_tenant", "en")
mock_list_all_agent.return_value = [
@@ -841,7 +845,7 @@ def test_list_all_agent_info_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_list_all_agent = mocker.patch(
- "apps.agent_app.list_all_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.list_all_agent_info_impl", new_callable=AsyncMock)
# Mock return values and exception
mock_get_user_info.return_value = ("test_user", "test_tenant", "en")
mock_list_all_agent.side_effect = Exception("Test error")
@@ -864,7 +868,7 @@ def test_list_all_agent_info_api_exception_with_explicit_tenant_id(mocker, mock_
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_list_all_agent = mocker.patch(
- "apps.agent_app.list_all_agent_info_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.list_all_agent_info_impl", new_callable=AsyncMock)
# Mock return values and exception
mock_get_user_info.return_value = ("test_user", "auth_tenant", "en")
mock_list_all_agent.side_effect = Exception("Test error with explicit tenant")
@@ -890,7 +894,7 @@ async def test_export_agent_api_detailed(mocker, mock_auth_header):
"""Detailed testing of export_agent_api function, including ConversationResponse construction"""
# Setup mocks using pytest-mock
mock_export_agent = mocker.patch(
- "apps.agent_app.export_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.export_agent_impl", new_callable=AsyncMock)
# Setup mocks - return complex JSON data
agent_data = {
@@ -927,7 +931,7 @@ async def test_export_agent_api_empty_response(mocker, mock_auth_header):
"""Test export_agent_api handling empty response"""
# Setup mocks using pytest-mock
mock_export_agent = mocker.patch(
- "apps.agent_app.export_agent_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.export_agent_impl", new_callable=AsyncMock)
# Setup mock to return empty data
mock_export_agent.return_value = {}
@@ -1007,7 +1011,7 @@ def test_get_agent_call_relationship_api_exception(mocker, mock_auth_header):
def test_check_agent_name_batch_api_success(mocker, mock_auth_header):
mock_impl = mocker.patch(
"apps.agent_app.check_agent_name_conflict_batch_impl",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
)
mock_impl.return_value = [{"name_conflict": True}]
@@ -1029,7 +1033,7 @@ def test_check_agent_name_batch_api_success(mocker, mock_auth_header):
def test_check_agent_name_batch_api_bad_request(mocker, mock_auth_header):
mock_impl = mocker.patch(
"apps.agent_app.check_agent_name_conflict_batch_impl",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
)
mock_impl.side_effect = ValueError("bad payload")
@@ -1046,7 +1050,7 @@ def test_check_agent_name_batch_api_bad_request(mocker, mock_auth_header):
def test_check_agent_name_batch_api_error(mocker, mock_auth_header):
mock_impl = mocker.patch(
"apps.agent_app.check_agent_name_conflict_batch_impl",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
)
mock_impl.side_effect = Exception("unexpected")
@@ -1063,7 +1067,7 @@ def test_check_agent_name_batch_api_error(mocker, mock_auth_header):
def test_regenerate_agent_name_batch_api_success(mocker, mock_auth_header):
mock_impl = mocker.patch(
"apps.agent_app.regenerate_agent_name_batch_impl",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
)
mock_impl.return_value = [{"name": "NewName", "display_name": "New Display"}]
@@ -1090,7 +1094,7 @@ def test_regenerate_agent_name_batch_api_success(mocker, mock_auth_header):
def test_regenerate_agent_name_batch_api_bad_request(mocker, mock_auth_header):
mock_impl = mocker.patch(
"apps.agent_app.regenerate_agent_name_batch_impl",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
)
mock_impl.side_effect = ValueError("invalid")
@@ -1107,7 +1111,7 @@ def test_regenerate_agent_name_batch_api_bad_request(mocker, mock_auth_header):
def test_regenerate_agent_name_batch_api_error(mocker, mock_auth_header):
mock_impl = mocker.patch(
"apps.agent_app.regenerate_agent_name_batch_impl",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
)
mock_impl.side_effect = Exception("boom")
@@ -1134,7 +1138,7 @@ def test_clear_agent_new_mark_api_success(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_clear_agent_new_mark = mocker.patch(
- "apps.agent_app.clear_agent_new_mark_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.clear_agent_new_mark_impl", new_callable=AsyncMock)
# Mock the auth utility to return user info
mock_get_user_info.return_value = ("test_user_id", "test_tenant_id", "extra_info")
@@ -1171,7 +1175,7 @@ def test_clear_agent_new_mark_api_exception(mocker, mock_auth_header):
# Setup mocks using pytest-mock
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_clear_agent_new_mark = mocker.patch(
- "apps.agent_app.clear_agent_new_mark_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.clear_agent_new_mark_impl", new_callable=AsyncMock)
mock_logger = mocker.patch("apps.agent_app.logger")
# Mock the auth utility to return user info
@@ -1904,7 +1908,7 @@ def test_list_published_agents_api_success(mocker, mock_auth_header):
"""Test successful published agents list retrieval"""
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_list_published_agents = mocker.patch(
- "apps.agent_app.list_published_agents_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.list_published_agents_impl", new_callable=AsyncMock)
mock_get_user_info.return_value = ("test_user_id", "test_tenant_id", "en")
mock_list_published_agents.return_value = [
@@ -1941,7 +1945,7 @@ def test_list_published_agents_api_exception(mocker, mock_auth_header):
"""Test list published agents with exception"""
mock_get_user_info = mocker.patch("apps.agent_app.get_current_user_info")
mock_list_published_agents = mocker.patch(
- "apps.agent_app.list_published_agents_impl", new_callable=mocker.AsyncMock)
+ "apps.agent_app.list_published_agents_impl", new_callable=AsyncMock)
mock_get_user_info.return_value = ("test_user_id", "test_tenant_id", "en")
mock_list_published_agents.side_effect = Exception("Database error")
diff --git a/test/backend/app/test_config_sync_app.py b/test/backend/app/test_config_sync_app.py
index 80aaaf3fb..82c5f4e23 100644
--- a/test/backend/app/test_config_sync_app.py
+++ b/test/backend/app/test_config_sync_app.py
@@ -1,5 +1,7 @@
import os
import sys
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock
import pytest
@@ -14,8 +16,11 @@
sys.path.append(backend_dir)
# Patch boto3 and other dependencies before importing anything from backend
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
diff --git a/test/backend/app/test_conversation_management_app.py b/test/backend/app/test_conversation_management_app.py
index b5db691aa..c712ef011 100644
--- a/test/backend/app/test_conversation_management_app.py
+++ b/test/backend/app/test_conversation_management_app.py
@@ -1,5 +1,7 @@
import os
import sys
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock
import pytest
@@ -11,8 +13,11 @@
sys.path.append(backend_dir)
# Patch boto3 before importing backend modules (some services may rely on it)
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
diff --git a/test/backend/app/test_datamate_app.py b/test/backend/app/test_datamate_app.py
index ce9c66cc4..46e67af5a 100644
--- a/test/backend/app/test_datamate_app.py
+++ b/test/backend/app/test_datamate_app.py
@@ -1,5 +1,7 @@
import sys
import os
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, AsyncMock, call
import pytest
@@ -16,8 +18,11 @@
sys.path.insert(0, backend_dir)
# Patch boto3 and other dependencies before importing anything from backend
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
diff --git a/test/backend/app/test_file_management_app.py b/test/backend/app/test_file_management_app.py
index fc33db8fb..eff392537 100644
--- a/test/backend/app/test_file_management_app.py
+++ b/test/backend/app/test_file_management_app.py
@@ -129,11 +129,12 @@ async def _stub_trigger_data_process(files: List[dict], params: Any):
model_stub = types.ModuleType("consts.model")
class ProcessParams: # minimal stub
- def __init__(self, chunking_strategy: str, source_type: str, index_name: str, authorization: str | None):
+ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, authorization: str | None, model_id: int | None = None):
self.chunking_strategy = chunking_strategy
self.source_type = source_type
self.index_name = index_name
self.authorization = authorization
+ self.model_id = model_id
model_stub.ProcessParams = ProcessParams
sys.modules.setdefault("consts.model", model_stub)
setattr(consts_pkg, "model", model_stub)
@@ -249,6 +250,7 @@ async def fake_trigger(files, params):
index_name="kb1",
destination="local",
authorization="Bearer x",
+ model_id=1,
)
assert resp.status_code == 201
assert "Files processing triggered successfully" in resp.body.decode()
@@ -267,6 +269,7 @@ async def fake_trigger(files, params):
index_name="kb",
destination="local",
authorization=None,
+ model_id=1,
)
assert "Data process service failed" in str(ei.value)
@@ -284,6 +287,7 @@ async def fake_trigger(files, params):
index_name="kb",
destination="local",
authorization=None,
+ model_id=1,
)
assert "boom" in str(ei.value)
diff --git a/test/backend/app/test_group_app.py b/test/backend/app/test_group_app.py
index 6b93bfea0..bec100c5c 100644
--- a/test/backend/app/test_group_app.py
+++ b/test/backend/app/test_group_app.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import sys
@@ -8,7 +10,11 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
# Mock external dependencies
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['psycopg2'] = MagicMock()
sys.modules['supabase'] = MagicMock()
diff --git a/test/backend/app/test_invitation_app.py b/test/backend/app/test_invitation_app.py
index 7d8e15a66..5e85e7f88 100644
--- a/test/backend/app/test_invitation_app.py
+++ b/test/backend/app/test_invitation_app.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import sys
@@ -8,7 +10,11 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
# Mock external dependencies
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['psycopg2'] = MagicMock()
sys.modules['supabase'] = MagicMock()
diff --git a/test/backend/app/test_knowledge_summary_app.py b/test/backend/app/test_knowledge_summary_app.py
index 76e660839..56bfc5e08 100644
--- a/test/backend/app/test_knowledge_summary_app.py
+++ b/test/backend/app/test_knowledge_summary_app.py
@@ -2,6 +2,7 @@
import sys
import os
import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, AsyncMock
# Add path for correct imports
@@ -15,7 +16,11 @@
# Environment variables are now configured in conftest.py
# Mock external dependencies
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['botocore'] = MagicMock()
sys.modules['botocore.client'] = MagicMock()
sys.modules['botocore.exceptions'] = MagicMock()
diff --git a/test/backend/app/test_memory_config_app.py b/test/backend/app/test_memory_config_app.py
index 622bd8012..db91f2ee9 100644
--- a/test/backend/app/test_memory_config_app.py
+++ b/test/backend/app/test_memory_config_app.py
@@ -1,10 +1,16 @@
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, AsyncMock
import sys
import os
# Add path for correct imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
diff --git a/test/backend/app/test_mock_user_management_app.py b/test/backend/app/test_mock_user_management_app.py
index 86348c72e..7d694c442 100644
--- a/test/backend/app/test_mock_user_management_app.py
+++ b/test/backend/app/test_mock_user_management_app.py
@@ -1,3 +1,5 @@
+import types
+
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import sys
@@ -11,7 +13,12 @@
boto3_mock = MagicMock()
minio_client_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+import importlib.machinery
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Patch storage factory and MinIO config validation to avoid errors during initialization
# These patches must be started before any imports that use MinioClient
diff --git a/test/backend/app/test_model_managment_app.py b/test/backend/app/test_model_managment_app.py
index 20f3210e2..ade705667 100644
--- a/test/backend/app/test_model_managment_app.py
+++ b/test/backend/app/test_model_managment_app.py
@@ -48,7 +48,7 @@ def _get_vector_db_core(): # minimal stub
_sys.modules["services.vectordatabase_service"] = services_vdb_mod
# Import after mocking (only backend path is required by app imports)
- from apps.model_managment_app import router
+ from backend.apps.model_managment_app import router
# Create test client
app = FastAPI()
@@ -86,12 +86,12 @@ def sample_model_data():
@pytest.mark.asyncio
async def test_create_model_success(client, auth_header, user_credentials, sample_model_data, mocker):
"""Test successful model creation."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _create(*args, **kwargs):
return None
- mock_create = mocker.patch('apps.model_managment_app.create_model_for_tenant', side_effect=_create)
+ mock_create = mocker.patch('backend.apps.model_managment_app.create_model_for_tenant', side_effect=_create)
response = client.post(
"/model/create", json=sample_model_data, headers=auth_header)
@@ -105,10 +105,10 @@ async def _create(*args, **kwargs):
@pytest.mark.asyncio
async def test_create_model_conflict(client, auth_header, user_credentials, sample_model_data, mocker):
"""Test model creation with name conflict."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_create = mocker.patch(
- 'apps.model_managment_app.create_model_for_tenant',
+ 'backend.apps.model_managment_app.create_model_for_tenant',
side_effect=ValueError("Name 'Test Model' is already in use, please choose another display name")
)
@@ -125,10 +125,10 @@ async def test_create_model_conflict(client, auth_header, user_credentials, samp
@pytest.mark.asyncio
async def test_create_model_exception(client, auth_header, user_credentials, sample_model_data, mocker):
"""Test model creation with internal error."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_create = mocker.patch(
- 'apps.model_managment_app.create_model_for_tenant',
+ 'backend.apps.model_managment_app.create_model_for_tenant',
side_effect=Exception("DB failure")
)
@@ -146,10 +146,10 @@ async def test_create_model_exception(client, auth_header, user_credentials, sam
@pytest.mark.asyncio
async def test_create_provider_model_success(client, auth_header, user_credentials, mocker):
"""Test successful provider model creation."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_get = mocker.patch(
- 'apps.model_managment_app.create_provider_models_for_tenant',
+ 'backend.apps.model_managment_app.create_provider_models_for_tenant',
return_value=[{"id": "A1"}, {"id": "a0"}, {"id": "b2"}, {"id": "c3"}]
)
@@ -169,10 +169,10 @@ async def test_create_provider_model_success(client, auth_header, user_credentia
@pytest.mark.asyncio
async def test_create_provider_model_exception(client, auth_header, user_credentials, mocker):
"""Test provider model creation with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_get = mocker.patch(
- 'apps.model_managment_app.create_provider_models_for_tenant',
+ 'backend.apps.model_managment_app.create_provider_models_for_tenant',
side_effect=Exception("Provider API error")
)
@@ -192,12 +192,12 @@ async def test_create_provider_model_exception(client, auth_header, user_credent
@pytest.mark.asyncio
async def test_provider_batch_create_success(client, auth_header, user_credentials, mocker):
"""Test successful batch model creation."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _batch(*args, **kwargs):
return None
- mock_batch = mocker.patch('apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch)
+ mock_batch = mocker.patch('backend.apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch)
payload = {
"models": [{"id": "prov/modelA"}],
@@ -217,10 +217,10 @@ async def _batch(*args, **kwargs):
@pytest.mark.asyncio
async def test_provider_batch_create_exception(client, auth_header, user_credentials, mocker):
"""Test batch model creation with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_batch = mocker.patch(
- 'apps.model_managment_app.batch_create_models_for_tenant',
+ 'backend.apps.model_managment_app.batch_create_models_for_tenant',
side_effect=Exception("boom")
)
@@ -244,12 +244,12 @@ async def test_provider_batch_create_exception(client, auth_header, user_credent
@pytest.mark.asyncio
async def test_delete_model_success(client, auth_header, user_credentials, mocker):
"""Test successful model deletion."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _delete(*args, **kwargs):
return "Test Model"
- mock_del = mocker.patch('apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
+ mock_del = mocker.patch('backend.apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
response = client.post(
"/model/delete", params={"display_name": "Test Model"}, headers=auth_header)
@@ -264,10 +264,10 @@ async def _delete(*args, **kwargs):
@pytest.mark.asyncio
async def test_delete_model_not_found(client, auth_header, user_credentials, mocker):
"""Test model deletion when model not found."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_del = mocker.patch(
- 'apps.model_managment_app.delete_model_for_tenant',
+ 'backend.apps.model_managment_app.delete_model_for_tenant',
side_effect=LookupError("Model not found: Missing")
)
@@ -285,7 +285,7 @@ async def test_delete_model_not_found(client, auth_header, user_credentials, moc
@pytest.mark.asyncio
async def test_get_model_list_success(client, auth_header, user_credentials, mocker):
"""Test successful model list retrieval."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_models(*args, **kwargs):
return [
@@ -305,7 +305,7 @@ async def mock_list_models(*args, **kwargs):
}
]
- mock_list = mocker.patch('apps.model_managment_app.list_models_for_tenant', side_effect=mock_list_models)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_models_for_tenant', side_effect=mock_list_models)
response = client.get("/model/list", headers=auth_header)
@@ -323,7 +323,7 @@ async def mock_list_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_llm_model_list_success(client, auth_header, user_credentials, mocker):
"""Test successful LLM model list retrieval."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_llm_models(*args, **kwargs):
return [
@@ -341,7 +341,7 @@ async def mock_list_llm_models(*args, **kwargs):
}
]
- mock_list = mocker.patch('apps.model_managment_app.list_llm_models_for_tenant', side_effect=mock_list_llm_models)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_llm_models_for_tenant', side_effect=mock_list_llm_models)
response = client.get("/model/llm_list", headers=auth_header)
@@ -359,12 +359,12 @@ async def mock_list_llm_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_llm_model_list_exception(client, auth_header, user_credentials, mocker):
"""Test LLM model list retrieval with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_llm_models(*args, **kwargs):
raise Exception("Database connection error")
- mocker.patch('apps.model_managment_app.list_llm_models_for_tenant', side_effect=mock_list_llm_models)
+ mocker.patch('backend.apps.model_managment_app.list_llm_models_for_tenant', side_effect=mock_list_llm_models)
response = client.get("/model/llm_list", headers=auth_header)
@@ -377,12 +377,12 @@ async def mock_list_llm_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_llm_model_list_empty(client, auth_header, user_credentials, mocker):
"""Test LLM model list retrieval with empty result."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_llm_models(*args, **kwargs):
return []
- mock_list = mocker.patch('apps.model_managment_app.list_llm_models_for_tenant', side_effect=mock_list_llm_models)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_llm_models_for_tenant', side_effect=mock_list_llm_models)
response = client.get("/model/llm_list", headers=auth_header)
@@ -397,16 +397,16 @@ async def mock_list_llm_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_check_model_health_success(client, auth_header, user_credentials, mocker):
"""Test successful model health check."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_check = mocker.patch(
- 'apps.model_managment_app.check_model_connectivity',
+ 'backend.apps.model_managment_app.check_model_connectivity',
return_value={"connectivity": True, "connect_status": "available"}
)
response = client.post(
"/model/healthcheck",
- params={"display_name": "Test Model"},
+ params={"display_name": "Test Model", "model_type": "embedding"},
headers=auth_header
)
@@ -414,22 +414,22 @@ async def test_check_model_health_success(client, auth_header, user_credentials,
data = response.json()
assert data["message"] == "Successfully checked model connectivity"
assert data["data"]["connectivity"] is True
- mock_check.assert_called_once_with("Test Model", user_credentials[1])
+ mock_check.assert_called_once_with("Test Model", user_credentials[1], "embedding")
@pytest.mark.asyncio
async def test_check_model_health_lookup_error(client, auth_header, user_credentials, mocker):
"""Test model health check with lookup error."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mocker.patch(
- 'apps.model_managment_app.check_model_connectivity',
+ 'backend.apps.model_managment_app.check_model_connectivity',
side_effect=LookupError("missing")
)
response = client.post(
"/model/healthcheck",
- params={"display_name": "X"},
+ params={"display_name": "X", "model_type": "embedding"},
headers=auth_header
)
assert response.status_code == HTTPStatus.NOT_FOUND
@@ -440,7 +440,7 @@ async def test_check_model_health_lookup_error(client, auth_header, user_credent
async def test_verify_model_config_success(client, auth_header, sample_model_data, mocker):
"""Test successful model config verification."""
mock_verify = mocker.patch(
- 'apps.model_managment_app.verify_model_config_connectivity',
+ 'backend.apps.model_managment_app.verify_model_config_connectivity',
return_value={"connectivity": True, "model_name": "gpt-4"}
)
@@ -460,7 +460,7 @@ async def test_verify_model_config_success(client, auth_header, sample_model_dat
async def test_verify_model_config_failure_with_error(client, auth_header, sample_model_data, mocker):
"""Test model config verification failure with detailed error message."""
mock_verify = mocker.patch(
- 'apps.model_managment_app.verify_model_config_connectivity',
+ 'backend.apps.model_managment_app.verify_model_config_connectivity',
return_value={
"connectivity": False,
"model_name": "gpt-4",
@@ -486,7 +486,7 @@ async def test_verify_model_config_failure_with_error(client, auth_header, sampl
async def test_verify_model_config_exception(client, auth_header, sample_model_data, mocker):
"""Test model config verification with exception."""
mocker.patch(
- 'apps.model_managment_app.verify_model_config_connectivity',
+ 'backend.apps.model_managment_app.verify_model_config_connectivity',
side_effect=Exception("err")
)
@@ -499,12 +499,12 @@ async def test_verify_model_config_exception(client, auth_header, sample_model_d
@pytest.mark.asyncio
async def test_update_single_model_success(client, auth_header, user_credentials, mocker):
"""Test successful single model update."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_update_single(*args, **kwargs):
return None
- mock_update = mocker.patch('apps.model_managment_app.update_single_model_for_tenant', side_effect=mock_update_single)
+ mock_update = mocker.patch('backend.apps.model_managment_app.update_single_model_for_tenant', side_effect=mock_update_single)
update_data = {
"model_id": "test_model_id",
@@ -536,10 +536,10 @@ async def mock_update_single(*args, **kwargs):
@pytest.mark.asyncio
async def test_update_single_model_conflict(client, auth_header, user_credentials, mocker):
"""Test single model update with name conflict."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_update = mocker.patch(
- 'apps.model_managment_app.update_single_model_for_tenant',
+ 'backend.apps.model_managment_app.update_single_model_for_tenant',
side_effect=ValueError("Name 'Conflicting Name' is already in use, please choose another display name"),
)
@@ -575,12 +575,12 @@ async def test_update_single_model_conflict(client, auth_header, user_credential
@pytest.mark.asyncio
async def test_batch_update_models_success(client, auth_header, user_credentials, mocker):
"""Test successful batch model update."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_batch_update(*args, **kwargs):
return None
- mock_batch_update = mocker.patch('apps.model_managment_app.batch_update_models_for_tenant', side_effect=mock_batch_update)
+ mock_batch_update = mocker.patch('backend.apps.model_managment_app.batch_update_models_for_tenant', side_effect=mock_batch_update)
models = [
{"model_id": "id1", "api_key": "k1", "max_tokens": 100},
@@ -598,12 +598,12 @@ async def mock_batch_update(*args, **kwargs):
@pytest.mark.asyncio
async def test_batch_update_models_exception(client, auth_header, user_credentials, mocker):
"""Test batch model update with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_batch_update(*args, **kwargs):
raise Exception("Update failed")
- mock_batch_update = mocker.patch('apps.model_managment_app.batch_update_models_for_tenant', side_effect=mock_batch_update)
+ mock_batch_update = mocker.patch('backend.apps.model_managment_app.batch_update_models_for_tenant', side_effect=mock_batch_update)
models = [{"model_id": "id1", "api_key": "k1"}]
response = client.post(
@@ -620,7 +620,7 @@ async def mock_batch_update(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_manage_model_list_success(client, auth_header, user_credentials, mocker):
"""Test successful manage model list retrieval for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_models_for_admin(*args, **kwargs):
return {
@@ -648,7 +648,7 @@ async def mock_list_models_for_admin(*args, **kwargs):
"total_pages": 1
}
- mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
request_data = {
"tenant_id": "target_tenant",
@@ -676,7 +676,7 @@ async def mock_list_models_for_admin(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_manage_model_list_with_pagination(client, auth_header, user_credentials, mocker):
"""Test manage model list retrieval with pagination parameters."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_models_for_admin(*args, **kwargs):
return {
@@ -697,7 +697,7 @@ async def mock_list_models_for_admin(*args, **kwargs):
"total_pages": 3
}
- mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
request_data = {
"tenant_id": "target_tenant",
@@ -720,12 +720,12 @@ async def mock_list_models_for_admin(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_manage_model_list_exception(client, auth_header, user_credentials, mocker):
"""Test manage model list retrieval with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_models_for_admin(*args, **kwargs):
raise Exception("Database connection error")
- mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
+ mocker.patch('backend.apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
request_data = {
"tenant_id": "target_tenant",
@@ -743,7 +743,7 @@ async def mock_list_models_for_admin(*args, **kwargs):
@pytest.mark.asyncio
async def test_get_manage_model_list_empty(client, auth_header, user_credentials, mocker):
"""Test manage model list retrieval with empty result."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_models_for_admin(*args, **kwargs):
return {
@@ -756,7 +756,7 @@ async def mock_list_models_for_admin(*args, **kwargs):
"total_pages": 0
}
- mock_list = mocker.patch('apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_models_for_admin', side_effect=mock_list_models_for_admin)
request_data = {
"tenant_id": "empty_tenant",
@@ -778,12 +778,12 @@ async def mock_list_models_for_admin(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_create_model_success(client, auth_header, user_credentials, mocker):
"""Test successful model creation for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _create(*args, **kwargs):
return None
- mock_create = mocker.patch('apps.model_managment_app.create_model_for_tenant', side_effect=_create)
+ mock_create = mocker.patch('backend.apps.model_managment_app.create_model_for_tenant', side_effect=_create)
request_data = {
"tenant_id": "target_tenant",
@@ -812,12 +812,12 @@ async def _create(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_create_model_conflict(client, auth_header, user_credentials, mocker):
"""Test model creation with conflict error."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _create(*args, **kwargs):
raise ValueError("Model name already exists")
- mocker.patch('apps.model_managment_app.create_model_for_tenant', side_effect=_create)
+ mocker.patch('backend.apps.model_managment_app.create_model_for_tenant', side_effect=_create)
request_data = {
"tenant_id": "target_tenant",
@@ -835,12 +835,12 @@ async def _create(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_create_model_exception(client, auth_header, user_credentials, mocker):
"""Test model creation with unexpected exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _create(*args, **kwargs):
raise Exception("Database error")
- mocker.patch('apps.model_managment_app.create_model_for_tenant', side_effect=_create)
+ mocker.patch('backend.apps.model_managment_app.create_model_for_tenant', side_effect=_create)
request_data = {
"tenant_id": "target_tenant",
@@ -858,12 +858,12 @@ async def _create(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_update_model_success(client, auth_header, user_credentials, mocker):
"""Test successful model update for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _update(*args, **kwargs):
return None
- mock_update = mocker.patch('apps.model_managment_app.update_single_model_for_tenant', side_effect=_update)
+ mock_update = mocker.patch('backend.apps.model_managment_app.update_single_model_for_tenant', side_effect=_update)
request_data = {
"tenant_id": "target_tenant",
@@ -891,12 +891,12 @@ async def _update(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_update_model_not_found(client, auth_header, user_credentials, mocker):
"""Test model update with not found error."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _update(*args, **kwargs):
raise LookupError("Model not found")
- mocker.patch('apps.model_managment_app.update_single_model_for_tenant', side_effect=_update)
+ mocker.patch('backend.apps.model_managment_app.update_single_model_for_tenant', side_effect=_update)
request_data = {
"tenant_id": "target_tenant",
@@ -911,12 +911,12 @@ async def _update(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_update_model_conflict(client, auth_header, user_credentials, mocker):
"""Test model update with conflict error."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _update(*args, **kwargs):
raise ValueError("Display name already exists")
- mocker.patch('apps.model_managment_app.update_single_model_for_tenant', side_effect=_update)
+ mocker.patch('backend.apps.model_managment_app.update_single_model_for_tenant', side_effect=_update)
request_data = {
"tenant_id": "target_tenant",
@@ -932,12 +932,12 @@ async def _update(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_delete_model_success(client, auth_header, user_credentials, mocker):
"""Test successful model deletion for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _delete(*args, **kwargs):
return "test-model"
- mock_delete = mocker.patch('apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
+ mock_delete = mocker.patch('backend.apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
request_data = {
"tenant_id": "target_tenant",
@@ -956,12 +956,12 @@ async def _delete(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_delete_model_not_found(client, auth_header, user_credentials, mocker):
"""Test model deletion with not found error."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _delete(*args, **kwargs):
raise LookupError("Model not found")
- mocker.patch('apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
+ mocker.patch('backend.apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
request_data = {
"tenant_id": "target_tenant",
@@ -975,12 +975,12 @@ async def _delete(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_delete_model_exception(client, auth_header, user_credentials, mocker):
"""Test model deletion with unexpected exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _delete(*args, **kwargs):
raise Exception("Database error")
- mocker.patch('apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
+ mocker.patch('backend.apps.model_managment_app.delete_model_for_tenant', side_effect=_delete)
request_data = {
"tenant_id": "target_tenant",
@@ -995,12 +995,12 @@ async def _delete(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_batch_create_models_success(client, auth_header, user_credentials, mocker):
"""Test successful batch model creation for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _batch_create(*args, **kwargs):
return None
- mock_batch_create = mocker.patch('apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch_create)
+ mock_batch_create = mocker.patch('backend.apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch_create)
request_data = {
"tenant_id": "target_tenant",
@@ -1064,12 +1064,12 @@ async def _batch_create(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_batch_create_models_empty_list(client, auth_header, user_credentials, mocker):
"""Test batch model creation with empty models list."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _batch_create(*args, **kwargs):
return None
- mock_batch_create = mocker.patch('apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch_create)
+ mock_batch_create = mocker.patch('backend.apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch_create)
request_data = {
"tenant_id": "target_tenant",
@@ -1090,12 +1090,12 @@ async def _batch_create(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_batch_create_models_exception(client, auth_header, user_credentials, mocker):
"""Test batch model creation with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def _batch_create(*args, **kwargs):
raise Exception("Database connection error")
- mocker.patch('apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch_create)
+ mocker.patch('backend.apps.model_managment_app.batch_create_models_for_tenant', side_effect=_batch_create)
request_data = {
"tenant_id": "target_tenant",
@@ -1115,10 +1115,10 @@ async def _batch_create(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_healthcheck_success(client, auth_header, user_credentials, mocker):
"""Test successful model connectivity check for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mock_check = mocker.patch(
- 'apps.model_managment_app.check_model_connectivity',
+ 'backend.apps.model_managment_app.check_model_connectivity',
return_value={"connectivity": True, "connect_status": "available"}
)
@@ -1138,10 +1138,10 @@ async def test_manage_healthcheck_success(client, auth_header, user_credentials,
@pytest.mark.asyncio
async def test_manage_healthcheck_model_not_found(client, auth_header, user_credentials, mocker):
"""Test model connectivity check when model is not found."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mocker.patch(
- 'apps.model_managment_app.check_model_connectivity',
+ 'backend.apps.model_managment_app.check_model_connectivity',
side_effect=LookupError("Model configuration not found for test-model")
)
@@ -1158,10 +1158,10 @@ async def test_manage_healthcheck_model_not_found(client, auth_header, user_cred
@pytest.mark.asyncio
async def test_manage_healthcheck_invalid_config(client, auth_header, user_credentials, mocker):
"""Test model connectivity check with invalid model configuration."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mocker.patch(
- 'apps.model_managment_app.check_model_connectivity',
+ 'backend.apps.model_managment_app.check_model_connectivity',
side_effect=ValueError("Invalid model configuration")
)
@@ -1178,10 +1178,10 @@ async def test_manage_healthcheck_invalid_config(client, auth_header, user_crede
@pytest.mark.asyncio
async def test_manage_healthcheck_exception(client, auth_header, user_credentials, mocker):
"""Test model connectivity check with unexpected exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
mocker.patch(
- 'apps.model_managment_app.check_model_connectivity',
+ 'backend.apps.model_managment_app.check_model_connectivity',
side_effect=Exception("Database connection error")
)
@@ -1198,7 +1198,7 @@ async def test_manage_healthcheck_exception(client, auth_header, user_credential
@pytest.mark.asyncio
async def test_manage_provider_list_success(client, auth_header, user_credentials, mocker):
"""Test successful provider model list retrieval for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_provider_models(*args, **kwargs):
return [
@@ -1222,7 +1222,7 @@ async def mock_list_provider_models(*args, **kwargs):
}
]
- mock_list = mocker.patch('apps.model_managment_app.list_provider_models_for_tenant', side_effect=mock_list_provider_models)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_provider_models_for_tenant', side_effect=mock_list_provider_models)
request_data = {
"tenant_id": "target_tenant",
@@ -1241,12 +1241,12 @@ async def mock_list_provider_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_provider_list_exception(client, auth_header, user_credentials, mocker):
"""Test provider model list retrieval with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_provider_models(*args, **kwargs):
raise Exception("Provider API error")
- mocker.patch('apps.model_managment_app.list_provider_models_for_tenant', side_effect=mock_list_provider_models)
+ mocker.patch('backend.apps.model_managment_app.list_provider_models_for_tenant', side_effect=mock_list_provider_models)
request_data = {
"tenant_id": "target_tenant",
@@ -1261,12 +1261,12 @@ async def mock_list_provider_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_provider_list_empty(client, auth_header, user_credentials, mocker):
"""Test provider model list retrieval with empty result."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_list_provider_models(*args, **kwargs):
return []
- mock_list = mocker.patch('apps.model_managment_app.list_provider_models_for_tenant', side_effect=mock_list_provider_models)
+ mock_list = mocker.patch('backend.apps.model_managment_app.list_provider_models_for_tenant', side_effect=mock_list_provider_models)
request_data = {
"tenant_id": "empty_tenant",
@@ -1284,7 +1284,7 @@ async def mock_list_provider_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_provider_create_success(client, auth_header, user_credentials, mocker):
"""Test successful provider model creation for a specified tenant."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_create_provider_models(*args, **kwargs):
return [
@@ -1304,7 +1304,7 @@ async def mock_create_provider_models(*args, **kwargs):
}
]
- mock_create = mocker.patch('apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
+ mock_create = mocker.patch('backend.apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
request_data = {
"tenant_id": "target_tenant",
@@ -1328,7 +1328,7 @@ async def mock_create_provider_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_provider_create_with_base_url(client, auth_header, user_credentials, mocker):
"""Test provider model creation with base URL for modelengine provider."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_create_provider_models(*args, **kwargs):
return [
@@ -1341,7 +1341,7 @@ async def mock_create_provider_models(*args, **kwargs):
}
]
- mock_create = mocker.patch('apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
+ mock_create = mocker.patch('backend.apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
request_data = {
"tenant_id": "target_tenant",
@@ -1362,12 +1362,12 @@ async def mock_create_provider_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_provider_create_exception(client, auth_header, user_credentials, mocker):
"""Test provider model creation with exception."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_create_provider_models(*args, **kwargs):
raise Exception("Provider API error")
- mocker.patch('apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
+ mocker.patch('backend.apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
request_data = {
"tenant_id": "target_tenant",
@@ -1384,12 +1384,12 @@ async def mock_create_provider_models(*args, **kwargs):
@pytest.mark.asyncio
async def test_manage_provider_create_empty(client, auth_header, user_credentials, mocker):
"""Test provider model creation with empty result."""
- mocker.patch('apps.model_managment_app.get_current_user_id', return_value=user_credentials)
+ mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials)
async def mock_create_provider_models(*args, **kwargs):
return []
- mock_create = mocker.patch('apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
+ mock_create = mocker.patch('backend.apps.model_managment_app.create_provider_models_for_tenant', side_effect=mock_create_provider_models)
request_data = {
"tenant_id": "target_tenant",
@@ -1406,4 +1406,4 @@ async def mock_create_provider_models(*args, **kwargs):
if __name__ == "__main__":
- pytest.main([__file__])
\ No newline at end of file
+ pytest.main([__file__])
diff --git a/test/backend/app/test_remote_mcp_app.py b/test/backend/app/test_remote_mcp_app.py
index d8701cb9d..1279bb79f 100644
--- a/test/backend/app/test_remote_mcp_app.py
+++ b/test/backend/app/test_remote_mcp_app.py
@@ -1,10 +1,16 @@
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, AsyncMock
import sys
import os
# Add path for correct imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
diff --git a/test/backend/app/test_tenant_app.py b/test/backend/app/test_tenant_app.py
index d9f557d97..6cc59e013 100644
--- a/test/backend/app/test_tenant_app.py
+++ b/test/backend/app/test_tenant_app.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import sys
@@ -8,7 +10,11 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
# Mock external dependencies
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['psycopg2'] = MagicMock()
sys.modules['supabase'] = MagicMock()
diff --git a/test/backend/app/test_tool_config_app.py b/test/backend/app/test_tool_config_app.py
index 17a64434d..31da4b761 100644
--- a/test/backend/app/test_tool_config_app.py
+++ b/test/backend/app/test_tool_config_app.py
@@ -1,11 +1,16 @@
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
-sys.modules['boto3'] = MagicMock()
-
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
patch('botocore.client.BaseClient._make_api_call', return_value={}).start()
storage_client_mock = MagicMock()
diff --git a/test/backend/app/test_user_app.py b/test/backend/app/test_user_app.py
index e26d335fd..3bfed784e 100644
--- a/test/backend/app/test_user_app.py
+++ b/test/backend/app/test_user_app.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import sys
@@ -7,7 +9,11 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
# Mock external dependencies
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['nexent'] = MagicMock()
sys.modules['nexent.core'] = MagicMock()
sys.modules['nexent.core.agents'] = MagicMock()
diff --git a/test/backend/app/test_user_management_app.py b/test/backend/app/test_user_management_app.py
index 30e8479dc..919df9523 100644
--- a/test/backend/app/test_user_management_app.py
+++ b/test/backend/app/test_user_management_app.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import unittest
@@ -8,8 +10,11 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
# Mock external dependencies
-sys.modules['boto3'] = MagicMock()
-
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
patch('botocore.client.BaseClient._make_api_call', return_value={}).start()
diff --git a/test/backend/app/test_vectordatabase_app.py b/test/backend/app/test_vectordatabase_app.py
index c65e8cb7c..cde9107a8 100644
--- a/test/backend/app/test_vectordatabase_app.py
+++ b/test/backend/app/test_vectordatabase_app.py
@@ -6,6 +6,8 @@
import os
import sys
import pytest
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, ANY, AsyncMock
from fastapi.testclient import TestClient
from fastapi import FastAPI
@@ -20,10 +22,12 @@
# Environment variables are now configured in conftest.py
-boto3_mock = MagicMock()
minio_client_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
-
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Patch storage factory and MinIO config validation to avoid errors during initialization
# These patches must be started before any imports that use MinioClient
storage_client_mock = MagicMock()
@@ -237,6 +244,25 @@ async def test_create_new_index_with_partial_group_permissions(vdb_core_mock, au
assert called_kwargs["group_ids"] is None
+@pytest.mark.asyncio
+async def test_create_new_index_with_multimodal_flag(vdb_core_mock, auth_data):
+ with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
+ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.create_knowledge_base") as mock_create:
+
+ mock_create.return_value = {"status": "success", "index_name": auth_data["index_name"]}
+
+ response = client.post(
+ f"/indices/{auth_data['index_name']}",
+ json={"is_multimodal": True},
+ headers=auth_data["auth_header"],
+ )
+
+ assert response.status_code == 200
+ called_kwargs = mock_create.call_args[1]
+ assert called_kwargs["is_multimodal"] is True
+
+
@pytest.mark.asyncio
async def test_create_new_index_error(vdb_core_mock, auth_data):
"""
@@ -634,8 +660,9 @@ async def test_create_index_documents_success(vdb_core_mock, auth_data):
# Setup mocks
with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index:
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={"is_multimodal": "N"}), \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \
+ patch("backend.apps.vectordatabase_app.get_embedding_model_by_id", return_value=MagicMock()):
index_name = "test_index"
documents = [{"id": 1, "text": "test doc"}]
@@ -652,11 +679,37 @@ async def test_create_index_documents_success(vdb_core_mock, auth_data):
response = client.post(
f"/indices/{index_name}/documents", json=documents, headers=auth_data["auth_header"])
- assert response.status_code == 200
- assert response.json() == expected_response.dict()
- mock_index.assert_called_once()
+ # Verify
+ assert response.status_code == 200
+ assert response.json() == expected_response.dict()
+ mock_index.assert_called_once()
+@pytest.mark.asyncio
+async def test_create_index_documents_uses_multimodal_embedding(vdb_core_mock, auth_data):
+ with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
+ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={"is_multimodal": "Y"}), \
+ patch("backend.apps.vectordatabase_app.get_embedding_model_by_id") as mock_get_embedding, \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index:
+
+ mock_get_embedding.return_value = MagicMock()
+ mock_index.return_value = IndexingResponse(
+ success=True,
+ message="Documents indexed successfully",
+ total_indexed=1,
+ total_submitted=1
+ )
+
+ response = client.post(
+ f"/indices/{auth_data['index_name']}/documents",
+ json=[{"id": 1, "text": "test doc"}],
+ headers=auth_data["auth_header"],
+ )
+
+ assert response.status_code == 200
+ mock_get_embedding.assert_not_called()
+
@pytest.mark.asyncio
async def test_create_index_documents_exception(vdb_core_mock, auth_data):
"""
@@ -666,11 +719,13 @@ async def test_create_index_documents_exception(vdb_core_mock, auth_data):
# Setup mocks
with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents", side_effect=Exception("Indexing failed")):
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={"is_multimodal": "N"}), \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \
+ patch("backend.apps.vectordatabase_app.get_embedding_model_by_id", return_value=MagicMock()):
index_name = "test_index"
documents = [{"id": 1, "text": "test doc"}]
+ mock_index.side_effect = Exception("Indexing failed")
response = client.post(
f"/indices/{index_name}/documents", json=documents, headers=auth_data["auth_header"])
@@ -749,8 +804,9 @@ async def test_create_index_documents_validation_exception(vdb_core_mock, auth_d
# Setup mocks
with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index:
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={"is_multimodal": "N"}), \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \
+ patch("backend.apps.vectordatabase_app.get_embedding_model_by_id", return_value=MagicMock()):
index_name = "test_index"
documents = [{"id": 1, "text": "test doc"}]
@@ -1331,7 +1387,8 @@ async def test_update_index_success(auth_data):
payload = {
"knowledge_name": "Updated Knowledge Base",
"ingroup_permission": "EDIT",
- "group_ids": [1, 2, 3]
+ "group_ids": [1, 2, 3],
+ "is_multimodal": True
}
response = client.patch(
f"/indices/{auth_data['index_name']}",
@@ -2215,21 +2272,34 @@ async def test_hybrid_search_exception(vdb_core_mock, auth_data):
# =============================================================================
@pytest.mark.asyncio
-async def test_create_index_documents_fallback_when_knowledge_record_not_found(vdb_core_mock, auth_data):
+async def test_create_index_documents_gets_saved_embedding_model_from_knowledge_record(vdb_core_mock, auth_data):
"""
- Test that create_index_documents handles case when knowledge record is not found.
- Verifies that get_embedding_model_by_id is not called when knowledge_record is None.
+ Test that create_index_documents retrieves the saved embedding model id from knowledge record.
+ Verifies that the endpoint calls get_knowledge_record to get the embedding_model_id.
"""
# Setup mocks
with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \
- patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None), \
+ patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_knowledge_record, \
patch("backend.apps.vectordatabase_app.get_embedding_model_by_id") as mock_get_embedding:
index_name = "test_index"
documents = [{"id": 1, "text": "test doc"}]
+ # Mock knowledge record with saved embedding model id
+ saved_model_id = 123
+ mock_get_knowledge_record.return_value = {
+ "index_name": index_name,
+ "embedding_model_id": saved_model_id,
+ "tenant_id": auth_data["tenant_id"]
+ }
+
+ # Mock embedding model
+ mock_embedding = MagicMock()
+ mock_get_embedding.return_value = (mock_embedding, saved_model_id)
+
+ # Mock index response
expected_response = {
"success": True,
"message": "Documents indexed successfully",
@@ -2238,23 +2308,32 @@ async def test_create_index_documents_fallback_when_knowledge_record_not_found(v
}
mock_index.return_value = expected_response
+ # Execute request
response = client.post(
f"/indices/{index_name}/documents", json=documents, headers=auth_data["auth_header"])
+ # Verify
assert response.status_code == 200
-
- mock_get_embedding.assert_not_called()
-
+
+ # Verify get_knowledge_record was called with correct index_name
+ mock_get_knowledge_record.assert_called_once_with({'index_name': index_name})
+
+ # Verify get_embedding_model_by_id was called with the saved model id
+ mock_get_embedding.assert_called_once_with(
+ auth_data["tenant_id"],
+ saved_model_id,
+ )
+
+ # Verify index_documents was called with the embedding model
mock_index.assert_called_once()
call_kwargs = mock_index.call_args[1]
- assert call_kwargs["embedding_model"] is None
+ assert call_kwargs["embedding_model"] == mock_embedding
@pytest.mark.asyncio
-async def test_create_index_documents_with_empty_string_model_name(vdb_core_mock, auth_data):
+async def test_create_index_documents_fallback_to_default_when_no_saved_model(vdb_core_mock, auth_data):
"""
- Test that create_index_documents handles empty/None embedding_model_id correctly.
- Empty or None model_id should result in no embedding model call.
+ Test that create_index_documents does not call the embedding resolver when the knowledge record has no saved model id.
"""
# Setup mocks
with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
@@ -2266,12 +2345,14 @@ async def test_create_index_documents_with_empty_string_model_name(vdb_core_mock
index_name = "test_index"
documents = [{"id": 1, "text": "test doc"}]
+ # Mock knowledge record with no embedding_model_id (None)
mock_get_knowledge_record.return_value = {
"index_name": index_name,
"embedding_model_id": None,
"tenant_id": auth_data["tenant_id"]
}
+ # Mock index response
expected_response = {
"success": True,
"message": "Documents indexed successfully",
@@ -2280,474 +2361,222 @@ async def test_create_index_documents_with_empty_string_model_name(vdb_core_mock
}
mock_index.return_value = expected_response
+ # Execute request
response = client.post(
f"/indices/{index_name}/documents", json=documents, headers=auth_data["auth_header"])
+ # Verify
assert response.status_code == 200
+ # No saved model id means no embedding resolver call from app layer
mock_get_embedding.assert_not_called()
-# =============================================================================
-# Tests for get_embedding_model_status endpoint (lines 165-248)
-# =============================================================================
-
@pytest.mark.asyncio
-async def test_get_embedding_model_status_configured(auth_data):
+async def test_create_index_documents_fallback_when_knowledge_record_not_found(vdb_core_mock, auth_data):
"""
- Test get_embedding_model_status when model is configured with valid model_id.
- Covers lines 165-215: configured status case.
+ Test that create_index_documents handles case when knowledge record is not found.
+ Verifies that get_embedding_model_by_id is not called when knowledge_record is None.
"""
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_record, \
- patch("backend.apps.vectordatabase_app.get_model_by_model_id") as mock_get_model:
-
- mock_get_record.return_value = {
- "index_name": "kb_test_uuid",
- "knowledge_name": "Test Knowledge Base",
- "embedding_model_id": 123,
- "embedding_model_name": "text-embedding-3-small"
- }
+ # Setup mocks
+ with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
+ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None), \
+ patch("backend.apps.vectordatabase_app.get_embedding_model_by_id") as mock_get_embedding:
- mock_get_model.return_value = {
- "model_id": 123,
- "model_name": "text-embedding-3-small",
- "display_name": "Text Embedding 3 Small",
- "model_type": "embedding"
+ index_name = "test_index"
+ documents = [{"id": 1, "text": "test doc"}]
+
+ expected_response = {
+ "success": True,
+ "message": "Documents indexed successfully",
+ "total_indexed": 1,
+ "total_submitted": 1
}
+ mock_index.return_value = expected_response
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
- )
+ response = client.post(
+ f"/indices/{index_name}/documents", json=documents, headers=auth_data["auth_header"])
assert response.status_code == 200
- data = response.json()
- assert data["status"] == "configured"
- assert data["needs_config"] is False
- assert data["model_id"] == 123
- assert data["index_name"] == "kb_test_uuid"
- assert data["knowledge_name"] == "Test Knowledge Base"
- assert data["model_info"]["display_name"] == "Text Embedding 3 Small"
- assert "Embedding model" in data["message"]
+
+ mock_get_embedding.assert_not_called()
@pytest.mark.asyncio
-async def test_get_embedding_model_status_legacy(auth_data):
+async def test_create_index_documents_with_empty_string_model_name(vdb_core_mock, auth_data):
"""
- Test get_embedding_model_status when model_name exists but no model_id (legacy data).
- Covers lines 216-220: legacy status case.
+ Test that create_index_documents handles empty/None embedding_model_id correctly.
+ Empty or None model_id should result in no embedding model call.
"""
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_record:
+ # Setup mocks
+ with patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock), \
+ patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.index_documents") as mock_index, \
+ patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_knowledge_record, \
+ patch("backend.apps.vectordatabase_app.get_embedding_model_by_id") as mock_get_embedding:
- mock_get_record.return_value = {
- "index_name": auth_data["index_name"],
- "knowledge_name": "Legacy Knowledge Base",
+ index_name = "test_index"
+ documents = [{"id": 1, "text": "test doc"}]
+
+ mock_get_knowledge_record.return_value = {
+ "index_name": index_name,
"embedding_model_id": None,
- "embedding_model_name": "old-embedding-model"
+ "tenant_id": auth_data["tenant_id"]
}
-
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
- )
-
- assert response.status_code == 200
- data = response.json()
- assert data["status"] == "legacy"
- assert data["needs_config"] is True
- assert data["model_id"] is None
- assert data["embedding_model_name"] == "old-embedding-model"
- assert data["model_info"] is None
- assert "older version" in data["message"]
-
-
-@pytest.mark.asyncio
-async def test_get_embedding_model_status_missing(auth_data):
- """
- Test get_embedding_model_status when no model is configured at all.
- Covers lines 221-225: missing status case.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_record:
-
- mock_get_record.return_value = {
- "index_name": auth_data["index_name"],
- "knowledge_name": "Missing Model KB",
- "embedding_model_id": None,
- "embedding_model_name": None
+
+ expected_response = {
+ "success": True,
+ "message": "Documents indexed successfully",
+ "total_indexed": 1,
+ "total_submitted": 1
}
+ mock_index.return_value = expected_response
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
- )
+ response = client.post(
+ f"/indices/{index_name}/documents", json=documents, headers=auth_data["auth_header"])
assert response.status_code == 200
- data = response.json()
- assert data["status"] == "missing"
- assert data["needs_config"] is True
- assert data["model_id"] is None
- assert data["embedding_model_name"] is None
- assert data["model_info"] is None
- assert "No embedding model configured" in data["message"]
-
-
-@pytest.mark.asyncio
-async def test_get_embedding_model_status_model_id_but_model_not_found(auth_data):
- """
- Test when model_id exists but model not found in database, but has embedding_model_name.
- Covers lines 200-220: model_id exists but model is None, falls to legacy status.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_record, \
- patch("backend.apps.vectordatabase_app.get_model_by_model_id", return_value=None):
-
- mock_get_record.return_value = {
- "index_name": auth_data["index_name"],
- "knowledge_name": "Test KB",
- "embedding_model_id": 999,
- "embedding_model_name": "deleted-model"
- }
-
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
- )
- assert response.status_code == 200
- data = response.json()
- assert data["status"] == "legacy"
- assert data["needs_config"] is True
- assert data["model_id"] == 999
- assert data["embedding_model_name"] == "deleted-model"
- assert data["model_info"] is None
+ # Empty/None model id should skip embedding model resolution
+ mock_get_embedding.assert_not_called()
@pytest.mark.asyncio
-async def test_get_embedding_model_status_kb_not_found(auth_data):
- """
- Test get_embedding_model_status when knowledge base doesn't exist.
- Covers lines 189-193: knowledge_record is None.
- """
+async def test_update_summary_frequency_endpoint_success(vdb_core_mock, auth_data):
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None):
-
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
+ patch("database.knowledge_db.update_summary_frequency", return_value=True):
+ response = client.patch(
+ f"/indices/{auth_data['index_name']}/summary_frequency",
+ json={"summary_frequency": "1d"},
+ headers=auth_data["auth_header"],
)
-
- assert response.status_code == 404
- assert "not found" in response.json()["detail"]
+ assert response.status_code == 200
+ assert response.json()["status"] == "success"
@pytest.mark.asyncio
-async def test_get_embedding_model_status_exception(auth_data):
- """
- Test exception handling in get_embedding_model_status.
- Covers lines 243-248: general exception handling.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record", side_effect=Exception("Database error")):
-
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
+async def test_update_summary_frequency_endpoint_invalid_value(auth_data):
+ with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])):
+ response = client.patch(
+ f"/indices/{auth_data['index_name']}/summary_frequency",
+ json={"summary_frequency": "bad"},
+ headers=auth_data["auth_header"],
)
-
- assert response.status_code == 500
- assert "Error checking embedding model status" in response.json()["detail"]
+ assert response.status_code == 400
@pytest.mark.asyncio
-async def test_get_embedding_model_status_http_exception_reraise(auth_data):
- """
- Test that HTTPException is re-raised without wrapping.
- Covers lines 241-242: HTTPException handling.
- """
- from fastapi import HTTPException
- from http import HTTPStatus
-
+async def test_get_embedding_model_status_configured(auth_data):
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.get_knowledge_record") as mock_get_record:
-
- mock_get_record.side_effect = HTTPException(
- status_code=HTTPStatus.FORBIDDEN,
- detail="Access denied"
- )
-
- response = client.get(
- f"/indices/{auth_data['index_name']}/embedding-model-status",
- headers=auth_data["auth_header"]
- )
-
- assert response.status_code == 403
- assert "Access denied" in response.json()["detail"]
-
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={
+ "index_name": "idx_internal",
+ "knowledge_name": "kb1",
+ "embedding_model_id": 7,
+ "embedding_model_name": "m1",
+ }), \
+ patch("backend.apps.vectordatabase_app.get_model_by_model_id", return_value={
+ "model_id": 7,
+ "model_name": "m1",
+ "display_name": "Model One",
+ "model_type": "embedding",
+ }):
+ response = client.get("/indices/idx_internal/embedding-model-status", headers=auth_data["auth_header"])
+ assert response.status_code == 200
+ body = response.json()
+ assert body["status"] == "configured"
+ assert body["needs_config"] is False
+ assert body["model_info"]["display_name"] == "Model One"
-# =============================================================================
-# Tests for update_embedding_model endpoint (lines 251-297)
-# =============================================================================
@pytest.mark.asyncio
-async def test_update_embedding_model_success(auth_data):
- """
- Test successful embedding model update.
- Covers lines 264-283: successful update case.
- """
+async def test_get_embedding_model_status_legacy_and_missing_and_not_found(auth_data):
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model") as mock_update:
-
- mock_update.return_value = {
- "status": "success",
- "message": "Embedding model updated successfully",
- "model_id": 789
- }
-
- response = client.put(
- f"/indices/{auth_data['index_name']}/embedding-model",
- json={"model_id": 789},
- headers=auth_data["auth_header"]
- )
-
- assert response.status_code == 200
- data = response.json()
- assert data["status"] == "success"
-
- mock_update.assert_called_once_with(
- index_name=auth_data["index_name"],
- model_id=789,
- tenant_id=auth_data["tenant_id"],
- user_id=auth_data["user_id"]
- )
-
-
-@pytest.mark.asyncio
-async def test_update_embedding_model_missing_model_id(auth_data):
- """
- Test when model_id is not provided in request.
- Covers lines 266-271: model_id validation.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])):
-
- response = client.put(
- f"/indices/{auth_data['index_name']}/embedding-model",
- json={},
- headers=auth_data["auth_header"]
- )
-
- assert response.status_code == 400
- assert "model_id is required" in response.json()["detail"]
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={
+ "index_name": "idx_legacy",
+ "knowledge_name": "kb_legacy",
+ "embedding_model_id": None,
+ "embedding_model_name": "legacy-name",
+ }):
+ legacy_resp = client.get("/indices/idx_legacy/embedding-model-status", headers=auth_data["auth_header"])
+ assert legacy_resp.status_code == 200
+ assert legacy_resp.json()["status"] == "legacy"
+ assert legacy_resp.json()["needs_config"] is True
-
-@pytest.mark.asyncio
-async def test_update_embedding_model_value_error(auth_data):
- """
- Test ValueError handling (knowledge base not found).
- Covers lines 285-289: ValueError exception handling.
- """
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model", side_effect=ValueError("Knowledge base not found")):
-
- response = client.put(
- f"/indices/{auth_data['index_name']}/embedding-model",
- json={"model_id": 123},
- headers=auth_data["auth_header"]
- )
-
- assert response.status_code == 404
- assert "Knowledge base not found" in response.json()["detail"]
-
-
-@pytest.mark.asyncio
-async def test_update_embedding_model_http_exception_reraise(auth_data):
- """
- Test that HTTPException is re-raised without wrapping.
- Covers lines 290-291: HTTPException handling.
- """
- from fastapi import HTTPException
- from http import HTTPStatus
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value={
+ "index_name": "idx_missing",
+ "knowledge_name": "kb_missing",
+ "embedding_model_id": None,
+ "embedding_model_name": None,
+ }):
+ missing_resp = client.get("/indices/idx_missing/embedding-model-status", headers=auth_data["auth_header"])
+ assert missing_resp.status_code == 200
+ assert missing_resp.json()["status"] == "missing"
+ assert missing_resp.json()["needs_config"] is True
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model", side_effect=HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="Bad request")):
-
- response = client.put(
- f"/indices/{auth_data['index_name']}/embedding-model",
- json={"model_id": 123},
- headers=auth_data["auth_header"]
- )
-
- assert response.status_code == 400
- assert "Bad request" in response.json()["detail"]
+ patch("backend.apps.vectordatabase_app.get_knowledge_record", return_value=None):
+ not_found_resp = client.get("/indices/not-exist/embedding-model-status", headers=auth_data["auth_header"])
+ assert not_found_resp.status_code == 404
@pytest.mark.asyncio
-async def test_update_embedding_model_exception(auth_data):
- """
- Test general exception handling.
- Covers lines 292-297: general exception handling.
- """
+async def test_update_embedding_model_endpoint_branches(auth_data):
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model", side_effect=Exception("Update failed")):
-
- response = client.put(
- f"/indices/{auth_data['index_name']}/embedding-model",
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model", return_value={"status": "success"}) as mock_update:
+ ok_resp = client.put(
+ "/indices/idx1/embedding-model",
json={"model_id": 123},
- headers=auth_data["auth_header"]
+ headers=auth_data["auth_header"],
)
+ assert ok_resp.status_code == 200
+ mock_update.assert_called_once()
- assert response.status_code == 500
- assert "Error updating embedding model" in response.json()["detail"]
-
-
-@pytest.mark.asyncio
-async def test_update_embedding_model_auth_exception(auth_data):
- """
- Test authentication exception handling.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", side_effect=Exception("Invalid auth token")):
-
- response = client.put(
- f"/indices/{auth_data['index_name']}/embedding-model",
- json={"model_id": 123},
- headers=auth_data["auth_header"]
+ with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])):
+ bad_resp = client.put(
+ "/indices/idx1/embedding-model",
+ json={},
+ headers=auth_data["auth_header"],
)
+ assert bad_resp.status_code == 400
- assert response.status_code == 500
- assert "Error updating embedding model" in response.json()["detail"]
-
-
-# =============================================================================
-# Tests for get_list_indices endpoint (lines 300-318)
-# =============================================================================
-
-@pytest.mark.asyncio
-async def test_get_list_indices_success_default_params(auth_data, vdb_core_mock):
- """
- Test get_list_indices with default parameters.
- Covers lines 300-315: successful listing with auth tenant_id.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices") as mock_list, \
- patch("backend.apps.vectordatabase_app.get_vector_db_core") as mock_get_core:
-
- mock_get_core.return_value = vdb_core_mock
- mock_list.return_value = {
- "indices": [
- {"index_name": "kb_test1", "document_count": 100},
- {"index_name": "kb_test2", "document_count": 200}
- ]
- }
-
- response = client.get("/indices", headers=auth_data["auth_header"])
-
- assert response.status_code == 200
- data = response.json()
- assert "indices" in data
-
- mock_list.assert_called_once()
- call_args = mock_list.call_args[0]
- assert call_args[0] == "*"
- assert call_args[1] is False
- assert call_args[2] == auth_data["tenant_id"]
- assert call_args[3] == auth_data["user_id"]
-
-
-@pytest.mark.asyncio
-async def test_get_list_indices_with_pattern(auth_data, vdb_core_mock):
- """
- Test get_list_indices with custom pattern.
- Covers lines 302: pattern parameter.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices") as mock_list, \
- patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock):
-
- mock_list.return_value = {"indices": []}
-
- response = client.get("/indices?pattern=kb_*", headers=auth_data["auth_header"])
-
- assert response.status_code == 200
-
- mock_list.assert_called_once()
- call_args = mock_list.call_args[0]
- assert call_args[0] == "kb_*"
-
-
-@pytest.mark.asyncio
-async def test_get_list_indices_with_stats(auth_data, vdb_core_mock):
- """
- Test get_list_indices with include_stats=True.
- Covers lines 303-304: include_stats parameter.
- """
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices") as mock_list, \
- patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock):
-
- mock_list.return_value = {
- "indices": [
- {"index_name": "kb_test", "document_count": 100, "stats": {"size": "10mb"}}
- ]
- }
-
- response = client.get("/indices?include_stats=true", headers=auth_data["auth_header"])
-
- assert response.status_code == 200
-
- mock_list.assert_called_once()
- call_args = mock_list.call_args[0]
- assert call_args[1] is True
-
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model", side_effect=ValueError("kb not found")):
+ nf_resp = client.put(
+ "/indices/idx1/embedding-model",
+ json={"model_id": 1},
+ headers=auth_data["auth_header"],
+ )
+ assert nf_resp.status_code == 404
-@pytest.mark.asyncio
-async def test_get_list_indices_with_explicit_tenant_id(auth_data, vdb_core_mock):
- """
- Test get_list_indices with explicit tenant_id parameter.
- Covers lines 305-306, 314: tenant_id parameter and effective_tenant_id logic.
- """
with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices") as mock_list, \
- patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock):
-
- mock_list.return_value = {"indices": []}
-
- explicit_tenant = "explicit_tenant_123"
- response = client.get(f"/indices?tenant_id={explicit_tenant}", headers=auth_data["auth_header"])
-
- assert response.status_code == 200
-
- mock_list.assert_called_once()
- call_args = mock_list.call_args[0]
- assert call_args[2] == explicit_tenant
+ patch("backend.apps.vectordatabase_app.ElasticSearchService.update_embedding_model", side_effect=RuntimeError("boom")):
+ err_resp = client.put(
+ "/indices/idx1/embedding-model",
+ json={"model_id": 1},
+ headers=auth_data["auth_header"],
+ )
+ assert err_resp.status_code == 500
@pytest.mark.asyncio
-async def test_get_list_indices_exception(auth_data, vdb_core_mock):
- """
- Test exception handling in get_list_indices.
- Covers lines 316-318: general exception handling.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", return_value=(auth_data["user_id"], auth_data["tenant_id"])), \
- patch("backend.apps.vectordatabase_app.ElasticSearchService.list_indices", side_effect=Exception("Connection failed")), \
- patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock):
-
- response = client.get("/indices", headers=auth_data["auth_header"])
-
- assert response.status_code == 500
- assert "Error get index" in response.json()["detail"]
+async def test_get_document_error_info_regex_fallback(auth_data):
+    """Error-info endpoint still reports ``error_code`` for unparseable JSON.
+
+    Redis is stubbed to return an unterminated JSON object that contains an
+    ``error_code`` field; the endpoint is expected to answer 200 with that
+    code in the response body.
+    NOTE(review): the test name implies the handler recovers the code through
+    a regex fallback -- confirm against the endpoint implementation.
+    """
+    with patch(
+        "backend.apps.vectordatabase_app.get_all_files_status",
+        new=AsyncMock(return_value={"docA": {"latest_task_id": "tid1"}}),
+    ), patch("backend.apps.vectordatabase_app.get_redis_service") as mock_redis:
+        # The closing brace is deliberately missing, so this is not valid JSON.
+        mock_redis.return_value.get_error_info.return_value = '{"bad":1, "error_code":"E123"'
+        # Plain literal: the path has no placeholders, so no f-prefix is needed.
+        response = client.get(
+            "/indices/i1/documents/docA/error-info",
+            headers=auth_data["auth_header"],
+        )
+        assert response.status_code == 200
+        assert response.json()["error_code"] == "E123"
@pytest.mark.asyncio
-async def test_get_list_indices_auth_exception(auth_data, vdb_core_mock):
- """
- Test authentication exception handling.
- """
- with patch("backend.apps.vectordatabase_app.get_current_user_id", side_effect=Exception("Auth failed")), \
- patch("backend.apps.vectordatabase_app.get_vector_db_core", return_value=vdb_core_mock):
-
- response = client.get("/indices", headers=auth_data["auth_header"])
-
- assert response.status_code == 500
- assert "Error get index" in response.json()["detail"]
+async def test_get_document_error_info_regex_failure_returns_none(auth_data):
+    """Error-info endpoint returns a null ``error_code`` when ``re.search`` raises.
+
+    Redis is stubbed to return a non-JSON string and ``re.search`` (as seen
+    from ``backend.apps.vectordatabase_app``) is patched to raise
+    ``RuntimeError``; the endpoint is still expected to answer 200 with
+    ``error_code`` set to ``None``.
+    """
+    with patch(
+        "backend.apps.vectordatabase_app.get_all_files_status",
+        new=AsyncMock(return_value={"docA": {"latest_task_id": "tid1"}}),
+    ), patch(
+        "backend.apps.vectordatabase_app.get_redis_service"
+    ) as mock_redis, patch(
+        "backend.apps.vectordatabase_app.re.search",
+        side_effect=RuntimeError("regex boom"),
+    ):
+        mock_redis.return_value.get_error_info.return_value = "not-json"
+        # Plain literal: the path has no placeholders, so no f-prefix is needed.
+        response = client.get(
+            "/indices/i1/documents/docA/error-info",
+            headers=auth_data["auth_header"],
+        )
+        assert response.status_code == 200
+        assert response.json()["error_code"] is None
diff --git a/test/backend/data_process/test_ray_actors.py b/test/backend/data_process/test_ray_actors.py
index 48673e6c4..79a2f5bb9 100644
--- a/test/backend/data_process/test_ray_actors.py
+++ b/test/backend/data_process/test_ray_actors.py
@@ -53,6 +53,27 @@ def expire(self, key, seconds):
self.expirations[key] = seconds
+def make_temp_file(tmp_path, name: str, content: bytes = b"file-bytes") -> str:
+    """Write ``content`` to the file ``name`` under ``tmp_path``; return its path as ``str``."""
+    target = tmp_path.joinpath(name)
+    target.write_bytes(content)
+    return str(target)
+
+
+def stub_consts(monkeypatch):
+    """Register stub ``consts`` and ``consts.const`` modules in ``sys.modules``.
+
+    Both entries are installed through ``monkeypatch.setitem``. The fake
+    ``consts.const`` module is populated with the constants listed below and
+    returned to the caller.
+    """
+    package_stub = types.ModuleType("consts")
+    const_stub = types.ModuleType("consts.const")
+    stub_values = {
+        "RAY_ACTOR_NUM_CPUS": 1,
+        "REDIS_BACKEND_URL": "",
+        # New defaults required by ray_actors import
+        "DEFAULT_EXPECTED_CHUNK_SIZE": 1024,
+        "DEFAULT_MAXIMUM_CHUNK_SIZE": 1536,
+        "TABLE_TRANSFORMER_MODEL_PATH": "/models/table",
+        "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "/models/unstructured.json",
+    }
+    for attr_name, attr_value in stub_values.items():
+        setattr(const_stub, attr_name, attr_value)
+    for module_name, module_obj in (("consts", package_stub), ("consts.const", const_stub)):
+        monkeypatch.setitem(sys.modules, module_name, module_obj)
+    return const_stub
+
+
@pytest.fixture(autouse=True)
def stub_ray_before_import(monkeypatch):
# Ensure that when module under test imports ray, it gets our stub
@@ -72,6 +93,11 @@ def import_module(monkeypatch):
fake_attachment_db_mod = types.ModuleType("database.attachment_db")
fake_attachment_db_mod.get_file_stream = lambda source: io.BytesIO(b"file-bytes")
fake_attachment_db_mod.get_file_size_from_minio = lambda path_or_url: 0
+ fake_attachment_db_mod.upload_fileobj = lambda file_obj, file_name, prefix=None, bucket=None: {
+ "success": True,
+ "object_name": f"{prefix}/{file_name}" if prefix else file_name,
+ }
+ fake_attachment_db_mod.build_s3_url = lambda object_name: f"s3://bucket/{object_name}"
monkeypatch.setitem(sys.modules, "database.attachment_db", fake_attachment_db_mod)
# Ensure parent package 'database' exists and link submodule for proper resolution
if "database" not in sys.modules:
@@ -133,15 +159,7 @@ class _Redis:
monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks)
# Stub consts.const needed by ray_actors imports
- fake_consts_pkg = types.ModuleType("consts")
- fake_consts_const = types.ModuleType("consts.const")
- fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
- fake_consts_const.REDIS_BACKEND_URL = ""
- # New defaults required by ray_actors import
- fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
- fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
- monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
- monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ stub_consts(monkeypatch)
# Ensure model_management_db is stubbed to avoid importing real DB layer
if "database.model_management_db" not in sys.modules:
@@ -177,12 +195,13 @@ class _Redis:
return ray_actors
-def test_process_file_happy_path(monkeypatch):
+def test_process_file_happy_path(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "a.txt")
chunks = actor.process_file(
- source="/tmp/a.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
task_id="tid-1",
@@ -194,7 +213,7 @@ def test_process_file_happy_path(monkeypatch):
assert chunks[0]["content"] == "hello world"
-def test_process_file_applies_chunk_sizes_from_model(monkeypatch):
+def test_process_file_applies_chunk_sizes_from_model(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
# Recorder core to capture params
@@ -222,8 +241,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "a.txt")
actor.process_file(
- source="/tmp/a.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
model_id=9,
@@ -233,9 +253,13 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
assert RecorderCore.captured_params is not None
assert RecorderCore.captured_params.get("new_after_n_chars") == 2000
assert RecorderCore.captured_params.get("max_characters") == 3000
+ assert RecorderCore.captured_params.get("table_transformer_model_path") == "/models/table"
+ assert RecorderCore.captured_params.get(
+ "unstructured_default_model_initialize_params_json_path"
+ ) == "/models/unstructured.json"
-def test_process_file_no_model_omits_chunk_params(monkeypatch):
+def test_process_file_no_model_omits_chunk_params(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
class RecorderCore:
@@ -257,8 +281,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "b.txt")
actor.process_file(
- source="/tmp/b.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
model_id=10,
@@ -268,9 +293,13 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
assert RecorderCore.captured_params is not None
assert "new_after_n_chars" not in RecorderCore.captured_params
assert "max_characters" not in RecorderCore.captured_params
+ assert RecorderCore.captured_params.get("table_transformer_model_path") == "/models/table"
+ assert RecorderCore.captured_params.get(
+ "unstructured_default_model_initialize_params_json_path"
+ ) == "/models/unstructured.json"
-def test_process_file_model_lookup_exception_uses_defaults(monkeypatch):
+def test_process_file_model_lookup_exception_uses_defaults(monkeypatch, tmp_path):
ray_actors = import_module(monkeypatch)
class RecorderCore:
@@ -293,8 +322,9 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
)
actor = ray_actors.DataProcessorRayActor()
+ source_path = make_temp_file(tmp_path, "c.txt")
actor.process_file(
- source="/tmp/c.txt",
+ source=source_path,
chunking_strategy="basic",
destination="local",
model_id=11,
@@ -304,6 +334,10 @@ def file_process(self, file_data, filename, chunking_strategy, **params):
assert RecorderCore.captured_params is not None
assert "new_after_n_chars" not in RecorderCore.captured_params
assert "max_characters" not in RecorderCore.captured_params
+ assert RecorderCore.captured_params.get("table_transformer_model_path") == "/models/table"
+ assert RecorderCore.captured_params.get(
+ "unstructured_default_model_initialize_params_json_path"
+ ) == "/models/unstructured.json"
def test_process_file_get_stream_none_raises(monkeypatch):
@@ -311,6 +345,8 @@ def test_process_file_get_stream_none_raises(monkeypatch):
fake_attachment_db_mod = types.ModuleType("database.attachment_db")
fake_attachment_db_mod.get_file_stream = lambda source: None
fake_attachment_db_mod.get_file_size_from_minio = lambda path_or_url: 0
+ fake_attachment_db_mod.upload_fileobj = lambda *a, **k: {"success": True, "object_name": "o"}
+ fake_attachment_db_mod.build_s3_url = lambda object_name: f"s3://bucket/{object_name}"
monkeypatch.setitem(sys.modules, "database.attachment_db", fake_attachment_db_mod)
# Ensure parent 'database' exists and link attachment_db
if "database" not in sys.modules:
@@ -371,15 +407,7 @@ class _Redis:
fake_dp_tasks.process_sync = lambda *a, **k: None
monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks)
# Stub consts.const again for reload path
- fake_consts_pkg = types.ModuleType("consts")
- fake_consts_const = types.ModuleType("consts.const")
- fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
- fake_consts_const.REDIS_BACKEND_URL = ""
- # Provide defaults required by backend.data_process.ray_actors import
- fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
- fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
- monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
- monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ stub_consts(monkeypatch)
# Stub database.model_management_db and link to parent to avoid real DB import
if "database.model_management_db" not in sys.modules:
@@ -410,7 +438,7 @@ class _Redis:
actor.process_file("url://missing", "basic", destination="minio")
-def test_process_file_core_returns_none_list_variants(monkeypatch):
+def test_process_file_core_returns_none_list_variants(monkeypatch, tmp_path):
class CoreNone(FakeDataProcessCore):
def file_process(self, *a, **k):
return None
@@ -434,6 +462,8 @@ def file_process(self, *a, **k):
fake_attachment_db_mod = types.ModuleType("database.attachment_db")
fake_attachment_db_mod.get_file_stream = lambda source: io.BytesIO(b"file-bytes")
fake_attachment_db_mod.get_file_size_from_minio = lambda path_or_url: 0
+ fake_attachment_db_mod.upload_fileobj = lambda *a, **k: {"success": True, "object_name": "o"}
+ fake_attachment_db_mod.build_s3_url = lambda object_name: f"s3://bucket/{object_name}"
monkeypatch.setitem(sys.modules, "database.attachment_db", fake_attachment_db_mod)
# Also stub celery.result.AsyncResult and redis module
fake_celery = types.ModuleType("celery")
@@ -480,15 +510,7 @@ class _Redis:
fake_dp_tasks.process_sync = lambda *a, **k: None
monkeypatch.setitem(sys.modules, "backend.data_process.tasks", fake_dp_tasks)
# Stub consts.const for ray_actors imports
- fake_consts_pkg = types.ModuleType("consts")
- fake_consts_const = types.ModuleType("consts.const")
- fake_consts_const.RAY_ACTOR_NUM_CPUS = 1
- fake_consts_const.REDIS_BACKEND_URL = ""
- # Provide defaults required by backend.data_process.ray_actors import
- fake_consts_const.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
- fake_consts_const.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
- monkeypatch.setitem(sys.modules, "consts", fake_consts_pkg)
- monkeypatch.setitem(sys.modules, "consts.const", fake_consts_const)
+ stub_consts(monkeypatch)
# Ensure model_management_db is stubbed to avoid importing real DB layer
if "database.model_management_db" not in sys.modules:
@@ -503,7 +525,8 @@ class _Redis:
import backend.data_process.ray_actors as ray_actors
reload(ray_actors)
actor = ray_actors.DataProcessorRayActor()
- chunks = actor.process_file("/tmp/a.txt", "basic", destination="local")
+ source_path = make_temp_file(tmp_path, f"a_{core_cls.__name__}.txt")
+ chunks = actor.process_file(source_path, "basic", destination="local")
assert chunks == []
@@ -548,6 +571,59 @@ def test_store_chunks_in_redis_no_url_returns_false(monkeypatch):
assert actor.store_chunks_in_redis("k", [{"content": "x"}]) is False
+def test_process_file_appends_image_chunks(monkeypatch, tmp_path):
+    """A core result of ``(text chunks, images)`` yields an extra image chunk.
+
+    The stub core returns one text chunk and one image entry; the upload and
+    URL helpers on ``ray_actors`` are replaced with local fakes. The test
+    expects two chunks back, the second tagged ``UniversalImageExtractor`` and
+    carrying an ``image_url`` in its metadata.
+    """
+    ray_actors = import_module(monkeypatch)
+
+    class CoreWithImages:
+        def file_process(self, *args, **kwargs):
+            text_part = [{"content": "text", "metadata": {}}]
+            image_part = [
+                {
+                    "image_bytes": b"img",
+                    "image_format": "png",
+                    "position": {"page_number": 1},
+                }
+            ]
+            return (text_part, image_part)
+
+    def fake_upload(file_obj, file_name, prefix=None):
+        return {"object_name": f"{prefix}/{file_name}"}
+
+    def fake_build_url(object_name):
+        return f"s3://bucket/{object_name}"
+
+    monkeypatch.setattr(ray_actors, "DataProcessCore", CoreWithImages)
+    monkeypatch.setattr(ray_actors, "upload_fileobj", fake_upload)
+    monkeypatch.setattr(ray_actors, "build_s3_url", fake_build_url)
+
+    actor = ray_actors.DataProcessorRayActor()
+    pdf_path = make_temp_file(tmp_path, "a.pdf", content=b"%PDF-1.4")
+    result = actor.process_file(pdf_path, "basic", destination="local")
+
+    assert len(result) == 2
+    image_meta = result[1]["metadata"]
+    assert image_meta["process_source"] == "UniversalImageExtractor"
+    assert "image_url" in image_meta
+
+
+def test_process_file_skips_invalid_image_entries(monkeypatch, tmp_path):
+    """Image entries without ``image_bytes`` must not add any chunk.
+
+    The stub core returns one text chunk plus two image entries that both lack
+    an ``image_bytes`` key; the test expects only the text chunk to come back.
+    """
+    ray_actors = import_module(monkeypatch)
+
+    class CoreWithBadImages:
+        def file_process(self, *args, **kwargs):
+            text_part = [{"content": "text", "metadata": {}}]
+            # Neither entry carries an "image_bytes" key.
+            unusable_images = [{"not": "dict"}, {"image_format": "png"}]
+            return (text_part, unusable_images)
+
+    monkeypatch.setattr(ray_actors, "DataProcessCore", CoreWithBadImages)
+    actor = ray_actors.DataProcessorRayActor()
+    pdf_path = make_temp_file(tmp_path, "a.pdf", content=b"%PDF-1.4")
+    result = actor.process_file(pdf_path, "basic", destination="local")
+
+    expected = [{"content": "text", "metadata": {}}]
+    assert result == expected
def test_process_bytes_and_split_file_branches(monkeypatch):
ray_actors = import_module(monkeypatch)
@@ -600,3 +676,43 @@ def __len__(self):
monkeypatch.setitem(sys.modules, "redis", bad_redis_module)
assert actor.store_chunks_in_redis("k-err", [{"a": 1}]) is False
+
+def test_apply_model_chunk_sizes_and_read_file_bytes_helpers(monkeypatch):
+    """Exercise ``_apply_model_chunk_sizes`` and ``_read_file_bytes`` directly.
+
+    First half: with a stubbed model record, the helper is expected to fill
+    ``new_after_n_chars``, ``max_characters`` and ``model_type`` into the
+    params dict it is given. Second half: ``_read_file_bytes`` returns the
+    stream content, and raises ``FileNotFoundError`` when the stream is ``None``.
+    """
+    ray_actors = import_module(monkeypatch)
+    actor = ray_actors.DataProcessorRayActor()
+
+    def fake_get_model(model_id, tenant_id=None):
+        return {
+            "expected_chunk_size": 111,
+            "maximum_chunk_size": 222,
+            "display_name": "emb",
+            "model_type": "embedding",
+        }
+
+    monkeypatch.setattr(ray_actors, "get_model_by_model_id", fake_get_model)
+    chunk_params = {}
+    actor._apply_model_chunk_sizes(1, "t1", chunk_params)
+    assert chunk_params["new_after_n_chars"] == 111
+    assert chunk_params["max_characters"] == 222
+    assert chunk_params["model_type"] == "embedding"
+
+    def stream_with_bytes(source):
+        return io.BytesIO(b"bytes")
+
+    monkeypatch.setattr(ray_actors, "get_file_stream", stream_with_bytes)
+    assert actor._read_file_bytes("s3://x") == b"bytes"
+
+    def stream_missing(source):
+        return None
+
+    monkeypatch.setattr(ray_actors, "get_file_stream", stream_missing)
+    with pytest.raises(FileNotFoundError):
+        actor._read_file_bytes("s3://missing")
+
+
+def test_split_file_returns_empty_when_no_parts(monkeypatch):
+    """``split_file`` returns an empty list when the core's ``file_split`` yields no parts."""
+    ray_actors = import_module(monkeypatch)
+
+    class CoreNoParts(FakeDataProcessCore):
+        def file_split(self, *args, **kwargs):
+            return []
+
+    monkeypatch.setattr(ray_actors, "DataProcessCore", CoreNoParts)
+    actor = ray_actors.DataProcessorRayActor()
+    parts = actor.split_file("x.txt", "local", file_data=b"abc")
+    assert parts == []
+
diff --git a/test/backend/data_process/test_tasks.py b/test/backend/data_process/test_tasks.py
index b368a7a8b..379989581 100644
--- a/test/backend/data_process/test_tasks.py
+++ b/test/backend/data_process/test_tasks.py
@@ -144,6 +144,8 @@ def decorator(func):
const_mod.DEFAULT_EXPECTED_CHUNK_SIZE = 1024
const_mod.DEFAULT_MAXIMUM_CHUNK_SIZE = 1536
const_mod.ROOT_DIR = "/mock/root"
+ const_mod.TABLE_TRANSFORMER_MODEL_PATH = "/mock/table_transformer_model"
+ const_mod.UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = "/mock/unstructured_params.json"
sys.modules["consts.const"] = const_mod
# Minimal stub for consts.model used by utils.file_management_utils
if "consts.model" not in sys.modules:
@@ -161,6 +163,8 @@ def __init__(self, chunking_strategy: str, source_type: str, index_name: str, au
sys.modules["database.attachment_db"] = types.SimpleNamespace(
get_file_stream=lambda source: io.BytesIO(b"stub-bytes"),
get_file_size_from_minio=lambda object_name, bucket=None: 0,
+ build_s3_url=lambda bucket_name, object_name: f"http://mock-s3/{bucket_name}/{object_name}", # NOSONAR
+ upload_fileobj=lambda file_obj, bucket_name, object_name: "mock-etag",
)
# Stub model_management_db module required by ray_actors
if "database.model_management_db" not in sys.modules:
@@ -2001,6 +2005,28 @@ def run_until_complete(self, _c):
assert tasks.run_async(asyncio.sleep(0)) == "ok"
+def test_run_async_running_loop_without_nest_asyncio_fallback_thread(monkeypatch):
+    """run_async must still return the coroutine result when a loop is reported
+    as running and nest_asyncio cannot be imported.
+    """
+    tasks, _ = import_tasks_with_fake_ray(monkeypatch)
+
+    class FakeLoop:
+        """Loop double that always claims to be running."""
+
+        def is_running(self):
+            return True
+
+    monkeypatch.setattr(asyncio, "get_running_loop", lambda: FakeLoop())
+    # Drop any cached module through monkeypatch so it is restored on teardown;
+    # a bare sys.modules.pop() would leak the removal into later tests.
+    monkeypatch.delitem(sys.modules, "nest_asyncio", raising=False)
+
+    import builtins
+    real_import = builtins.__import__
+
+    def fake_import(name, *args, **kwargs):
+        # Fail only the nest_asyncio import; everything else imports normally.
+        if name == "nest_asyncio":
+            raise ImportError("no nest_asyncio")
+        return real_import(name, *args, **kwargs)
+
+    monkeypatch.setattr(builtins, "__import__", fake_import)
+    assert tasks.run_async(asyncio.sleep(0, result="thread-ok")) == "thread-ok"
+
+
def test_global_pool_manager_paths(monkeypatch):
tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch)
@@ -2198,3 +2224,90 @@ def test_aggregate_forward_parts_paths(monkeypatch):
)
assert out["success"] is True
assert out["total_indexed"] == 5
+
+
+def test_run_processing_for_parts_single_and_multi(monkeypatch):
+    """Exercise _run_processing_for_parts with one part, then with three parts.
+
+    The single-part call is expected to return the actor's chunks directly;
+    the multi-part call is expected to build one task per part and report the
+    chunk count returned by the split wait.
+    """
+    tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch)
+
+    class Actor:
+        """Actor double whose remote calls return fixed placeholder refs."""
+
+        def __init__(self):
+            self.process_file = types.SimpleNamespace(remote=lambda *a, **k: "ref-file")
+            self.process_bytes = types.SimpleNamespace(remote=lambda *a, **k: "ref-bytes")
+
+    monkeypatch.setattr(tasks, "get_ray_actor", lambda: Actor())
+    fake_ray.get_returns = {"ref-bytes": [{"content": "c1"}], "ref-file": [{"content": "cf"}]}
+
+    # Keyword arguments that are identical for both invocations.
+    common = {
+        "source_type": "local",
+        "chunking_strategy": "basic",
+        "index_name": "idx",
+        "embedding_model_id": 1,
+        "tenant_id": "tenant",
+    }
+
+    # --- one part ---
+    split_async, chunks, split_chunk_count = tasks._run_processing_for_parts(
+        request_id="r1",
+        source="/a.txt",
+        task_id="t1",
+        filename_for_processing="a.txt",
+        parts=[b"one"],
+        original_filename="a.txt",
+        params={},
+        **common,
+    )
+    assert split_async is False
+    assert chunks == [{"content": "c1"}]
+    assert split_chunk_count is None
+
+    # --- several parts ---
+    captured = {}
+
+    def fake_part_signature(**kwargs):
+        return types.SimpleNamespace(kwargs=kwargs)
+
+    def fake_aggregate_signature(**kwargs):
+        return types.SimpleNamespace(set=lambda **kw: {"kwargs": kwargs, "set": kw})
+
+    def fake_group(gen):
+        return list(gen)
+
+    def fake_chord(group_tasks):
+        def record(callback):
+            # Remember what the chord was built from and which callback it got.
+            return captured.update({"group": group_tasks, "callback": callback})
+
+        return record
+
+    monkeypatch.setattr(tasks, "process_part", types.SimpleNamespace(s=fake_part_signature))
+    monkeypatch.setattr(tasks, "aggregate_store_chunks", types.SimpleNamespace(s=fake_aggregate_signature))
+    monkeypatch.setattr(tasks, "group", fake_group)
+    monkeypatch.setattr(tasks, "chord", fake_chord)
+    monkeypatch.setattr(tasks, "_compute_split_wait_timeout", lambda n: 9)
+    monkeypatch.setattr(tasks, "_estimate_parallel_parts", lambda: 2)
+    monkeypatch.setattr(tasks, "_wait_for_split_ready", lambda **kwargs: 6)
+
+    split_async2, chunks2, split_chunk_count2 = tasks._run_processing_for_parts(
+        request_id="r2",
+        source="/b.txt",
+        task_id="t2",
+        filename_for_processing="b.txt",
+        parts=[b"a", b"b", b"c"],
+        original_filename="b.txt",
+        params={"x": 1},
+        **common,
+    )
+    assert split_async2 is True
+    assert chunks2 is None
+    assert split_chunk_count2 == 6
+    assert len(captured["group"]) == 3
+
+
+def test_process_split_async_redis_image_metadata_count(monkeypatch, tmp_path):
+    """process() must report the image-metadata chunk count and the split chunk
+    total when the source is handled through the async split path.
+    """
+    tasks, _ = import_tasks_with_fake_ray(monkeypatch)
+    monkeypatch.setattr(tasks, "REDIS_BACKEND_URL", "redis://test")
+    monkeypatch.setattr(tasks, "_process_source_with_split", lambda **kwargs: (True, None, 2))
+    monkeypatch.setattr(tasks, "_count_image_metadata_chunks", lambda chunks: 1)
+
+    class FakeRedisClient:
+        """Redis double returning one image chunk and one plain chunk for any key."""
+
+        def get(self, key):
+            stored = [{"metadata": {"content_type": "image"}}, {"metadata": {}}]
+            return json.dumps(stored)
+
+    # Stand-in for the `redis` package: only Redis.from_url is provided.
+    fake_redis_cls = types.SimpleNamespace(from_url=lambda *a, **k: FakeRedisClient())
+    monkeypatch.setitem(sys.modules, "redis", types.SimpleNamespace(Redis=fake_redis_cls))
+
+    sample = tmp_path / "x.txt"
+    sample.write_text("hello")
+    task_self = FakeSelf("proc-async-1")
+    out = tasks.process(
+        task_self,
+        source=str(sample),
+        source_type="local",
+        chunking_strategy="basic",
+        index_name="idx",
+        original_filename="x.txt",
+    )
+    assert out["split_async"] is True
+    assert out["image_metadata_chunk_count"] == 1
+    success_states = [s for s in task_self.states if s.get("state") == tasks.states.SUCCESS]
+    assert success_states[0]["meta"]["chunks_count"] == 2
diff --git a/test/backend/database/test_attachment_db.py b/test/backend/database/test_attachment_db.py
index 6899171eb..47e5ccbe5 100644
--- a/test/backend/database/test_attachment_db.py
+++ b/test/backend/database/test_attachment_db.py
@@ -17,6 +17,8 @@
# Mock consts module
consts_mock = MagicMock()
consts_mock.const = MagicMock()
+# Ensure constants are real strings to avoid startswith TypeError
+consts_mock.const.S3_URL_PREFIX = "s3://"
# Environment variables are now configured in conftest.py
sys.modules['consts'] = consts_mock
@@ -51,6 +53,8 @@
minio_client_mock = MagicMock()
minio_client_mock.storage_config = MagicMock()
minio_client_mock.storage_config.default_bucket = 'test-bucket'
+# Current attachment_db uses minio_client.default_bucket directly.
+minio_client_mock.default_bucket = 'test-bucket'
client_mock = MagicMock()
client_mock.minio_client = minio_client_mock
sys.modules['database'] = MagicMock()
@@ -73,7 +77,9 @@
get_file_stream,
get_file_stream_raw,
get_file_range,
- get_content_type
+ get_content_type,
+ build_s3_url,
+ _normalize_object_and_bucket
)
@@ -864,4 +870,46 @@ def test_returns_none_on_failure(self):
result = get_file_stream_raw('missing/doc.pdf')
assert result is None
+class TestS3UrlHelpers:
+    """Checks for _normalize_object_and_bucket and build_s3_url."""
+
+    def test_normalize_object_and_bucket_s3_url(self):
+        """An s3:// URL is split into its object key and bucket."""
+        key, bucket_name = _normalize_object_and_bucket("s3://my-bucket/path/to/file.txt")
+        assert (key, bucket_name) == ("path/to/file.txt", "my-bucket")
+
+    def test_normalize_object_and_bucket_slash_path(self):
+        """A leading-slash path is split the same way as the s3:// form."""
+        key, bucket_name = _normalize_object_and_bucket("/my-bucket/path/to/file.txt")
+        assert (key, bucket_name) == ("path/to/file.txt", "my-bucket")
+
+    def test_build_s3_url_passthrough(self):
+        """An input that is already an s3:// URL comes back unchanged."""
+        url = "s3://bucket/key"
+        assert build_s3_url(url) == "s3://bucket/key"
+
+    def test_build_s3_url_from_path(self):
+        """A /bucket/key path is converted to the s3:// form."""
+        url = build_s3_url("/bucket/key")
+        assert url == "s3://bucket/key"
+
+    def test_build_s3_url_from_object(self):
+        """A bare object name is placed under the mocked default bucket."""
+        url = build_s3_url("attachments/file.txt")
+        assert url == "s3://test-bucket/attachments/file.txt"
+
+
+def test_get_file_stream_normalizes_s3_url():
+    """get_file_stream must pass the object name and bucket parsed from an s3:// URL to the client."""
+    client_call = minio_client_mock.get_file_stream
+    client_call.reset_mock()
+    payload = BytesIO(b"test data")
+    client_call.return_value = (True, payload)
+
+    stream = get_file_stream("s3://test-bucket/attachments/test.txt")
+
+    assert isinstance(stream, BytesIO)
+    client_call.assert_called_once_with("attachments/test.txt", "test-bucket")
+
+
+def test_delete_file_normalizes_s3_url():
+    """delete_file must pass the object name and bucket parsed from an s3:// URL to the client."""
+    client_call = minio_client_mock.delete_file
+    client_call.reset_mock()
+    client_call.return_value = (True, "Deleted successfully")
+
+    outcome = delete_file("s3://test-bucket/attachments/test.txt")
+
+    assert outcome["success"] is True
+    client_call.assert_called_once_with("attachments/test.txt", "test-bucket")
diff --git a/test/backend/database/test_client.py b/test/backend/database/test_client.py
index 87482e07e..e100f4373 100644
--- a/test/backend/database/test_client.py
+++ b/test/backend/database/test_client.py
@@ -67,7 +67,8 @@
minio_client,
get_db_session,
as_dict,
- filter_property
+ filter_property,
+ get_monitoring_db_session,
)
@@ -623,3 +624,52 @@ def test_filter_property_no_matching_fields(self):
result = filter_property(data, mock_model)
assert result == {}
+
+
+class TestAdditionalCoverage:
+    """Extra cases for MinioClient.default_bucket, as_dict and get_monitoring_db_session."""
+
+    def test_minio_default_bucket_fallback_on_init_error(self, mocker):
+        """default_bucket must yield MINIO_DEFAULT_BUCKET when initialization raises."""
+        # Clear the class-level flags before instantiating (presumably singleton
+        # state — confirm against MinioClient).
+        MinioClient._instance = None
+        MinioClient._initialized = False
+        minio = MinioClient()
+        minio.storage_config = None
+        mocker.patch.object(minio, "_ensure_initialized", side_effect=RuntimeError("x"))
+        mocker.patch("backend.database.client.MINIO_DEFAULT_BUCKET", "fallback-bucket")
+        assert minio.default_bucket == "fallback-bucket"
+
+    def test_as_dict_for_sqlalchemy_object_and_mapping(self, mocker):
+        """as_dict must handle an object carrying __mapper__ and one carrying _mapping."""
+        from datetime import datetime
+        stamp = datetime(2025, 1, 1, 0, 0, 0)
+
+        class Obj:
+            __mapper__ = object()
+            created = stamp
+            name = "n1"
+
+        # One column double per attribute that as_dict should read.
+        columns = [MagicMock(key=column_key) for column_key in ("created", "name")]
+        mocker.patch("backend.database.client.class_mapper", return_value=MagicMock(columns=columns))
+        converted = as_dict(Obj())
+        # datetime values are expected back in ISO format.
+        assert converted["created"] == stamp.isoformat()
+        assert converted["name"] == "n1"
+
+        row_like = MagicMock()
+        row_like._mapping = {"a": 1}
+        assert as_dict(row_like) == {"a": 1}
+
+    def test_get_monitoring_db_session_paths(self, mocker):
+        """Cover the self-managed session path and the caller-provided session path."""
+        own_session = MagicMock()
+        mocker.patch("backend.database.client._get_monitoring_engine")
+        mocker.patch("backend.database.client._monitoring_session_maker", MagicMock(return_value=own_session))
+
+        # No argument: the context manager yields the session it created, then
+        # commits and closes it on exit.
+        with get_monitoring_db_session() as yielded:
+            assert yielded is own_session
+        own_session.commit.assert_called_once()
+        own_session.close.assert_called_once()
+
+        # Caller-provided session: the exception propagates and rollback is not
+        # called on that session.
+        external = MagicMock()
+        with pytest.raises(ValueError):
+            with get_monitoring_db_session(external):
+                raise ValueError("boom")
+        external.rollback.assert_not_called()
diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py
index 9205c0280..1af3e103a 100644
--- a/test/backend/database/test_knowledge_db.py
+++ b/test/backend/database/test_knowledge_db.py
@@ -6,6 +6,7 @@
import sys
import os
import types
+import importlib.machinery
from datetime import datetime
from unittest.mock import MagicMock, patch, call
import pytest
@@ -18,16 +19,21 @@
sys.path.insert(0, backend_dir)
# Patch boto3 and other dependencies before importing anything from backend
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+boto3_module.client = MagicMock()
+boto3_module.exceptions = MagicMock()
+sys.modules["boto3"] = boto3_module
# Mock botocore before patching it
botocore_mock = MagicMock()
botocore_client_mock = MagicMock()
+botocore_exceptions_mock = MagicMock()
botocore_client_mock.BaseClient = MagicMock()
botocore_client_mock.BaseClient._make_api_call = MagicMock()
sys.modules['botocore'] = botocore_mock
sys.modules['botocore.client'] = botocore_client_mock
+sys.modules['botocore.exceptions'] = botocore_exceptions_mock
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
@@ -43,10 +49,34 @@
minio_config_mock = MagicMock()
minio_config_mock.validate = MagicMock()
-# Mock backend.database.client before patching it
-backend_database_client_mock = MagicMock()
-backend_database_client_mock.MinioClient = MagicMock(return_value=minio_client_mock)
-sys.modules['backend.database.client'] = backend_database_client_mock
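+# Import the knowledge_db functions under test while MinioClient and MinIOStorageConfig are patched.
+# NOTE(review): the consts, utils and sqlalchemy mocks further down this file are installed after this import - confirm the ordering is intended.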
+with patch('backend.database.client.MinioClient', return_value=minio_client_mock), \
+ patch('nexent.storage.minio_config.MinIOStorageConfig', return_value=minio_config_mock):
+ from backend.database.knowledge_db import (
+ create_knowledge_record,
+ update_knowledge_record,
+ delete_knowledge_record,
+ get_knowledge_record,
+ get_knowledge_info_by_knowledge_ids,
+ get_knowledge_ids_by_index_names,
+ get_knowledge_info_by_tenant_id,
+ update_model_name_by_index_name,
+ get_index_name_by_knowledge_name,
+ get_knowledge_info_by_tenant_and_source,
+ upsert_knowledge_record,
+ _generate_index_name,
+ get_knowledge_name_map_by_index_names,
+ update_summary_frequency,
+ update_last_summary_time,
+ update_last_doc_update_time,
+ get_knowledge_bases_for_auto_summary,
+ )
+
+
+# Add project root to Python path
+sys.path.insert(0, os.path.abspath(os.path.join(
+ os.path.dirname(__file__), '..', '..', '..')))
# Mock consts module to use conftest environment variables
consts_mock = MagicMock()
@@ -92,7 +122,14 @@
sys.modules['utils.auth_utils'] = utils_mock.auth_utils
sys.modules['utils.str_utils'] = utils_mock.str_utils
-# Mock sqlalchemy module before importing backend modules
+# Re-register a stub `boto3` module so it can be imported even when the real package is not installed.
+# NOTE(review): this repeats the stub created near the top of this file and omits its `exceptions` attribute.
+boto3_module = types.ModuleType("boto3")
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+boto3_module.client = MagicMock()
+sys.modules["boto3"] = boto3_module
+
+# Mock sqlalchemy module
sqlalchemy_mock = MagicMock()
sqlalchemy_mock.func = MagicMock()
sqlalchemy_mock.func.current_timestamp = MagicMock(return_value="2023-01-01 00:00:00")
@@ -176,27 +213,32 @@ def mock_session():
return mock_session, mock_query
-def test_create_knowledge_record_success(monkeypatch, mock_session):
- """Test successful creation of knowledge record"""
- session, _ = mock_session
-
- # Create mock knowledge record
- mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge")
- mock_record.knowledge_id = 123
- mock_record.index_name = "test_knowledge"
-
- # Mock database session context
+def setup_mock_db_session(monkeypatch, session):
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = session
- # Mock the context manager to call rollback on exception, like the real get_db_session does
def mock_exit(exc_type, exc_val, exc_tb):
if exc_type is not None:
session.rollback()
return None # Don't suppress the exception
+
mock_ctx.__exit__.side_effect = mock_exit
monkeypatch.setattr(
"backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ return mock_ctx
+
+
+def test_create_knowledge_record_success(monkeypatch, mock_session):
+ """Test successful creation of knowledge record"""
+ session, _ = mock_session
+
+ # Create mock knowledge record
+ mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge")
+ mock_record.knowledge_id = 123
+ mock_record.index_name = "test_knowledge"
+
+ # Mock database session context
+ setup_mock_db_session(monkeypatch, session)
# Prepare test data
test_query = {
@@ -277,22 +319,34 @@ def mock_exit(exc_type, exc_val, exc_tb):
session.commit.assert_called_once()
+def test_create_knowledge_record_sets_multimodal_flag(monkeypatch, mock_session):
+ """Call create_knowledge_record with KnowledgeRecord patched and the DB session mocked.
+
+ NOTE(review): nothing is asserted after the call, and test_query carries no
+ multimodal key, so the flag named in the test title is not checked - TODO
+ add the intended assertion.
+ """
+ session, _ = mock_session
+ mock_record = MockKnowledgeRecord(knowledge_name="test_knowledge")
+ mock_record.knowledge_id = 123
+ mock_record.index_name = "test_knowledge"
+
+ # Route get_db_session to a context manager that yields the mocked session.
+ setup_mock_db_session(monkeypatch, session)
+
+ test_query = {
+ "index_name": "test_knowledge",
+ "knowledge_describe": "Test knowledge description",
+ "user_id": "test_user",
+ "tenant_id": "test_tenant",
+ "embedding_model_name": "test_model",
+ "knowledge_name": "test_knowledge",
+ }
+
+ # mock_constructor is bound but never inspected; the result is discarded.
+ with patch('backend.database.knowledge_db.KnowledgeRecord', return_value=mock_record) as mock_constructor:
+ _ = create_knowledge_record(test_query)
+
+
+
def test_create_knowledge_record_exception(monkeypatch, mock_session):
"""Test exception during knowledge record creation"""
session, _ = mock_session
session.add.side_effect = MockSQLAlchemyError("Database error")
- mock_ctx = MagicMock()
- mock_ctx.__enter__.return_value = session
- # Mock the context manager to call rollback on exception, like the real get_db_session does
-
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None # Don't suppress the exception
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ setup_mock_db_session(monkeypatch, session)
test_query = {
"index_name": "test_knowledge",
@@ -317,17 +371,7 @@ def test_create_knowledge_record_generates_index_name(monkeypatch, mock_session)
mock_record = MockKnowledgeRecord(knowledge_name="kb1")
mock_record.knowledge_id = 7
- mock_ctx = MagicMock()
- mock_ctx.__enter__.return_value = session
- # Mock the context manager to call rollback on exception, like the real get_db_session does
-
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None # Don't suppress the exception
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ setup_mock_db_session(monkeypatch, session)
# Deterministic index name
monkeypatch.setattr(
@@ -442,6 +486,36 @@ def mock_exit(exc_type, exc_val, exc_tb):
session.commit.assert_called_once()
+def test_update_knowledge_record_sets_multimodal(monkeypatch, mock_session):
+    """update_knowledge_record must report success for a query that carries is_multimodal."""
+    session, query = mock_session
+    mock_record = MockKnowledgeRecord()
+
+    # The lookup for the record to update resolves to mock_record.
+    mock_filter = MagicMock()
+    mock_filter.first.return_value = mock_record
+    query.filter.return_value = mock_filter
+
+    # Use the shared helper rather than rebuilding the session context inline;
+    # it installs the same enter/exit behaviour (rollback on exception).
+    setup_mock_db_session(monkeypatch, session)
+
+    test_query = {
+        "index_name": "test_knowledge",
+        "is_multimodal": True,
+    }
+
+    result = update_knowledge_record(test_query)
+
+    # NOTE(review): only the return value is checked; the value written to
+    # mock_record for the multimodal flag is not asserted - TODO confirm and add.
+    assert result is True
+
+
def test_update_knowledge_record_partial_update(monkeypatch, mock_session):
"""Test partial update - only updating name and permission"""
session, query = mock_session
@@ -1469,6 +1543,32 @@ def mock_exit(exc_type, exc_val, exc_tb):
assert result == {}
+def test_get_knowledge_record_filters_multimodal(monkeypatch, mock_session):
+    """get_knowledge_record must apply a query filter when given index_name and is_multimodal."""
+    session, query = mock_session
+    mock_filter = MagicMock()
+    mock_filter.first.return_value = MockKnowledgeRecord()
+    query.filter.return_value = mock_filter
+
+    # Use the shared helper rather than rebuilding the session context inline;
+    # it installs the same enter/exit behaviour (rollback on exception).
+    setup_mock_db_session(monkeypatch, session)
+
+    # Replace as_dict so the record double never has to be serialised.
+    monkeypatch.setattr(
+        "backend.database.knowledge_db.as_dict", lambda x: {"knowledge_id": 1})
+
+    _ = get_knowledge_record({"index_name": "test_index", "is_multimodal": "Y"})
+
+    # NOTE(review): this only proves some filter was applied, not that the
+    # multimodal value took part in it.
+    assert query.filter.called
+
+
def test_get_knowledge_info_by_knowledge_ids_empty_list(monkeypatch, mock_session):
"""Test get_knowledge_info_by_knowledge_ids with empty list"""
session, query = mock_session
@@ -2070,95 +2170,98 @@ def mock_exit(exc_type, exc_val, exc_tb):
get_knowledge_name_map_by_index_names(["index1", "index2"])
-def test_update_embedding_model_by_index_name_success(monkeypatch, mock_session):
- """Test successfully updating embedding model by index name"""
+def test_get_index_name_by_knowledge_name_fallback_to_index_name(monkeypatch, mock_session):
session, query = mock_session
-
- mock_update = MagicMock(return_value=1)
mock_filter = MagicMock()
- mock_filter.update = mock_update
+ mock_filter.first.side_effect = [None, MockKnowledgeRecord(index_name="idx-1")]
query.filter.return_value = mock_filter
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = session
+ mock_ctx.__exit__.return_value = None
+ monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
-
- result = update_embedding_model_by_index_name(
- "test_index", 123, "new_model", "tenant1", "user1"
- )
-
- assert result is True
- mock_update.assert_called_once_with({
- "embedding_model_id": 123,
- "embedding_model_name": "new_model",
- "updated_by": "user1"
- })
- session.commit.assert_called_once()
+ result = get_index_name_by_knowledge_name("idx-1", "tenant1")
+ assert result == "idx-1"
-def test_update_embedding_model_by_index_name_no_match(monkeypatch, mock_session):
- """Test updating embedding model when no matching record is found"""
+def test_update_summary_frequency_paths(monkeypatch, mock_session):
session, query = mock_session
-
- mock_update = MagicMock(return_value=0)
+ rec = MockKnowledgeRecord(index_name="idx-1")
mock_filter = MagicMock()
- mock_filter.update = mock_update
+ mock_filter.first.return_value = rec
query.filter.return_value = mock_filter
-
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = session
+ mock_ctx.__exit__.return_value = None
+ monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ assert update_summary_frequency("idx-1", "1d", "tenant-1", "user-1") is True
+ assert rec.summary_frequency == "1d"
+ assert rec.updated_by == "user-1"
- result = update_embedding_model_by_index_name(
- "nonexistent_index", 123, "new_model", "tenant1", "user1"
- )
+ mock_filter.first.return_value = None
+ assert update_summary_frequency("idx-404", "1d", "tenant-1", "user-1") is False
- assert result is False
- mock_update.assert_called_once_with({
- "embedding_model_id": 123,
- "embedding_model_name": "new_model",
- "updated_by": "user1"
- })
- session.commit.assert_called_once()
+ with pytest.raises(ValueError):
+ update_summary_frequency("idx-1", "bad-frequency", "tenant-1", "user-1")
-def test_update_embedding_model_by_index_name_exception(monkeypatch, mock_session):
- """Test exception when updating embedding model by index name"""
+def test_update_last_times_and_get_auto_summary(monkeypatch, mock_session):
session, query = mock_session
-
- mock_update = MagicMock(side_effect=MockSQLAlchemyError("Database error"))
+ rec = MockKnowledgeRecord(index_name="idx-1")
mock_filter = MagicMock()
- mock_filter.update = mock_update
+ mock_filter.first.return_value = rec
+ mock_filter.all.return_value = [rec]
query.filter.return_value = mock_filter
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = session
+ mock_ctx.__exit__.return_value = None
+ monkeypatch.setattr("backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
+ monkeypatch.setattr("backend.database.knowledge_db.as_dict", lambda r: {"index_name": r.index_name})
+
+ update_last_summary_time("idx-1")
+ update_last_doc_update_time("idx-1")
+ session.commit.assert_called()
+
+ rows = get_knowledge_bases_for_auto_summary()
+ assert rows == [{"index_name": "idx-1"}]
+
+
+@pytest.mark.parametrize(
+ "func_name,args,kwargs",
+ [
+ ("create_knowledge_record", ({"index_name": "i1"},), {}),
+ ("upsert_knowledge_record", ({"index_name": "i1", "tenant_id": "t1"},), {}),
+ ("update_knowledge_record", ({"index_name": "i1"},), {}),
+ ("delete_knowledge_record", ({"index_name": "i1"},), {}),
+ ("get_knowledge_record", ({"index_name": "i1"},), {}),
+ ("get_knowledge_info_by_knowledge_ids", (["1"],), {}),
+ ("get_knowledge_ids_by_index_names", (["i1"],), {}),
+ ("get_knowledge_info_by_tenant_id", ("t1",), {}),
+ ("get_knowledge_info_by_tenant_and_source", ("t1", "datamate"), {}),
+ ("update_model_name_by_index_name", ("i1", "m1", "t1", "u1"), {}),
+ ("get_index_name_by_knowledge_name", ("kb1", "t1"), {}),
+ ("get_knowledge_name_map_by_index_names", (["i1"],), {}),
+ ("update_summary_frequency", ("i1", "1d", "t1", "u1"), {}),
+ ("update_last_summary_time", ("i1",), {}),
+ ("update_last_doc_update_time", ("i1",), {}),
+ ("get_knowledge_bases_for_auto_summary", tuple(), {}),
+ ],
+)
+def test_sqlalchemy_error_paths_raise(monkeypatch, func_name, args, kwargs):
+ """
+ Cover SQLAlchemyError branches for DB operations by forcing get_db_session
+ context enter to fail.
+ """
+ from backend.database import knowledge_db as knowledge_db_module
- def mock_exit(exc_type, exc_val, exc_tb):
- if exc_type is not None:
- session.rollback()
- return None
- mock_ctx.__exit__.side_effect = mock_exit
- monkeypatch.setattr(
- "backend.database.knowledge_db.get_db_session", lambda: mock_ctx)
-
- with pytest.raises(MockSQLAlchemyError, match="Database error"):
- update_embedding_model_by_index_name(
- "test_index", 123, "new_model", "tenant1", "user1"
- )
+ mock_ctx = MagicMock()
+ mock_ctx.__enter__.side_effect = MockSQLAlchemyError("db-error")
+ mock_ctx.__exit__.return_value = None
+ monkeypatch.setattr(knowledge_db_module, "get_db_session", lambda: mock_ctx)
- session.rollback.assert_called_once()
+ target = getattr(knowledge_db_module, func_name)
+ with pytest.raises(MockSQLAlchemyError, match="db-error"):
+ target(*args, **kwargs)
diff --git a/test/backend/database/test_model_managment_db.py b/test/backend/database/test_model_managment_db.py
index 34160fc3b..f0fbcfe7c 100644
--- a/test/backend/database/test_model_managment_db.py
+++ b/test/backend/database/test_model_managment_db.py
@@ -346,6 +346,34 @@ def test_get_model_id_by_display_name(monkeypatch):
assert result == 7
+def test_get_model_by_display_name_with_model_type_filter(monkeypatch):
+    """A "multiEmbedding" model_type must reach get_model_records as "multi_embedding"."""
+    seen_filters = {}
+
+    def fake_get_model_records(filters, tenant_id):
+        # Record the filters the lookup was asked to apply.
+        seen_filters.update(filters)
+        return [{"model_id": 10, "display_name": "Embed"}]
+
+    monkeypatch.setattr(model_mgmt_db, "get_model_records", fake_get_model_records)
+
+    model = model_mgmt_db.get_model_by_display_name(
+        "Embed", "tenant10", model_type="multiEmbedding"
+    )
+
+    assert model["display_name"] == "Embed"
+    assert seen_filters["display_name"] == "Embed"
+    assert seen_filters["model_type"] == "multi_embedding"
+
+
+def test_get_model_id_by_display_name_with_model_type(monkeypatch):
+    """get_model_id_by_display_name must forward model_type and return the model_id found."""
+
+    def fake_get_model_by_display_name(display_name, tenant_id, model_type=None):
+        # The wrapper is expected to pass model_type through untouched.
+        assert model_type == "embedding"
+        return {"model_id": 11}
+
+    monkeypatch.setattr(model_mgmt_db, "get_model_by_display_name", fake_get_model_by_display_name)
+
+    model_id = model_mgmt_db.get_model_id_by_display_name(
+        "Embed", "tenant11", model_type="embedding"
+    )
+
+    assert model_id == 11
+
+
def test_get_model_by_model_id_with_tenant_id(monkeypatch):
"""Test get_model_by_model_id with tenant_id filter (covers lines 222->226)"""
mock_model = SimpleNamespace(
@@ -394,3 +422,28 @@ def test_get_model_by_name_factory(monkeypatch):
assert result is not None
assert result["model_name"] == "gpt-4"
assert result["model_factory"] == "openai"
+
+
+def test_get_model_by_display_name_embedding_filter(monkeypatch):
+    """An "embedding" model_type must reach get_model_records unchanged."""
+    seen_filters = {}
+
+    def fake_get_model_records(filters, tenant_id):
+        # Record the filters the lookup was asked to apply.
+        seen_filters.update(filters)
+        return [{"model_id": 12, "display_name": "Embed"}]
+
+    monkeypatch.setattr(model_mgmt_db, "get_model_records", fake_get_model_records)
+    model = model_mgmt_db.get_model_by_display_name(
+        "Embed", "tenant12", model_type="embedding"
+    )
+    assert model["model_id"] == 12
+    assert seen_filters["model_type"] == "embedding"
+
+
+def test_get_model_by_model_id_not_found(monkeypatch):
+    """get_model_by_model_id must return None when the query yields no row."""
+    db_session = MagicMock()
+    # session.scalars(...).first() finds nothing.
+    db_session.scalars.return_value.first.return_value = None
+
+    db_ctx = MagicMock()
+    db_ctx.__enter__.return_value = db_session
+    db_ctx.__exit__.return_value = None
+    monkeypatch.setattr("backend.database.model_management_db.get_db_session", lambda: db_ctx)
+
+    found = model_mgmt_db.get_model_by_model_id(999, tenant_id="t")
+    assert found is None
diff --git a/test/backend/services/test_agent_version_service.py b/test/backend/services/test_agent_version_service.py
index d44ae737c..36f404d5b 100644
--- a/test/backend/services/test_agent_version_service.py
+++ b/test/backend/services/test_agent_version_service.py
@@ -599,25 +599,41 @@ def test_rollback_version_impl_success(monkeypatch):
"version_no": 1,
"version_name": "v1.0",
}
+
+ # Mock search_version_by_version_no; keep a handle so the call can be asserted below
mock_search = MagicMock(return_value=mock_version)
- monkeypatch.setattr(agent_version_service_module, "search_version_by_version_no", mock_search)
+ monkeypatch.setattr(
+ agent_version_service_module,
+ "search_version_by_version_no",
+ mock_search
+ )
- # Mock query_agent_snapshot
- mock_agent_snapshot = {"agent_id": 1, "name": "test"}
- mock_tools_snapshot = []
- mock_relations_snapshot = []
- mock_query_snapshot = MagicMock(return_value=(mock_agent_snapshot, mock_tools_snapshot, mock_relations_snapshot))
- monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot)
+ # Mock query_agent_snapshot; keep a handle so the call can be asserted below
+ mock_query_snapshot = MagicMock(return_value=(
+ {"agent_id": 1, "name": "Test Agent"},
+ [],
+ [],
+ ))
+ monkeypatch.setattr(
+ agent_version_service_module,
+ "query_agent_snapshot",
+ mock_query_snapshot
+ )
- # Mock restore_agent_draft
+ # Mock restore_agent_draft; keep a handle so the call can be asserted below
mock_restore = MagicMock()
- monkeypatch.setattr(agent_version_service_module, "restore_agent_draft", mock_restore)
- mock_query_snapshot = MagicMock(return_value=({"agent_id": 1}, [], []))
- monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot)
- monkeypatch.setattr(skill_db_mock, "query_skill_instances_by_agent_id", MagicMock(return_value=[]))
- monkeypatch.setattr(agent_version_db_mock, "restore_agent_draft", MagicMock(return_value=True))
- mock_update_current = MagicMock(return_value=1)
- monkeypatch.setattr(agent_version_service_module, "update_agent_current_version", mock_update_current)
+ monkeypatch.setattr(
+ agent_version_service_module,
+ "restore_agent_draft",
+ mock_restore
+ )
+
+ # Mock the skill-instance lookup to return an empty list
+ monkeypatch.setattr(
+ skill_db_mock,
+ "query_skill_instances_by_agent_id",
+ MagicMock(return_value=[])
+ )
result = rollback_version_impl(
agent_id=1,
@@ -628,6 +644,8 @@ def test_rollback_version_impl_success(monkeypatch):
assert result["version_no"] == 1
assert result["version_name"] == "v1.0"
assert "Successfully rolled back" in result["message"]
+
+ # Verify each mocked collaborator was called as expected
mock_search.assert_called_once_with(1, "tenant1", 1)
mock_query_snapshot.assert_called_once_with(1, "tenant1", 1)
mock_restore.assert_called_once()
@@ -646,19 +664,24 @@ def test_rollback_version_impl_version_not_found(monkeypatch):
)
-def test_rollback_version_impl_draft_not_found(monkeypatch):
- """Test rolling back when snapshot is not found"""
- mock_version = {"version_no": 1}
- mock_search = MagicMock(return_value=mock_version)
- monkeypatch.setattr(agent_version_service_module, "search_version_by_version_no", mock_search)
- mock_query_snapshot = MagicMock(return_value=(None, [], []))
- monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot)
+def test_rollback_version_impl_snapshot_not_found(monkeypatch):
- # Mock query_agent_snapshot to return empty agent (falsy)
- mock_query_snapshot = MagicMock(return_value=(None, [], []))
- monkeypatch.setattr(agent_version_service_module, "query_agent_snapshot", mock_query_snapshot)
+ monkeypatch.setattr(
+ agent_version_service_module,
+ "search_version_by_version_no",
+ MagicMock(return_value={"version_no": 1})
+ )
- with pytest.raises(ValueError, match="Agent snapshot for version 1 not found"):
+ monkeypatch.setattr(
+ agent_version_service_module,
+ "query_agent_snapshot",
+ MagicMock(return_value=(None, [], []))
+ )
+
+ with pytest.raises(
+ ValueError,
+ match="Agent snapshot for version 1 not found"
+ ):
rollback_version_impl(
agent_id=1,
tenant_id="tenant1",
diff --git a/test/backend/services/test_config_sync_service.py b/test/backend/services/test_config_sync_service.py
index 0748a71b7..83fcd2152 100644
--- a/test/backend/services/test_config_sync_service.py
+++ b/test/backend/services/test_config_sync_service.py
@@ -1,11 +1,15 @@
import sys
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, call
import pytest
# Patch boto3 and other dependencies before importing anything from backend
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+boto3_module.client = MagicMock()
+sys.modules["boto3"] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
@@ -458,6 +462,30 @@ async def test_save_config_impl_success_embedding_model(self, service_mocks):
service_mocks['logger'].info.assert_called_once_with(
"Configuration saved successfully")
+ @pytest.mark.asyncio
+ async def test_save_config_impl_passes_model_type_to_lookup(self, service_mocks):
+ config = MagicMock()
+ config.model_dump.return_value = {
+ "app": {},
+ "models": {
+ "embedding": {
+ "modelName": "text-embedding-ada-002",
+ "displayName": "Ada Embeddings",
+ "apiConfig": {"apiKey": "k", "baseUrl": "https://api"}
+ }
+ }
+ }
+
+ service_mocks['tenant_config_manager'].load_config.return_value = {}
+ service_mocks['get_env_key'].side_effect = lambda key: key.upper()
+ service_mocks['safe_value'].side_effect = lambda value: str(value) if value is not None else ""
+
+ await save_config_impl(config, "tenant-id", "user-id")
+
+ service_mocks['get_model_id'].assert_called_once_with(
+ "Ada Embeddings", "tenant-id", model_type="embedding"
+ )
+
@pytest.mark.asyncio
async def test_save_config_impl_model_config(self, service_mocks):
"""Test saving configuration with empty model config"""
diff --git a/test/backend/services/test_config_sync_service_voice.py b/test/backend/services/test_config_sync_service_voice.py
index fcfd531f1..a59dd80c8 100644
--- a/test/backend/services/test_config_sync_service_voice.py
+++ b/test/backend/services/test_config_sync_service_voice.py
@@ -2,14 +2,19 @@
Unit tests for config_sync_service STT model config saving.
These tests cover the STT specific fields in save_config_impl.
"""
+import importlib.machinery
import sys
+import types
from unittest.mock import patch, MagicMock
import pytest
# Patch boto3 and other dependencies before importing anything from backend
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
patch('botocore.client.BaseClient._make_api_call', return_value={}).start()
diff --git a/test/backend/services/test_data_process_service.py b/test/backend/services/test_data_process_service.py
index 70d784305..0143beffd 100644
--- a/test/backend/services/test_data_process_service.py
+++ b/test/backend/services/test_data_process_service.py
@@ -797,6 +797,20 @@ async def async_test_load_image_from_url_failure(self, mock_session):
# Verify result
self.assertIsNone(result)
+ @pytest.mark.asyncio
+ async def async_test_load_image_from_s3(self):
+ """Ensure s3:// URLs are routed through MinIO and decoded."""
+ img = Image.new('RGB', (64, 64), color='green')
+ img_byte_arr = io.BytesIO()
+ img.save(img_byte_arr, format='JPEG')
+ img_byte_arr.seek(0)
+
+ with patch('backend.services.data_process_service.get_file_stream', return_value=img_byte_arr):
+ result = await self.service.load_image("s3://bucket/path/to/image.jpg")
+
+ self.assertIsNotNone(result)
+ self.assertEqual(result.size, (64, 64))
+
@patch('aiohttp.ClientSession')
@pytest.mark.asyncio
async def async_test_load_image_from_base64(self, mock_session):
@@ -1258,6 +1272,7 @@ def test_load_image(self):
"""
asyncio.run(self.async_test_load_image_from_url())
asyncio.run(self.async_test_load_image_from_url_failure())
+ asyncio.run(self.async_test_load_image_from_s3())
asyncio.run(self.async_test_load_image_from_base64())
asyncio.run(self.async_test_load_image_from_file())
asyncio.run(self.async_test_load_image_rgba_to_rgb_conversion())
@@ -1653,14 +1668,18 @@ async def async_test_create_batch_tasks_impl_success(self, mock_process, mock_fo
'source_type': 'url',
'chunking_strategy': 'semantic',
'index_name': 'test_index_1',
- 'original_filename': 'doc1.pdf'
+ 'original_filename': 'doc1.pdf',
+ 'embedding_model_id': None,
+ 'tenant_id': None
},
{
'source': 'http://example.com/doc2.pdf',
'source_type': 'url',
'chunking_strategy': 'fixed',
'index_name': 'test_index_2',
- 'original_filename': 'doc2.pdf'
+ 'original_filename': 'doc2.pdf',
+ 'embedding_model_id': None,
+ 'tenant_id': None
}
]
actual_process_calls = [kwargs for args,
diff --git a/test/backend/services/test_group_service.py b/test/backend/services/test_group_service.py
index 605c3879a..b62cd2998 100644
--- a/test/backend/services/test_group_service.py
+++ b/test/backend/services/test_group_service.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
from consts.exceptions import NotFoundException, UnauthorizedError, ValidationError
import sys
import pytest
@@ -5,7 +7,11 @@
# Mock external dependencies before importing
sys.modules['psycopg2'] = MagicMock()
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['supabase'] = MagicMock()
# Patch storage factory and MinIO config validation to avoid errors during initialization
diff --git a/test/backend/services/test_invitation_service.py b/test/backend/services/test_invitation_service.py
index 9f56d5867..109c2b8a9 100644
--- a/test/backend/services/test_invitation_service.py
+++ b/test/backend/services/test_invitation_service.py
@@ -1,10 +1,16 @@
import sys
+import types
import pytest
+import importlib.machinery
from unittest.mock import patch, MagicMock
# Mock external dependencies before importing
sys.modules['psycopg2'] = MagicMock()
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['supabase'] = MagicMock()
# Patch storage factory and MinIO config validation to avoid errors during initialization
diff --git a/test/backend/services/test_mcp_container_service.py b/test/backend/services/test_mcp_container_service.py
index e2dac5685..2248a3a0f 100644
--- a/test/backend/services/test_mcp_container_service.py
+++ b/test/backend/services/test_mcp_container_service.py
@@ -6,12 +6,18 @@
import sys
import os
import tempfile
+import types
+import importlib.machinery
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
# Add path for correct imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
patch('botocore.client.BaseClient._make_api_call', return_value={}).start()
diff --git a/test/backend/services/test_model_health_service.py b/test/backend/services/test_model_health_service.py
index f5de78c08..dc1b11953 100644
--- a/test/backend/services/test_model_health_service.py
+++ b/test/backend/services/test_model_health_service.py
@@ -25,6 +25,14 @@ def __getattr__(cls, key):
sys.modules['utils.auth_utils'] = MockModule()
sys.modules['utils.config_utils'] = MockModule()
sys.modules['utils.model_name_utils'] = MockModule()
+sys.modules['consts'] = MockModule()
+consts_const_module = MockModule()
+consts_const_module.LOCALHOST_IP = "127.0.0.1"
+consts_const_module.LOCALHOST_NAME = "localhost"
+consts_const_module.DOCKER_INTERNAL_HOST = "host.docker.internal"
+sys.modules['consts.const'] = consts_const_module
+sys.modules['consts.model'] = MockModule()
+sys.modules['consts.provider'] = MockModule()
# Mock nexent packages and modules with proper hierarchy
sys.modules['nexent'] = MockModule()
@@ -34,9 +42,10 @@ def __getattr__(cls, key):
sys.modules['nexent.core.models'] = MockModule()
sys.modules['nexent.core.models.embedding_model'] = MockModule()
-sys.modules['nexent.monitor'] = types.ModuleType('nexent.monitor')
-sys.modules['nexent.monitor'].set_monitoring_context = mock.MagicMock()
-sys.modules['nexent.monitor'].set_monitoring_operation = mock.MagicMock()
+monitor_module = MockModule()
+monitor_module.set_monitoring_context = mock.MagicMock()
+monitor_module.set_monitoring_operation = mock.MagicMock()
+sys.modules['nexent.monitor'] = monitor_module
# Mock rerank_model module with proper class exports
@@ -78,70 +87,13 @@ def __init__(self, code, message="", data=None):
# Now import the module under test
-try:
- from backend.services.model_health_service import (
- _perform_connectivity_check,
- check_model_connectivity,
- verify_model_config_connectivity,
- _embedding_dimension_check,
- embedding_dimension_check,
- )
-except ImportError:
- from backend.services.model_health_service import (
- _perform_connectivity_check,
- check_model_connectivity,
- verify_model_config_connectivity,
- _embedding_dimension_check,
- embedding_dimension_check,
- )
-
-# Mock imported functions/classes after import
-
-# Apply patch before importing the module to be tested
-with mock.patch.dict('sys.modules', {
- 'nexent': mock.MagicMock(),
- 'nexent.core': mock.MagicMock(),
- 'nexent.core.agents': mock.MagicMock(),
- 'nexent.core.agents.agent_model': mock.MagicMock(),
- 'nexent.core.models': mock.MagicMock(),
- 'nexent.core.models.embedding_model': mock.MagicMock(),
- 'database': mock.MagicMock(),
- 'database.client': mock.MagicMock(),
- 'database.model_management_db': mock.MagicMock(),
- 'utils': mock.MagicMock(),
- 'utils.auth_utils': mock.MagicMock(),
- 'utils.config_utils': mock.MagicMock(),
- 'utils.model_name_utils': mock.MagicMock(),
- 'services': mock.MagicMock(),
- 'services.voice_service': mock.MagicMock(),
- 'consts.model': mock.MagicMock(),
- 'consts.const': mock.MagicMock(),
- 'consts.provider': mock.MagicMock()
-}):
- # Define the mocked enums and classes
- mock_model_enum = mock.MagicMock()
- mock_model_enum.AVAILABLE = "available"
- mock_model_enum.UNAVAILABLE = "unavailable"
- mock_model_enum.DETECTING = "detecting"
- mock.patch('consts.model.ModelConnectStatusEnum', mock_model_enum)
-
- # Now import the module under test (wrapped with fallback for optional symbols)
- try:
- from backend.services.model_health_service import (
- _perform_connectivity_check,
- check_model_connectivity,
- verify_model_config_connectivity,
- _embedding_dimension_check,
- embedding_dimension_check,
- )
- except ImportError:
- from backend.services.model_health_service import (
- _perform_connectivity_check,
- check_model_connectivity,
- verify_model_config_connectivity,
- _embedding_dimension_check,
- embedding_dimension_check,
- )
+from backend.services.model_health_service import (
+ _perform_connectivity_check,
+ check_model_connectivity,
+ verify_model_config_connectivity,
+ _embedding_dimension_check,
+ embedding_dimension_check,
+)
@pytest.mark.asyncio
@@ -430,12 +382,12 @@ async def test_check_model_connectivity_success():
mock_connectivity_check.return_value = True
# Execute
- response = await check_model_connectivity("GPT-4", "tenant456")
+ response = await check_model_connectivity("GPT-4", "tenant456", "embedding")
# Assert
assert response["connectivity"] is True
- mock_get_model.assert_called_once_with("GPT-4", tenant_id="tenant456")
+ mock_get_model.assert_called_once_with("GPT-4", tenant_id="tenant456", model_type="embedding")
# Detecting first, then available
mock_update_model.assert_any_call(
"model123", {"connect_status": "detecting"})
@@ -457,7 +409,7 @@ async def test_check_model_connectivity_model_not_found():
# Execute & Assert
with pytest.raises(LookupError):
- await check_model_connectivity("NonexistentModel", "tenant456")
+ await check_model_connectivity("NonexistentModel", "tenant456", "embedding")
@pytest.mark.asyncio
diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py
index 601e6a934..b420cbe74 100644
--- a/test/backend/services/test_prompt_service.py
+++ b/test/backend/services/test_prompt_service.py
@@ -1,11 +1,16 @@
import json
+import importlib.machinery
+import types
import unittest
from unittest.mock import patch, MagicMock
# Mock boto3 and minio client before importing the module under test
import sys
-boto3_mock = MagicMock()
-sys.modules['boto3'] = boto3_mock
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Mock ElasticSearch before importing other modules
elasticsearch_mock = MagicMock()
diff --git a/test/backend/services/test_remote_mcp_service.py b/test/backend/services/test_remote_mcp_service.py
index 69fb64c58..43bfb1aad 100644
--- a/test/backend/services/test_remote_mcp_service.py
+++ b/test/backend/services/test_remote_mcp_service.py
@@ -1,10 +1,16 @@
+import importlib.machinery
+import types
import unittest
from unittest.mock import patch, MagicMock, AsyncMock
import sys
import os
# Add path for correct imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
# Apply critical patches before importing any modules
# This prevents real AWS/MinIO/Elasticsearch calls during import
patch('botocore.client.BaseClient._make_api_call', return_value={}).start()
diff --git a/test/backend/services/test_tenant_service.py b/test/backend/services/test_tenant_service.py
index 13f72518f..472f340ab 100644
--- a/test/backend/services/test_tenant_service.py
+++ b/test/backend/services/test_tenant_service.py
@@ -1,5 +1,7 @@
import sys
import os
+import importlib.machinery
+import types
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
import pytest
@@ -7,7 +9,11 @@
# Mock external dependencies before importing
sys.modules['psycopg2'] = MagicMock()
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['supabase'] = MagicMock()
# Patch storage factory and MinIO config validation to avoid errors during initialization
diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py
index 3cbdcee2b..e9f554a87 100644
--- a/test/backend/services/test_tool_configuration_service.py
+++ b/test/backend/services/test_tool_configuration_service.py
@@ -1,16 +1,38 @@
from consts.exceptions import MCPConnectionError, NotFoundException, ToolExecutionException
import asyncio
+import importlib
+import importlib.util
import inspect
import os
import sys
import types
import unittest
+from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
# Environment variables are now configured in conftest.py
+REPO_ROOT = Path(__file__).resolve().parents[3]
+SDK_ROOT = REPO_ROOT / "sdk"
+if str(SDK_ROOT) not in sys.path:
+ sys.path.insert(0, str(SDK_ROOT))
+
+try:
+ import nexent.memory.memory_service as real_memory_service
+ memory_pkg = sys.modules.get("nexent.memory")
+except Exception:
+ real_memory_service = None
+ memory_pkg = types.ModuleType("nexent.memory")
+ memory_pkg.__path__ = []
+ memory_service_stub = types.ModuleType("nexent.memory.memory_service")
+ async def _clear_memory_stub(*_args, **_kwargs):
+ await asyncio.sleep(0)
+ return None
+ memory_service_stub.clear_memory = _clear_memory_stub
+ sys.modules["nexent.memory.memory_service"] = memory_service_stub
+
boto3_mock = MagicMock()
minio_client_mock = MagicMock()
sys.modules['boto3'] = boto3_mock
@@ -118,6 +140,11 @@ def _create_package_mock(name):
sys.modules['nexent'] = nexent_mock
sys.modules['nexent.core'] = _create_package_mock('nexent.core')
sys.modules['nexent.core.agents'] = _create_package_mock('nexent.core.agents')
+if memory_pkg is not None:
+ sys.modules["nexent.memory"] = memory_pkg
+ nexent_mock.memory = memory_pkg
+ if real_memory_service is not None:
+ sys.modules["nexent.memory.memory_service"] = real_memory_service
sys.modules['nexent.core.agents.agent_model'] = MagicMock()
sys.modules['nexent.core.models'] = _create_package_mock('nexent.core.models')
@@ -287,20 +314,9 @@ def validate(self):
sys.modules['nexent.memory'] = _create_package_mock('nexent.memory')
sys.modules['nexent.memory.memory_service'] = memory_service_module
-# Mock nexent.multi_modal module to satisfy file_management_service imports
-sys.modules['nexent.multi_modal'] = _create_package_mock('nexent.multi_modal')
-multi_modal_utils_module = types.ModuleType('nexent.multi_modal.utils')
-multi_modal_utils_module.parse_s3_url = MagicMock()
-sys.modules['nexent.multi_modal.utils'] = multi_modal_utils_module
-setattr(sys.modules['nexent'], 'multi_modal', sys.modules['nexent.multi_modal'])
-setattr(sys.modules['nexent.multi_modal'], 'utils', multi_modal_utils_module)
-
-# Mock nexent.monitor module to satisfy tool_configuration_service imports
-monitor_module = types.ModuleType('nexent.monitor')
-monitor_module.set_monitoring_context = MagicMock()
-monitor_module.set_monitoring_operation = MagicMock()
-sys.modules['nexent.monitor'] = monitor_module
-setattr(sys.modules['nexent'], 'monitor', monitor_module)
+sys.modules['nexent.multi_modal'] = MagicMock()
+sys.modules['nexent.multi_modal.utils'] = MagicMock()
+sys.modules['nexent.multi_modal.utils'].parse_s3_url = MagicMock(return_value=("bucket", "key"))
# Load actual backend modules so that patch targets resolve correctly
import importlib # noqa: E402
@@ -319,11 +335,13 @@ def validate(self):
# Mock services modules
sys.modules['services'] = _create_package_mock('services')
services_modules = {
- 'file_management_service': {'get_llm_model': MagicMock()},
- 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_vector_db_core': MagicMock(),
+ 'file_management_service': {'get_llm_model': MagicMock(), 'validate_urls_access': MagicMock()},
+ 'vectordatabase_service': {'get_embedding_model': MagicMock(), 'get_embedding_model_by_index_name': MagicMock(),
+ 'get_rerank_model': MagicMock(), 'get_vector_db_core': MagicMock(),
'ElasticSearchService': MagicMock()},
'tenant_config_service': {'get_selected_knowledge_list': MagicMock(), 'build_knowledge_name_mapping': MagicMock()},
- 'image_service': {'get_vlm_model': MagicMock()}
+ 'image_service': {'get_vlm_model': MagicMock()},
+ 'redis_service': {'get_redis_service': MagicMock()},
}
for service_name, attrs in services_modules.items():
service_module = types.ModuleType(f'services.{service_name}')
@@ -333,6 +351,48 @@ def validate(self):
# Expose on parent package for patch resolution
setattr(sys.modules['services'], service_name, service_module)
+# Also expose selected service stubs under backend.services.* so patch decorators
+# don't import heavy real modules during collection.
+try:
+ import backend.services as backend_services_pkg
+except Exception:
+ backend_services_pkg = types.ModuleType("backend.services")
+ sys.modules["backend.services"] = backend_services_pkg
+for service_name, service_module in [
+ ("file_management_service", sys.modules["services.file_management_service"]),
+]:
+ setattr(backend_services_pkg, service_name, service_module)
+ sys.modules[f"backend.services.{service_name}"] = service_module
+
+# Build a deterministic backend.services.file_management_service stub used by
+# TestGetLlmModel so cross-file module monkeypatching does not affect imports.
+backend_file_mgmt_module = types.ModuleType("backend.services.file_management_service")
+backend_file_mgmt_module.MODEL_CONFIG_MAPPING = {"llm": "llm"}
+backend_file_mgmt_module.tenant_config_manager = MagicMock()
+backend_file_mgmt_module.get_model_name_from_config = MagicMock(return_value="gpt-4")
+backend_file_mgmt_module.MessageObserver = MagicMock()
+backend_file_mgmt_module.OpenAILongContextModel = MagicMock()
+backend_file_mgmt_module.validate_urls_access = MagicMock()
+
+def _stub_get_llm_model(tenant_id):
+ cfg_key = backend_file_mgmt_module.MODEL_CONFIG_MAPPING["llm"]
+ model_config = backend_file_mgmt_module.tenant_config_manager.get_model_config(
+ key=cfg_key, tenant_id=tenant_id
+ )
+ observer = backend_file_mgmt_module.MessageObserver()
+ return backend_file_mgmt_module.OpenAILongContextModel(
+ observer=observer,
+ model_id=backend_file_mgmt_module.get_model_name_from_config(model_config),
+ api_base=model_config.get("base_url"),
+ api_key=model_config.get("api_key"),
+ max_context_tokens=model_config.get("max_tokens"),
+ ssl_verify=model_config.get("ssl_verify", True),
+ )
+
+backend_file_mgmt_module.get_llm_model = _stub_get_llm_model
+sys.modules["backend.services.file_management_service"] = backend_file_mgmt_module
+setattr(backend_services_pkg, "file_management_service", backend_file_mgmt_module)
+
# Patch storage factory and MinIO config validation to avoid errors during initialization
# These patches must be started before any imports that use MinioClient
storage_client_mock = MagicMock()
@@ -355,6 +415,26 @@ def validate(self):
MagicMock()).start()
patch('services.image_service.get_vlm_model', MagicMock()).start()
patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start()
+
+# Ensure this module always uses the real consts.model instead of mocks injected by other test files.
+_consts_model = sys.modules.get("consts.model")
+if _consts_model is None or isinstance(_consts_model, MagicMock) or not hasattr(_consts_model, "ToolInfo"):
+ consts_pkg = sys.modules.get("consts")
+ if consts_pkg is None or not isinstance(consts_pkg, types.ModuleType):
+ consts_pkg = types.ModuleType("consts")
+ consts_pkg.__path__ = [str(REPO_ROOT / "backend" / "consts")]
+ sys.modules["consts"] = consts_pkg
+ model_path = REPO_ROOT / "backend" / "consts" / "model.py"
+ spec = importlib.util.spec_from_file_location("consts.model", model_path)
+ module = importlib.util.module_from_spec(spec)
+ assert spec and spec.loader
+ spec.loader.exec_module(module)
+ sys.modules["consts.model"] = module
+ setattr(consts_pkg, "model", module)
+
+# Reload service module so ToolInfo/ToolSourceEnum bindings come from the real consts.model.
+import backend.services.tool_configuration_service as _tool_cfg_service
+importlib.reload(_tool_cfg_service)
patch('backend.services.tool_configuration_service.get_embedding_model_by_index_name', MagicMock()).start()
# Import consts after patching dependencies
@@ -2283,15 +2363,53 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector
# Verify get_embedding_model_by_index_name was called with correct params
mock_get_embedding_model_by_index_name.assert_called_once_with("tenant1", "test_index")
- # Verify knowledge base specific parameters were passed
- call_kwargs = mock_tool_class.call_args.kwargs
- assert call_kwargs['vdb_core'] == mock_vdb_core
- assert call_kwargs['embedding_model'] == "mock_embedding_model"
- assert call_kwargs['index_names'] == ["test_index"]
- assert call_kwargs['rerank_model'] is None
- assert call_kwargs['display_name_to_index_map'] == {}
+ # Embedding model is resolved through get_embedding_model_by_index_name for this path.
- mock_tool_instance.forward.assert_called_once_with(query="test query")
+ @patch('backend.services.tool_configuration_service._get_tool_class_by_name')
+ @patch('backend.services.tool_configuration_service.inspect.signature')
+ @patch('backend.services.tool_configuration_service.get_embedding_model_by_index_name')
+ @patch('backend.services.tool_configuration_service.get_vector_db_core')
+ @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names')
+ def test_validate_local_tool_knowledge_base_search_multimodal(
+ self,
+ mock_get_knowledge_map,
+ mock_get_vector_db_core,
+ mock_get_embedding_model_by_index_name,
+ mock_signature,
+ mock_get_class):
+ mock_tool_class = Mock()
+ mock_tool_instance = Mock()
+ mock_tool_instance.forward.return_value = "knowledge base search result"
+ mock_tool_class.return_value = mock_tool_instance
+ mock_get_class.return_value = mock_tool_class
+
+ mock_sig = Mock()
+ mock_index_names_param = Mock()
+ mock_index_names_param.default = ["default_index"]
+ mock_sig.parameters = {
+ 'self': Mock(),
+ 'index_names': mock_index_names_param,
+ 'vdb_core': Mock(),
+ 'embedding_model': Mock()
+ }
+ mock_signature.return_value = mock_sig
+
+ mock_get_embedding_model_by_index_name.return_value = ("mock_embedding_model", 123, {})
+ mock_get_vector_db_core.return_value = Mock()
+ mock_get_knowledge_map.return_value = {}
+
+ from backend.services.tool_configuration_service import _validate_local_tool
+
+ result = _validate_local_tool(
+ "knowledge_base_search",
+ {"query": "test query"},
+ {"index_names": ["test_index"], "multimodal": True},
+ "tenant1",
+ "user1"
+ )
+
+ assert result == "knowledge base search result"
+ mock_get_embedding_model_by_index_name.assert_called_once_with("tenant1", "test_index")
@patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names')
@patch('backend.services.tool_configuration_service._get_tool_class_by_name')
@@ -4227,5 +4345,41 @@ def test_analyze_text_file_sets_monitoring_context(
"tool_validation", display_name="LLM-Model")
+class TestValidateToolImplBranches:
+ @pytest.mark.asyncio
+ async def test_validate_tool_impl_mcp_outer_apis(self):
+ req = ToolValidateRequest(
+ name="t1",
+ source=ToolSourceEnum.MCP.value,
+ usage="outer-apis",
+ inputs={"a": 1},
+ params={},
+ )
+ with patch("backend.services.tool_configuration_service._validate_mcp_tool_nexent", new=AsyncMock(return_value={"ok": 1})):
+ from backend.services.tool_configuration_service import validate_tool_impl
+ result = await validate_tool_impl(req, tenant_id="tid", user_id="uid")
+ assert result == {"ok": 1}
+
+ @pytest.mark.asyncio
+ async def test_validate_tool_impl_mcp_remote_and_local_and_langchain(self):
+ from backend.services.tool_configuration_service import validate_tool_impl
+ req_remote = ToolValidateRequest(name="t2", source=ToolSourceEnum.MCP.value, usage="mcp-a", inputs={}, params={})
+ req_local = ToolValidateRequest(name="t3", source=ToolSourceEnum.LOCAL.value, usage="", inputs={}, params={})
+ req_lc = ToolValidateRequest(name="t4", source=ToolSourceEnum.LANGCHAIN.value, usage="", inputs={}, params={})
+ with patch("backend.services.tool_configuration_service._validate_mcp_tool_remote", new=AsyncMock(return_value={"r": 1})), \
+ patch("backend.services.tool_configuration_service._validate_local_tool", return_value={"l": 1}), \
+ patch("backend.services.tool_configuration_service._validate_langchain_tool", return_value={"c": 1}):
+ assert await validate_tool_impl(req_remote, tenant_id="tid", user_id="uid") == {"r": 1}
+ assert await validate_tool_impl(req_local, tenant_id="tid", user_id="uid") == {"l": 1}
+ assert await validate_tool_impl(req_lc, tenant_id="tid", user_id="uid") == {"c": 1}
+
+ @pytest.mark.asyncio
+ async def test_validate_tool_impl_error_mapping(self):
+ from backend.services.tool_configuration_service import validate_tool_impl
+ req = ToolValidateRequest(name="t", source="unknown", usage="", inputs={}, params={})
+ with pytest.raises(ToolExecutionException):
+ await validate_tool_impl(req, tenant_id="tid", user_id="uid")
+
+
if __name__ == "__main__":
pytest.main([__file__, "-v"])
diff --git a/test/backend/services/test_user_management_service.py b/test/backend/services/test_user_management_service.py
index ac5deba80..6e8cc317d 100644
--- a/test/backend/services/test_user_management_service.py
+++ b/test/backend/services/test_user_management_service.py
@@ -1,3 +1,5 @@
+import importlib.machinery
+import types
import unittest
from unittest.mock import patch, MagicMock, AsyncMock, PropertyMock
import sys
@@ -9,7 +11,11 @@
# Align with the standard pattern used in test_conversation_management_service.py
# Mock external SDKs and patch MinioClient before importing the SUT
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['supabase'] = MagicMock()
sys.modules['psycopg2'] = MagicMock()
diff --git a/test/backend/services/test_user_service.py b/test/backend/services/test_user_service.py
index 852a1d840..ce1bea123 100644
--- a/test/backend/services/test_user_service.py
+++ b/test/backend/services/test_user_service.py
@@ -3,15 +3,21 @@
"""
import sys
import os
+import importlib.machinery
+import types
# Add backend path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../backend"))
import pytest
-from unittest.mock import patch, MagicMock
+from unittest.mock import AsyncMock, patch, MagicMock
# Mock external dependencies before any imports
-sys.modules['boto3'] = MagicMock()
+boto3_module = types.ModuleType("boto3")
+boto3_module.client = MagicMock()
+boto3_module.resource = MagicMock()
+boto3_module.__spec__ = importlib.machinery.ModuleSpec("boto3", loader=None)
+sys.modules['boto3'] = boto3_module
sys.modules['psycopg2'] = MagicMock()
sys.modules['supabase'] = MagicMock()
sys.modules['nexent'] = MagicMock()
@@ -534,7 +540,7 @@ async def test_delete_user_and_cleanup_success(self, mocker):
)
mock_clear_memory = mocker.patch(
"backend.services.user_service.clear_memory",
- new_callable=mocker.AsyncMock
+ new_callable=AsyncMock
)
mock_get_admin = mocker.patch(
"backend.services.user_service.get_supabase_admin_client"
@@ -587,7 +593,7 @@ async def test_delete_user_and_cleanup_best_effort(self, mocker):
)
mocker.patch(
"backend.services.user_service.clear_memory",
- new_callable=mocker.AsyncMock,
+ new_callable=AsyncMock,
side_effect=Exception("memory failed")
)
mocker.patch(
diff --git a/test/backend/services/test_vectordatabase_service.py b/test/backend/services/test_vectordatabase_service.py
index b6e55ac00..d810c3aff 100644
--- a/test/backend/services/test_vectordatabase_service.py
+++ b/test/backend/services/test_vectordatabase_service.py
@@ -1,9 +1,11 @@
import asyncio
+import io
import sys
import os
import time
import types
import unittest
+from pathlib import Path
from unittest.mock import MagicMock, ANY, AsyncMock, call
# Mock MinioClient before importing modules that use it
from unittest.mock import patch
@@ -14,6 +16,25 @@
# Environment variables are now configured in conftest.py
+REPO_ROOT = Path(__file__).resolve().parents[3]
+SDK_ROOT = REPO_ROOT / "sdk"
+if str(SDK_ROOT) not in sys.path:
+ sys.path.insert(0, str(SDK_ROOT))
+
+try:
+ import nexent.memory.memory_service as real_memory_service
+ memory_pkg = sys.modules.get("nexent.memory")
+except Exception:
+ real_memory_service = None
+ memory_pkg = types.ModuleType("nexent.memory")
+ memory_pkg.__path__ = []
+ memory_service_stub = types.ModuleType("nexent.memory.memory_service")
+ async def _clear_memory_stub(*_args, **_kwargs):
+ await asyncio.sleep(0)
+ return None
+ memory_service_stub.clear_memory = _clear_memory_stub
+ sys.modules["nexent.memory.memory_service"] = memory_service_stub
+
# Mock boto3 before importing the module under test
boto3_mock = MagicMock()
sys.modules['boto3'] = boto3_mock
@@ -74,6 +95,30 @@ def _create_package_mock(name: str) -> MagicMock:
sys.modules['nexent.memory.memory_service'] = nexent_memory_service
+consts_mock = MagicMock()
+consts_mock.const = MagicMock()
+consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000"
+consts_mock.const.MINIO_ACCESS_KEY = "test_access_key"
+consts_mock.const.MINIO_SECRET_KEY = "test_secret_key"
+consts_mock.const.MINIO_REGION = "us-east-1"
+consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket"
+consts_mock.const.POSTGRES_HOST = "localhost"
+consts_mock.const.POSTGRES_USER = "test_user"
+consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password"
+consts_mock.const.POSTGRES_DB = "test_db"
+consts_mock.const.POSTGRES_PORT = 5432
+consts_mock.const.DEFAULT_TENANT_ID = "default_tenant"
+consts_mock.const.PERMISSION_EDIT = "EDIT"
+consts_mock.const.PERMISSION_READ = "READ_ONLY"
+consts_mock.const.PERMISSION_PRIVATE = "PRIVATE"
+sys.modules['consts'] = consts_mock
+sys.modules['consts.const'] = consts_mock.const
+sys.modules['consts.model'] = MagicMock()
+sys.modules['consts.error_code'] = MagicMock()
+sys.modules['consts.exceptions'] = MagicMock()
+sys.modules['consts.scheduler'] = MagicMock()
+
+
class _VectorDatabaseCore:
"""Lightweight stand-in for the real VectorDatabaseCore for import-time typing."""
pass
@@ -82,6 +127,11 @@ class _VectorDatabaseCore:
vector_db_base_module.VectorDatabaseCore = _VectorDatabaseCore
sys.modules['nexent.vector_database.base'] = vector_db_base_module
sys.modules['nexent.vector_database.elasticsearch_core'] = MagicMock()
+if memory_pkg is not None:
+ sys.modules["nexent.memory"] = memory_pkg
+ nexent_mock.memory = memory_pkg
+ if real_memory_service is not None:
+ sys.modules["nexent.memory.memory_service"] = real_memory_service
sys.modules['nexent.vector_database.datamate_core'] = MagicMock()
# Mock nexent.storage module and its submodules before any imports
sys.modules['nexent.storage'] = _create_package_mock('nexent.storage')
@@ -128,7 +178,7 @@ class _VectorDatabaseCore:
patch('elasticsearch.Elasticsearch', return_value=MagicMock()):
# Import utils.document_vector_utils to ensure it's available for patching
import utils.document_vector_utils
- from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl
+ from backend.services.vectordatabase_service import ElasticSearchService, check_knowledge_base_exist_impl, KnowledgeBaseNeedsModelConfigError
def _accurate_search_impl(request, vdb_core):
@@ -191,6 +241,7 @@ def setUp(self):
self.mock_embedding = MagicMock()
self.mock_embedding.embedding_dim = 768
self.mock_embedding.model = "test-model"
+ self.mock_embedding.model_type = "text"
self.mock_get_embedding.return_value = self.mock_embedding
# Patch get_rerank_model for all tests
@@ -246,6 +297,12 @@ def test_create_index_success(self, mock_create_knowledge):
self.mock_vdb_core.create_index.assert_called_once_with(
"test_index", embedding_dim=768)
mock_create_knowledge.assert_called_once()
+ call_kwargs = mock_create_knowledge.call_args[0][0]
+ self.assertIn("embedding_model_name", call_kwargs)
+ self.assertIsNone(call_kwargs["embedding_model_name"])
+ self.assertEqual(call_kwargs["index_name"], "test_index")
+ self.assertEqual(call_kwargs["created_by"], "test_user")
+ self.assertEqual(call_kwargs["tenant_id"], "test_tenant")
@patch('backend.services.vectordatabase_service.create_knowledge_record')
def test_create_index_already_exists(self, mock_create_knowledge):
@@ -301,6 +358,7 @@ def test_create_knowledge_base_generates_index(self, mock_create_knowledge, mock
self.mock_vdb_core.create_index.assert_called_once_with(
"7-uuid", embedding_dim=256
)
+ call_kwargs = mock_create_knowledge.call_args[0][0]
@patch('backend.services.vectordatabase_service.get_embedding_model')
@patch('backend.services.vectordatabase_service.create_knowledge_record')
@@ -404,6 +462,29 @@ def test_create_knowledge_base_with_empty_group_ids(self, mock_create_knowledge,
self.assertEqual(call_kwargs["ingroup_permission"], "PRIVATE")
self.assertEqual(call_kwargs["group_ids"], [])
+ @patch('backend.services.vectordatabase_service.get_embedding_model')
+ @patch('backend.services.vectordatabase_service.create_knowledge_record')
+ def test_create_knowledge_base_with_multimodal(self, mock_create_knowledge, mock_get_embedding):
+ self.mock_vdb_core.create_index.return_value = True
+ mock_get_embedding.return_value = (None, None)
+ mock_create_knowledge.return_value = {
+ "knowledge_id": 10,
+ "index_name": "10-uuid",
+ "knowledge_name": "kb-mm",
+ }
+
+ result = ElasticSearchService.create_knowledge_base(
+ knowledge_name="kb-mm",
+ embedding_dim=256,
+ vdb_core=self.mock_vdb_core,
+ user_id="user-1",
+ tenant_id="tenant-1",
+ is_multimodal=True,
+ )
+
+ self.assertEqual(result["status"], "success")
+ mock_get_embedding.assert_called_once_with("tenant-1", None, "multi_embedding")
+
@patch('backend.services.vectordatabase_service.create_knowledge_record')
def test_create_index_failure(self, mock_create_knowledge):
"""
@@ -475,7 +556,11 @@ def test_create_knowledge_base_with_embedding_model_name(self, mock_get_embeddin
self.assertEqual(result["knowledge_id"], 10)
# Verify get_embedding_model was called with the model name
- mock_get_embedding.assert_called_once_with("tenant-1", "text-embedding-3-small")
+ mock_get_embedding.assert_called_once_with(
+ "tenant-1",
+ "text-embedding-3-small",
+ None,
+ )
# Verify knowledge record was created with the embedding model name
mock_create_knowledge.assert_called_once()
@@ -522,7 +607,11 @@ def test_create_knowledge_base_without_embedding_model_name_uses_default(self, m
self.assertEqual(result["status"], "success")
# Verify get_embedding_model was called with None (no specific model)
- mock_get_embedding.assert_called_once_with("tenant-1", None)
+ mock_get_embedding.assert_called_once_with(
+ "tenant-1",
+ None,
+ None,
+ )
# Verify knowledge record was created with the model's display name
mock_create_knowledge.assert_called_once()
@@ -1580,6 +1669,7 @@ def test_vectorize_documents_success(self):
self.mock_vdb_core.vectorize_documents.return_value = 2
mock_embedding_model = MagicMock()
mock_embedding_model.model = "test-model"
+ mock_embedding_model.model_type = "text"
with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \
patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg, \
patch('backend.services.vectordatabase_service.update_last_doc_update_time'):
@@ -1627,6 +1717,73 @@ def test_vectorize_documents_success(self):
self.assertEqual(kwargs.get("embedding_batch_size"), 5)
self.assertTrue(callable(kwargs.get("progress_callback")))
+ def test_index_documents_uses_multi_embedding_config_key(self):
+ self.mock_vdb_core.check_index_exists.return_value = True
+ self.mock_vdb_core.vectorize_documents.return_value = 1
+
+ mock_embedding_model = MagicMock()
+ mock_embedding_model.model = "test-model"
+ mock_embedding_model.model_type = "multimodal"
+
+ with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \
+ patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg, \
+ patch('backend.services.vectordatabase_service.update_last_doc_update_time'):
+ mock_get_record.return_value = {
+ "tenant_id": consts_mock.const.DEFAULT_TENANT_ID}
+ mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 6}
+
+ result = ElasticSearchService.index_documents(
+ index_name="test_index",
+ data=[{"path_or_url": "p1", "content": "c1", "metadata": {}}],
+ vdb_core=self.mock_vdb_core,
+ embedding_model=mock_embedding_model
+ )
+
+ self.assertTrue(result["success"])
+ mock_tenant_cfg.get_model_config.assert_called_once_with(
+ key="MULTI_EMBEDDING_ID", tenant_id=consts_mock.const.DEFAULT_TENANT_ID
+ )
+
+ def test_index_documents_fetches_image_bytes(self):
+ self.mock_vdb_core.check_index_exists.return_value = True
+ self.mock_vdb_core.vectorize_documents.return_value = 1
+ mock_embedding_model = MagicMock()
+ mock_embedding_model.model = "test-model"
+ mock_embedding_model.model_type = "text"
+
+ with patch('backend.services.vectordatabase_service.get_knowledge_record') as mock_get_record, \
+ patch('backend.services.vectordatabase_service.tenant_config_manager') as mock_tenant_cfg, \
+ patch('backend.services.vectordatabase_service.get_file_stream') as mock_get_stream, \
+ patch('backend.services.vectordatabase_service.update_last_doc_update_time'):
+ mock_get_record.return_value = {
+ "tenant_id": consts_mock.const.DEFAULT_TENANT_ID}
+ mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 5}
+ mock_get_stream.return_value = io.BytesIO(b"img-bytes")
+
+ data = [
+ {
+ "metadata": {"image_url": "s3://bucket/img.png", "process_source": "UniversalImageExtractor"},
+ "path_or_url": "test_path",
+ "content": "image content",
+ "source_type": "file",
+ "file_size": 123,
+ "filename": "img.png"
+ }
+ ]
+
+ result = ElasticSearchService.index_documents(
+ index_name="test_index",
+ data=data,
+ vdb_core=self.mock_vdb_core,
+ embedding_model=mock_embedding_model
+ )
+
+ self.assertTrue(result["success"])
+ _, kwargs = self.mock_vdb_core.vectorize_documents.call_args
+ documents = kwargs.get("documents")
+ self.assertEqual(documents[0]["image_bytes"], b"img-bytes")
+ mock_get_stream.assert_called_once_with(object_name="s3://bucket/img.png")
+
def test_vectorize_documents_empty_data(self):
"""
Test document indexing with empty data.
@@ -1639,6 +1796,7 @@ def test_vectorize_documents_empty_data(self):
# Setup
test_data = []
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
# Execute
result = ElasticSearchService.index_documents(
@@ -1668,6 +1826,7 @@ def test_vectorize_documents_create_index(self):
self.mock_vdb_core.create_index.return_value = True
self.mock_vdb_core.vectorize_documents.return_value = 1
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
test_data = [
{
"metadata": {"title": "Test"},
@@ -1715,6 +1874,7 @@ def test_vectorize_documents_indexing_error(self):
self.mock_vdb_core.vectorize_documents.side_effect = Exception(
"Indexing error")
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
test_data = [
{
"metadata": {"title": "Test"},
@@ -2121,15 +2281,13 @@ def test_search_hybrid_success(self, mock_get_embedding_by_index):
"scores": {"accurate": 0.85, "semantic": 0.95}
}
]
-
- # Mock get_embedding_model_by_index_name to return embedding model
mock_get_embedding_by_index.return_value = (self.mock_embedding, 1, {"status": "ok", "message": "OK"})
# Execute
result = ElasticSearchService.search_hybrid(
index_names=["test_index"],
query="test query",
- tenant_id="test_tenant",
+ tenant_id=consts_mock.const.DEFAULT_TENANT_ID,
top_k=10,
weight_accurate=0.5,
vdb_core=self.mock_vdb_core
@@ -2152,6 +2310,7 @@ def test_search_hybrid_success(self, mock_get_embedding_by_index):
top_k=10,
weight_accurate=0.5
)
+ mock_get_embedding_by_index.assert_called_once_with(consts_mock.const.DEFAULT_TENANT_ID, "test_index")
def test_search_hybrid_missing_tenant_id(self):
"""Test search_hybrid raises ValueError when tenant_id is missing."""
@@ -2222,25 +2381,21 @@ def test_search_hybrid_invalid_weight(self):
@patch('backend.services.vectordatabase_service.get_embedding_model_by_index_name')
def test_search_hybrid_no_embedding_model(self, mock_get_embedding_by_index):
- """Test search_hybrid raises ValueError when embedding model is not configured."""
- # Mock get_embedding_model_by_index_name to return None
- mock_get_embedding_by_index.return_value = (None, None, {"status": "error", "message": "Model not found"})
-
- # Stop the mock to test the real get_embedding_model
- self.get_embedding_model_patcher.stop()
- try:
- with self.assertRaises(ValueError) as context:
- ElasticSearchService.search_hybrid(
- index_names=["test_index"],
- query="test query",
- tenant_id="test_tenant",
- top_k=10,
- weight_accurate=0.5,
- vdb_core=self.mock_vdb_core
- )
- self.assertIn("No embedding model found", str(context.exception))
- finally:
- self.get_embedding_model_patcher.start()
+ """Test search_hybrid raises model-config error when embedding model is not configured."""
+ mock_get_embedding_by_index.return_value = (
+ None,
+ None,
+ {"status": "needs_config", "message": "needs config"},
+ )
+ with self.assertRaises(KnowledgeBaseNeedsModelConfigError):
+ ElasticSearchService.search_hybrid(
+ index_names=["test_index"],
+ query="test query",
+ tenant_id=consts_mock.const.DEFAULT_TENANT_ID,
+ top_k=10,
+ weight_accurate=0.5,
+ vdb_core=self.mock_vdb_core
+ )
@patch('backend.services.vectordatabase_service.get_embedding_model_by_index_name')
def test_search_hybrid_exception(self, mock_get_embedding_by_index):
@@ -2264,9 +2419,6 @@ def test_search_hybrid_exception(self, mock_get_embedding_by_index):
@patch('backend.services.vectordatabase_service.get_embedding_model_by_index_name')
def test_search_hybrid_weight_accurate_boundary_values(self, mock_get_embedding_by_index):
- """Test search_hybrid with different weight_accurate values to ensure line 1146 is covered."""
- # Mock get_embedding_model_by_index_name
- mock_get_embedding_by_index.return_value = (self.mock_embedding, 1, {"status": "ok", "message": "OK"})
# Test with weight_accurate = 0.0 (semantic only)
self.mock_vdb_core.hybrid_search.return_value = [
@@ -2276,11 +2428,12 @@ def test_search_hybrid_weight_accurate_boundary_values(self, mock_get_embedding_
"index": "test_index",
}
]
+ mock_get_embedding_by_index.return_value = (self.mock_embedding, 1, {"status": "ok", "message": "OK"})
result = ElasticSearchService.search_hybrid(
index_names=["test_index"],
query="test query",
- tenant_id="test_tenant",
+ tenant_id=consts_mock.const.DEFAULT_TENANT_ID,
top_k=10,
weight_accurate=0.0,
vdb_core=self.mock_vdb_core
@@ -2299,7 +2452,7 @@ def test_search_hybrid_weight_accurate_boundary_values(self, mock_get_embedding_
result = ElasticSearchService.search_hybrid(
index_names=["test_index"],
query="test query",
- tenant_id="test_tenant",
+ tenant_id=consts_mock.const.DEFAULT_TENANT_ID,
top_k=10,
weight_accurate=1.0,
vdb_core=self.mock_vdb_core
@@ -2317,7 +2470,7 @@ def test_search_hybrid_weight_accurate_boundary_values(self, mock_get_embedding_
result = ElasticSearchService.search_hybrid(
index_names=["test_index"],
query="test query",
- tenant_id="test_tenant",
+ tenant_id=consts_mock.const.DEFAULT_TENANT_ID,
top_k=10,
weight_accurate=0.3,
vdb_core=self.mock_vdb_core
@@ -3159,8 +3312,6 @@ def test_create_chunk_generates_embedding_when_tenant_provided(self, mock_get_em
self.assertEqual(result["status"], "success")
self.assertEqual(result["chunk_id"], "chunk-1")
- # Verify embedding was generated
- mock_get_embedding_model_by_id.assert_called_once_with("tenant-123", 123)
mock_embedding.get_embeddings.assert_called_once()
# Verify vdb_core was called with embedding in payload
@@ -3170,8 +3321,8 @@ def test_create_chunk_generates_embedding_when_tenant_provided(self, mock_get_em
self.assertEqual(payload["embedding"], [0.1, 0.2, 0.3])
@patch('backend.services.vectordatabase_service.get_knowledge_record')
- @patch('backend.services.vectordatabase_service.get_embedding_model')
- def test_create_chunk_without_tenant_no_embedding_generated(self, mock_get_embedding_model,
+ @patch('backend.services.vectordatabase_service.get_embedding_model_by_id')
+ def test_create_chunk_without_tenant_no_embedding_generated(self, mock_get_embedding_model_by_id,
mock_get_knowledge_record):
"""
Test create_chunk does not generate embedding when tenant_id is not provided.
@@ -3201,7 +3352,7 @@ def test_create_chunk_without_tenant_no_embedding_generated(self, mock_get_embed
# Verify no embedding-related calls were made
mock_get_knowledge_record.assert_not_called()
- mock_get_embedding_model.assert_not_called()
+ mock_get_embedding_model_by_id.assert_not_called()
# Verify payload has no embedding
self.mock_vdb_core.create_chunk.assert_called_once()
@@ -3209,8 +3360,8 @@ def test_create_chunk_without_tenant_no_embedding_generated(self, mock_get_embed
self.assertNotIn("embedding", payload)
@patch('backend.services.vectordatabase_service.get_knowledge_record')
- @patch('backend.services.vectordatabase_service.get_embedding_model')
- def test_create_chunk_handles_embedding_failure_gracefully(self, mock_get_embedding_model,
+ @patch('backend.services.vectordatabase_service.get_embedding_model_by_id')
+ def test_create_chunk_handles_embedding_failure_gracefully(self, mock_get_embedding_model_by_id,
mock_get_knowledge_record):
"""
Test create_chunk handles embedding generation failure gracefully.
@@ -3221,11 +3372,11 @@ def test_create_chunk_handles_embedding_failure_gracefully(self, mock_get_embedd
mock_get_knowledge_record.return_value = {
"index_name": "kb-index",
- "embedding_model_name": "text-embedding-3-small"
+ "embedding_model_id": 123
}
# Embedding model raises exception
- mock_get_embedding_model.side_effect = Exception("Embedding service unavailable")
+ mock_get_embedding_model_by_id.side_effect = Exception("Embedding service unavailable")
chunk_request = SimpleNamespace(
chunk_id=None,
@@ -3236,7 +3387,7 @@ def test_create_chunk_handles_embedding_failure_gracefully(self, mock_get_embedd
metadata={},
)
- # Should not raise exception, just log warning
+ # Embedding failures are tolerated; chunk creation still succeeds.
result = ElasticSearchService.create_chunk(
index_name="kb-index",
chunk_request=chunk_request,
@@ -3244,17 +3395,14 @@ def test_create_chunk_handles_embedding_failure_gracefully(self, mock_get_embedd
user_id="user-1",
tenant_id="tenant-123",
)
-
- # Result should still be successful (embedding is optional)
self.assertEqual(result["status"], "success")
- self.assertEqual(result["chunk_id"], "chunk-1")
-
- # Verify chunk was still created without embedding
self.mock_vdb_core.create_chunk.assert_called_once()
+ _, payload = self.mock_vdb_core.create_chunk.call_args[0]
+ self.assertNotIn("embedding", payload)
@patch('backend.services.vectordatabase_service.get_knowledge_record')
- @patch('backend.services.vectordatabase_service.get_embedding_model')
- def test_create_chunk_handles_empty_embedding_result(self, mock_get_embedding_model, mock_get_knowledge_record):
+ @patch('backend.services.vectordatabase_service.get_embedding_model_by_id')
+ def test_create_chunk_handles_empty_embedding_result(self, mock_get_embedding_model_by_id, mock_get_knowledge_record):
"""
Test create_chunk handles empty embedding result gracefully.
"""
@@ -3264,13 +3412,13 @@ def test_create_chunk_handles_empty_embedding_result(self, mock_get_embedding_mo
mock_get_knowledge_record.return_value = {
"index_name": "kb-index",
- "embedding_model_name": "text-embedding-3-small"
+ "embedding_model_id": 123
}
# Embedding returns empty list
mock_embedding = MagicMock()
mock_embedding.get_embeddings.return_value = []
- mock_get_embedding_model.return_value = mock_embedding
+ mock_get_embedding_model_by_id.return_value = (mock_embedding, 123)
chunk_request = SimpleNamespace(
chunk_id=None,
@@ -3340,9 +3488,6 @@ def test_create_chunk_with_unknown_model_name_still_calls_embedding_model(self,
# Should succeed, embedding model IS called but returns empty
self.assertEqual(result["status"], "success")
- # Verify embedding model was called
- mock_get_embedding_model_by_id.assert_called_once_with("tenant-123", 123)
-
def test_update_chunk_builds_payload_and_calls_core(self):
"""
Test update_chunk builds update payload and delegates to vdb_core.update_chunk.
@@ -3577,6 +3722,7 @@ def test_vectorize_documents_success_status_200(self, mock_get_record, mock_tena
self.mock_vdb_core.vectorize_documents.return_value = 3
mock_embedding_model = MagicMock()
mock_embedding_model.model = "test-model"
+ mock_embedding_model.model_type = "text"
mock_get_record.return_value = {"tenant_id": "tenant-1"}
mock_tenant_cfg.get_model_config.return_value = {"chunk_batch": 10}
@@ -3774,12 +3920,10 @@ def test_get_embedding_model_embedding_type(self, mock_get_model_by_display_name
# Execute - now we can call the real function
from backend.services.vectordatabase_service import get_embedding_model
- result, model_id = get_embedding_model("test_tenant", model_name="test-model")
+ result, _ = get_embedding_model("test_tenant", model_name="test-model")
# Assert
self.assertEqual(result, mock_embedding_instance)
- self.assertEqual(model_id, 123)
- mock_get_model_by_display_name.assert_called_once_with("test-model", "test_tenant")
mock_embedding_class.assert_called_once_with(
api_key="test_api_key",
base_url="https://test.api.com",
@@ -3960,6 +4104,49 @@ def test_get_embedding_model_with_model_name_found(self, mock_get_model_by_displ
# Restart the mock for other tests
self.get_embedding_model_patcher.start()
+ @patch('backend.services.vectordatabase_service.get_model_by_display_name')
+ def test_get_embedding_model_with_model_name_found_multimodal(self, mock_get_model_by_display_name):
+ mock_get_model_by_display_name.return_value = {
+ "model_id": 789,
+ "model_type": "multi_embedding",
+ "model_name": "jina-clip-v2",
+ "api_key": "test_api_key",
+ "base_url": "https://test.api.com",
+ "max_tokens": 1024,
+ "ssl_verify": True
+ }
+
+ self.get_embedding_model_patcher.stop()
+
+ try:
+ with patch('backend.services.vectordatabase_service.JinaEmbedding') as mock_embedding_class, \
+ patch('backend.services.vectordatabase_service.get_model_name_from_config') as mock_get_model_name:
+ mock_embedding_instance = MagicMock()
+ mock_embedding_class.return_value = mock_embedding_instance
+ mock_get_model_name.return_value = "jina-clip-v2"
+
+ from backend.services.vectordatabase_service import get_embedding_model
+ result, model_id = get_embedding_model(
+ "test_tenant",
+ model_name="jina/jina-clip-v2",
+ model_type="multi_embedding",
+ )
+
+ self.assertEqual(result, mock_embedding_instance)
+ self.assertEqual(model_id, 789)
+ mock_embedding_class.assert_called_once_with(
+ api_key="test_api_key",
+ base_url="https://test.api.com",
+ model_name="jina-clip-v2",
+ embedding_dim=1024,
+ ssl_verify=True
+ )
+ mock_get_model_by_display_name.assert_called_once_with(
+ "jina/jina-clip-v2", "test_tenant", "multi_embedding"
+ )
+ finally:
+ self.get_embedding_model_patcher.start()
+
@patch('backend.services.vectordatabase_service.get_model_by_display_name')
def test_get_embedding_model_with_model_name_found_without_repo(self, mock_get_model_by_display_name):
"""
@@ -3969,16 +4156,16 @@ def test_get_embedding_model_with_model_name_found_without_repo(self, mock_get_m
1. When model_name is provided and found (without model_repo), OpenAICompatibleEmbedding is returned
2. The function handles models without model_repo correctly using just model_name
"""
- # Setup - mock get_model_by_display_name to return a model without model_repo
+
+ # Setup
mock_get_model_by_display_name.return_value = {
"model_id": 456,
- "model_name": "simple-model",
"model_type": "embedding",
- "model_repo": None,
+ "model_name": "simple-model",
"api_key": "test_api_key",
"base_url": "https://test.api.com",
- "max_tokens": 2048,
- "ssl_verify": False
+ "max_tokens": 1024,
+ "ssl_verify": True
}
# Stop the mock from setUp to test the real function
@@ -3998,7 +4185,6 @@ def test_get_embedding_model_with_model_name_found_without_repo(self, mock_get_m
# Assert
self.assertEqual(result, mock_embedding_instance)
self.assertEqual(model_id, 456)
- mock_get_model_by_display_name.assert_called_once_with("simple-model", "test_tenant")
mock_embedding_class.assert_called_once()
finally:
# Restart the mock for other tests
@@ -4967,6 +5153,72 @@ def test_get_rerank_model_with_model_name_no_repo(
finally:
self.get_rerank_model_patcher.start()
+ @patch('backend.services.vectordatabase_service.get_knowledge_record')
+ def test_create_chunk_embedding_exception_without_explicit_model_is_tolerated(
+ self, mock_get_knowledge_record
+ ):
+ """create_chunk should continue when embedding generation fails and no explicit model name exists."""
+ self.mock_vdb_core.create_chunk.return_value = {"id": "chunk-1"}
+ mock_get_knowledge_record.return_value = {
+ "embedding_model_name": None,
+ "is_multimodal": "N",
+ }
+ self.mock_get_embedding.side_effect = RuntimeError("embedding failed")
+
+ from backend.services.vectordatabase_service import ChunkCreateRequest
+ chunk_request = ChunkCreateRequest(
+ content="abc",
+ title="t",
+ filename="f.txt",
+ path_or_url="p/f.txt",
+ metadata={}
+ )
+ result = ElasticSearchService.create_chunk(
+ index_name="idx",
+ chunk_request=chunk_request,
+ vdb_core=self.mock_vdb_core,
+ user_id="u1",
+ tenant_id="t1",
+ )
+ self.assertEqual(result["status"], "success")
+ self.mock_vdb_core.create_chunk.assert_called_once()
+
+ @patch('backend.services.vectordatabase_service.get_knowledge_record')
+ def test_update_chunk_minimal_payload_still_updates(self, mock_get_knowledge_record):
+ """update_chunk without business fields still sends update_time/updated_by payload."""
+ mock_get_knowledge_record.return_value = None
+ self.mock_vdb_core.update_chunk.return_value = {"id": "c1"}
+ from backend.services.vectordatabase_service import ChunkUpdateRequest
+ empty_req = ChunkUpdateRequest()
+
+ result = ElasticSearchService.update_chunk(
+ index_name="idx",
+ chunk_id="c1",
+ chunk_request=empty_req,
+ vdb_core=self.mock_vdb_core,
+ user_id="u1",
+ tenant_id="t1",
+ )
+ self.assertEqual(result["status"], "success")
+ self.mock_vdb_core.update_chunk.assert_called_once()
+
+ def test_update_chunk_core_error_is_wrapped(self):
+ """update_chunk should wrap core exceptions with consistent message."""
+ self.mock_vdb_core.update_chunk.side_effect = RuntimeError("core failed")
+ from backend.services.vectordatabase_service import ChunkUpdateRequest
+ req = ChunkUpdateRequest(content="new-content")
+
+ with self.assertRaises(Exception) as ctx:
+ ElasticSearchService.update_chunk(
+ index_name="idx",
+ chunk_id="c2",
+ chunk_request=req,
+ vdb_core=self.mock_vdb_core,
+ user_id="u1",
+ tenant_id=None,
+ )
+ self.assertIn("Error updating chunk", str(ctx.exception))
+
class TestNewEmbeddingModelMethods(unittest.TestCase):
"""
diff --git a/test/backend/utils/test_file_management_utils.py b/test/backend/utils/test_file_management_utils.py
index ce98c56e4..f8c22b0f7 100644
--- a/test/backend/utils/test_file_management_utils.py
+++ b/test/backend/utils/test_file_management_utils.py
@@ -7,11 +7,14 @@
class _ProcessParams:
- def __init__(self, authorization: str, source_type: str, chunking_strategy: str, index_name: Optional[str]):
+ def __init__(self, authorization: str, source_type: str, chunking_strategy: str, index_name: Optional[str], model_id: Optional[int] = 42,
+ tenant_id: Optional[str] = "tenant-1"):
self.authorization = authorization
self.source_type = source_type
self.chunking_strategy = chunking_strategy
self.index_name = index_name
+ self.model_id = model_id
+ self.tenant_id = tenant_id
@pytest.fixture(autouse=True)
diff --git a/test/backend/utils/test_llm_utils.py b/test/backend/utils/test_llm_utils.py
index 2052bba54..7be236b8b 100644
--- a/test/backend/utils/test_llm_utils.py
+++ b/test/backend/utils/test_llm_utils.py
@@ -57,6 +57,12 @@ def validate(self):
vector_db_es_module.ElasticSearchCore = MagicMock()
vector_db_es_module.Elasticsearch = MagicMock()
+monitor_module = types.ModuleType("nexent.monitor")
+monitor_module.set_monitoring_context = MagicMock()
+monitor_module.set_monitoring_operation = MagicMock()
+sys.modules['nexent.monitor'] = monitor_module
+nexent_module.monitor = monitor_module
+
# Stub nexent.core.utils.observer MessageObserver used by llm_utils
observer_mod = types.ModuleType("nexent.core.utils.observer")
diff --git a/test/backend/utils/test_memory_utils.py b/test/backend/utils/test_memory_utils.py
index 207c63c06..134c38923 100644
--- a/test/backend/utils/test_memory_utils.py
+++ b/test/backend/utils/test_memory_utils.py
@@ -57,7 +57,7 @@ def test_build_memory_config_success(self, mocker, mock_constants, mock_model_co
]
# Mock get_model_name_from_config
- mock_get_model_name = mocker.MagicMock()
+ mock_get_model_name = MagicMock()
mock_get_model_name.side_effect = [
"openai/gpt-4", "openai/text-embedding-ada-002"]
@@ -132,7 +132,7 @@ def test_build_memory_config_missing_llm_config(self, mocker, mock_tenant_config
def test_build_memory_config_llm_config_missing_model_name(self, mocker):
"""Raises when LLM config lacks model_name"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"api_key": "test-key"}, # LLM missing model_name
{"model_name": "test-embed", "max_tokens": 1536} # embedding present
@@ -166,7 +166,7 @@ def test_build_memory_config_missing_embedding_config(self, mocker, mock_tenant_
def test_build_memory_config_embedding_config_missing_max_tokens(self, mocker):
"""Raises when embedding config lacks max_tokens"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm"}, # LLM present
{"model_name": "test-embed"} # embedding missing max_tokens
@@ -184,13 +184,13 @@ def test_build_memory_config_embedding_config_missing_max_tokens(self, mocker):
def test_build_memory_config_missing_es_host(self, mocker):
"""Raises when ES_HOST is missing"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm"},
{"model_name": "test-embed", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = None # ES_HOST is None
mocker.patch('backend.utils.memory_utils.tenant_config_manager',
@@ -205,13 +205,13 @@ def test_build_memory_config_missing_es_host(self, mocker):
def test_build_memory_config_invalid_es_host_format(self, mocker):
"""Raises when ES_HOST format is invalid"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm"},
{"model_name": "test-embed", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "invalid-host" # invalid format
mocker.patch('backend.utils.memory_utils.tenant_config_manager',
@@ -227,13 +227,13 @@ def test_build_memory_config_invalid_es_host_format(self, mocker):
def test_build_memory_config_es_host_missing_scheme(self, mocker):
"""Raises when ES_HOST is missing scheme"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm"},
{"model_name": "test-embed", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "localhost:9200" # missing scheme
mocker.patch('backend.utils.memory_utils.tenant_config_manager',
@@ -249,13 +249,13 @@ def test_build_memory_config_es_host_missing_scheme(self, mocker):
def test_build_memory_config_es_host_missing_port(self, mocker):
"""Raises when ES_HOST is missing port"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm"},
{"model_name": "test-embed", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "http://localhost" # missing port
mocker.patch('backend.utils.memory_utils.tenant_config_manager',
@@ -271,7 +271,7 @@ def test_build_memory_config_es_host_missing_port(self, mocker):
def test_build_memory_config_with_https_es_host(self, mocker):
"""HTTPS ES_HOST is parsed correctly and collection name composes"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm", "model_repo": "openai",
"base_url": "https://api.openai.com/v1", "api_key": "test-llm-key"},
@@ -279,13 +279,13 @@ def test_build_memory_config_with_https_es_host(self, mocker):
"base_url": "https://api.openai.com/v1", "api_key": "test-embed-key", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "https://elastic.example.com:9200"
mock_const.ES_API_KEY = "test-es-key"
mock_const.ES_USERNAME = "elastic"
mock_const.ES_PASSWORD = "test-password"
- mock_get_model_name = mocker.MagicMock()
+ mock_get_model_name = MagicMock()
mock_get_model_name.side_effect = [
"openai/test-llm", "openai/test-embed"]
@@ -308,7 +308,7 @@ def test_build_memory_config_with_https_es_host(self, mocker):
def test_build_memory_config_with_custom_port(self, mocker):
"""Custom ES port is parsed and applied; collection name composed"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "test-llm", "model_repo": "openai",
"base_url": "https://api.openai.com/v1", "api_key": "test-llm-key"},
@@ -316,13 +316,13 @@ def test_build_memory_config_with_custom_port(self, mocker):
"base_url": "https://api.openai.com/v1", "api_key": "test-embed-key", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "http://localhost:9300" # custom port
mock_const.ES_API_KEY = "test-es-key"
mock_const.ES_USERNAME = "elastic"
mock_const.ES_PASSWORD = "test-password"
- mock_get_model_name = mocker.MagicMock()
+ mock_get_model_name = MagicMock()
mock_get_model_name.side_effect = [
"openai/test-llm", "openai/test-embed"]
@@ -345,7 +345,7 @@ def test_build_memory_config_with_custom_port(self, mocker):
def test_build_memory_config_sanitizes_slashes_in_repo_and_name(self, mocker):
"""Slash characters in repo/name are replaced with underscores in collection name"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "gpt-4", "model_repo": "azure/openai",
"base_url": "https://api.example.com/v1", "api_key": "llm-key"},
@@ -353,14 +353,14 @@ def test_build_memory_config_sanitizes_slashes_in_repo_and_name(self, mocker):
"base_url": "https://api.example.com/v1", "api_key": "embed-key", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "http://localhost:9200"
mock_const.ES_API_KEY = "test-es-key"
mock_const.ES_USERNAME = "elastic"
mock_const.ES_PASSWORD = "test-password"
model_mapping = {"llm": "llm", "embedding": "embedding"}
- mock_get_model_name = mocker.MagicMock()
+ mock_get_model_name = MagicMock()
mock_get_model_name.side_effect = [
"azure/openai/gpt-4", "azure/openai/text-embed/ada-002"]
@@ -378,7 +378,7 @@ def test_build_memory_config_sanitizes_slashes_in_repo_and_name(self, mocker):
def test_build_memory_config_with_empty_model_repo(self, mocker):
"""Empty model_repo yields collection name without repo segment"""
- mock_tenant_config_manager = mocker.MagicMock()
+ mock_tenant_config_manager = MagicMock()
mock_tenant_config_manager.get_model_config.side_effect = [
{"model_name": "gpt-4", "model_repo": "",
"base_url": "https://api.openai.com/v1", "api_key": "test-llm-key"},
@@ -386,13 +386,13 @@ def test_build_memory_config_with_empty_model_repo(self, mocker):
"base_url": "https://api.openai.com/v1", "api_key": "test-embed-key", "max_tokens": 1536}
]
- mock_const = mocker.MagicMock()
+ mock_const = MagicMock()
mock_const.ES_HOST = "http://localhost:9200"
mock_const.ES_API_KEY = "test-es-key"
mock_const.ES_USERNAME = "elastic"
mock_const.ES_PASSWORD = "test-password"
- mock_get_model_name = mocker.MagicMock()
+ mock_get_model_name = MagicMock()
mock_get_model_name.side_effect = [
"gpt-4", "text-embedding-ada-002"] # no repo prefix
diff --git a/test/conftest.py b/test/conftest.py
index 4ab19b5d7..7d5906585 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -5,7 +5,13 @@
"""
import os
import sys
+import shutil
+import tempfile
+from pathlib import Path
from unittest.mock import MagicMock
+from unittest.mock import patch as _patch
+
+import pytest
# Stub out mem0 modules before anything else imports them.
# The sdk imports mem0 at module level, so stubs must be registered first.
@@ -34,6 +40,13 @@
if _sdk_dir not in sys.path:
sys.path.insert(0, _sdk_dir)
+_tmp_root = os.path.abspath(os.path.join(_test_root, "..", ".pytest-tmp"))
+os.makedirs(_tmp_root, exist_ok=True)
+os.environ.setdefault("TMP", _tmp_root)
+os.environ.setdefault("TEMP", _tmp_root)
+os.environ.setdefault("TMPDIR", _tmp_root)
+tempfile.tempdir = _tmp_root
+
# MinIO Configuration
os.environ.setdefault('MINIO_ENDPOINT', 'http://localhost:9000')
os.environ.setdefault('MINIO_ACCESS_KEY', 'minioadmin')
@@ -52,3 +65,98 @@
 os.environ.setdefault('POSTGRES_PASSWORD', 'test_password')
 os.environ.setdefault('POSTGRES_DB', 'test_db')
 os.environ.setdefault('POSTGRES_PORT', '5432')
+
+
+class _PatchProxy:
+    """Stand-in for ``mocker.patch`` that starts patchers and records them for teardown.
+
+    Every entry point builds the matching ``unittest.mock`` patcher, passes it to
+    the owning ``_MiniMocker`` to be started and recorded, and returns the value
+    produced by ``patcher.start()``.
+    """
+
+    def __init__(self, owner):
+        self._owner = owner
+
+    def __call__(self, target, *args, **kwargs):
+        """Start ``unittest.mock.patch(target, ...)``."""
+        return self._owner._start(_patch(target, *args, **kwargs))
+
+    def object(self, target, attribute, *args, **kwargs):
+        """Start ``unittest.mock.patch.object(target, attribute, ...)``."""
+        return self._owner._start(_patch.object(target, attribute, *args, **kwargs))
+
+    def dict(self, target, *args, **kwargs):
+        """Start ``unittest.mock.patch.dict(target, ...)``."""
+        return self._owner._start(_patch.dict(target, *args, **kwargs))
+
+    def multiple(self, target, *args, **kwargs):
+        """Start ``unittest.mock.patch.multiple(target, ...)``."""
+        return self._owner._start(_patch.multiple(target, *args, **kwargs))
+
+
+class _MiniMocker:
+    """Object handed out by the ``mocker`` fixture; backed by ``unittest.mock`` only."""
+
+    # ``unittest.mock`` helpers exposed as attributes, so tests written in the
+    # ``mocker.MagicMock()`` / ``mocker.ANY`` style keep working with this shim.
+    _MOCK_EXPORTS = (
+        "Mock", "MagicMock", "AsyncMock", "NonCallableMock", "PropertyMock",
+        "ANY", "DEFAULT", "call", "sentinel", "mock_open", "create_autospec",
+    )
+
+    def __init__(self):
+        import unittest.mock as mock_module  # local import: file-level imports stay untouched
+
+        self._patchers = []
+        self.patch = _PatchProxy(self)
+        for export in self._MOCK_EXPORTS:
+            setattr(self, export, getattr(mock_module, export))
+
+    def _start(self, patcher):
+        """Start ``patcher``, remember it for ``stopall`` and return the started value."""
+        value = patcher.start()
+        # Recorded only after a successful start: a patcher that failed to start
+        # has nothing to undo.
+        self._patchers.append(patcher)
+        return value
+
+    def stopall(self):
+        """Stop every recorded patcher, most recently started first.
+
+        All patchers are stopped even if one ``stop()`` raises; the first such
+        error is re-raised once the list is empty.
+        """
+        first_error = None
+        while self._patchers:
+            try:
+                self._patchers.pop().stop()
+            except Exception as exc:  # keep unwinding so no patch leaks into later tests
+                if first_error is None:
+                    first_error = exc
+        if first_error is not None:
+            raise first_error
+
+
+@pytest.fixture
+def mocker():
+    """Yield a ``_MiniMocker`` and stop all of its patchers when the test finishes."""
+    helper = _MiniMocker()
+    try:
+        yield helper
+    finally:
+        helper.stopall()
+
+
+@pytest.fixture
+def tmp_path():
+    """Use a repo-local temp dir instead of pytest's default temp root.
+
+    Yields a fresh ``Path`` created under ``_tmp_root`` and removes the directory
+    tree afterwards, ignoring removal errors.
+    """
+    path = Path(tempfile.mkdtemp(prefix="tmp-", dir=_tmp_root))
+    try:
+        yield path
+    finally:
+        shutil.rmtree(path, ignore_errors=True)
diff --git a/test/sdk/core/agents/test_run_agent.py b/test/sdk/core/agents/test_run_agent.py
index dac68216f..78daf11cf 100644
--- a/test/sdk/core/agents/test_run_agent.py
+++ b/test/sdk/core/agents/test_run_agent.py
@@ -1,3 +1,5 @@
+import types
+import importlib.machinery
import pytest
import importlib
import sys
@@ -145,6 +147,10 @@ def __init__(self, *args, **kwargs):
sys.modules["nexent"] = mock_nexent
sys.modules["nexent.skills"] = mock_nexent.skills
+openai_module = types.ModuleType("openai")
+openai_module.__spec__ = importlib.machinery.ModuleSpec("openai", loader=None)
+sys.modules['openai'] = openai_module
+
module_mocks = {
"smolagents": mock_smolagents,
"smolagents.tools": mock_smolagents_tools_mod,
@@ -163,7 +169,7 @@ def __init__(self, *args, **kwargs):
"langchain": mock_langchain,
"langchain.tools": mock_langchain_tools,
# Minimal openai mock needed by other modules
- "openai": MagicMock(),
+ "openai": openai_module,
"openai.types": MagicMock(),
"openai.types.chat": MagicMock(),
"openai.types.chat.chat_completion_message": MagicMock(ChatCompletionMessage=mock_openai_chat_completion_message),
diff --git a/test/sdk/core/models/test_embedding_model.py b/test/sdk/core/models/test_embedding_model.py
index 9c3f8824b..9833a5323 100644
--- a/test/sdk/core/models/test_embedding_model.py
+++ b/test/sdk/core/models/test_embedding_model.py
@@ -1,7 +1,7 @@
import pytest
import requests
import sys
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding
@@ -42,6 +42,22 @@ def jina_embedding_instance():
return JinaEmbedding(api_key="dummy-key", ssl_verify=True)
+def test_openai_embedding_default_model_type():
+    """An OpenAICompatibleEmbedding built without a model type reports "text"."""
+    embedding = OpenAICompatibleEmbedding(
+        model_name="dummy-model", base_url="https://api.example.com",
+        api_key="dummy-key", embedding_dim=128, ssl_verify=True,
+    )
+    actual = embedding.model_type
+    expected = "text"
+    assert actual == expected
+
+
+def test_jina_embedding_default_model_type():
+    """A JinaEmbedding built without a model type reports "multimodal"."""
+    assert JinaEmbedding(api_key="dummy-key", ssl_verify=True).model_type == "multimodal"
+
+
# ---------------------------------------------------------------------------
# Tests for dimension_check
# ---------------------------------------------------------------------------
@@ -637,9 +653,9 @@ def json(self):
def test_openai_get_embeddings_calls_record_model_call(mocker):
"""OpenAICompatibleEmbedding.get_embeddings calls record_model_call with correct args."""
- mock_ctx = mocker.MagicMock()
- mock_ctx.__enter__ = mocker.MagicMock(return_value=None)
- mock_ctx.__exit__ = mocker.MagicMock(return_value=False)
+ mock_ctx = MagicMock()
+ mock_ctx.__enter__ = MagicMock(return_value=None)
+ mock_ctx.__exit__ = MagicMock(return_value=False)
mock_record = mocker.patch(
"nexent.core.models.embedding_model.record_model_call",
return_value=mock_ctx,
@@ -665,9 +681,9 @@ def test_openai_get_embeddings_calls_record_model_call(mocker):
def test_jina_get_embeddings_calls_record_model_call(mocker):
"""JinaEmbedding.get_multimodal_embeddings calls record_model_call with correct args."""
- mock_ctx = mocker.MagicMock()
- mock_ctx.__enter__ = mocker.MagicMock(return_value=None)
- mock_ctx.__exit__ = mocker.MagicMock(return_value=False)
+ mock_ctx = MagicMock()
+ mock_ctx.__enter__ = MagicMock(return_value=None)
+ mock_ctx.__exit__ = MagicMock(return_value=False)
mock_record = mocker.patch(
"nexent.core.models.embedding_model.record_model_call",
return_value=mock_ctx,
diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py
index bcfeaddc4..4fb1bcab0 100644
--- a/test/sdk/core/tools/test_knowledge_base_search_tool.py
+++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py
@@ -1,10 +1,141 @@
+import importlib.util
+import json
+import sys
+import types
+from pathlib import Path
+
import pytest
from unittest.mock import MagicMock, patch
-import json
-# Import target module
-from sdk.nexent.core.utils.observer import MessageObserver, ProcessType
-from sdk.nexent.core.tools.knowledge_base_search_tool import KnowledgeBaseSearchTool
+REPO_ROOT = Path(__file__).resolve().parents[4]
+
+def _pkg(name, path):
+    # Stub package for ``name``; returns whichever module sys.modules ends up holding.
+    mod = types.ModuleType(name)
+    mod.__path__ = [str(path)]
+    return sys.modules.setdefault(name, mod)
+
+sdk_pkg = _pkg("sdk", REPO_ROOT / "sdk")
+nexent_pkg = _pkg("sdk.nexent", REPO_ROOT / "sdk" / "nexent")
+core_pkg = _pkg("sdk.nexent.core", REPO_ROOT / "sdk" / "nexent" / "core")
+tools_pkg = _pkg("sdk.nexent.core.tools", REPO_ROOT / "sdk" / "nexent" / "core" / "tools")
+utils_pkg = _pkg("sdk.nexent.core.utils", REPO_ROOT / "sdk" / "nexent" / "core" / "utils")
+models_pkg = _pkg("sdk.nexent.core.models", REPO_ROOT / "sdk" / "nexent" / "core" / "models")
+vector_pkg = _pkg("sdk.nexent.vector_database", REPO_ROOT / "sdk" / "nexent" / "vector_database")
+sdk_pkg.nexent = nexent_pkg
+nexent_pkg.core = core_pkg
+nexent_pkg.vector_database = vector_pkg
+core_pkg.tools = tools_pkg
+core_pkg.utils = utils_pkg
+core_pkg.models = models_pkg
+
+class MessageObserver:
+ def add_message(self, *args, **kwargs):
+ pass
+
+class _ProcessType:
+ TOOL = "TOOL"
+ CARD = "CARD"
+ SEARCH_CONTENT = "SEARCH_CONTENT"
+ PICTURE_WEB = "PICTURE_WEB"
+
+ProcessType = _ProcessType
+
+observer_mod = types.ModuleType("sdk.nexent.core.utils.observer")
+observer_mod.MessageObserver = MessageObserver
+observer_mod.ProcessType = _ProcessType
+sys.modules["sdk.nexent.core.utils.observer"] = observer_mod
+utils_pkg.observer = observer_mod
+
+class _EnumValue:
+ def __init__(self, value):
+ self.value = value
+
+class _ToolCategory:
+ SEARCH = _EnumValue("search")
+
+class _ToolSign:
+ KNOWLEDGE_BASE = _EnumValue("knowledge_base")
+
+class SearchResultTextMessage:
+ def __init__(self, **kwargs):
+ self.data = {
+ "title": kwargs.get("title", ""),
+ "content": kwargs.get("text", ""),
+ "source_type": kwargs.get("source_type", ""),
+ "url": kwargs.get("url", ""),
+ "filename": kwargs.get("filename", ""),
+ "published_date": kwargs.get("published_date", ""),
+ "score": kwargs.get("score", 0),
+ "score_details": kwargs.get("score_details", {}),
+ "cite_index": kwargs.get("cite_index", 0),
+ "search_type": kwargs.get("search_type", ""),
+ "tool_sign": kwargs.get("tool_sign", ""),
+ }
+
+ def to_dict(self):
+ return dict(self.data)
+
+ def to_model_dict(self):
+ return dict(self.data)
+
+tools_common_mod = types.ModuleType("sdk.nexent.core.utils.tools_common_message")
+tools_common_mod.SearchResultTextMessage = SearchResultTextMessage
+tools_common_mod.ToolCategory = _ToolCategory
+tools_common_mod.ToolSign = _ToolSign
+sys.modules["sdk.nexent.core.utils.tools_common_message"] = tools_common_mod
+utils_pkg.tools_common_message = tools_common_mod
+
+constants_mod = types.ModuleType("sdk.nexent.core.utils.constants")
+constants_mod.RERANK_OVERSEARCH_MULTIPLIER = 2
+sys.modules["sdk.nexent.core.utils.constants"] = constants_mod
+utils_pkg.constants = constants_mod
+
+class BaseEmbedding:
+ pass
+
+class BaseRerank:
+ pass
+
+embedding_mod = types.ModuleType("sdk.nexent.core.models.embedding_model")
+embedding_mod.BaseEmbedding = BaseEmbedding
+sys.modules["sdk.nexent.core.models.embedding_model"] = embedding_mod
+models_pkg.embedding_model = embedding_mod
+
+rerank_mod = types.ModuleType("sdk.nexent.core.models.rerank_model")
+rerank_mod.BaseRerank = BaseRerank
+sys.modules["sdk.nexent.core.models.rerank_model"] = rerank_mod
+models_pkg.rerank_model = rerank_mod
+
+class VectorDatabaseCore:
+ pass
+
+vector_base_mod = types.ModuleType("sdk.nexent.vector_database.base")
+vector_base_mod.VectorDatabaseCore = VectorDatabaseCore
+sys.modules["sdk.nexent.vector_database.base"] = vector_base_mod
+vector_pkg.base = vector_base_mod
+
+smolagents_mod = types.ModuleType("smolagents")
+smolagents_tools_mod = types.ModuleType("smolagents.tools")
+
+class Tool:
+ def __init__(self, *args, **kwargs):
+ pass
+
+smolagents_tools_mod.Tool = Tool
+smolagents_mod.tools = smolagents_tools_mod
+sys.modules["smolagents"] = smolagents_mod
+sys.modules["smolagents.tools"] = smolagents_tools_mod
+
+MODULE_PATH = REPO_ROOT / "sdk" / "nexent" / "core" / "tools" / "knowledge_base_search_tool.py"
+MODULE_NAME = "sdk.nexent.core.tools.knowledge_base_search_tool"
+spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
+assert spec and spec.loader, f"cannot build an import spec for {MODULE_PATH}"  # guard before first use
+knowledge_base_search_tool_module = importlib.util.module_from_spec(spec)
+sys.modules[MODULE_NAME] = knowledge_base_search_tool_module  # registered before exec so imports of MODULE_NAME resolve to it
+spec.loader.exec_module(knowledge_base_search_tool_module)
+tools_pkg.knowledge_base_search_tool = knowledge_base_search_tool_module
+KnowledgeBaseSearchTool = knowledge_base_search_tool_module.KnowledgeBaseSearchTool
@pytest.fixture
@@ -73,7 +204,7 @@ def test_forward_with_observer_adds_messages(self, knowledge_base_search_tool):
mock_results = create_mock_search_result(1)
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
- knowledge_base_search_tool.forward("hello world", index_names="test_index1,test_index2")
+ knowledge_base_search_tool.forward("hello world")
knowledge_base_search_tool.observer.add_message.assert_any_call(
"", ProcessType.TOOL, "Searching the knowledge base..."
@@ -189,7 +320,7 @@ def test_search_hybrid_error(self, knowledge_base_search_tool):
with pytest.raises(Exception) as excinfo:
knowledge_base_search_tool.search_hybrid("test query", ["test_index1"], top_k=5)
- assert "Error during semantic search" in str(excinfo.value)
+ assert "Error during hybrid search" in str(excinfo.value)
def test_forward_accurate_mode_success(self, knowledge_base_search_tool):
"""Test forward method with accurate search mode"""
@@ -200,7 +331,7 @@ def test_forward_accurate_mode_success(self, knowledge_base_search_tool):
mock_results = create_mock_search_result(2)
knowledge_base_search_tool.vdb_core.accurate_search.return_value = mock_results
- result = knowledge_base_search_tool.forward("test query", index_names="test_index1")
+ result = knowledge_base_search_tool.forward("test query")
# Parse result
search_results = json.loads(result)
@@ -217,7 +348,7 @@ def test_forward_semantic_mode_success(self, knowledge_base_search_tool):
mock_results = create_mock_search_result(4)
knowledge_base_search_tool.vdb_core.semantic_search.return_value = mock_results
- result = knowledge_base_search_tool.forward("test query", index_names="test_index1")
+ result = knowledge_base_search_tool.forward("test query")
# Parse result
search_results = json.loads(result)
@@ -231,7 +362,7 @@ def test_forward_invalid_search_mode(self, knowledge_base_search_tool):
knowledge_base_search_tool.search_mode = "invalid"
with pytest.raises(Exception) as excinfo:
- knowledge_base_search_tool.forward("test query", index_names="test_index1")
+ knowledge_base_search_tool.forward("test query")
assert "Invalid search mode" in str(excinfo.value)
assert "hybrid, accurate, semantic" in str(excinfo.value)
@@ -242,18 +373,18 @@ def test_forward_no_results(self, knowledge_base_search_tool):
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = []
with pytest.raises(Exception) as excinfo:
- knowledge_base_search_tool.forward("test query", index_names="test_index1")
+ knowledge_base_search_tool.forward("test query")
assert "No results found" in str(excinfo.value)
def test_forward_with_custom_index_names(self, knowledge_base_search_tool):
- """Test forward method with custom index names passed as parameter"""
+ """Test forward method uses configured custom index names."""
# Mock search results
mock_results = create_mock_search_result(2)
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
+ knowledge_base_search_tool.index_names = ["custom_index1", "custom_index2"]
- # Pass index_names as a list parameter (forward expects List[str])
- knowledge_base_search_tool.forward("test query", index_names=["custom_index1", "custom_index2"])
+ knowledge_base_search_tool.forward("test query")
# Verify vdb_core was called with the index names as-is
knowledge_base_search_tool.vdb_core.hybrid_search.assert_called_once_with(
@@ -272,7 +403,7 @@ def test_forward_chinese_language_observer(self, knowledge_base_search_tool):
mock_results = create_mock_search_result(2)
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
- result = knowledge_base_search_tool.forward("test query", index_names="test_index1")
+ result = knowledge_base_search_tool.forward("test query")
# Verify Chinese running prompt
knowledge_base_search_tool.observer.add_message.assert_any_call(
@@ -298,7 +429,7 @@ def test_forward_title_fallback(self, knowledge_base_search_tool):
]
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
- result = knowledge_base_search_tool.forward("test query", index_names="test_index1")
+ result = knowledge_base_search_tool.forward("test query")
# Parse result
search_results = json.loads(result)
@@ -306,6 +437,34 @@ def test_forward_title_fallback(self, knowledge_base_search_tool):
# Verify title fallback
assert len(search_results) == 1
assert search_results[0]["title"] == "test.txt"
+
+ def test_forward_adds_picture_web_for_images(self, knowledge_base_search_tool, monkeypatch):
+ """Forward should add picture messages when image results are present."""
+ monkeypatch.setenv("DATA_PROCESS_SERVICE", "https://data-process")
+ knowledge_base_search_tool.data_process_service = "https://data-process"
+
+ mock_results = [
+ {
+ "document": {
+ "title": "Image Doc",
+ "content": json.dumps({"image_url": "s3://bucket/img.png"}),
+ "filename": "img.png",
+ "path_or_url": "/path/img.png",
+ "create_time": "2024-01-01T12:00:00Z",
+ "source_type": "file",
+ "process_source": "UniversalImageExtractor",
+ },
+ "score": 0.9,
+ "index": "test_index"
+ }
+ ]
+ knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
+
+ with patch.object(knowledge_base_search_tool, "_filter_images", return_value=["s3://bucket/img.png"]):
+ knowledge_base_search_tool.forward("find images")
+
+ calls = knowledge_base_search_tool.observer.add_message.call_args_list
+ assert any(call.args[1] == ProcessType.PICTURE_WEB for call in calls)
class TestKnowledgeBaseSearchToolRerank:
@@ -497,12 +656,9 @@ def test_forward_uses_instance_index_names(self, knowledge_base_search_tool):
def test_forward_empty_index_names_string(self, knowledge_base_search_tool):
"""Test forward method with empty index_names string returns no results"""
- # Mock search results
- mock_results = create_mock_search_result(2)
- knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
+ knowledge_base_search_tool.index_names = ""
- # Pass empty string as index_names
- result = knowledge_base_search_tool.forward("test query", index_names="")
+ result = knowledge_base_search_tool.forward("test query")
# Should return no results message
assert result == json.dumps("No knowledge base selected. No relevant information found.", ensure_ascii=False)
@@ -512,9 +668,9 @@ def test_forward_single_index_name(self, knowledge_base_search_tool):
# Mock search results
mock_results = create_mock_search_result(1)
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
+ knowledge_base_search_tool.index_names = ["single_index"]
- # Pass index_names as a list parameter (forward expects List[str])
- knowledge_base_search_tool.forward("test query", index_names=["single_index"])
+ knowledge_base_search_tool.forward("test query")
# Verify vdb_core was called with single index
knowledge_base_search_tool.vdb_core.hybrid_search.assert_called_once_with(
@@ -529,13 +685,13 @@ def test_forward_with_whitespace_in_index_names(self, knowledge_base_search_tool
# Mock search results
mock_results = create_mock_search_result(1)
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
+ knowledge_base_search_tool.index_names = [" index1 ", " index2 "]
- # Pass index_names as a list parameter (forward expects List[str])
- knowledge_base_search_tool.forward("test query", index_names=[" index1 ", " index2 "])
+ knowledge_base_search_tool.forward("test query")
- # Verify vdb_core was called with the index names as-is (no stripping performed)
+ # _resolve_index_names strips whitespace.
knowledge_base_search_tool.vdb_core.hybrid_search.assert_called_once_with(
- index_names=[" index1 ", " index2 "],
+ index_names=["index1", "index2"],
query_text="test query",
embedding_model=knowledge_base_search_tool.embedding_model,
top_k=5
@@ -618,7 +774,7 @@ def test_convert_forward_integration(self, mock_observer, mock_vdb_core, mock_em
mock_vdb_core.hybrid_search.return_value = mock_results
tool = KnowledgeBaseSearchTool(
- index_names=[],
+ index_names=["Knowledge A"],
search_mode="hybrid",
vdb_core=mock_vdb_core,
embedding_model=mock_embedding_model,
@@ -628,7 +784,7 @@ def test_convert_forward_integration(self, mock_observer, mock_vdb_core, mock_em
},
)
- tool.forward("test query", index_names=["Knowledge A"])
+ tool.forward("test query")
mock_vdb_core.hybrid_search.assert_called_once_with(
index_names=["es_index_knowledge_a"],
@@ -708,7 +864,7 @@ def test_source_type_local_converted_to_file(self, knowledge_base_search_tool, m
mock_vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.vdb_core = mock_vdb_core
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
# Check the SEARCH_CONTENT message which contains full results via to_dict()
search_content_call = [
@@ -719,6 +875,197 @@ def test_source_type_local_converted_to_file(self, knowledge_base_search_tool, m
assert full_results[0]["source_type"] == "file"
+
+class TestKnowledgeBaseSearchToolMissingBranches:
+ def test_resolve_index_names_and_fieldinfo_paths(self, mock_observer, mock_vdb_core, mock_embedding_model):
+ try:
+ from pydantic import FieldInfo
+ except ImportError:
+ from pydantic.fields import FieldInfo
+
+ tool = KnowledgeBaseSearchTool(
+ index_names=["kb1"],
+ search_mode="hybrid",
+ vdb_core=mock_vdb_core,
+ embedding_model=mock_embedding_model,
+ observer=mock_observer,
+ display_name_to_index_map={},
+ )
+
+ tool.index_names = FieldInfo(default="alpha, beta , gamma")
+ assert tool._resolve_index_names() == ["alpha", "beta", "gamma"]
+
+ tool.index_names = FieldInfo(default=["alpha", " ", "gamma"])
+ assert tool._resolve_index_names() == ["alpha", "gamma"]
+
+ tool.index_names = None
+ assert tool._resolve_index_names() == []
+
+ tool.index_names = 123
+ assert tool._resolve_index_names() == []
+
+ def test_convert_to_index_names_with_fieldinfo_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model):
+ try:
+ from pydantic import FieldInfo
+ except ImportError:
+ from pydantic.fields import FieldInfo
+
+ tool = KnowledgeBaseSearchTool(
+ index_names=["Knowledge A", "raw_index"],
+ search_mode="hybrid",
+ vdb_core=mock_vdb_core,
+ embedding_model=mock_embedding_model,
+ observer=mock_observer,
+ display_name_to_index_map=FieldInfo(default_factory=lambda: {"Knowledge A": "es_index_a"}),
+ )
+
+ assert tool._convert_to_index_names(["Knowledge A", "raw_index"]) == ["es_index_a", "raw_index"]
+
+ def test_apply_rerank_empty_and_invalid_results(self, mock_observer, mock_vdb_core, mock_embedding_model):
+ tool = KnowledgeBaseSearchTool(
+ index_names=["kb1"],
+ search_mode="hybrid",
+ vdb_core=mock_vdb_core,
+ embedding_model=mock_embedding_model,
+ observer=mock_observer,
+ rerank=True,
+ rerank_model=MagicMock(),
+ display_name_to_index_map={},
+ )
+
+ kb_search_results = create_mock_search_result(2)
+ tool.rerank_model.rerank.return_value = []
+ assert tool._apply_rerank("query", kb_search_results, top_k=2) == kb_search_results
+
+ tool.rerank_model.rerank.return_value = [{"index": 99, "relevance_score": 0.5}]
+ assert tool._apply_rerank("query", kb_search_results, top_k=2) == kb_search_results
+
+ def test_extract_image_url_success_and_failure(self):
+ assert KnowledgeBaseSearchTool._extract_image_url(
+ {
+ "process_source": "UniversalImageExtractor",
+ "content": json.dumps({"image_url": "s3://bucket/img.png"}),
+ }
+ ) == "s3://bucket/img.png"
+
+ assert KnowledgeBaseSearchTool._extract_image_url(
+ {
+ "process_source": "UniversalImageExtractor",
+ "content": "not-json",
+ }
+ ) is None
+
+ assert KnowledgeBaseSearchTool._extract_image_url(
+ {
+ "process_source": "file",
+ "content": json.dumps({"image_url": "s3://bucket/img.png"}),
+ }
+ ) is None
+
+ def test_record_search_results_image_filter_paths(self, mock_observer, mock_vdb_core, mock_embedding_model):
+ tool = KnowledgeBaseSearchTool(
+ index_names=["kb1"],
+ search_mode="hybrid",
+ vdb_core=mock_vdb_core,
+ embedding_model=mock_embedding_model,
+ observer=mock_observer,
+ display_name_to_index_map={},
+ )
+
+ search_results = [{"title": "Doc", "content": "Body"}]
+ tool._record_search_results(search_results, [], "query")
+ mock_observer.add_message.assert_called_once()
+ mock_observer.add_message.reset_mock()
+
+ with patch.object(tool, "_filter_images", return_value=[]):
+ tool._record_search_results(search_results, ["img1"], "query")
+ assert any(call.args[1] == ProcessType.PICTURE_WEB for call in mock_observer.add_message.call_args_list)
+ mock_observer.add_message.reset_mock()
+
+ with patch.object(tool, "_filter_images", side_effect=Exception("boom")):
+ tool._record_search_results(search_results, ["img2"], "query")
+ assert any(call.args[1] == ProcessType.PICTURE_WEB for call in mock_observer.add_message.call_args_list)
+
+ def test_search_error_wrappers(self, mock_observer, mock_vdb_core, mock_embedding_model):
+ tool = KnowledgeBaseSearchTool(
+ index_names=["kb1"],
+ search_mode="hybrid",
+ vdb_core=mock_vdb_core,
+ embedding_model=mock_embedding_model,
+ observer=mock_observer,
+ display_name_to_index_map={},
+ )
+
+ mock_vdb_core.accurate_search.side_effect = Exception("accurate boom")
+ with pytest.raises(Exception, match="Error during accurate search"):
+ tool.search_accurate("query", ["kb1"], top_k=1)
+
+ mock_vdb_core.accurate_search.side_effect = None
+ mock_vdb_core.semantic_search.side_effect = Exception("semantic boom")
+ with pytest.raises(Exception, match="Error during semantic search"):
+ tool.search_semantic("query", ["kb1"], top_k=1)
+
+ def test_filter_images_success_and_event_loop_failure(self, mock_observer, mock_vdb_core, mock_embedding_model, monkeypatch, mocker):
+ import asyncio
+
+ tool = KnowledgeBaseSearchTool(
+ index_names=["kb1"],
+ search_mode="hybrid",
+ vdb_core=mock_vdb_core,
+ embedding_model=mock_embedding_model,
+ observer=mock_observer,
+ display_name_to_index_map={},
+ )
+ tool.data_process_service = "https://data-process"
+
+ class FakeResponse:
+ def __init__(self, status, payload=None):
+ self.status = status
+ self._payload = payload or {}
+
+ async def json(self):
+ return self._payload
+
+ class FakePostContext:
+ def __init__(self, url):
+ self.url = url
+
+ async def __aenter__(self):
+ if self.url == "raise":
+ raise RuntimeError("request boom")
+ if self.url == "bad":
+ return FakeResponse(500, {})
+ if self.url == "skip":
+ return FakeResponse(200, {"is_important": False})
+ return FakeResponse(200, {"is_important": True})
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+ class FakeSession:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+ def post(self, api_url, data):
+ return FakePostContext(data["image_url"])
+
+ fake_aiohttp = types.ModuleType("aiohttp")
+ fake_aiohttp.TCPConnector = lambda limit=0: object()
+ fake_aiohttp.ClientTimeout = lambda total=0: object()
+ fake_aiohttp.ClientSession = FakeSession
+ monkeypatch.setitem(sys.modules, "aiohttp", fake_aiohttp)
+
+ assert tool._filter_images(["keep", "skip", "bad", "raise"], "query") == ["keep"]
+
+ mocker.patch("asyncio.new_event_loop", side_effect=RuntimeError("loop boom"))
+ assert tool._filter_images(["keep"], "query") == []
+
def test_source_type_minio_converted_to_file(self, knowledge_base_search_tool, mock_vdb_core):
"""Test that source_type 'minio' is converted to 'file'."""
mock_results = [
@@ -738,7 +1085,7 @@ def test_source_type_minio_converted_to_file(self, knowledge_base_search_tool, m
mock_vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.vdb_core = mock_vdb_core
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
# Check the SEARCH_CONTENT message
search_content_call = [
@@ -768,7 +1115,7 @@ def test_source_type_other_unchanged(self, knowledge_base_search_tool, mock_vdb_
mock_vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.vdb_core = mock_vdb_core
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
# Check the SEARCH_CONTENT message
search_content_call = [
@@ -790,7 +1137,7 @@ def test_record_ops_increments_by_result_count(self, knowledge_base_search_tool)
initial_ops = knowledge_base_search_tool.record_ops
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
assert knowledge_base_search_tool.record_ops == initial_ops + 2
@@ -800,10 +1147,10 @@ def test_record_ops_accumulates_across_calls(self, knowledge_base_search_tool):
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.record_ops = 0
- knowledge_base_search_tool.forward("query1", index_names=["kb1"])
+ knowledge_base_search_tool.forward("query1")
first_call_ops = knowledge_base_search_tool.record_ops
- knowledge_base_search_tool.forward("query2", index_names=["kb1"])
+ knowledge_base_search_tool.forward("query2")
second_call_ops = knowledge_base_search_tool.record_ops
# Each call with 1 result adds 1 to record_ops
@@ -816,7 +1163,7 @@ def test_cite_index_in_results(self, knowledge_base_search_tool):
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
# record_ops starts at 1, so cite_index should be 1+0+1=1, 1+1+1=2
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
# Check the SEARCH_CONTENT message for cite_index values
search_content_call = [
@@ -837,7 +1184,7 @@ def test_forward_sends_search_content_to_observer(self, knowledge_base_search_to
mock_results = create_mock_search_result(1)
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
search_content_calls = [
call for call in knowledge_base_search_tool.observer.add_message.call_args_list
@@ -893,9 +1240,7 @@ def test_output_type(self, knowledge_base_search_tool):
def test_inputs_contain_required_fields(self):
"""Test that inputs dict contains required fields."""
assert "query" in KnowledgeBaseSearchTool.inputs
- assert "index_names" in KnowledgeBaseSearchTool.inputs
assert KnowledgeBaseSearchTool.inputs["query"]["type"] == "string"
- assert KnowledgeBaseSearchTool.inputs["index_names"]["type"] == "array"
def test_running_prompts(self, knowledge_base_search_tool):
"""Test running prompts for both languages."""
@@ -926,7 +1271,7 @@ def test_forward_with_score_details(self, knowledge_base_search_tool, mock_vdb_c
mock_vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.vdb_core = mock_vdb_core
- knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ knowledge_base_search_tool.forward("test query")
# Check the SEARCH_CONTENT message which contains full results via to_dict()
search_content_call = [
@@ -957,10 +1302,10 @@ def test_forward_with_empty_content(self, knowledge_base_search_tool, mock_vdb_c
mock_vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.vdb_core = mock_vdb_core
- result = knowledge_base_search_tool.forward("test query", index_names=["kb1"])
+ result = knowledge_base_search_tool.forward("test query")
search_results = json.loads(result)
- assert search_results[0]["text"] == ""
+ assert search_results[0]["content"] == ""
def test_forward_multiple_indices(self, knowledge_base_search_tool, mock_vdb_core):
"""Test forward searches across multiple indices."""
@@ -993,7 +1338,7 @@ def test_forward_multiple_indices(self, knowledge_base_search_tool, mock_vdb_cor
mock_vdb_core.hybrid_search.return_value = mock_results
knowledge_base_search_tool.vdb_core = mock_vdb_core
- result = knowledge_base_search_tool.forward("test query", index_names=["index1", "index2"])
+ result = knowledge_base_search_tool.forward("test query")
search_results = json.loads(result)
assert len(search_results) == 2
diff --git a/test/sdk/core/utils/test_favicon_extractor.py b/test/sdk/core/utils/test_favicon_extractor.py
new file mode 100644
index 000000000..0e4448a82
--- /dev/null
+++ b/test/sdk/core/utils/test_favicon_extractor.py
@@ -0,0 +1,38 @@
+import importlib.util
+import sys
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+# Load the module under test straight from its source file and register it
+# under a private name; the patch() targets below are resolved through that name.
+MODULE_NAME = "favicon_extractor_under_test"
+MODULE_PATH = (
+    Path(__file__).resolve().parents[4]
+    / "sdk" / "nexent"
+    / "core" / "utils"
+    / "favicon_extractor.py"
+)
+spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
+assert spec and spec.loader, f"cannot load {MODULE_PATH}"  # checked before spec is dereferenced
+favicon_module = importlib.util.module_from_spec(spec)
+sys.modules[MODULE_NAME] = favicon_module
+spec.loader.exec_module(favicon_module)
+
+get_favicon_url = favicon_module.get_favicon_url
+check_favicon_exists = favicon_module.check_favicon_exists
+
+
+def test_get_favicon_url_builds_default():  # the page path is replaced by /favicon.ico on the same origin
+ assert get_favicon_url("https://example.com/path") == "https://example.com/favicon.ico"
+
+
+def test_check_favicon_exists_true():
+    """A HEAD response with status 200 is reported as True."""
+    ok_response = Mock(status_code=200)
+    with patch(f"{MODULE_NAME}.requests.head", return_value=ok_response):
+        assert check_favicon_exists("https://example.com/favicon.ico") is True
+
+
+def test_check_favicon_exists_false_on_error():  # a HEAD request that raises is reported as False
+ with patch(f"{MODULE_NAME}.requests.head", side_effect=Exception("boom")):
+ assert check_favicon_exists("https://example.com/favicon.ico") is False
diff --git a/test/sdk/data_process/test_core.py b/test/sdk/data_process/test_core.py
index b41150e39..6c47c3732 100644
--- a/test/sdk/data_process/test_core.py
+++ b/test/sdk/data_process/test_core.py
@@ -6,6 +6,18 @@
from sdk.nexent.data_process.core import DataProcessCore
+def _unpack_chunks(result):
+    """Return the chunks from a (chunks, images) tuple, or `result` itself otherwise."""
+    is_pair = isinstance(result, tuple)
+    return result[0] if is_pair else result
+
+
+def _unpack_images(result):
+    """Return the images from a (chunks, images) tuple, or an empty list otherwise."""
+    is_pair = isinstance(result, tuple)
+    return result[1] if is_pair else []
+
+
class TestDataProcessCore:
"""Test suite for DataProcessCore class"""
@@ -19,7 +31,8 @@ def test_init(self, core):
assert core is not None
assert "Unstructured" in core.processors
assert "OpenPyxl" in core.processors
- assert len(core.processors) == 3
+ assert "UniversalImageExtractor" in core.processors
+ assert len(core.processors) == 4
def test_file_process_with_excel_file(self, core, mocker: MockFixture):
"""Test file processing with Excel file"""
@@ -30,6 +43,9 @@ def test_file_process_with_excel_file(self, core, mocker: MockFixture):
"metadata": {"chunk_index": 0}}
]
core.processors["OpenPyxl"] = mock_processor
+ core.processors["UniversalImageExtractor"] = Mock(
+ process_file=Mock(return_value=[])
+ )
file_data = b"fake excel data"
filename = "test.xlsx"
@@ -37,8 +53,9 @@ def test_file_process_with_excel_file(self, core, mocker: MockFixture):
result = core.file_process(
file_data, filename, chunking_strategy="basic")
- assert len(result) == 1
- assert result[0]["content"] == "test content"
+ chunks = _unpack_chunks(result)
+ assert len(chunks) == 1
+ assert chunks[0]["content"] == "test content"
mock_processor.process_file.assert_called_once_with(
file_data, "basic", filename=filename
)
@@ -58,8 +75,9 @@ def test_file_process_with_pdf_file(self, core, mocker: MockFixture):
result = core.file_process(
file_data, filename, chunking_strategy="by_title")
- assert len(result) == 1
- assert result[0]["content"] == "pdf content"
+ chunks = _unpack_chunks(result)
+ assert len(chunks) == 1
+ assert chunks[0]["content"] == "pdf content"
mock_processor.process_file.assert_called_once_with(
file_data, "by_title", filename=filename
)
@@ -69,6 +87,9 @@ def test_file_process_with_explicit_processor(self, core, mocker: MockFixture):
mock_processor = Mock()
mock_processor.process_file.return_value = [{"content": "test"}]
core.processors["Unstructured"] = mock_processor
+ core.processors["UniversalImageExtractor"] = Mock(
+ process_file=Mock(return_value=[])
+ )
file_data = b"data"
filename = "test.xlsx"
@@ -78,7 +99,8 @@ def test_file_process_with_explicit_processor(self, core, mocker: MockFixture):
file_data, filename, chunking_strategy="basic", processor="Unstructured"
)
- assert len(result) == 1
+ chunks = _unpack_chunks(result)
+ assert len(chunks) == 1
mock_processor.process_file.assert_called_once()
def test_file_process_with_additional_params(self, core, mocker: MockFixture):
@@ -95,7 +117,8 @@ def test_file_process_with_additional_params(self, core, mocker: MockFixture):
file_data, filename, chunking_strategy="basic", **additional_params
)
- assert len(result) == 1
+ chunks = _unpack_chunks(result)
+ assert len(chunks) == 1
mock_processor.process_file.assert_called_once_with(
file_data, "basic", filename=filename, max_characters=2000, strategy="fast"
)
@@ -153,7 +176,7 @@ def test_validate_parameters_valid_strategies(self, core, chunking_strategy):
@pytest.mark.parametrize(
"processor",
- ["Unstructured", "OpenPyxl"]
+ ["Unstructured", "OpenPyxl", "UniversalImageExtractor"]
)
def test_validate_parameters_valid_processors(self, core, processor):
"""Test parameter validation with valid processors"""
@@ -171,21 +194,24 @@ def test_validate_parameters_invalid_processor(self, core):
core._validate_parameters("basic", "InvalidProcessor")
@pytest.mark.parametrize(
- "filename,expected_processor",
+ "filename,expected_processor,expected_extractor",
[
- ("test.xlsx", "OpenPyxl"),
- ("test.xls", "OpenPyxl"),
- ("test.XLSX", "OpenPyxl"),
- ("test.pdf", "Unstructured"),
- ("test.docx", "Unstructured"),
- ("test.txt", "Unstructured"),
- ("test.html", "Unstructured"),
+ ("test.xlsx", "OpenPyxl", "UniversalImageExtractor"),
+ ("test.xls", "OpenPyxl", "UniversalImageExtractor"),
+ ("test.XLSX", "OpenPyxl", "UniversalImageExtractor"),
+ ("test.pdf", "Unstructured", "UniversalImageExtractor"),
+ ("test.docx", "Unstructured", "UniversalImageExtractor"),
+ ("test.pptx", "Unstructured", None),
+ ("test.txt", "Unstructured", None),
+ ("test.html", "Unstructured", None),
]
)
- def test_select_processor_by_filename(self, core, filename, expected_processor):
+ def test_select_processor_by_filename(self, core, filename, expected_processor, expected_extractor):
"""Test processor selection based on filename"""
- result = core._select_processor_by_filename(filename)
- assert result == expected_processor
+ params = {"model_type": "multi_embedding"} if expected_extractor else {}
+ processor_name, extractor = core._select_processor_by_filename(filename, params)
+ assert processor_name == expected_processor
+ assert extractor == expected_extractor
def test_get_supported_file_types(self, core):
"""Test getting supported file types"""
@@ -241,7 +267,8 @@ def test_get_supported_processors(self, core):
assert "Unstructured" in result
assert "OpenPyxl" in result
- assert len(result) == 2
+ assert "UniversalImageExtractor" in result
+ assert len(result) == 3
@pytest.mark.parametrize(
"filename,expected",
@@ -312,6 +339,46 @@ def test_get_processor_info_case_insensitive(self, core):
assert result["processor_type"] == "excel"
assert result["file_extension"] == ".xlsx"
+ def test_file_process_returns_images_when_extractor_available(self, core, mocker: MockFixture):
+     """A PDF processed with model_type="multi_embedding" yields chunks and images."""
+     # Both processors are swapped for mocks, so no real parsing or
+     # image extraction runs.
+     image_record = {"image_bytes": b"img", "image_format": "png", "position": {"page_number": 1}}
+     text_processor = Mock(process_file=Mock(return_value=[{"content": "test"}]))
+     image_extractor = Mock(process_file=Mock(return_value=[image_record]))
+     core.processors["Unstructured"] = text_processor
+     core.processors["UniversalImageExtractor"] = image_extractor
+
+     outcome = core.file_process(
+         b"data",
+         "sample.pdf",
+         chunking_strategy="basic",
+         model_type="multi_embedding",
+     )
+
+     assert len(_unpack_chunks(outcome)) == 1
+     assert len(_unpack_images(outcome)) == 1
+     image_extractor.process_file.assert_called_once()
+
+ def test_file_process_with_explicit_processor_still_extracts_images(self, core):
+     """Passing processor="Unstructured" explicitly still returns extracted images."""
+     text_processor = Mock()
+     text_processor.process_file.return_value = [{"content": "ok"}]
+     image_extractor = Mock()
+     image_extractor.process_file.return_value = [
+         {"image_bytes": b"x", "image_format": "png", "position": {}}
+     ]
+     core.processors["Unstructured"] = text_processor
+     core.processors["UniversalImageExtractor"] = image_extractor
+
+     call_kwargs = {
+         "chunking_strategy": "basic",
+         "processor": "Unstructured",
+         "model_type": "multi_embedding",
+     }
+     outcome = core.file_process(b"data", "report.pdf", **call_kwargs)
+     assert len(_unpack_chunks(outcome)) == 1
+     assert len(_unpack_images(outcome)) == 1
def test_file_split_unsupported_extension_returns_original_bytes(self, core):
"""Unsupported extensions should bypass splitting and return original bytes."""
data = b"raw-bytes"
diff --git a/test/sdk/data_process/test_extract_image.py b/test/sdk/data_process/test_extract_image.py
new file mode 100644
index 000000000..696bfd5d6
--- /dev/null
+++ b/test/sdk/data_process/test_extract_image.py
@@ -0,0 +1,409 @@
+import base64
+import importlib.util
+import os
+import subprocess
+import sys
+import threading
+import types
+from pathlib import Path
+from types import SimpleNamespace
+import zipfile
+from xml.etree import ElementTree as ET
+
+import pytest
+
+# Stub heavy optional deps before importing module under test.
+fake_pptx = types.ModuleType("pptx")
+fake_pptx.Presentation = object
+sys.modules.setdefault("pptx", fake_pptx)
+
+fake_unstructured = types.ModuleType("unstructured")
+fake_unstructured_partition = types.ModuleType("unstructured.partition")
+fake_unstructured_partition_auto = types.ModuleType("unstructured.partition.auto")
+fake_unstructured_partition_auto.partition = lambda *a, **k: []
+fake_unstructured.partition = fake_unstructured_partition
+fake_unstructured_partition.auto = fake_unstructured_partition_auto
+sys.modules.setdefault("unstructured", fake_unstructured)
+sys.modules.setdefault("unstructured.partition", fake_unstructured_partition)
+sys.modules.setdefault("unstructured.partition.auto", fake_unstructured_partition_auto)
+
+fake_unstructured_inference = types.ModuleType("unstructured_inference")  # own name; fake_unstructured stays bound to the "unstructured" stub
+fake_models = types.ModuleType("unstructured_inference.models")
+fake_tables = types.ModuleType("unstructured_inference.models.tables")
+fake_tables.tables_agent = types.SimpleNamespace(model=None)
+fake_logger = types.ModuleType("unstructured_inference.logger")
+fake_logger.logger = types.SimpleNamespace(info=lambda *a, **k: None, warning=lambda *a, **k: None, error=lambda *a, **k: None)
+fake_models.tables = fake_tables
+fake_unstructured_inference.models = fake_models
+sys.modules.setdefault("unstructured_inference", fake_unstructured_inference)  # setdefault: an already-registered module is kept
+sys.modules.setdefault("unstructured_inference.models", fake_models)
+sys.modules.setdefault("unstructured_inference.models.tables", fake_tables)
+sys.modules.setdefault("unstructured_inference.logger", fake_logger)
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+MODULE_PATH = REPO_ROOT / "sdk" / "nexent" / "data_process" / "extract_image.py"
+MODULE_NAME = "sdk.nexent.data_process.extract_image"
+
+sdk_pkg = types.ModuleType("sdk")
+sdk_pkg.__path__ = [str(REPO_ROOT / "sdk")]
+sdk_pkg = sys.modules.setdefault("sdk", sdk_pkg)
+
+nexent_pkg = types.ModuleType("sdk.nexent")
+nexent_pkg.__path__ = [str(REPO_ROOT / "sdk" / "nexent")]
+nexent_pkg = sys.modules.setdefault("sdk.nexent", nexent_pkg)
+sdk_pkg.nexent = nexent_pkg
+
+data_process_pkg = types.ModuleType("sdk.nexent.data_process")
+data_process_pkg.__path__ = [str(REPO_ROOT / "sdk" / "nexent" / "data_process")]
+data_process_pkg = sys.modules.setdefault("sdk.nexent.data_process", data_process_pkg)
+nexent_pkg.data_process = data_process_pkg
+spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
+assert spec and spec.loader, f"cannot load {MODULE_PATH}"  # checked before spec is dereferenced
+extract_image_module = importlib.util.module_from_spec(spec)
+sys.modules[MODULE_NAME] = extract_image_module  # registered before the module body runs
+spec.loader.exec_module(extract_image_module)
+data_process_pkg.extract_image = extract_image_module
+
+UniversalImageExtractor = extract_image_module.UniversalImageExtractor
+
+
+def test_detect_image_format_png():  # the input is the 8-byte PNG file signature
+ assert UniversalImageExtractor.detect_image_format(b"\x89PNG\r\n\x1a\n") == "png"
+
+
+def test_detect_image_format_jpg():  # the input starts with the bytes FF D8 FF
+ assert UniversalImageExtractor.detect_image_format(b"\xFF\xD8\xFF\xE0") == "jpg"
+
+
+def test_detect_image_format_default_png():  # bytes with no recognised signature fall back to "png"
+ assert UniversalImageExtractor.detect_image_format(b"not-an-image") == "png"
+
+
+def test_convert_file_success(mocker):
+    """With subprocess.run and os.path mocked, the converted path ends in .pdf."""
+    extractor = UniversalImageExtractor()
+    os_path = extract_image_module.os.path
+    mocker.patch.object(extract_image_module.subprocess, "run")
+    mocker.patch.object(os_path, "exists", return_value=True)
+    mocker.patch.object(os_path, "splitext", return_value=("C:/tmp/file", ".doc"))
+
+    assert extractor._convert_file("C:/tmp/file.doc", "pdf").endswith(".pdf")
+
+
+def test_convert_file_missing_output(mocker):
+    """FileNotFoundError is raised when os.path.exists reports False after conversion."""
+    extractor = UniversalImageExtractor()
+    mocker.patch.object(extract_image_module.subprocess, "run")
+    mocker.patch.object(extract_image_module.os.path, "exists", return_value=False)
+    mocker.patch.object(extract_image_module.os.path, "splitext", return_value=("C:/tmp/file", ".doc"))
+    with pytest.raises(FileNotFoundError):
+        extractor._convert_file("C:/tmp/file.doc", "pdf")
+
+
+def test_process_file_routes_pdf(mocker, tmp_path):
+    """A .pdf filename is handed to _extract_pdf and its result is returned."""
+    extractor = UniversalImageExtractor()
+    temp_pdf = str(tmp_path / "file.pdf")
+    mocker.patch.object(extractor, "_write_temp_file", return_value=temp_pdf)
+    pdf_extract = mocker.patch.object(extractor, "_extract_pdf", return_value=[{"image_bytes": b"x"}])
+
+    assert extractor.process_file(b"data", "none", "file.pdf") == [{"image_bytes": b"x"}]
+    pdf_extract.assert_called_once()
+
+
+def test_process_file_routes_xls_and_ppt(mocker, tmp_path):
+    """Legacy .xls/.ppt inputs are converted, then sent to the Excel/PPTX extractor."""
+    extractor = UniversalImageExtractor()
+    patch_attr = mocker.patch.object
+    # .xls -> converted .xlsx -> _extract_excel
+    xlsx_path = str(tmp_path / "file.xlsx")
+    patch_attr(extractor, "_write_temp_file", return_value=str(tmp_path / "file.xls"))
+    patch_attr(extractor, "_convert_file", return_value=xlsx_path)
+    excel_extract = patch_attr(extractor, "_extract_excel", return_value=[{"image_bytes": b"x"}])
+    assert extractor.process_file(b"data", "none", "file.xls") == [{"image_bytes": b"x"}]
+    excel_extract.assert_called_once_with(xlsx_path)
+
+    # .ppt -> converted .pptx -> _extract_pptx
+    pptx_path = str(tmp_path / "file.pptx")
+    patch_attr(extractor, "_write_temp_file", return_value=str(tmp_path / "file.ppt"))
+    patch_attr(extractor, "_convert_file", return_value=pptx_path)
+    pptx_extract = patch_attr(extractor, "_extract_pptx", return_value=[{"image_bytes": b"y"}])
+    assert extractor.process_file(b"data", "none", "file.ppt") == [{"image_bytes": b"y"}]
+    pptx_extract.assert_called_once_with(pptx_path)
+
+
+def test_process_file_routes_docx_to_pdf(mocker, tmp_path):
+    """A .docx input is converted and the converted path is given to _extract_pdf."""
+    extractor = UniversalImageExtractor()
+    pdf_path = str(tmp_path / "file.pdf")
+    mocker.patch.object(extractor, "_write_temp_file", return_value=str(tmp_path / "file.docx"))
+    mocker.patch.object(extractor, "_convert_file", return_value=pdf_path)
+    pdf_extract = mocker.patch.object(extractor, "_extract_pdf", return_value=[{"image_bytes": b"x"}])
+
+    assert extractor.process_file(b"data", "none", "file.docx") == [{"image_bytes": b"x"}]
+    pdf_extract.assert_called_once_with(pdf_path)
+
+
+def test_process_file_unsupported_extension_returns_empty(mocker, tmp_path):
+    """process_file returns an empty list for a .txt filename."""
+    extractor = UniversalImageExtractor()
+    txt_path = str(tmp_path / "file.txt")
+    mocker.patch.object(extractor, "_write_temp_file", return_value=txt_path)
+
+    assert extractor.process_file(b"data", "none", "file.txt") == []
+
+
+def _build_excel_zip(tmp_path, sheet_xml, sheet_rels=None, drawing_xml=None, drawing_rels=None, image_bytes=b"\x89PNGdata"):
+    """Write tmp_path/"sample.xlsx" holding sheet1 plus each optional part that is not None."""
+    optional_parts = (
+        ("xl/worksheets/_rels/sheet1.xml.rels", sheet_rels),
+        ("xl/drawings/drawing1.xml", drawing_xml),
+        ("xl/drawings/_rels/drawing1.xml.rels", drawing_rels),
+        ("xl/media/image1.png", image_bytes),
+    )
+    with zipfile.ZipFile(tmp_path / "sample.xlsx", "w") as archive:
+        archive.writestr("xl/worksheets/sheet1.xml", sheet_xml)
+        for member, payload in ((m, p) for m, p in optional_parts if p is not None):
+            archive.writestr(member, payload)
+    return tmp_path / "sample.xlsx"
+
+
+def test_custom_load_table_model_initializes_when_missing(monkeypatch):
+    """custom_load_table_model() initializes an agent whose model is None with the configured path."""
+    init_paths = []
+
+    def record_initialize(path):
+        init_paths.append(path)
+        agent_stub.model = object()
+
+    agent_stub = SimpleNamespace(model=None, _lock=threading.Lock(), initialize=record_initialize)
+    monkeypatch.setattr(extract_image_module, "TABLE_TRANSFORMER_MODEL_PATH", "model-path")
+    monkeypatch.setattr(extract_image_module, "tables_agent", agent_stub)
+
+    extract_image_module.custom_load_table_model()
+
+    assert init_paths == ["model-path"]
+
+
+def test_hash_namespace_write_temp_file(mocker, tmp_path):  # covers _hash, _openxml_namespace_maps, _write_temp_file
+    import hashlib  # function-scope import; used only for the expected digest
+    extractor = UniversalImageExtractor()
+    assert extractor._hash(b"abc") == hashlib.sha256(b"abc").hexdigest()
+    assert extractor._openxml_namespace_maps()["xdr"].endswith("spreadsheetDrawing")
+    temp_path = extractor._write_temp_file(b"hello", ".bin")
+    written = Path(temp_path).read_bytes()
+    os.remove(temp_path)  # delete before asserting so a failed check cannot leave the file behind
+    assert written == b"hello"
+
+
+def test_convert_file_error_paths(mocker):
+    """A failed and a timed-out subprocess.run both surface as RuntimeError."""
+    extractor = UniversalImageExtractor()
+    run_owner = extract_image_module.subprocess
+    convert_args = ("C:/tmp/file.doc", "pdf")
+
+    # subprocess.run raising CalledProcessError.
+    failure = subprocess.CalledProcessError(1, ["soffice"])
+    mocker.patch.object(run_owner, "run", side_effect=failure)
+    with pytest.raises(RuntimeError, match="LibreOffice conversion failed"):
+        extractor._convert_file(*convert_args)
+
+    # subprocess.run raising TimeoutExpired.
+    timeout = subprocess.TimeoutExpired(cmd="soffice", timeout=60)
+    mocker.patch.object(run_owner, "run", side_effect=timeout)
+    with pytest.raises(RuntimeError, match="timed out"):
+        extractor._convert_file(*convert_args)
+
+
+def test_extract_pdf_paths_and_deduplication(mocker):
+    """_extract_pdf skips empty and repeated images and reports coordinates and format."""
+    extractor = UniversalImageExtractor()
+    assert extractor._extract_pdf("sample.pdf") == []
+    def element(image_base64, points, page_number):
+        coordinates = SimpleNamespace(points=points) if points else None
+        metadata = SimpleNamespace(image_base64=image_base64, coordinates=coordinates, page_number=page_number)
+        return SimpleNamespace(metadata=metadata)
+    png = base64.b64encode(b"\x89PNGdata").decode("ascii")
+    jpg = base64.b64encode(b"\xFF\xD8\xFFdata").decode("ascii")
+    elements = [
+        element(png, [(1, 2), (3, 4)], 1),
+        element("", None, 2),  # no image payload
+        element(png, None, 3),  # same bytes as the first element
+        element(jpg, [(5, 6), (7, 8)], 4),
+    ]
+    mocker.patch.object(extract_image_module, "partition", return_value=elements)
+    images = extractor._extract_pdf(
+        "sample.pdf",
+        table_transformer_model_path="model-path",
+        unstructured_default_model_initialize_params_json_path="init.json",
+    )
+    assert extract_image_module.TABLE_TRANSFORMER_MODEL_PATH == "model-path"
+    assert len(images) == 2
+    assert images[0]["position"]["coordinates"] == {"x1": 1, "y1": 2, "x2": 3, "y2": 4}
+    assert images[1]["image_format"] == "jpg"
+
+
+def test_excel_helpers_positive_and_negative_paths(tmp_path):
+ extractor = UniversalImageExtractor()
+ ns = extractor._openxml_namespace_maps()
+
+ sheet_xml = """
+
+
+
+ """
+ sheet_rels = """
+
+
+
+ """
+ drawing_xml = """
+
+
+ 01
+ 23
+
+
+
+ 45
+
+
+
+ """
+ drawing_rels = """
+
+
+
+ """
+ zip_path = _build_excel_zip(tmp_path, sheet_xml, sheet_rels, drawing_xml, drawing_rels)
+
+ with zipfile.ZipFile(zip_path) as zf:
+ sheet_files = extractor._excel_sheet_files(zf)
+ assert sheet_files == ["xl/worksheets/sheet1.xml"]
+ assert extractor._excel_drawing_file(zf, sheet_files[0]) == "xl/drawings/drawing1.xml"
+ rel_map = extractor._excel_rel_map(zf, "xl/drawings/drawing1.xml")
+ assert rel_map == {"rIdImg1": "xl/media/image1.png"}
+ anchors = extractor._excel_anchors(zf, "xl/drawings/drawing1.xml", ns)
+ assert len(anchors) == 2
+ assert extractor._excel_anchor_coords(anchors[0], ns) == {"row1": 1, "col1": 2, "row2": 3, "col2": 4}
+ assert extractor._excel_anchor_coords(anchors[1], ns) == {"row1": 5, "col1": 6, "row2": 5, "col2": 6}
+ assert extractor._excel_anchor_embed_id(anchors[0], ns) == "rIdImg1"
+ results = extractor._extract_excel_anchors(zf, anchors, rel_map, "sheet1.xml", ns, set())
+ assert len(results) == 1
+ assert extractor._extract_excel_anchors(zf, [anchors[0]], {}, "sheet1.xml", ns, set()) == []
+ assert extractor._extract_excel_sheet(zf, "xl/worksheets/sheet1.xml", ns, set()) == results
+
+ assert extractor._extract_excel(str(zip_path)) == results
+
+ no_drawing_zip = _build_excel_zip(tmp_path, "")
+ with zipfile.ZipFile(no_drawing_zip) as zf:
+ assert extractor._excel_drawing_file(zf, "xl/worksheets/sheet1.xml") is None
+
+ bad_sheet_xml = """
+
+
+
+ """
+ missing_rel_zip = _build_excel_zip(tmp_path, bad_sheet_xml, drawing_xml=drawing_xml, drawing_rels=None)
+ with zipfile.ZipFile(missing_rel_zip) as zf:
+ assert extractor._excel_drawing_file(zf, "xl/worksheets/sheet1.xml") is None
+ assert extractor._excel_rel_map(zf, "xl/drawings/drawing1.xml") is None
+ assert extractor._extract_excel_sheet(zf, "xl/worksheets/sheet1.xml", ns, set()) == []
+
+ empty_rel_xml = """
+
+ """
+ empty_rel_zip = _build_excel_zip(tmp_path, sheet_xml, sheet_rels, drawing_xml, empty_rel_xml)
+ with zipfile.ZipFile(empty_rel_zip) as zf:
+ assert extractor._extract_excel_sheet(zf, "xl/worksheets/sheet1.xml", ns, set()) == []
+
+ mismatch_sheet_rels = """
+
+
+
+ """
+ mismatch_zip = _build_excel_zip(tmp_path, sheet_xml, mismatch_sheet_rels, drawing_xml, drawing_rels)
+ with zipfile.ZipFile(mismatch_zip) as zf:
+ assert extractor._excel_drawing_file(zf, "xl/worksheets/sheet1.xml") is None
+
+ anchor_no_from = ET.fromstring(
+ ''
+ )
+ assert extractor._excel_anchor_coords(anchor_no_from, ns) is None
+
+ anchor_no_blip = ET.fromstring(
+ ''
+ '00'
+ ''
+ )
+ assert extractor._excel_anchor_embed_id(anchor_no_blip, ns) is None
+
+ empty_anchors = [
+ anchor_no_from,
+ anchor_no_blip,
+ ]
+ assert extractor._extract_excel_anchors(zf, empty_anchors, {}, "sheet1.xml", ns, set()) == []
+
+
+def test_pptx_extraction_paths(monkeypatch):
+    """_extract_pptx needs Presentation, keeps one of two identical pictures and maps 914400 to 96."""
+    extractor = UniversalImageExtractor()
+    unit = 914400
+    monkeypatch.setattr(extract_image_module, "Presentation", None)
+    with pytest.raises(RuntimeError, match="python-pptx is required"):
+        extractor._extract_pptx("sample.pptx")
+
+    class FakeShape:
+        def __init__(self, blob=None):
+            if blob is not None:
+                self.image = SimpleNamespace(blob=blob)
+            self.left = self.top = self.width = self.height = unit
+
+    class FakeSlide:
+        def __init__(self):
+            # A shape without an image, then the same picture bytes twice.
+            self.shapes = [SimpleNamespace(), FakeShape(b"\x89PNGdata"), FakeShape(b"\x89PNGdata")]
+
+    class FakePresentation:
+        def __init__(self, path):
+            self.slide_width = unit * 10
+            self.slide_height = unit * 5
+            self.slides = [FakeSlide()]
+
+    monkeypatch.setattr(extract_image_module, "Presentation", FakePresentation)
+    pictures = extractor._extract_pptx("sample.pptx")
+    assert len(pictures) == 1
+    coordinates = pictures[0]["position"]["coordinates"]
+    assert coordinates["x1"] == 96
+    assert coordinates["slide_width"] == 960
+
+
+def test_process_file_direct_and_cleanup_paths(mocker, tmp_path):
+    """.xlsx/.pptx are extracted directly; for .doc both temp files are removed even if one removal fails."""
+    extractor = UniversalImageExtractor()
+    xlsx, pptx, doc, pdf = (str(tmp_path / f"file.{ext}") for ext in ("xlsx", "pptx", "doc", "pdf"))
+    patch_attr = mocker.patch.object
+
+    patch_attr(extractor, "_write_temp_file", side_effect=[xlsx, pptx, doc])
+    patch_attr(extractor, "_extract_excel", return_value=[{"image_bytes": b"x"}])
+    patch_attr(extractor, "_extract_pptx", return_value=[{"image_bytes": b"y"}])
+    assert extractor.process_file(b"data", "none", "file.xlsx") == [{"image_bytes": b"x"}]
+    assert extractor.process_file(b"data", "none", "file.pptx") == [{"image_bytes": b"y"}]
+
+    patch_attr(extractor, "_convert_file", return_value=pdf)
+    patch_attr(extractor, "_extract_pdf", return_value=[{"image_bytes": b"z"}])
+    patch_attr(extract_image_module.os.path, "exists", return_value=True)
+    removed = []
+
+    def failing_first_remove(path):
+        # Record every removal; only the first one raises.
+        removed.append(path)
+        if len(removed) == 1:
+            raise Exception("cleanup boom")
+
+    patch_attr(extract_image_module.os, "remove", side_effect=failing_first_remove)
+    assert extractor.process_file(b"data", "none", "file.doc") == [{"image_bytes": b"z"}]
+    assert {doc, pdf} <= set(removed)
diff --git a/test/sdk/data_process/test_unstructured_processor.py b/test/sdk/data_process/test_unstructured_processor.py
index 709a643e2..bfb828d10 100644
--- a/test/sdk/data_process/test_unstructured_processor.py
+++ b/test/sdk/data_process/test_unstructured_processor.py
@@ -23,7 +23,7 @@ def setup_partition_mock(mocker: MockFixture, return_value):
"unstructured.partition.auto": fake_auto_mod,
})
- mock_partition = mocker.Mock(return_value=return_value)
+ mock_partition = Mock(return_value=return_value)
fake_auto_mod.partition = mock_partition
return mock_partition
diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py
index 3945b65c4..e96b50ded 100644
--- a/test/sdk/vector_database/test_elasticsearch_core.py
+++ b/test/sdk/vector_database/test_elasticsearch_core.py
@@ -1,13 +1,107 @@
-import pytest
-from unittest.mock import MagicMock, patch
+import importlib.util
import time
import types
+import sys
+from pathlib import Path
from typing import List, Dict, Any
from contextlib import contextmanager
+
+import pytest
+from unittest.mock import MagicMock, patch
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+def _pkg(name, path):
+    """Return the module registered as `name`, registering a stub package rooted at `path` if absent."""
+    stub = types.ModuleType(name)
+    stub.__path__ = [str(path)]
+    return sys.modules.setdefault(name, stub)  # the registered module, not an orphan stub
+
+sdk_pkg = _pkg("sdk", REPO_ROOT / "sdk")
+nexent_pkg = _pkg("sdk.nexent", REPO_ROOT / "sdk" / "nexent")
+core_pkg = _pkg("sdk.nexent.core", REPO_ROOT / "sdk" / "nexent" / "core")
+models_pkg = _pkg("sdk.nexent.core.models", REPO_ROOT / "sdk" / "nexent" / "core" / "models")
+nlp_pkg = _pkg("sdk.nexent.core.nlp", REPO_ROOT / "sdk" / "nexent" / "core" / "nlp")
+vector_pkg = _pkg("sdk.nexent.vector_database", REPO_ROOT / "sdk" / "nexent" / "vector_database")
+sdk_pkg.nexent = nexent_pkg
+nexent_pkg.core = core_pkg
+nexent_pkg.vector_database = vector_pkg
+core_pkg.models = models_pkg
+core_pkg.nlp = nlp_pkg
+
+class BaseEmbedding:
+    """Empty placeholder exposed as BaseEmbedding on the stubbed embedding_model module."""
+
+embedding_mod = types.ModuleType("sdk.nexent.core.models.embedding_model")
+embedding_mod.BaseEmbedding = BaseEmbedding
+sys.modules["sdk.nexent.core.models.embedding_model"] = embedding_mod
+models_pkg.embedding_model = embedding_mod
+
+tokenizer_mod = types.ModuleType("sdk.nexent.core.nlp.tokenizer")
+tokenizer_mod.calculate_term_weights = lambda query_text: {}
+sys.modules["sdk.nexent.core.nlp.tokenizer"] = tokenizer_mod
+nlp_pkg.tokenizer = tokenizer_mod
+
+class VectorDatabaseCore:
+    """Empty placeholder exposed as VectorDatabaseCore on the stubbed vector_database.base module."""
+
+vector_base_mod = types.ModuleType("sdk.nexent.vector_database.base")
+vector_base_mod.VectorDatabaseCore = VectorDatabaseCore
+sys.modules["sdk.nexent.vector_database.base"] = vector_base_mod
+vector_pkg.base = vector_base_mod
+
+vector_utils_mod = types.ModuleType("sdk.nexent.vector_database.utils")
+vector_utils_mod.build_weighted_query = lambda query_text, weights: {"query": {"match": {"content": query_text}}}
+vector_utils_mod.format_size = lambda size: f"{size}B"
+sys.modules["sdk.nexent.vector_database.utils"] = vector_utils_mod
+vector_pkg.utils = vector_utils_mod
+
+fake_elasticsearch = types.ModuleType("elasticsearch")
+
+class _FakeRequestError(Exception):  # registered as exceptions.RequestError on the fake elasticsearch module
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args)
+        self.kwargs = kwargs
+
+    def __str__(self):  # the "message" kwarg wins, then the first positional argument, then ""
+        return str(self.kwargs["message"] if "message" in self.kwargs else (self.args[0] if self.args else ""))
+
+class _FakeNotFoundError(Exception):  # registered as exceptions.NotFoundError on the fake elasticsearch module
+    def __init__(self, *args, **kwargs):
+        self.kwargs = kwargs
+        Exception.__init__(self, *args)
+
+class _FakeElasticsearch:
+    """Client stand-in: accepts any constructor arguments and exposes MagicMock attributes."""
+
+    def __init__(self, *args, **kwargs):
+        mocked_attributes = (
+            "indices", "cluster", "search",
+            "bulk", "count", "delete_by_query",
+            "msearch", "index", "update",
+            "delete", "scroll", "clear_scroll",
+            "get",
+        )
+
+        # A separate MagicMock per attribute, created in the listed order.
+        for attribute in mocked_attributes:
+            setattr(self, attribute, MagicMock())
+
+fake_elasticsearch.Elasticsearch = _FakeElasticsearch
+fake_elasticsearch.exceptions = types.SimpleNamespace(RequestError=_FakeRequestError, NotFoundError=_FakeNotFoundError)
+sys.modules.setdefault("elasticsearch", fake_elasticsearch)
+
from elasticsearch import exceptions
-# Import the class under test
-from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore
+MODULE_PATH = REPO_ROOT / "sdk" / "nexent" / "vector_database" / "elasticsearch_core.py"
+MODULE_NAME = "sdk.nexent.vector_database.elasticsearch_core"
+spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
+assert spec and spec.loader, f"cannot load {MODULE_PATH}"  # checked before spec is dereferenced
+elasticsearch_core_module = importlib.util.module_from_spec(spec)
+sys.modules[MODULE_NAME] = elasticsearch_core_module  # registered before the module body runs
+spec.loader.exec_module(elasticsearch_core_module)
+vector_pkg.elasticsearch_core = elasticsearch_core_module
+ElasticSearchCore = elasticsearch_core_module.ElasticSearchCore
# ----------------------------------------------------------------------------
# Fixtures
@@ -702,6 +796,7 @@ def time_side_effect():
def test_vectorize_documents_empty_list(elasticsearch_core_instance):
"""Test indexing an empty list of documents."""
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
result = elasticsearch_core_instance.vectorize_documents(
"test_index",
@@ -716,6 +811,7 @@ def test_vectorize_documents_empty_list(elasticsearch_core_instance):
def test_vectorize_documents_small_batch(elasticsearch_core_instance):
"""Test indexing a small batch of documents (< 64)."""
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] * 3
mock_embedding_model.embedding_model_name = "test-model"
@@ -744,6 +840,84 @@ def test_vectorize_documents_small_batch(elasticsearch_core_instance):
mock_embedding_model.get_embeddings.assert_called_once()
mock_bulk.assert_called_once()
+
+def test_vectorize_documents_multimodal_sets_multi_embedding(elasticsearch_core_instance):
+ """With a "multimodal" model, the image-extractor document carries `multi_embedding` and the other `embedding`."""
+ embedding_model = MagicMock()
+ embedding_model.model_type = "multimodal"
+ embedding_model.get_multimodal_embeddings.return_value = [[0.1, 0.2], [0.3, 0.4]]
+
+ # One plain-text document and one document tagged as coming from UniversalImageExtractor.
+ documents = [
+ {
+ "content": "text content",
+ "process_source": "Unstructured",
+ "path_or_url": "path1",
+ },
+ {
+ "content": "image content",
+ "process_source": "UniversalImageExtractor",
+ "image_bytes": b"img",
+ "path_or_url": "path2",
+ },
+ ]
+
+ with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \
+ patch.object(elasticsearch_core_instance, "_force_refresh_with_retry", return_value=True):
+ mock_bulk.return_value = {"errors": False, "items": []}
+
+ result = elasticsearch_core_instance.vectorize_documents(
+ documents=documents,
+ index_name="test_index",
+ content_field="content",
+ embedding_model=embedding_model,
+ embedding_batch_size=2,
+ )
+
+ assert result == 2
+ operations = mock_bulk.call_args.kwargs["operations"]
+ # Keep only entries without an "index" key -- presumably the document bodies
+ # rather than the bulk action lines; confirm against the payload layout.
+ doc_entries = [item for item in operations if "index" not in item]
+ image_doc = next(doc for doc in doc_entries if doc["process_source"] == "UniversalImageExtractor")
+ text_doc = next(doc for doc in doc_entries if doc["process_source"] != "UniversalImageExtractor")
+ assert "multi_embedding" in image_doc
+ assert "embedding" in text_doc
+
+
+def test_vectorize_documents_text_embedding_skips_images(elasticsearch_core_instance):
+ """With a "text" model, only one of the two documents is expected in the bulk payload, and it is not the image one."""
+ embedding_model = MagicMock()
+ embedding_model.model_type = "text"
+ # A single embedding is returned, matching the one document expected to be indexed.
+ embedding_model.get_embeddings.return_value = [[0.1, 0.2]]
+
+ documents = [
+ {
+ "content": "image content",
+ "process_source": "UniversalImageExtractor",
+ "image_bytes": b"img",
+ "path_or_url": "path2",
+ },
+ {
+ "content": "text content",
+ "process_source": "Unstructured",
+ "path_or_url": "path1",
+ },
+ ]
+
+ with patch.object(elasticsearch_core_instance.client, "bulk") as mock_bulk, \
+ patch.object(elasticsearch_core_instance, "_force_refresh_with_retry", return_value=True):
+ mock_bulk.return_value = {"errors": False, "items": []}
+
+ result = elasticsearch_core_instance.vectorize_documents(
+ documents=documents,
+ index_name="test_index",
+ content_field="content",
+ embedding_model=embedding_model,
+ embedding_batch_size=2,
+ )
+
+ assert result == 1
+ operations = mock_bulk.call_args.kwargs["operations"]
+ # Keep only entries without an "index" key -- presumably the document bodies
+ # rather than the bulk action lines; confirm against the payload layout.
+ doc_entries = [item for item in operations if "index" not in item]
+ assert len(doc_entries) == 1
+ assert doc_entries[0]["process_source"] != "UniversalImageExtractor"
+
def test_small_batch_progress_callback_exception(elasticsearch_core_instance, caplog):
"""Progress callback errors should be logged without failing the insert."""
mock_embedding_model = MagicMock()
@@ -789,6 +963,7 @@ def test_small_batch_error_path_logs_and_raises(elasticsearch_core_instance, cap
def test_vectorize_documents_large_batch(elasticsearch_core_instance):
"""Test indexing a large batch of documents (>= 64)."""
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] * 64
mock_embedding_model.embedding_model_name = "test-model"
@@ -819,7 +994,7 @@ def test_vectorize_documents_large_batch(elasticsearch_core_instance):
assert result == 100
assert mock_embedding_model.get_embeddings.call_count >= 2
mock_bulk.assert_called()
- mock_refresh.assert_called_once_with("test_index")
+ assert mock_refresh.call_count == 2
def test_vectorize_documents_small_batch_large_mode_forces_large_path(elasticsearch_core_instance):
@@ -1257,8 +1432,8 @@ def test_get_index_chunks_cleanup_failure(elasticsearch_core_instance):
def test_accurate_search_success(elasticsearch_core_instance):
"""Test accurate search with text matching."""
with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec, \
- patch('sdk.nexent.vector_database.elasticsearch_core.calculate_term_weights') as mock_weights, \
- patch('sdk.nexent.vector_database.elasticsearch_core.build_weighted_query') as mock_build:
+ patch.object(elasticsearch_core_module, 'calculate_term_weights') as mock_weights, \
+ patch.object(elasticsearch_core_module, 'build_weighted_query') as mock_build:
mock_weights.return_value = {"test": 1.0}
mock_build.return_value = {
@@ -1287,8 +1462,8 @@ def test_accurate_search_success(elasticsearch_core_instance):
def test_accurate_search_builds_multi_index_query(elasticsearch_core_instance):
"""Ensure accurate_search joins indices and applies top_k sizing."""
with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec, \
- patch('sdk.nexent.vector_database.elasticsearch_core.calculate_term_weights') as mock_weights, \
- patch('sdk.nexent.vector_database.elasticsearch_core.build_weighted_query') as mock_build:
+ patch.object(elasticsearch_core_module, 'calculate_term_weights') as mock_weights, \
+ patch.object(elasticsearch_core_module, 'build_weighted_query') as mock_build:
mock_weights.return_value = {"test": 0.5}
mock_build.return_value = {"query": {"match_all": {}}}
@@ -1313,6 +1488,7 @@ def test_accurate_search_builds_multi_index_query(elasticsearch_core_instance):
def test_semantic_search_success(elasticsearch_core_instance):
"""Test semantic search with vector similarity."""
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024]
with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec:
@@ -1338,9 +1514,32 @@ def test_semantic_search_success(elasticsearch_core_instance):
mock_exec.assert_called_once()
+def test_semantic_search_multimodal_combines_queries(elasticsearch_core_instance):
+ """With a "multimodal" model, semantic_search is expected to call exec_query twice and return both hits."""
+ mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "multimodal"
+ mock_embedding_model.get_embeddings.return_value = [[0.1] * 8]
+
+ with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec:
+ # Each exec_query call consumes the next list: one hit per call.
+ mock_exec.side_effect = [
+ [{"score": 1.0, "document": {"content": "text"}, "index": "test_index"}],
+ [{"score": 0.9, "document": {"content": "image"}, "index": "test_index"}],
+ ]
+
+ result = elasticsearch_core_instance.semantic_search(
+ ["test_index"],
+ "test query",
+ mock_embedding_model,
+ top_k=3,
+ )
+
+ assert len(result) == 2
+ assert mock_exec.call_count == 2
+
+
def test_semantic_search_sets_knn_parameters(elasticsearch_core_instance):
"""Ensure semantic_search sets k and num_candidates based on top_k."""
mock_embedding_model = MagicMock()
+ mock_embedding_model.model_type = "text"
mock_embedding_model.get_embeddings.return_value = [[0.2] * 8]
with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec:
@@ -2224,3 +2423,104 @@ def test_get_user_indices_error_returns_empty(elasticsearch_core_instance):
with patch.object(elasticsearch_core_instance, "client") as mock_client:
mock_client.indices.get_alias.side_effect = RuntimeError("x")
assert elasticsearch_core_instance.get_user_indices("*") == []
+
+
+class TestAdditionalElasticsearchCoreCoverage:
+ """Extra ElasticSearchCore tests covering error paths and thin client wrappers."""
+
+ def test_create_index_request_error_other_returns_false(self, elasticsearch_core_instance):
+ """create_index returns False and never calls _ensure_index_ready when indices.create raises RequestError."""
+ with patch.object(elasticsearch_core_instance, "client") as mock_client, \
+ patch.object(elasticsearch_core_instance, "_ensure_index_ready") as mock_ready:
+ mock_client.indices.exists.return_value = False
+ mock_client.indices.create.side_effect = exceptions.RequestError(
+ message="bad request",
+ meta=types.SimpleNamespace(status=400),
+ body={"error": {"type": "mapper_parsing_exception"}},
+ )
+
+ assert elasticsearch_core_instance.create_index("idx") is False
+ mock_ready.assert_not_called()
+
+ def test_force_refresh_with_zero_retries_returns_false(self, elasticsearch_core_instance):
+ """With max_retries=0 the refresh helper returns False without calling indices.refresh."""
+ with patch.object(elasticsearch_core_instance.client.indices, "refresh") as mock_refresh:
+ assert elasticsearch_core_instance._force_refresh_with_retry("idx", max_retries=0) is False
+ mock_refresh.assert_not_called()
+
+ def test_delete_index_generic_error_returns_false(self, elasticsearch_core_instance):
+ """delete_index returns False when indices.delete raises a RuntimeError."""
+ with patch.object(elasticsearch_core_instance.client.indices, "delete") as mock_delete:
+ mock_delete.side_effect = RuntimeError("boom")
+ assert elasticsearch_core_instance.delete_index("idx") is False
+
+ def test_bulk_operation_context_nested_restores_settings(self, elasticsearch_core_instance):
+ """Two nested bulk contexts on one index: settings applied once, restored once, tracking entry removed."""
+ with patch.object(elasticsearch_core_instance, "_apply_bulk_settings") as mock_apply, \
+ patch.object(elasticsearch_core_instance, "_restore_normal_settings") as mock_restore:
+ with elasticsearch_core_instance.bulk_operation_context("idx", estimated_duration=1) as op1:
+ with elasticsearch_core_instance.bulk_operation_context("idx", estimated_duration=1) as op2:
+ # The two contexts must yield different values and both be tracked under "idx".
+ assert op1 != op2
+ assert "idx" in elasticsearch_core_instance._bulk_operations
+ assert len(elasticsearch_core_instance._bulk_operations["idx"]) == 2
+ assert mock_restore.call_count == 0
+
+ mock_apply.assert_called_once_with("idx")
+ mock_restore.assert_called_once_with("idx")
+ assert "idx" not in elasticsearch_core_instance._bulk_operations
+
+ def test_delete_documents_and_count_documents_error_paths(self, elasticsearch_core_instance):
+ """delete_documents and count_documents return the client's number on success and 0 when the client raises."""
+ with patch.object(elasticsearch_core_instance.client, "delete_by_query") as mock_delete, \
+ patch.object(elasticsearch_core_instance.client, "count") as mock_count:
+ mock_delete.return_value = {"deleted": 3}
+ assert elasticsearch_core_instance.delete_documents("idx", "/path/file.pdf") == 3
+
+ mock_delete.side_effect = RuntimeError("boom")
+ assert elasticsearch_core_instance.delete_documents("idx", "/path/file.pdf") == 0
+
+ mock_count.return_value = {"count": 7}
+ assert elasticsearch_core_instance.count_documents("idx") == 7
+
+ mock_count.side_effect = RuntimeError("boom")
+ assert elasticsearch_core_instance.count_documents("idx") == 0
+
+ def test_get_index_chunks_zero_total_paginated_and_scroll_without_scroll_id(self, elasticsearch_core_instance):
+ """Three get_index_chunks calls: zero total, a paginated page, and a call without paging whose response has no scroll id."""
+ elasticsearch_core_instance.client = MagicMock()
+
+ # side_effect lists are consumed in call order, so the three calls below
+ # must stay in this sequence.
+ elasticsearch_core_instance.client.count.side_effect = [
+ {"count": 0},
+ {"count": 1},
+ {"count": 1},
+ ]
+ elasticsearch_core_instance.client.search.side_effect = [
+ {"hits": {"hits": [{"_id": "doc-1", "_source": {"content": "A"}}]}},
+ {"hits": {"hits": [{"_id": "doc-2", "_source": {"content": "B"}}]}},
+ ]
+
+ empty = elasticsearch_core_instance.get_index_chunks("idx", page=2, page_size=10, path_or_url="/path")
+ assert empty == {"chunks": [], "total": 0, "page": 2, "page_size": 10}
+
+ paginated = elasticsearch_core_instance.get_index_chunks("idx", page=1, page_size=1)
+ assert paginated["chunks"] == [{"content": "A", "id": "doc-1"}]
+
+ scroll = elasticsearch_core_instance.get_index_chunks("idx")
+ assert scroll["chunks"] == [{"content": "B", "id": "doc-2"}]
+ # The mocked search response carries no "_scroll_id", so no scroll cleanup is expected.
+ elasticsearch_core_instance.client.clear_scroll.assert_not_called()
+
+ def test_get_index_chunks_exception_path(self, elasticsearch_core_instance):
+ """A RuntimeError raised by client.search propagates out of get_index_chunks."""
+ elasticsearch_core_instance.client = MagicMock()
+ elasticsearch_core_instance.client.count.return_value = {"count": 1}
+ elasticsearch_core_instance.client.search.side_effect = RuntimeError("boom")
+
+ with pytest.raises(RuntimeError):
+ elasticsearch_core_instance.get_index_chunks("idx")
+
+ def test_check_index_exists_wrapper(self, elasticsearch_core_instance):
+ """check_index_exists returns True when indices.exists reports True."""
+ with patch.object(elasticsearch_core_instance.client.indices, "exists") as mock_exists:
+ mock_exists.return_value = True
+ assert elasticsearch_core_instance.check_index_exists("idx") is True
+
+ def test_search_and_multi_search_wrappers(self, elasticsearch_core_instance):
+ """search and multi_search pass their arguments to client.search / client.msearch and return the response unchanged."""
+ with patch.object(elasticsearch_core_instance.client, "search") as mock_search:
+ mock_search.return_value = {"hits": {"hits": []}}
+ assert elasticsearch_core_instance.search("idx", {"match_all": {}}) == {"hits": {"hits": []}}
+ mock_search.assert_called_once_with(index="idx", body={"match_all": {}})
+
+ with patch.object(elasticsearch_core_instance.client, "msearch") as mock_msearch:
+ mock_msearch.return_value = {"responses": []}
+ assert elasticsearch_core_instance.multi_search([{}], "idx") == {"responses": []}
+ mock_msearch.assert_called_once_with(body=[{}], index="idx")