diff --git a/docs/chinese-llm-setup.md b/docs/chinese-llm-setup.md index 37042aa2f..1fb0ce2a1 100644 --- a/docs/chinese-llm-setup.md +++ b/docs/chinese-llm-setup.md @@ -24,7 +24,7 @@ SurfSense 现已支持以下国产 LLM: 1. 登录 SurfSense Dashboard 2. 进入 **Settings** → **API Keys** (或 **LLM Configurations**) -3. 点击 **Add New Configuration** +3. 点击 **Add LLM Model** 4. 从 **Provider** 下拉菜单中选择你的国产 LLM 提供商 5. 填写必填字段(见下方各提供商详细配置) 6. 点击 **Save** diff --git a/surfsense_backend/alembic/versions/111_add_prompts_table.py b/surfsense_backend/alembic/versions/111_add_prompts_table.py index 7d4d69fd2..f61c4e298 100644 --- a/surfsense_backend/alembic/versions/111_add_prompts_table.py +++ b/surfsense_backend/alembic/versions/111_add_prompts_table.py @@ -42,7 +42,9 @@ def upgrade() -> None: ) """) op.execute("CREATE INDEX ix_prompts_user_id ON prompts (user_id)") - op.execute("CREATE INDEX ix_prompts_search_space_id ON prompts (search_space_id)") + op.execute( + "CREATE INDEX ix_prompts_search_space_id ON prompts (search_space_id)" + ) def downgrade() -> None: diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py index a712c9a45..8dffb18dd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/create_file.py @@ -81,7 +81,8 @@ async def create_onedrive_file( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) connectors = result.scalars().all() @@ -95,12 +96,14 @@ async def create_onedrive_file( accounts = [] for c in connectors: cfg = c.config or {} - accounts.append({ - "id": c.id, - "name": c.name, - "user_email": cfg.get("user_email"), - "auth_expired": cfg.get("auth_expired", False), - }) + accounts.append( + { + "id": c.id, + "name": c.name, + "user_email": cfg.get("user_email"), + "auth_expired": cfg.get("auth_expired", False), + } + ) if all(a.get("auth_expired") for a in accounts): return { @@ -119,16 +122,22 @@ async def create_onedrive_file( client = OneDriveClient(session=db_session, connector_id=cid) items, err = await client.list_children("root") if err: - logger.warning("Failed to list folders for connector %s: %s", cid, err) + logger.warning( + "Failed to list folders for connector %s: %s", cid, err + ) parent_folders[cid] = [] else: parent_folders[cid] = [ {"folder_id": item["id"], "name": item["name"]} for item in items - if item.get("folder") is not None and item.get("id") and item.get("name") + if item.get("folder") is not None + and item.get("id") + and item.get("name") ] except Exception: - logger.warning("Error fetching folders for connector %s", cid, exc_info=True) + logger.warning( + "Error fetching folders for connector %s", cid, exc_info=True + ) parent_folders[cid] = [] context: dict[str, Any] = { @@ -152,8 +161,12 @@ async def create_onedrive_file( } ) - decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else [] - decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + decisions_raw = ( + approval.get("decisions", []) if isinstance(approval, dict) else [] + ) + decisions = ( + decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + ) decisions = [d for d in decisions if isinstance(d, dict)] if not decisions: return {"status": "error", "message": "No approval decision received"} @@ -192,7 +205,8 @@ async def create_onedrive_file( SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) connector = result.scalars().first() @@ -200,7 +214,10 @@ async def create_onedrive_file( connector = connectors[0] if not connector: - return {"status": "error", "message": "Selected OneDrive connector is invalid."} + return { + "status": "error", + "message": "Selected OneDrive connector is invalid.", + } docx_bytes = _markdown_to_docx(final_content or "") @@ -212,7 +229,9 @@ async def create_onedrive_file( mime_type=DOCX_MIME, ) - logger.info(f"OneDrive file created: id={created.get('id')}, name={created.get('name')}") + logger.info( + f"OneDrive file created: id={created.get('id')}, name={created.get('name')}" + ) kb_message_suffix = "" try: diff --git a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py index ae7c5e306..79d8222fd 100644 --- a/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/onedrive/trash_file.py @@ -52,10 +52,15 @@ async def delete_onedrive_file( - If status is "not_found", relay the exact message to the user and ask them to verify the file name or check if it has been indexed. """ - logger.info(f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}") + logger.info( + f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" + ) if db_session is None or search_space_id is None or user_id is None: - return {"status": "error", "message": "OneDrive tool not properly configured."} + return { + "status": "error", + "message": "OneDrive tool not properly configured.", + } try: doc_result = await db_session.execute( @@ -89,8 +94,12 @@ async def delete_onedrive_file( Document.search_space_id == search_space_id, Document.document_type == DocumentType.ONEDRIVE_FILE, func.lower( - cast(Document.document_metadata["onedrive_file_name"], String) - ) == func.lower(file_name), + cast( + Document.document_metadata["onedrive_file_name"], + String, + ) + ) + == func.lower(file_name), SearchSourceConnector.user_id == user_id, ) ) @@ -110,14 +119,20 @@ async def delete_onedrive_file( } if not document.connector_id: - return {"status": "error", "message": "Document has no associated connector."} + return { + "status": "error", + "message": "Document has no associated connector.", + } meta = document.document_metadata or {} file_id = meta.get("onedrive_file_id") document_id = document.id if not file_id: - return {"status": "error", "message": "File ID is missing. Please re-index the file."} + return { + "status": "error", + "message": "File ID is missing. Please re-index the file.", + } conn_result = await db_session.execute( select(SearchSourceConnector).filter( @@ -125,13 +140,17 @@ async def delete_onedrive_file( SearchSourceConnector.id == document.connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) ) connector = conn_result.scalars().first() if not connector: - return {"status": "error", "message": "OneDrive connector not found or access denied."} + return { + "status": "error", + "message": "OneDrive connector not found or access denied.", + } cfg = connector.config or {} if cfg.get("auth_expired"): @@ -170,8 +189,12 @@ async def delete_onedrive_file( } ) - decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else [] - decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + decisions_raw = ( + approval.get("decisions", []) if isinstance(approval, dict) else [] + ) + decisions = ( + decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + ) decisions = [d for d in decisions if isinstance(d, dict)] if not decisions: return {"status": "error", "message": "No approval decision received"} @@ -206,7 +229,8 @@ async def delete_onedrive_file( SearchSourceConnector.id == final_connector_id, SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.user_id == user_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) ) @@ -224,10 +248,14 @@ async def delete_onedrive_file( f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}" ) - client = OneDriveClient(session=db_session, connector_id=actual_connector_id) + client = OneDriveClient( + session=db_session, connector_id=actual_connector_id + ) await client.trash_file(final_file_id) - logger.info(f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}") + logger.info( + f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}" + ) trash_result: dict[str, Any] = { "status": "success", @@ -272,6 +300,9 @@ async def delete_onedrive_file( if isinstance(e, GraphInterrupt): raise logger.error(f"Error deleting OneDrive file: {e}", exc_info=True) - return {"status": "error", "message": "Something went wrong while trashing the file. Please try again."} + return { + "status": "error", + "message": "Something went wrong while trashing the file. Please try again.", + } return delete_onedrive_file diff --git a/surfsense_backend/app/connectors/onedrive/client.py b/surfsense_backend/app/connectors/onedrive/client.py index cc118c0c9..37c5823a3 100644 --- a/surfsense_backend/app/connectors/onedrive/client.py +++ b/surfsense_backend/app/connectors/onedrive/client.py @@ -39,7 +39,9 @@ async def _get_valid_token(self) -> str: cfg = connector.config or {} is_encrypted = cfg.get("_token_encrypted", False) - token_encryption = TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None + token_encryption = ( + TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None + ) access_token = cfg.get("access_token", "") refresh_token = cfg.get("refresh_token") @@ -206,18 +208,20 @@ async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]: async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None: """Stream file content to disk. Returns error message on failure.""" token = await self._get_valid_token() - async with httpx.AsyncClient(follow_redirects=True) as client: - async with client.stream( + async with ( + httpx.AsyncClient(follow_redirects=True) as client, + client.stream( "GET", f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content", headers={"Authorization": f"Bearer {token}"}, timeout=120.0, - ) as resp: - if resp.status_code != 200: - return f"Download failed: {resp.status_code}" - with open(dest_path, "wb") as f: - async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024): - f.write(chunk) + ) as resp, + ): + if resp.status_code != 200: + return f"Download failed: {resp.status_code}" + with open(dest_path, "wb") as f: + async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024): + f.write(chunk) return None async def create_file( diff --git a/surfsense_backend/app/connectors/onedrive/content_extractor.py b/surfsense_backend/app/connectors/onedrive/content_extractor.py index 109a8cb15..8917ba1fd 100644 --- a/surfsense_backend/app/connectors/onedrive/content_extractor.py +++ b/surfsense_backend/app/connectors/onedrive/content_extractor.py @@ -5,6 +5,7 @@ """ import asyncio +import contextlib import logging import os import tempfile @@ -60,7 +61,9 @@ async def download_and_extract_content( temp_file_path = None try: - extension = Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin" + extension = ( + Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin" + ) with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: temp_file_path = tmp.name @@ -76,10 +79,8 @@ async def download_and_extract_content( return None, metadata, str(e) finally: if temp_file_path and os.path.exists(temp_file_path): - try: + with contextlib.suppress(Exception): os.unlink(temp_file_path) - except Exception: - pass async def _parse_file_to_markdown(file_path: str, filename: str) -> str: @@ -94,9 +95,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: return f.read() if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")): - from app.config import config as app_config from litellm import atranscription + from app.config import config as app_config + stt_service_type = ( "local" if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/") @@ -106,9 +108,13 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: from app.services.stt_service import stt_service t0 = time.monotonic() - logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}") + logger.info( + f"[local-stt] START file={filename} thread={threading.current_thread().name}" + ) result = await asyncio.to_thread(stt_service.transcribe_file, file_path) - logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s") + logger.info( + f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s" + ) text = result.get("text", "") else: with open(file_path, "rb") as audio_file: @@ -150,7 +156,9 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: parse_with_llamacloud_retry, ) - result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50) + result = await parse_with_llamacloud_retry( + file_path=file_path, estimated_pages=50 + ) markdown_documents = await result.aget_markdown_documents(split_by_page=False) if not markdown_documents: raise RuntimeError(f"LlamaCloud returned no documents for {filename}") @@ -161,9 +169,13 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str: converter = DocumentConverter() t0 = time.monotonic() - logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}") + logger.info( + f"[docling] START file={filename} thread={threading.current_thread().name}" + ) result = await asyncio.to_thread(converter.convert, file_path) - logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s") + logger.info( + f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s" + ) return result.document.export_to_markdown() raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}") diff --git a/surfsense_backend/app/connectors/onedrive/folder_manager.py b/surfsense_backend/app/connectors/onedrive/folder_manager.py index 7f286453c..6fa725ca1 100644 --- a/surfsense_backend/app/connectors/onedrive/folder_manager.py +++ b/surfsense_backend/app/connectors/onedrive/folder_manager.py @@ -27,7 +27,10 @@ async def list_folder_contents( if item["isFolder"]: item.setdefault("mimeType", "application/vnd.ms-folder") else: - item.setdefault("mimeType", item.get("file", {}).get("mimeType", "application/octet-stream")) + item.setdefault( + "mimeType", + item.get("file", {}).get("mimeType", "application/octet-stream"), + ) items.sort(key=lambda x: (not x["isFolder"], x.get("name", "").lower())) @@ -63,7 +66,9 @@ async def get_files_in_folder( client, item["id"], include_subfolders=True ) if sub_error: - logger.warning(f"Error recursing into folder {item.get('name')}: {sub_error}") + logger.warning( + f"Error recursing into folder {item.get('name')}: {sub_error}" + ) continue files.extend(sub_files) elif not should_skip_file(item): diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index a2b7a154a..644ab07dc 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -33,9 +33,10 @@ from .notes_routes import router as notes_router from .notifications_routes import router as notifications_router from .notion_add_connector_route import router as notion_add_connector_router +from .onedrive_add_connector_route import router as onedrive_add_connector_router from .podcasts_routes import router as podcasts_router -from .public_chat_routes import router as public_chat_router from .prompts_routes import router as prompts_router +from .public_chat_routes import router as public_chat_router from .rbac_routes import router as rbac_router from .reports_routes import router as reports_router from .sandbox_routes import router as sandbox_router @@ -44,7 +45,6 @@ from .slack_add_connector_route import router as slack_add_connector_router from .surfsense_docs_routes import router as surfsense_docs_router from .teams_add_connector_route import router as teams_add_connector_router -from .onedrive_add_connector_route import router as onedrive_add_connector_router from .video_presentations_routes import router as video_presentations_router from .youtube_routes import router as youtube_router diff --git a/surfsense_backend/app/routes/onedrive_add_connector_route.py b/surfsense_backend/app/routes/onedrive_add_connector_route.py index 19bcbe6ff..2f41efca7 100644 --- a/surfsense_backend/app/routes/onedrive_add_connector_route.py +++ b/surfsense_backend/app/routes/onedrive_add_connector_route.py @@ -79,9 +79,13 @@ async def connect_onedrive(space_id: int, user: User = Depends(current_active_us if not space_id: raise HTTPException(status_code=400, detail="space_id is required") if not config.MICROSOFT_CLIENT_ID: - raise HTTPException(status_code=500, detail="Microsoft OneDrive OAuth not configured.") + raise HTTPException( + status_code=500, detail="Microsoft OneDrive OAuth not configured." + ) if not config.SECRET_KEY: - raise HTTPException(status_code=500, detail="SECRET_KEY not configured for OAuth security.") + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) state_manager = get_state_manager() state_encoded = state_manager.generate_secure_state(space_id, user.id) @@ -96,14 +100,18 @@ async def connect_onedrive(space_id: int, user: User = Depends(current_active_us } auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" - logger.info("Generated OneDrive OAuth URL for user %s, space %s", user.id, space_id) + logger.info( + "Generated OneDrive OAuth URL for user %s, space %s", user.id, space_id + ) return {"auth_url": auth_url} except HTTPException: raise except Exception as e: logger.error("Failed to initiate OneDrive OAuth: %s", str(e), exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to initiate OneDrive OAuth: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to initiate OneDrive OAuth: {e!s}" + ) from e @router.get("/auth/onedrive/connector/reauth") @@ -121,15 +129,20 @@ async def reauth_onedrive( SearchSourceConnector.id == connector_id, SearchSourceConnector.user_id == user.id, SearchSourceConnector.search_space_id == space_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) connector = result.scalars().first() if not connector: - raise HTTPException(status_code=404, detail="OneDrive connector not found or access denied") + raise HTTPException( + status_code=404, detail="OneDrive connector not found or access denied" + ) if not config.SECRET_KEY: - raise HTTPException(status_code=500, detail="SECRET_KEY not configured for OAuth security.") + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) state_manager = get_state_manager() extra: dict = {"connector_id": connector_id} @@ -148,14 +161,20 @@ async def reauth_onedrive( } auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" - logger.info("Initiating OneDrive re-auth for user %s, connector %s", user.id, connector_id) + logger.info( + "Initiating OneDrive re-auth for user %s, connector %s", + user.id, + connector_id, + ) return {"auth_url": auth_url} except HTTPException: raise except Exception as e: logger.error("Failed to initiate OneDrive re-auth: %s", str(e), exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to initiate OneDrive re-auth: {e!s}") from e + raise HTTPException( + status_code=500, detail=f"Failed to initiate OneDrive re-auth: {e!s}" + ) from e @router.get("/auth/onedrive/connector/callback") @@ -182,10 +201,14 @@ async def onedrive_callback( return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=onedrive_oauth_denied" ) - return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_oauth_denied") + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_oauth_denied" + ) if not code or not state: - raise HTTPException(status_code=400, detail="Missing required OAuth parameters") + raise HTTPException( + status_code=400, detail="Missing required OAuth parameters" + ) state_manager = get_state_manager() try: @@ -194,7 +217,9 @@ async def onedrive_callback( user_id = UUID(data["user_id"]) except (HTTPException, ValueError, KeyError) as e: logger.error("Invalid OAuth state: %s", str(e)) - return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=invalid_state") + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=invalid_state" + ) reauth_connector_id = data.get("connector_id") reauth_return_url = data.get("return_url") @@ -222,20 +247,26 @@ async def onedrive_callback( error_detail = error_json.get("error_description", error_detail) except Exception: pass - raise HTTPException(status_code=400, detail=f"Token exchange failed: {error_detail}") + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) token_json = token_response.json() access_token = token_json.get("access_token") refresh_token = token_json.get("refresh_token") if not access_token: - raise HTTPException(status_code=400, detail="No access token received from Microsoft") + raise HTTPException( + status_code=400, detail="No access token received from Microsoft" + ) token_encryption = get_token_encryption() expires_at = None if token_json.get("expires_in"): - expires_at = datetime.now(UTC) + timedelta(seconds=int(token_json["expires_in"])) + expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) user_info: dict = {} try: @@ -248,7 +279,8 @@ async def onedrive_callback( if user_response.status_code == 200: user_data = user_response.json() user_info = { - "user_email": user_data.get("mail") or user_data.get("userPrincipalName"), + "user_email": user_data.get("mail") + or user_data.get("userPrincipalName"), "user_name": user_data.get("displayName"), } except Exception as e: @@ -256,7 +288,9 @@ async def onedrive_callback( connector_config = { "access_token": token_encryption.encrypt_token(access_token), - "refresh_token": token_encryption.encrypt_token(refresh_token) if refresh_token else None, + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, "token_type": token_json.get("token_type", "Bearer"), "expires_in": token_json.get("expires_in"), "expires_at": expires_at.isoformat() if expires_at else None, @@ -273,22 +307,36 @@ async def onedrive_callback( SearchSourceConnector.id == reauth_connector_id, SearchSourceConnector.user_id == user_id, SearchSourceConnector.search_space_id == space_id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) db_connector = result.scalars().first() if not db_connector: - raise HTTPException(status_code=404, detail="Connector not found or access denied during re-auth") + raise HTTPException( + status_code=404, + detail="Connector not found or access denied during re-auth", + ) existing_delta_link = db_connector.config.get("delta_link") - db_connector.config = {**connector_config, "delta_link": existing_delta_link, "auth_expired": False} + db_connector.config = { + **connector_config, + "delta_link": existing_delta_link, + "auth_expired": False, + } flag_modified(db_connector, "config") await session.commit() await session.refresh(db_connector) - logger.info("Re-authenticated OneDrive connector %s for user %s", db_connector.id, user_id) + logger.info( + "Re-authenticated OneDrive connector %s for user %s", + db_connector.id, + user_id, + ) if reauth_return_url and reauth_return_url.startswith("/"): - return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}") + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" + ) return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={db_connector.id}" ) @@ -298,16 +346,26 @@ async def onedrive_callback( SearchSourceConnectorType.ONEDRIVE_CONNECTOR, connector_config ) is_duplicate = await check_duplicate_connector( - session, SearchSourceConnectorType.ONEDRIVE_CONNECTOR, space_id, user_id, connector_identifier, + session, + SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + space_id, + user_id, + connector_identifier, ) if is_duplicate: - logger.warning("Duplicate OneDrive connector for user %s, space %s", user_id, space_id) + logger.warning( + "Duplicate OneDrive connector for user %s, space %s", user_id, space_id + ) return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=ONEDRIVE_CONNECTOR" ) connector_name = await generate_unique_connector_name( - session, SearchSourceConnectorType.ONEDRIVE_CONNECTOR, space_id, user_id, connector_identifier, + session, + SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + space_id, + user_id, + connector_identifier, ) new_connector = SearchSourceConnector( @@ -323,20 +381,30 @@ async def onedrive_callback( session.add(new_connector) await session.commit() await session.refresh(new_connector) - logger.info("Successfully created OneDrive connector %s for user %s", new_connector.id, user_id) + logger.info( + "Successfully created OneDrive connector %s for user %s", + new_connector.id, + user_id, + ) return RedirectResponse( url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={new_connector.id}" ) except IntegrityError as e: await session.rollback() - logger.error("Database integrity error creating OneDrive connector: %s", str(e)) - return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=connector_creation_failed") + logger.error( + "Database integrity error creating OneDrive connector: %s", str(e) + ) + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=connector_creation_failed" + ) except HTTPException: raise except (IntegrityError, ValueError) as e: logger.error("OneDrive OAuth callback error: %s", str(e), exc_info=True) - return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_auth_error") + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_auth_error" + ) @router.get("/connectors/{connector_id}/onedrive/folders") @@ -353,28 +421,44 @@ async def list_onedrive_folders( select(SearchSourceConnector).filter( SearchSourceConnector.id == connector_id, SearchSourceConnector.user_id == user.id, - SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, ) ) connector = result.scalars().first() if not connector: - raise HTTPException(status_code=404, detail="OneDrive connector not found or access denied") + raise HTTPException( + status_code=404, detail="OneDrive connector not found or access denied" + ) onedrive_client = OneDriveClient(session, connector_id) items, error = await list_folder_contents(onedrive_client, parent_id=parent_id) if error: error_lower = error.lower() - if "401" in error or "authentication expired" in error_lower or "invalid_grant" in error_lower: + if ( + "401" in error + or "authentication expired" in error_lower + or "invalid_grant" in error_lower + ): try: if connector and not connector.config.get("auth_expired"): connector.config = {**connector.config, "auth_expired": True} flag_modified(connector, "config") await session.commit() except Exception: - logger.warning("Failed to persist auth_expired for connector %s", connector_id, exc_info=True) - raise HTTPException(status_code=400, detail="OneDrive authentication expired. Please re-authenticate.") - raise HTTPException(status_code=500, detail=f"Failed to list folder contents: {error}") + logger.warning( + "Failed to persist auth_expired for connector %s", + connector_id, + exc_info=True, + ) + raise HTTPException( + status_code=400, + detail="OneDrive authentication expired. Please re-authenticate.", + ) + raise HTTPException( + status_code=500, detail=f"Failed to list folder contents: {error}" + ) return {"items": items} @@ -391,8 +475,13 @@ async def list_onedrive_folders( await session.commit() except Exception: pass - raise HTTPException(status_code=400, detail="OneDrive authentication expired. Please re-authenticate.") from e - raise HTTPException(status_code=500, detail=f"Failed to list OneDrive contents: {e!s}") from e + raise HTTPException( + status_code=400, + detail="OneDrive authentication expired. Please re-authenticate.", + ) from e + raise HTTPException( + status_code=500, detail=f"Failed to list OneDrive contents: {e!s}" + ) from e async def refresh_onedrive_token( @@ -410,10 +499,15 @@ async def refresh_onedrive_token( refresh_token = token_encryption.decrypt_token(refresh_token) except Exception as e: logger.error("Failed to decrypt refresh token: %s", str(e)) - raise HTTPException(status_code=500, detail="Failed to decrypt stored refresh token") from e + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e if not refresh_token: - raise HTTPException(status_code=400, detail=f"No refresh token available for connector {connector.id}") + raise HTTPException( + status_code=400, + detail=f"No refresh token available for connector {connector.id}", + ) refresh_data = { "client_id": config.MICROSOFT_CLIENT_ID, @@ -425,8 +519,10 @@ async def refresh_onedrive_token( async with httpx.AsyncClient() as client: token_response = await client.post( - TOKEN_URL, data=refresh_data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, timeout=30.0, + TOKEN_URL, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, ) if token_response.status_code != 200: @@ -439,16 +535,27 @@ async def refresh_onedrive_token( except Exception: pass error_lower = (error_detail + error_code).lower() - if "invalid_grant" in error_lower or "expired" in error_lower or "revoked" in error_lower: - raise HTTPException(status_code=401, detail="OneDrive authentication failed. Please re-authenticate.") - raise HTTPException(status_code=400, detail=f"Token refresh failed: {error_detail}") + if ( + "invalid_grant" in error_lower + or "expired" in error_lower + or "revoked" in error_lower + ): + raise HTTPException( + status_code=401, + detail="OneDrive authentication failed. Please re-authenticate.", + ) + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) token_json = token_response.json() access_token = token_json.get("access_token") new_refresh_token = token_json.get("refresh_token") if not access_token: - raise HTTPException(status_code=400, detail="No access token received from Microsoft refresh") + raise HTTPException( + status_code=400, detail="No access token received from Microsoft refresh" + ) expires_at = None expires_in = token_json.get("expires_in") diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 7e9ac1e59..d12fa3745 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -2567,8 +2567,12 @@ async def run_onedrive_indexing( search_space_id=search_space_id, folder_count=len(items_dict.get("folders", [])), file_count=len(items_dict.get("files", [])), - folder_names=[f.get("name", "Unknown") for f in items_dict.get("folders", [])], - file_names=[f.get("name", "Unknown") for f in items_dict.get("files", [])], + folder_names=[ + f.get("name", "Unknown") for f in items_dict.get("folders", []) + ], + file_names=[ + f.get("name", "Unknown") for f in items_dict.get("files", []) + ], ) if notification: @@ -2593,7 +2597,9 @@ async def run_onedrive_indexing( ) if _is_auth_error(error_message): await _persist_auth_expired(session, connector_id) - error_message = "OneDrive authentication expired. Please re-authenticate." + error_message = ( + "OneDrive authentication expired. Please re-authenticate." + ) else: if notification: await session.refresh(notification) diff --git a/surfsense_backend/app/services/onedrive/kb_sync_service.py b/surfsense_backend/app/services/onedrive/kb_sync_service.py index 5e82950a5..962c19fc9 100644 --- a/surfsense_backend/app/services/onedrive/kb_sync_service.py +++ b/surfsense_backend/app/services/onedrive/kb_sync_service.py @@ -56,9 +56,7 @@ async def sync_after_create( indexable_content = (content or "").strip() if not indexable_content: - indexable_content = ( - f"OneDrive file: {file_name} (type: {mime_type})" - ) + indexable_content = f"OneDrive file: {file_name} (type: {mime_type})" content_hash = generate_content_hash(indexable_content, search_space_id) @@ -95,9 +93,7 @@ async def sync_after_create( ) else: logger.warning("No LLM configured — using fallback summary") - summary_content = ( - f"OneDrive File: {file_name}\n\n{indexable_content}" - ) + summary_content = f"OneDrive File: {file_name}\n\n{indexable_content}" summary_embedding = embed_text(summary_content) chunks = await create_document_chunks(indexable_content) diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index c1ca089d0..4b37cb69e 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1075,6 +1075,37 @@ def complete_current_step() -> str | None: "thread_id": thread_id_str, }, ) + elif tool_name == "web_search": + xml = ( + tool_output.get("result", str(tool_output)) + if isinstance(tool_output, dict) + else str(tool_output) + ) + citations: dict[str, dict[str, str]] = {} + for m in re.finditer( + r"<!\[CDATA\[(.*?)\]\]>\s*", + xml, + ): + title, url = m.group(1).strip(), m.group(2).strip() + if url.startswith("http") and url not in citations: + citations[url] = {"title": title} + for m in re.finditer( + r"", + xml, + ): + chunk_url, content = m.group(1).strip(), m.group(2).strip() + if ( + chunk_url.startswith("http") + and chunk_url in citations + and content + ): + citations[chunk_url]["snippet"] = ( + content[:200] + "…" if len(content) > 200 else content + ) + yield streaming_service.format_tool_output_available( + tool_call_id, + {"status": "completed", "citations": citations}, + ) else: yield streaming_service.format_tool_output_available( tool_call_id, diff --git a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py index e565f6a6a..748cb0988 100644 --- a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -45,6 +45,7 @@ # Helpers # --------------------------------------------------------------------------- + async def _should_skip_file( session: AsyncSession, file: dict, @@ -186,9 +187,13 @@ async def _download_one(file: dict) -> ConnectorDocument | None: logger.warning(f"Download/ETL failed for {file_name}: {reason}") return None doc = _build_connector_doc( - file, markdown, od_metadata, - connector_id=connector_id, search_space_id=search_space_id, - user_id=user_id, enable_summary=enable_summary, + file, + markdown, + od_metadata, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, ) async with hb_lock: completed_count += 1 @@ -204,9 +209,7 @@ async def _download_one(file: dict) -> ConnectorDocument | None: failed = 0 for outcome in outcomes: - if isinstance(outcome, Exception): - failed += 1 - elif outcome is None: + if isinstance(outcome, Exception) or outcome is None: failed += 1 else: results.append(outcome) @@ -227,9 +230,12 @@ async def _download_and_index( ) -> tuple[int, int]: """Parallel download then parallel indexing. Returns (batch_indexed, total_failed).""" connector_docs, download_failed = await _download_files_parallel( - onedrive_client, files, - connector_id=connector_id, search_space_id=search_space_id, - user_id=user_id, enable_summary=enable_summary, + onedrive_client, + files, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, on_heartbeat=on_heartbeat, ) @@ -242,7 +248,9 @@ async def _get_llm(s): return await get_user_long_context_llm(s, user_id, search_space_id) _, batch_indexed, batch_failed = await pipeline.index_batch_parallel( - connector_docs, _get_llm, max_concurrency=3, + connector_docs, + _get_llm, + max_concurrency=3, on_heartbeat=on_heartbeat, ) @@ -305,10 +313,14 @@ async def _index_selected_files( files_to_download.append(file) - batch_indexed, failed = await _download_and_index( - onedrive_client, session, files_to_download, - connector_id=connector_id, search_space_id=search_space_id, - user_id=user_id, enable_summary=enable_summary, + batch_indexed, _failed = await _download_and_index( + onedrive_client, + session, + files_to_download, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, on_heartbeat=on_heartbeat, ) @@ -319,6 +331,7 @@ async def _index_selected_files( # Scan strategies # --------------------------------------------------------------------------- + async def _index_full_scan( onedrive_client: OneDriveClient, session: AsyncSession, @@ -338,7 +351,11 @@ async def _index_full_scan( await task_logger.log_task_progress( log_entry, f"Starting full scan of folder: {folder_name}", - {"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders}, + { + "stage": "full_scan", + "folder_id": folder_id, + "include_subfolders": include_subfolders, + }, ) renamed_count = 0 @@ -346,12 +363,16 @@ async def _index_full_scan( files_to_download: list[dict] = [] all_files, error = await get_files_in_folder( - onedrive_client, folder_id, include_subfolders=include_subfolders, + onedrive_client, + folder_id, + include_subfolders=include_subfolders, ) if error: err_lower = error.lower() if "401" in error or "authentication expired" in err_lower: - raise Exception(f"OneDrive authentication failed. Please re-authenticate. (Error: {error})") + raise Exception( + f"OneDrive authentication failed. Please re-authenticate. (Error: {error})" + ) raise Exception(f"Failed to list OneDrive files: {error}") for file in all_files[:max_files]: @@ -365,14 +386,20 @@ async def _index_full_scan( files_to_download.append(file) batch_indexed, failed = await _download_and_index( - onedrive_client, session, files_to_download, - connector_id=connector_id, search_space_id=search_space_id, - user_id=user_id, enable_summary=enable_summary, + onedrive_client, + session, + files_to_download, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, on_heartbeat=on_heartbeat_callback, ) indexed = renamed_count + batch_indexed - logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed") + logger.info( + f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed" + ) return indexed, skipped @@ -392,7 +419,8 @@ async def _index_with_delta_sync( ) -> tuple[int, int, str | None]: """Delta sync using OneDrive change tracking. Returns (indexed, skipped, new_delta_link).""" await task_logger.log_task_progress( - log_entry, "Starting delta sync", + log_entry, + "Starting delta sync", {"stage": "delta_sync"}, ) @@ -402,7 +430,9 @@ async def _index_with_delta_sync( if error: err_lower = error.lower() if "401" in error or "authentication expired" in err_lower: - raise Exception(f"OneDrive authentication failed. Please re-authenticate. (Error: {error})") + raise Exception( + f"OneDrive authentication failed. Please re-authenticate. (Error: {error})" + ) raise Exception(f"Failed to fetch OneDrive changes: {error}") if not changes: @@ -444,14 +474,20 @@ async def _index_with_delta_sync( files_to_download.append(change) batch_indexed, failed = await _download_and_index( - onedrive_client, session, files_to_download, - connector_id=connector_id, search_space_id=search_space_id, - user_id=user_id, enable_summary=enable_summary, + onedrive_client, + session, + files_to_download, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=enable_summary, on_heartbeat=on_heartbeat_callback, ) indexed = renamed_count + batch_indexed - logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed") + logger.info( + f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed" + ) return indexed, skipped, new_delta_link @@ -459,6 +495,7 @@ async def _index_with_delta_sync( # Public entry point # --------------------------------------------------------------------------- + async def index_onedrive_files( session: AsyncSession, connector_id: int, @@ -489,13 +526,20 @@ async def index_onedrive_files( ) if not connector: error_msg = f"OneDrive connector with ID {connector_id} not found" - await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}) + await task_logger.log_task_failure( + log_entry, error_msg, None, {"error_type": "ConnectorNotFound"} + ) return 0, 0, error_msg token_encrypted = connector.config.get("_token_encrypted", False) if token_encrypted and not config.SECRET_KEY: error_msg = "SECRET_KEY not configured but credentials are encrypted" - await task_logger.log_task_failure(log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"}) + await task_logger.log_task_failure( + log_entry, + error_msg, + "Missing SECRET_KEY", + {"error_type": "MissingSecretKey"}, + ) return 0, 0, error_msg connector_enable_summary = getattr(connector, "enable_summary", True) @@ -513,10 +557,14 @@ async def index_onedrive_files( selected_files = items_dict.get("files", []) if selected_files: file_tuples = [(f["id"], f.get("name")) for f in selected_files] - indexed, skipped, errors = await _index_selected_files( - onedrive_client, session, file_tuples, - connector_id=connector_id, search_space_id=search_space_id, - user_id=user_id, enable_summary=connector_enable_summary, + indexed, skipped, _errors = await _index_selected_files( + onedrive_client, + session, + file_tuples, + connector_id=connector_id, + search_space_id=search_space_id, + user_id=user_id, + enable_summary=connector_enable_summary, ) total_indexed += indexed total_skipped += skipped @@ -534,8 +582,16 @@ async def index_onedrive_files( if can_use_delta: logger.info(f"Using delta sync for folder {folder_name}") indexed, skipped, new_delta_link = await _index_with_delta_sync( - onedrive_client, session, connector_id, search_space_id, user_id, - folder_id, delta_link, task_logger, log_entry, max_files, + onedrive_client, + session, + connector_id, + search_space_id, + user_id, + folder_id, + delta_link, + task_logger, + log_entry, + max_files, enable_summary=connector_enable_summary, ) total_indexed += indexed @@ -550,18 +606,36 @@ async def index_onedrive_files( # Reconciliation full scan ri, rs = await _index_full_scan( - onedrive_client, session, connector_id, search_space_id, user_id, - folder_id, folder_name, task_logger, log_entry, max_files, - include_subfolders, enable_summary=connector_enable_summary, + onedrive_client, + session, + connector_id, + search_space_id, + user_id, + folder_id, + folder_name, + task_logger, + log_entry, + max_files, + include_subfolders, + enable_summary=connector_enable_summary, ) total_indexed += ri total_skipped += rs else: logger.info(f"Using full scan for folder {folder_name}") indexed, skipped = await _index_full_scan( - onedrive_client, session, connector_id, search_space_id, user_id, - folder_id, folder_name, task_logger, log_entry, max_files, - include_subfolders, enable_summary=connector_enable_summary, + onedrive_client, + session, + connector_id, + search_space_id, + user_id, + folder_id, + folder_name, + task_logger, + log_entry, + max_files, + include_subfolders, + enable_summary=connector_enable_summary, ) total_indexed += indexed total_skipped += skipped @@ -585,22 +659,28 @@ async def index_onedrive_files( f"Successfully completed OneDrive indexing for connector {connector_id}", {"files_processed": total_indexed, "files_skipped": total_skipped}, ) - logger.info(f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped") + logger.info( + f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped" + ) return total_indexed, total_skipped, None except SQLAlchemyError as db_error: await session.rollback() await task_logger.log_task_failure( - log_entry, f"Database error during OneDrive indexing for connector {connector_id}", - str(db_error), {"error_type": "SQLAlchemyError"}, + log_entry, + f"Database error during OneDrive indexing for connector {connector_id}", + str(db_error), + {"error_type": "SQLAlchemyError"}, ) logger.error(f"Database error: {db_error!s}", exc_info=True) return 0, 0, f"Database error: {db_error!s}" except Exception as e: await session.rollback() await task_logger.log_task_failure( - log_entry, f"Failed to index OneDrive files for connector {connector_id}", - str(e), {"error_type": type(e).__name__}, + log_entry, + f"Failed to index OneDrive files for connector {connector_id}", + str(e), + {"error_type": type(e).__name__}, ) logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True) return 0, 0, f"Failed to index OneDrive files: {e!s}" diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_onedrive_pipeline.py b/surfsense_backend/tests/integration/indexing_pipeline/test_onedrive_pipeline.py index ee83795a5..541e3a38e 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/test_onedrive_pipeline.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_onedrive_pipeline.py @@ -13,7 +13,9 @@ pytestmark = pytest.mark.integration -def _onedrive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument: +def _onedrive_doc( + *, unique_id: str, search_space_id: int, connector_id: int, user_id: str +) -> ConnectorDocument: return ConnectorDocument( title=f"File {unique_id}.docx", source_markdown=f"## Document\n\nContent from {unique_id}", @@ -32,7 +34,9 @@ def _onedrive_doc(*, unique_id: str, search_space_id: int, connector_id: int, us ) -@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_texts", "patched_chunk_text" +) async def test_onedrive_pipeline_creates_ready_document( db_session, db_search_space, db_connector, db_user, mocker ): @@ -61,7 +65,9 @@ async def test_onedrive_pipeline_creates_ready_document( assert DocumentStatus.is_state(row.status, DocumentStatus.READY) -@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text") +@pytest.mark.usefixtures( + "patched_summarize", "patched_embed_texts", "patched_chunk_text" +) async def test_onedrive_duplicate_content_skipped( db_session, db_search_space, db_connector, db_user, mocker ): @@ -87,8 +93,6 @@ async def test_onedrive_duplicate_content_skipped( ) first_doc = result.scalars().first() assert first_doc is not None - first_id = first_doc.id - doc2 = _onedrive_doc( unique_id="od-dup-file", search_space_id=space_id, @@ -97,4 +101,6 @@ async def test_onedrive_duplicate_content_skipped( ) prepared2 = await service.prepare_for_indexing([doc2]) - assert len(prepared2) == 0 or (len(prepared2) == 1 and prepared2[0].existing_document is not None) + assert len(prepared2) == 0 or ( + len(prepared2) == 1 and prepared2[0].existing_document is not None + ) diff --git a/surfsense_backend/tests/unit/connector_indexers/test_onedrive_parallel.py b/surfsense_backend/tests/unit/connector_indexers/test_onedrive_parallel.py index b5c774c6f..12a912b03 100644 --- a/surfsense_backend/tests/unit/connector_indexers/test_onedrive_parallel.py +++ b/surfsense_backend/tests/unit/connector_indexers/test_onedrive_parallel.py @@ -48,12 +48,14 @@ def _patch(side_effect=None, return_value=None): mock, ) return mock + return _patch # Slice 1: Tracer bullet async def test_single_file_returns_one_connector_document( - mock_onedrive_client, patch_extract, + mock_onedrive_client, + patch_extract, ): patch_extract(return_value=_mock_extract_ok("f1", "test.txt")) @@ -75,7 +77,8 @@ async def test_single_file_returns_one_connector_document( # Slice 2: Multiple files all produce documents async def test_multiple_files_all_produce_documents( - mock_onedrive_client, patch_extract, + mock_onedrive_client, + patch_extract, ): files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)] patch_extract( @@ -98,7 +101,8 @@ async def test_multiple_files_all_produce_documents( # Slice 3: Error isolation async def test_one_download_exception_does_not_block_others( - mock_onedrive_client, patch_extract, + mock_onedrive_client, + patch_extract, ): files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)] patch_extract( @@ -125,7 +129,8 @@ async def test_one_download_exception_does_not_block_others( # Slice 4: ETL error counts as download failure async def test_etl_error_counts_as_download_failure( - mock_onedrive_client, patch_extract, + mock_onedrive_client, + patch_extract, ): files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")] patch_extract( @@ -150,7 +155,8 @@ async def test_etl_error_counts_as_download_failure( # Slice 5: Semaphore bound async def test_concurrency_bounded_by_semaphore( - mock_onedrive_client, monkeypatch, + mock_onedrive_client, + monkeypatch, ): lock = asyncio.Lock() active = 0 @@ -190,7 +196,8 @@ async def _slow_extract(client, file): # Slice 6: Heartbeat fires async def test_heartbeat_fires_during_parallel_downloads( - mock_onedrive_client, monkeypatch, + mock_onedrive_client, + monkeypatch, ): import app.tasks.connector_indexers.onedrive_indexer as _mod diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index 25e4e990b..1715e525f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -183,6 +183,10 @@ export function DashboardClientLayout({ ); } + if (isOnboardingPage) { + return <>{children}; + } + return ( diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx index 68d971fc4..92ced6e47 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx @@ -744,7 +744,11 @@ export function DocumentsTableShell({ - onOpenInTab ? onOpenInTab(doc) : handleViewDocument(doc)}> + + onOpenInTab ? onOpenInTab(doc) : handleViewDocument(doc) + } + > Open @@ -986,9 +990,10 @@ export function DocumentsTableShell({ handleDeleteFromMenu(); }} disabled={isDeleting} - className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + className="relative bg-destructive text-destructive-foreground hover:bg-destructive/90" > - {isDeleting ? : "Delete"} + Delete + {isDeleting && } @@ -1104,9 +1109,10 @@ export function DocumentsTableShell({ handleBulkDelete(); }} disabled={isBulkDeleting} - className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + className="relative bg-destructive text-destructive-foreground hover:bg-destructive/90" > - {isBulkDeleting ? : "Delete"} + Delete + {isBulkDeleting && } diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/RowActions.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/RowActions.tsx index a8b85e20b..5b7451c61 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/RowActions.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/RowActions.tsx @@ -24,7 +24,7 @@ import { } from "@/components/ui/dropdown-menu"; import type { Document } from "./types"; -const EDITABLE_DOCUMENT_TYPES = ["NOTE"] as const; +const EDITABLE_DOCUMENT_TYPES = ["FILE", "NOTE"] as const; // SURFSENSE_DOCS are system-managed and cannot be deleted const NON_DELETABLE_DOCUMENT_TYPES = ["SURFSENSE_DOCS"] as const; diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 9809c9b2e..8928974d9 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -33,7 +33,7 @@ import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom"; import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms"; import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom"; import { membersAtom } from "@/atoms/members/members-query.atoms"; -import { updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom"; +import { removeChatTabAtom, updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { ThinkingStepsDataUI } from "@/components/assistant-ui/thinking-steps"; import { Thread } from "@/components/assistant-ui/thread"; @@ -70,6 +70,7 @@ import { getThreadMessages, type ThreadRecord, } from "@/lib/chat/thread-persistence"; +import { NotFoundError } from "@/lib/error"; import { trackChatCreated, trackChatError, @@ -131,6 +132,7 @@ function extractMentionedDocuments(content: unknown): MentionedDocumentInfo[] { * Tools that should render custom UI in the chat. */ const TOOLS_WITH_UI = new Set([ + "web_search", "generate_podcast", "generate_report", "generate_video_presentation", @@ -194,6 +196,7 @@ export default function NewChatPage() { const closeReportPanel = useSetAtom(closeReportPanelAtom); const closeEditorPanel = useSetAtom(closeEditorPanelAtom); const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom); + const removeChatTab = useSetAtom(removeChatTabAtom); const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom); // Get current user for author info in shared chats @@ -323,6 +326,14 @@ export default function NewChatPage() { // This improves UX (instant load) and avoids orphan threads } catch (error) { console.error("[NewChatPage] Failed to initialize thread:", error); + if (urlChatId > 0 && error instanceof NotFoundError) { + removeChatTab(urlChatId); + if (typeof window !== "undefined") { + window.history.replaceState(null, "", `/dashboard/${searchSpaceId}/new-chat`); + } + toast.error("This chat was deleted."); + return; + } // Keep threadId as null - don't use Date.now() as it creates an invalid ID // that will cause 404 errors on subsequent API calls setThreadId(null); @@ -338,12 +349,14 @@ export default function NewChatPage() { setSidebarDocuments, closeReportPanel, closeEditorPanel, + removeChatTab, + searchSpaceId, ]); // Initialize on mount, and re-init when switching search spaces (even if urlChatId is the same) useEffect(() => { initializeThread(); - }, [initializeThread, searchSpaceId]); + }, [initializeThread]); // Prefetch document titles for @ mention picker // Runs when user lands on page so data is ready when they type @ @@ -483,18 +496,17 @@ export default function NewChatPage() { // Add user message to state const userMsgId = `msg-user-${Date.now()}`; - // Include author metadata for shared chats - const authorMetadata = - currentThread?.visibility === "SEARCH_SPACE" && currentUser - ? { - custom: { - author: { - displayName: currentUser.display_name ?? null, - avatarUrl: currentUser.avatar_url ?? null, - }, + // Always include author metadata so the UI layer can decide visibility + const authorMetadata = currentUser + ? { + custom: { + author: { + displayName: currentUser.display_name ?? null, + avatarUrl: currentUser.avatar_url ?? null, }, - } - : undefined; + }, + } + : undefined; const userMessage: ThreadMessageLike = { id: userMsgId, @@ -654,62 +666,62 @@ export default function NewChatPage() { const scheduleFlush = () => batcher.schedule(flushMessages); for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + break; - case "tool-input-start": - addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); - batcher.flush(); - break; + case "tool-input-start": + addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); + batcher.flush(); + break; - case "tool-input-available": { - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); - } else { - addToolCall( - contentPartsState, - TOOLS_WITH_UI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {} - ); + case "tool-input-available": { + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + } else { + addToolCall( + contentPartsState, + TOOLS_WITH_UI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {} + ); + } + batcher.flush(); + break; } - batcher.flush(); - break; - } - case "tool-output-available": { - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + case "tool-output-available": { + updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + markInterruptsCompleted(contentParts); + if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { + const idx = toolCallIndices.get(parsed.toolCallId); + if (idx !== undefined) { + const part = contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(parsed.output.podcast_id)); + } } } + batcher.flush(); + break; } - batcher.flush(); - break; - } - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } } + break; } - break; - } - case "data-thread-title-update": { + case "data-thread-title-update": { const titleData = parsed.data as { threadId: number; title: string }; if (titleData?.title && titleData?.threadId === currentThreadId) { setCurrentThread((prev) => (prev ? { ...prev, title: titleData.title } : prev)); @@ -882,7 +894,6 @@ export default function NewChatPage() { setMessageDocumentsMap, setAgentCreatedDocuments, queryClient, - currentThread, currentUser, disabledTools, updateChatTabTitle, @@ -1001,7 +1012,7 @@ export default function NewChatPage() { throw new Error(`Backend error: ${response.status}`); } - const flushMessages = () => { + const flushMessages = () => { setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1013,55 +1024,55 @@ export default function NewChatPage() { const scheduleFlush = () => batcher.schedule(flushMessages); for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + break; - case "tool-input-start": - addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); - batcher.flush(); - break; + case "tool-input-start": + addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); + batcher.flush(); + break; + + case "tool-input-available": + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { + args: parsed.input || {}, + }); + } else { + addToolCall( + contentPartsState, + TOOLS_WITH_UI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {} + ); + } + batcher.flush(); + break; - case "tool-input-available": - if (toolCallIndices.has(parsed.toolCallId)) { + case "tool-output-available": updateToolCall(contentPartsState, parsed.toolCallId, { - args: parsed.input || {}, + result: parsed.output, }); - } else { - addToolCall( - contentPartsState, - TOOLS_WITH_UI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {} - ); - } - batcher.flush(); - break; - - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { - result: parsed.output, - }); - markInterruptsCompleted(contentParts); - batcher.flush(); - break; + markInterruptsCompleted(contentParts); + batcher.flush(); + break; - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } } + break; } - break; - } - case "data-interrupt-request": { + case "data-interrupt-request": { const interruptData = parsed.data as Record; const actionRequests = (interruptData.action_requests ?? []) as Array<{ name: string; @@ -1319,7 +1330,7 @@ export default function NewChatPage() { throw new Error(`Backend error: ${response.status}`); } - const flushMessages = () => { + const flushMessages = () => { setMessages((prev) => prev.map((m) => m.id === assistantMsgId @@ -1331,63 +1342,63 @@ export default function NewChatPage() { const scheduleFlush = () => batcher.schedule(flushMessages); for await (const parsed of readSSEStream(response)) { - switch (parsed.type) { - case "text-delta": - appendText(contentPartsState, parsed.delta); - scheduleFlush(); - break; + switch (parsed.type) { + case "text-delta": + appendText(contentPartsState, parsed.delta); + scheduleFlush(); + break; - case "tool-input-start": - addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); - batcher.flush(); - break; + case "tool-input-start": + addToolCall(contentPartsState, TOOLS_WITH_UI, parsed.toolCallId, parsed.toolName, {}); + batcher.flush(); + break; - case "tool-input-available": - if (toolCallIndices.has(parsed.toolCallId)) { - updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); - } else { - addToolCall( - contentPartsState, - TOOLS_WITH_UI, - parsed.toolCallId, - parsed.toolName, - parsed.input || {} - ); - } - batcher.flush(); - break; + case "tool-input-available": + if (toolCallIndices.has(parsed.toolCallId)) { + updateToolCall(contentPartsState, parsed.toolCallId, { args: parsed.input || {} }); + } else { + addToolCall( + contentPartsState, + TOOLS_WITH_UI, + parsed.toolCallId, + parsed.toolName, + parsed.input || {} + ); + } + batcher.flush(); + break; - case "tool-output-available": - updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); - markInterruptsCompleted(contentParts); - if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { - const idx = toolCallIndices.get(parsed.toolCallId); - if (idx !== undefined) { - const part = contentParts[idx]; - if (part?.type === "tool-call" && part.toolName === "generate_podcast") { - setActivePodcastTaskId(String(parsed.output.podcast_id)); + case "tool-output-available": + updateToolCall(contentPartsState, parsed.toolCallId, { result: parsed.output }); + markInterruptsCompleted(contentParts); + if (parsed.output?.status === "pending" && parsed.output?.podcast_id) { + const idx = toolCallIndices.get(parsed.toolCallId); + if (idx !== undefined) { + const part = contentParts[idx]; + if (part?.type === "tool-call" && part.toolName === "generate_podcast") { + setActivePodcastTaskId(String(parsed.output.podcast_id)); + } } } - } - batcher.flush(); - break; + batcher.flush(); + break; - case "data-thinking-step": { - const stepData = parsed.data as ThinkingStepData; - if (stepData?.id) { - currentThinkingSteps.set(stepData.id, stepData); - const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); - if (didUpdate) { - scheduleFlush(); + case "data-thinking-step": { + const stepData = parsed.data as ThinkingStepData; + if (stepData?.id) { + currentThinkingSteps.set(stepData.id, stepData); + const didUpdate = updateThinkingSteps(contentPartsState, currentThinkingSteps); + if (didUpdate) { + scheduleFlush(); + } } + break; } - break; - } - case "error": - throw new Error(parsed.errorText || "Server error"); + case "error": + throw new Error(parsed.errorText || "Server error"); + } } - } batcher.flush(); @@ -1536,4 +1547,4 @@ export default function NewChatPage() { ); -} \ No newline at end of file +} diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index b188d7c8f..4dba3bbb6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -1,7 +1,6 @@ "use client"; -import { useAtomValue, useSetAtom } from "jotai"; -import { motion } from "motion/react"; +import { useAtomValue } from "jotai"; import { useParams, useRouter } from "next/navigation"; import { useEffect, useRef, useState } from "react"; import { toast } from "sonner"; @@ -13,19 +12,17 @@ import { globalNewLLMConfigsAtom, llmPreferencesAtom, } from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { searchSpaceSettingsDialogAtom } from "@/atoms/settings/settings-dialog.atoms"; import { Logo } from "@/components/Logo"; import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; -import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; +import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getBearerToken, redirectToLogin } from "@/lib/auth-utils"; export default function OnboardPage() { const router = useRouter(); const params = useParams(); const searchSpaceId = Number(params.search_space_id); - const setSearchSpaceSettingsDialog = useSetAtom(searchSpaceSettingsDialogAtom); - // Queries const { data: globalConfigs = [], @@ -62,14 +59,12 @@ export default function OnboardPage() { preferences.document_summary_llm_id !== null && preferences.document_summary_llm_id !== undefined; - // If onboarding is already complete, redirect immediately useEffect(() => { if (!preferencesLoading && isOnboardingComplete) { router.push(`/dashboard/${searchSpaceId}/new-chat`); } }, [preferencesLoading, isOnboardingComplete, router, searchSpaceId]); - // Auto-configure if global configs are available useEffect(() => { const autoConfigureWithGlobal = async () => { if (hasAttemptedAutoConfig.current) return; @@ -77,7 +72,6 @@ export default function OnboardPage() { if (!globalConfigsLoaded) return; if (isOnboardingComplete) return; - // Only auto-configure if we have global configs if (globalConfigs.length > 0) { hasAttemptedAutoConfig.current = true; setIsAutoConfiguring(true); @@ -97,7 +91,6 @@ export default function OnboardPage() { description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`, }); - // Redirect to new-chat router.push(`/dashboard/${searchSpaceId}/new-chat`); } catch (error) { console.error("Auto-configuration failed:", error); @@ -119,13 +112,10 @@ export default function OnboardPage() { router, ]); - // Handle form submission const handleSubmit = async (formData: LLMConfigFormData) => { try { - // Create the config const newConfig = await createConfig(formData); - // Auto-assign to all roles await updatePreferences({ search_space_id: searchSpaceId, data: { @@ -138,7 +128,6 @@ export default function OnboardPage() { description: "Redirecting to chat...", }); - // Redirect to new-chat router.push(`/dashboard/${searchSpaceId}/new-chat`); } catch (error) { console.error("Failed to create config:", error); @@ -150,124 +139,59 @@ export default function OnboardPage() { const isSubmitting = isCreating || isUpdatingPreferences; - // Loading state - if (globalConfigsLoading || preferencesLoading || isAutoConfiguring) { - return ( -
- -
-
-
- -
-
-
-

- {isAutoConfiguring ? "Setting up your AI..." : "Loading..."} -

-

- {isAutoConfiguring - ? "Auto-configuring with available settings" - : "Please wait while we check your configuration"} -

-
-
- {[0, 1, 2].map((i) => ( - - ))} -
- -
- ); + const isLoading = globalConfigsLoading || preferencesLoading || isAutoConfiguring; + useGlobalLoadingEffect(isLoading); + + if (isLoading) { + return null; } - // If global configs exist but auto-config failed, show simple message if (globalConfigs.length > 0 && !isAutoConfiguring) { - return null; // Will redirect via useEffect + return null; } - // No global configs - show the config form return ( -
-
- - {/* Header */} -
- - - - -
-

Configure Your AI

-

- Add your LLM provider to get started with SurfSense -

-
+
+
+ {/* Header */} +
+ +
+

Configure Your AI

+

+ Add your LLM provider to get started with SurfSense +

- - {/* Config Form */} - - - - LLM Configuration - - - - - - - - {/* Footer note */} - + + {/* Form card */} +
+ +
+ + {/* Footer */} +
+ - - + Start Using SurfSense + {isSubmitting && } + +

You can add more configurations later

+
); diff --git a/surfsense_web/app/dashboard/[search_space_id]/team/team-content.tsx b/surfsense_web/app/dashboard/[search_space_id]/team/team-content.tsx index b6f008887..d9ca9efb3 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/team/team-content.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/team/team-content.tsx @@ -308,7 +308,8 @@ export function TeamContent({ searchSpaceId }: TeamContentProps) { {invitesLoading ? ( ) : ( - canInvite && activeInvites.length > 0 && ( + canInvite && + activeInvites.length > 0 && ( ) )} diff --git a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx index 38ccafa94..c2d2c01de 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/user-settings/components/PromptsContent.tsx @@ -3,11 +3,11 @@ import { PenLine, Plus, Sparkles, Trash2 } from "lucide-react"; import { useCallback, useEffect, useState } from "react"; import { toast } from "sonner"; -import type { PromptRead } from "@/contracts/types/prompts.types"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Spinner } from "@/components/ui/spinner"; +import type { PromptRead } from "@/contracts/types/prompts.types"; import { promptsApiService } from "@/lib/apis/prompts-api.service"; interface PromptFormData { @@ -99,7 +99,9 @@ export function PromptsContent() {

- Create prompt templates triggered with / in the chat composer. + Create prompt templates triggered with{" "} + / in the + chat composer.

{!showForm && (
@@ -153,7 +159,9 @@ export function PromptsContent() { setFormData((p) => ({ ...p, name: e.target.value }))} - /> -
- -
- - - setFormData((p) => ({ ...p, description: e.target.value })) - } - /> -
- - - -
- - -
- -
- - {suggestedModels.length > 0 ? ( - - - - - - - - setFormData((p) => ({ ...p, model_name: val })) - } - /> - - - - Type a custom model name - - - - {suggestedModels.map((m) => ( - { - setFormData((p) => ({ ...p, model_name: m.value })); - setModelComboboxOpen(false); - }} - > - - {m.value} - - {m.label} - - - ))} - - - - - - ) : ( - - setFormData((p) => ({ ...p, model_name: e.target.value })) - } - /> - )} -
- -
- - setFormData((p) => ({ ...p, api_key: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, api_base: e.target.value }))} - /> -
- - {formData.provider === "AZURE_OPENAI" && ( -
- - - setFormData((p) => ({ ...p, api_version: e.target.value })) - } - /> -
- )} -
- )} -
- - {/* Fixed footer */} -
- - {mode === "create" || (mode === "edit" && !isGlobal) ? ( - - ) : isAutoMode ? ( - - ) : isGlobal && config ? ( - - ) : null} -
-
- - - )} - - ); - - return typeof document !== "undefined" ? createPortal(dialogContent, document.body) : null; -} diff --git a/surfsense_web/components/new-chat/model-config-dialog.tsx b/surfsense_web/components/new-chat/model-config-dialog.tsx deleted file mode 100644 index 06ec3b9b5..000000000 --- a/surfsense_web/components/new-chat/model-config-dialog.tsx +++ /dev/null @@ -1,489 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, X, Zap } from "lucide-react"; -import { AnimatePresence, motion } from "motion/react"; -import { useCallback, useEffect, useRef, useState } from "react"; -import { createPortal } from "react-dom"; -import { toast } from "sonner"; -import { - createNewLLMConfigMutationAtom, - updateLLMPreferencesMutationAtom, - updateNewLLMConfigMutationAtom, -} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Spinner } from "@/components/ui/spinner"; -import type { - GlobalNewLLMConfig, - LiteLLMProvider, - NewLLMConfigPublic, -} from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; - -interface ModelConfigDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - config: NewLLMConfigPublic | GlobalNewLLMConfig | null; - isGlobal: boolean; - searchSpaceId: number; - mode: "create" | "edit" | "view"; -} - -export function ModelConfigDialog({ - open, - onOpenChange, - config, - isGlobal, - searchSpaceId, - mode, -}: ModelConfigDialogProps) { - const [isSubmitting, setIsSubmitting] = useState(false); - const [mounted, setMounted] = useState(false); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const scrollRef = useRef(null); - - useEffect(() => { - setMounted(true); - }, []); - - const handleScroll = useCallback((e: React.UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - const { mutateAsync: createConfig } = useAtomValue(createNewLLMConfigMutationAtom); - const { mutateAsync: updateConfig } = useAtomValue(updateNewLLMConfigMutationAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - useEffect(() => { - const handleEscape = (e: KeyboardEvent) => { - if (e.key === "Escape" && open) { - onOpenChange(false); - } - }; - window.addEventListener("keydown", handleEscape); - return () => window.removeEventListener("keydown", handleEscape); - }, [open, onOpenChange]); - - const isAutoMode = config && "is_auto_mode" in config && config.is_auto_mode; - - const getTitle = () => { - if (mode === "create") return "Add New Configuration"; - if (isAutoMode) return "Auto Mode (Fastest)"; - if (isGlobal) return "View Global Configuration"; - return "Edit Configuration"; - }; - - const getSubtitle = () => { - if (mode === "create") return "Set up a new LLM provider for this search space"; - if (isAutoMode) return "Automatically routes requests across providers"; - if (isGlobal) return "Read-only global configuration"; - return "Update your configuration settings"; - }; - - const handleSubmit = useCallback( - async (data: LLMConfigFormData) => { - setIsSubmitting(true); - try { - if (mode === "create") { - const result = await createConfig({ - ...data, - search_space_id: searchSpaceId, - }); - - if (result?.id) { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { - agent_llm_id: result.id, - }, - }); - } - - toast.success("Configuration created and assigned!"); - onOpenChange(false); - } else if (!isGlobal && config) { - await updateConfig({ - id: config.id, - data: { - name: data.name, - description: data.description, - provider: data.provider, - custom_provider: data.custom_provider, - model_name: data.model_name, - api_key: data.api_key, - api_base: data.api_base, - litellm_params: data.litellm_params, - system_instructions: data.system_instructions, - use_default_system_instructions: data.use_default_system_instructions, - citations_enabled: data.citations_enabled, - }, - }); - toast.success("Configuration updated!"); - onOpenChange(false); - } - } catch (error) { - console.error("Failed to save configuration:", error); - toast.error("Failed to save configuration"); - } finally { - setIsSubmitting(false); - } - }, - [ - mode, - isGlobal, - config, - searchSpaceId, - createConfig, - updateConfig, - updatePreferences, - onOpenChange, - ] - ); - - const handleUseGlobalConfig = useCallback(async () => { - if (!config || !isGlobal) return; - setIsSubmitting(true); - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { - agent_llm_id: config.id, - }, - }); - toast.success(`Now using ${config.name}`); - onOpenChange(false); - } catch (error) { - console.error("Failed to set model:", error); - toast.error("Failed to set model"); - } finally { - setIsSubmitting(false); - } - }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); - - if (!mounted) return null; - - const dialogContent = ( - - {open && ( - <> - {/* Backdrop */} - onOpenChange(false)} - /> - - {/* Dialog */} - -
e.stopPropagation()} - onKeyDown={(e) => { - if (e.key === "Escape") onOpenChange(false); - }} - > - {/* Header */} -
-
-
-

{getTitle()}

- {isAutoMode && ( - - Recommended - - )} - {isGlobal && !isAutoMode && mode !== "create" && ( - - Global - - )} - {!isGlobal && mode !== "create" && !isAutoMode && ( - - Custom - - )} -
-

{getSubtitle()}

- {config && !isAutoMode && mode !== "create" && ( -

- {config.model_name} -

- )} -
- -
- - {/* Scrollable content */} -
- {isAutoMode && ( - - - Auto mode automatically distributes requests across all available LLM - providers to optimize performance and avoid rate limits. - - - )} - - {isGlobal && !isAutoMode && mode !== "create" && ( - - - - Global configurations are read-only. To customize settings, create a new - configuration based on this template. - - - )} - - {mode === "create" ? ( - - ) : isAutoMode && config ? ( -
-
-
-
- How It Works -
-

{config.description}

-
- -
- -
-
- Key Benefits -
-
-
- -
-

- Automatic (Fastest) -

-

- Distributes requests across all configured LLM providers -

-
-
-
- -
-

- Rate Limit Protection -

-

- Automatically handles rate limits with cooldowns and retries -

-
-
-
- -
-

- Automatic Failover -

-

- Falls back to other providers if one becomes unavailable -

-
-
-
-
-
-
- ) : isGlobal && config ? ( -
-
-
-
-
- Configuration Name -
-

{config.name}

-
- {config.description && ( -
-
- Description -
-

{config.description}

-
- )} -
- -
- -
-
-
- Provider -
-

{config.provider}

-
-
-
- Model -
-

{config.model_name}

-
-
- -
- -
-
-
- Citations -
- - {config.citations_enabled ? "Enabled" : "Disabled"} - -
-
- - {config.system_instructions && ( - <> -
-
-
- System Instructions -
-
-

- {config.system_instructions} -

-
-
- - )} -
-
- ) : config ? ( - - ) : null} -
- - {/* Fixed footer */} -
- - {mode === "create" || (!isGlobal && !isAutoMode && config) ? ( - - ) : isAutoMode ? ( - - ) : isGlobal && config ? ( - - ) : null} -
-
- - - )} - - ); - - return typeof document !== "undefined" ? createPortal(dialogContent, document.body) : null; -} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 62e666001..7a2a471ba 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -498,7 +498,7 @@ export function ModelSelector({ }} > - Add New Configuration + Add LLM Model
diff --git a/surfsense_web/components/new-chat/prompt-picker.tsx b/surfsense_web/components/new-chat/prompt-picker.tsx index dee3eae32..7f0dab8a4 100644 --- a/surfsense_web/components/new-chat/prompt-picker.tsx +++ b/surfsense_web/components/new-chat/prompt-picker.tsx @@ -1,5 +1,6 @@ "use client"; +import { useSetAtom } from "jotai"; import { BookOpen, Check, @@ -8,11 +9,10 @@ import { List, Minimize2, PenLine, + Plus, Search, Zap, - Plus, } from "lucide-react"; -import { useSetAtom } from "jotai"; import { forwardRef, useCallback, @@ -53,113 +53,192 @@ const ICONS: Record = { zap: , }; -const DEFAULT_ACTIONS: { name: string; prompt: string; mode: "transform" | "explore"; icon: string }[] = [ - { name: "Fix grammar", prompt: "Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}", mode: "transform", icon: "check" }, - { name: "Make shorter", prompt: "Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}", mode: "transform", icon: "minimize" }, - { name: "Translate", prompt: "Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}", mode: "transform", icon: "languages" }, - { name: "Rewrite", prompt: "Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}", mode: "transform", icon: "pen-line" }, - { name: "Summarize", prompt: "Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}", mode: "transform", icon: "list" }, - { name: "Explain", prompt: "Explain the following text in simple terms:\n\n{selection}", mode: "explore", icon: "book-open" }, - { name: "Ask my knowledge base", prompt: "Search my knowledge base for information related to:\n\n{selection}", mode: "explore", icon: "search" }, - { name: "Look up on the web", prompt: "Search the web for information about:\n\n{selection}", mode: "explore", icon: "globe" }, +const DEFAULT_ACTIONS: { + name: string; + prompt: string; + mode: "transform" | "explore"; + icon: string; +}[] = [ + { + name: "Fix grammar", + prompt: + "Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}", + mode: "transform", + icon: "check", + }, + { + name: "Make shorter", + prompt: + "Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}", + mode: "transform", + icon: "minimize", + }, + { + name: "Translate", + prompt: + "Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}", + mode: "transform", + icon: "languages", + }, + { + name: "Rewrite", + prompt: + "Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}", + mode: "transform", + icon: "pen-line", + }, + { + name: "Summarize", + prompt: + "Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}", + mode: "transform", + icon: "list", + }, + { + name: "Explain", + prompt: "Explain the following text in simple terms:\n\n{selection}", + mode: "explore", + icon: "book-open", + }, + { + name: "Ask my knowledge base", + prompt: "Search my knowledge base for information related to:\n\n{selection}", + mode: "explore", + icon: "search", + }, + { + name: "Look up on the web", + prompt: "Search the web for information about:\n\n{selection}", + mode: "explore", + icon: "globe", + }, ]; -export const PromptPicker = forwardRef( - function PromptPicker({ onSelect, onDone, externalSearch = "", containerStyle }, ref) { - const setUserSettingsDialog = useSetAtom(userSettingsDialogAtom); - const [highlightedIndex, setHighlightedIndex] = useState(0); - const [customPrompts, setCustomPrompts] = useState([]); - const scrollContainerRef = useRef(null); - const shouldScrollRef = useRef(false); - const itemRefs = useRef>(new Map()); - - useEffect(() => { - promptsApiService.list().then(setCustomPrompts).catch(() => {}); - }, []); - - const allActions = useMemo(() => { - const customs = customPrompts.map((a) => ({ - name: a.name, - prompt: a.prompt, - mode: a.mode as "transform" | "explore", - icon: a.icon || "zap", - })); - return [...DEFAULT_ACTIONS, ...customs]; - }, [customPrompts]); - - const filtered = useMemo(() => { - if (!externalSearch) return allActions; - return allActions.filter((a) => - a.name.toLowerCase().includes(externalSearch.toLowerCase()) - ); - }, [allActions, externalSearch]); - - // Reset highlight when results change - const prevSearchRef = useRef(externalSearch); - if (prevSearchRef.current !== externalSearch) { - prevSearchRef.current = externalSearch; - if (highlightedIndex !== 0) { - setHighlightedIndex(0); - } +export const PromptPicker = forwardRef(function PromptPicker( + { onSelect, onDone, externalSearch = "", containerStyle }, + ref +) { + const setUserSettingsDialog = useSetAtom(userSettingsDialogAtom); + const [highlightedIndex, setHighlightedIndex] = useState(0); + const [customPrompts, setCustomPrompts] = useState([]); + const scrollContainerRef = useRef(null); + const shouldScrollRef = useRef(false); + const itemRefs = useRef>(new Map()); + + useEffect(() => { + promptsApiService + .list() + .then(setCustomPrompts) + .catch(() => {}); + }, []); + + const allActions = useMemo(() => { + const customs = customPrompts.map((a) => ({ + name: a.name, + prompt: a.prompt, + mode: a.mode as "transform" | "explore", + icon: a.icon || "zap", + })); + return [...DEFAULT_ACTIONS, ...customs]; + }, [customPrompts]); + + const filtered = useMemo(() => { + if (!externalSearch) return allActions; + return allActions.filter((a) => a.name.toLowerCase().includes(externalSearch.toLowerCase())); + }, [allActions, externalSearch]); + + // Reset highlight when results change + const prevSearchRef = useRef(externalSearch); + if (prevSearchRef.current !== externalSearch) { + prevSearchRef.current = externalSearch; + if (highlightedIndex !== 0) { + setHighlightedIndex(0); } + } - const handleSelect = useCallback( - (index: number) => { - const action = filtered[index]; - if (!action) return; - onSelect({ name: action.name, prompt: action.prompt, mode: action.mode }); - }, - [filtered, onSelect] - ); - - // Auto-scroll highlighted item into view - useEffect(() => { - if (!shouldScrollRef.current) return; - shouldScrollRef.current = false; - - const rafId = requestAnimationFrame(() => { - const item = itemRefs.current.get(highlightedIndex); - const container = scrollContainerRef.current; - if (item && container) { - const itemRect = item.getBoundingClientRect(); - const containerRect = container.getBoundingClientRect(); - if (itemRect.top < containerRect.top || itemRect.bottom > containerRect.bottom) { - item.scrollIntoView({ block: "nearest" }); - } + const handleSelect = useCallback( + (index: number) => { + const action = filtered[index]; + if (!action) return; + onSelect({ name: action.name, prompt: action.prompt, mode: action.mode }); + }, + [filtered, onSelect] + ); + + // Auto-scroll highlighted item into view + useEffect(() => { + if (!shouldScrollRef.current) return; + shouldScrollRef.current = false; + + const rafId = requestAnimationFrame(() => { + const item = itemRefs.current.get(highlightedIndex); + const container = scrollContainerRef.current; + if (item && container) { + const itemRect = item.getBoundingClientRect(); + const containerRect = container.getBoundingClientRect(); + if (itemRect.top < containerRect.top || itemRect.bottom > containerRect.bottom) { + item.scrollIntoView({ block: "nearest" }); } - }); - - return () => cancelAnimationFrame(rafId); - }, [highlightedIndex]); - - useImperativeHandle( - ref, - () => ({ - selectHighlighted: () => handleSelect(highlightedIndex), - moveUp: () => { - shouldScrollRef.current = true; - setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : filtered.length - 1)); - }, - moveDown: () => { - shouldScrollRef.current = true; - setHighlightedIndex((prev) => (prev < filtered.length - 1 ? prev + 1 : 0)); - }, - }), - [filtered.length, highlightedIndex, handleSelect] - ); - - if (filtered.length === 0) return null; - - const defaultFiltered = filtered.filter((_, i) => i < DEFAULT_ACTIONS.length); - const customFiltered = filtered.filter((_, i) => i >= DEFAULT_ACTIONS.length); - - return ( -
-
- {defaultFiltered.map((action, index) => ( + } + }); + + return () => cancelAnimationFrame(rafId); + }, [highlightedIndex]); + + useImperativeHandle( + ref, + () => ({ + selectHighlighted: () => handleSelect(highlightedIndex), + moveUp: () => { + shouldScrollRef.current = true; + setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : filtered.length - 1)); + }, + moveDown: () => { + shouldScrollRef.current = true; + setHighlightedIndex((prev) => (prev < filtered.length - 1 ? prev + 1 : 0)); + }, + }), + [filtered.length, highlightedIndex, handleSelect] + ); + + if (filtered.length === 0) return null; + + const defaultFiltered = filtered.filter((_, i) => i < DEFAULT_ACTIONS.length); + const customFiltered = filtered.filter((_, i) => i >= DEFAULT_ACTIONS.length); + + return ( +
+
+ {defaultFiltered.map((action, index) => ( + + ))} + + {customFiltered.length > 0 &&
} + + {customFiltered.map((action, i) => { + const index = defaultFiltered.length + i; + return ( - ))} - - {customFiltered.length > 0 && ( -
- )} - - {customFiltered.map((action, i) => { - const index = defaultFiltered.length + i; - return ( - - ); - })} - -
- -
+ ); + })} + +
+
- ); - } -); +
+ ); +}); diff --git a/surfsense_web/components/public-chat/public-thread.tsx b/surfsense_web/components/public-chat/public-thread.tsx index bee0496f6..8678cef52 100644 --- a/surfsense_web/components/public-chat/public-thread.tsx +++ b/surfsense_web/components/public-chat/public-thread.tsx @@ -10,6 +10,7 @@ import { import { CheckIcon, CopyIcon } from "lucide-react"; import Image from "next/image"; import { type FC, type ReactNode, useState } from "react"; +import { CitationMetadataProvider } from "@/components/assistant-ui/citation-metadata-context"; import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; @@ -142,30 +143,33 @@ const PublicAssistantMessage: FC = () => { className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150" data-role="assistant" > -
- null, - multi_link_preview: () => null, - scrape_webpage: () => null, + +
+ null, + link_preview: () => null, + multi_link_preview: () => null, + scrape_webpage: () => null, + }, + Fallback: ToolFallback, }, - Fallback: ToolFallback, - }, - }} - /> -
+ }} + /> +
-
- -
+
+ +
+ ); }; diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index 877fa991c..8f08b7db3 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -1,31 +1,15 @@ "use client"; import { useAtomValue } from "jotai"; -import { - AlertCircle, - Check, - ChevronsUpDown, - Edit3, - Info, - Key, - Plus, - RefreshCw, - Trash2, - Wand2, -} from "lucide-react"; -import { useCallback, useMemo, useState } from "react"; -import { toast } from "sonner"; -import { - createImageGenConfigMutationAtom, - deleteImageGenConfigMutationAtom, - updateImageGenConfigMutationAtom, -} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { AlertCircle, Edit3, Info, Plus, RefreshCw, Trash2, Wand2 } from "lucide-react"; +import { useMemo, useState } from "react"; +import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; import { globalImageGenConfigsAtom, imageGenConfigsAtom, } from "@/atoms/image-gen-config/image-gen-config-query.atoms"; import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { AlertDialog, @@ -40,39 +24,9 @@ import { import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; import { Skeleton } from "@/components/ui/skeleton"; import { Spinner } from "@/components/ui/spinner"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import { - getImageGenModelsByProvider, - IMAGE_GEN_PROVIDERS, -} from "@/contracts/enums/image-gen-providers"; import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types"; import { useMediaQuery } from "@/hooks/use-media-query"; import { getProviderIcon } from "@/lib/provider-icons"; @@ -92,23 +46,12 @@ function getInitials(name: string): string { export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { const isDesktop = useMediaQuery("(min-width: 768px)"); - // Image gen config atoms - const { - mutateAsync: createConfig, - isPending: isCreating, - error: createError, - } = useAtomValue(createImageGenConfigMutationAtom); - const { - mutateAsync: updateConfig, - isPending: isUpdating, - error: updateError, - } = useAtomValue(updateImageGenConfigMutationAtom); + const { mutateAsync: deleteConfig, isPending: isDeleting, error: deleteError, } = useAtomValue(deleteImageGenConfigMutationAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); const { data: userConfigs, @@ -119,7 +62,6 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { const { data: globalConfigs = [], isFetching: globalLoading } = useAtomValue(globalImageGenConfigsAtom); - // Members for user resolution const { data: members } = useAtomValue(membersAtom); const memberMap = useMemo(() => { const map = new Map(); @@ -135,7 +77,6 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { return map; }, [members]); - // Permissions const { data: access } = useAtomValue(myAccessAtom); const canCreate = useMemo(() => { if (!access) return false; @@ -147,127 +88,36 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { if (access.is_owner) return true; return access.permissions?.includes("image_generations:delete") ?? false; }, [access]); - // Backend uses image_generations:create for update as well const canUpdate = canCreate; const isReadOnly = !canCreate && !canDelete; - // Local state const [isDialogOpen, setIsDialogOpen] = useState(false); const [editingConfig, setEditingConfig] = useState(null); const [configToDelete, setConfigToDelete] = useState(null); - const isSubmitting = isCreating || isUpdating; const isLoading = configsLoading || globalLoading; - const errors = [createError, updateError, deleteError, fetchError].filter(Boolean) as Error[]; - - // Form state for create/edit dialog - const [formData, setFormData] = useState({ - name: "", - description: "", - provider: "", - custom_provider: "", - model_name: "", - api_key: "", - api_base: "", - api_version: "", - }); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - const resetForm = () => { - setFormData({ - name: "", - description: "", - provider: "", - custom_provider: "", - model_name: "", - api_key: "", - api_base: "", - api_version: "", - }); + const openEditDialog = (config: ImageGenerationConfig) => { + setEditingConfig(config); + setIsDialogOpen(true); }; - const handleFormSubmit = useCallback(async () => { - if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) { - toast.error("Please fill in all required fields"); - return; - } - try { - if (editingConfig) { - await updateConfig({ - id: editingConfig.id, - data: { - name: formData.name, - description: formData.description || undefined, - provider: formData.provider as any, - custom_provider: formData.custom_provider || undefined, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - }, - }); - } else { - const result = await createConfig({ - name: formData.name, - description: formData.description || undefined, - provider: formData.provider as any, - custom_provider: formData.custom_provider || undefined, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - search_space_id: searchSpaceId, - }); - // Auto-assign newly created config - if (result?.id) { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: result.id }, - }); - } - } - setIsDialogOpen(false); - setEditingConfig(null); - resetForm(); - } catch { - // Error handled by mutation - } - }, [editingConfig, formData, searchSpaceId, createConfig, updateConfig, updatePreferences]); + const openNewDialog = () => { + setEditingConfig(null); + setIsDialogOpen(true); + }; const handleDelete = async () => { if (!configToDelete) return; try { - await deleteConfig(configToDelete.id); + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); setConfigToDelete(null); } catch { // Error handled by mutation } }; - const openEditDialog = (config: ImageGenerationConfig) => { - setEditingConfig(config); - setFormData({ - name: config.name, - description: config.description || "", - provider: config.provider, - custom_provider: config.custom_provider || "", - model_name: config.model_name, - api_key: config.api_key, - api_base: config.api_base || "", - api_version: config.api_version || "", - }); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - resetForm(); - setIsDialogOpen(true); - }; - - const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); - const suggestedModels = getImageGenModelsByProvider(formData.provider); - return (
{/* Header */} @@ -336,11 +186,16 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { - - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length} global - image model(s) - {" "} - available from your administrator. +

+ + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} + global image{" "} + {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === 1 + ? "model" + : "models"} + {" "} + available from your administrator. Use the model selector to view and select them. +

)} @@ -348,31 +203,26 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { {/* Loading Skeleton */} {isLoading && (
- {/* Your Image Models Section Skeleton */}
- {/* Cards Grid Skeleton */}
{["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - {/* Header */}
- {/* Provider + Model */}
- {/* Footer */}
@@ -529,216 +379,27 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
)} - {/* Create/Edit Dialog */} - { - if (!open) { - setIsDialogOpen(false); - setEditingConfig(null); - resetForm(); - } + setIsDialogOpen(open); + if (!open) setEditingConfig(null); }} - > - e.preventDefault()} - > - - {editingConfig ? "Edit Image Model" : "Add Image Model"} - - {editingConfig - ? "Update your image generation model" - : "Configure a new image generation model (DALL-E 3, GPT Image 1, etc.)"} - - - -
- {/* Name */} -
- - setFormData((p) => ({ ...p, name: e.target.value }))} - /> -
- - {/* Description */} -
- - setFormData((p) => ({ ...p, description: e.target.value }))} - /> -
- - - - {/* Provider */} -
- - -
- - {/* Model Name */} -
- - {suggestedModels.length > 0 ? ( - - - - - - - setFormData((p) => ({ ...p, model_name: val }))} - /> - - - - Type a custom model name - - - - {suggestedModels.map((m) => ( - { - setFormData((p) => ({ ...p, model_name: m.value })); - setModelComboboxOpen(false); - }} - > - - {m.value} - {m.label} - - ))} - - - - - - ) : ( - setFormData((p) => ({ ...p, model_name: e.target.value }))} - /> - )} -
- - {/* API Key */} -
- - setFormData((p) => ({ ...p, api_key: e.target.value }))} - /> -
- - {/* API Base (optional) */} -
- - setFormData((p) => ({ ...p, api_base: e.target.value }))} - /> -
- - {/* API Version (Azure) */} - {formData.provider === "AZURE_OPENAI" && ( -
- - setFormData((p) => ({ ...p, api_version: e.target.value }))} - /> -
- )} - - {/* Actions */} -
- - -
-
-
-
+ config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> {/* Delete Confirmation */} !open && setConfigToDelete(null)} > - + - - - Delete Image Model - + Delete Image Model Are you sure you want to delete{" "} {configToDelete?.name}? @@ -749,19 +410,10 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { - {isDeleting ? ( - <> - - Deleting - - ) : ( - <> - - Delete - - )} + Delete + {isDeleting && } diff --git a/surfsense_web/components/settings/model-config-manager.tsx b/surfsense_web/components/settings/model-config-manager.tsx index 80bfd8e31..046288a96 100644 --- a/surfsense_web/components/settings/model-config-manager.tsx +++ b/surfsense_web/components/settings/model-config-manager.tsx @@ -12,18 +12,14 @@ import { Trash2, Wand2, } from "lucide-react"; -import { useCallback, useMemo, useState } from "react"; +import { useMemo, useState } from "react"; import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { - createNewLLMConfigMutationAtom, - deleteNewLLMConfigMutationAtom, - updateNewLLMConfigMutationAtom, -} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; import { globalNewLLMConfigsAtom, newLLMConfigsAtom, } from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; +import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { AlertDialog, @@ -39,13 +35,6 @@ import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; import { Skeleton } from "@/components/ui/skeleton"; import { Spinner } from "@/components/ui/spinner"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; @@ -69,12 +58,6 @@ function getInitials(name: string): string { export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { const isDesktop = useMediaQuery("(min-width: 768px)"); // Mutations - const { mutateAsync: createConfig, isPending: isCreating } = useAtomValue( - createNewLLMConfigMutationAtom - ); - const { mutateAsync: updateConfig, isPending: isUpdating } = useAtomValue( - updateNewLLMConfigMutationAtom - ); const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( deleteNewLLMConfigMutationAtom ); @@ -128,33 +111,10 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { const [editingConfig, setEditingConfig] = useState(null); const [configToDelete, setConfigToDelete] = useState(null); - const isSubmitting = isCreating || isUpdating; - - const handleFormSubmit = useCallback( - async (formData: LLMConfigFormData) => { - try { - if (editingConfig) { - const { search_space_id, ...updateData } = formData; - await updateConfig({ - id: editingConfig.id, - data: updateData, - }); - } else { - await createConfig(formData); - } - setIsDialogOpen(false); - setEditingConfig(null); - } catch { - // Error is displayed inside the dialog by the form - } - }, - [editingConfig, createConfig, updateConfig] - ); - const handleDelete = async () => { if (!configToDelete) return; try { - await deleteConfig({ id: configToDelete.id }); + await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); setConfigToDelete(null); } catch { // Error handled by mutation state @@ -171,11 +131,6 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { setIsDialogOpen(true); }; - const closeDialog = () => { - setIsDialogOpen(false); - setEditingConfig(null); - }; - return (
{/* Header actions */} @@ -196,7 +151,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { onClick={openNewDialog} className="gap-2 bg-white text-black hover:bg-neutral-100 dark:bg-white dark:text-black dark:hover:bg-neutral-200" > - Add Configuration + Add LLM Model )}
@@ -243,18 +198,17 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { {/* Global Configs Info */} {globalConfigs.length > 0 && ( -
- - - - {globalConfigs.length} global configuration(s){" "} - available from your administrator. These are pre-configured and ready to use.{" "} - - Global configs: {globalConfigs.map((g) => g.name).join(", ")} - - - -
+ + + +

+ + {globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"} + {" "} + available from your administrator. Use the model selector to view and select them. +

+
+
)} {/* Loading Skeleton */} @@ -463,66 +417,26 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { )} {/* Add/Edit Configuration Dialog */} - !open && closeDialog()}> - e.preventDefault()} - > - - - {editingConfig ? "Edit Configuration" : "Create New Configuration"} - - - {editingConfig - ? "Update your AI model and prompt configuration" - : "Set up a new AI model with custom prompts and citation settings"} - - - - - - + { + setIsDialogOpen(open); + if (!open) setEditingConfig(null); + }} + config={editingConfig} + isGlobal={false} + searchSpaceId={searchSpaceId} + mode={editingConfig ? "edit" : "create"} + /> {/* Delete Confirmation Dialog */} !open && setConfigToDelete(null)} > - + - - - Delete Configuration - + Delete LLM Model Are you sure you want to delete{" "} {configToDelete?.name}? This @@ -542,10 +456,7 @@ export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) { Deleting ) : ( - <> - - Delete - + "Delete" )} diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx new file mode 100644 index 000000000..1cfbf8842 --- /dev/null +++ b/surfsense_web/components/shared/image-config-dialog.tsx @@ -0,0 +1,454 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; +import { + createImageGenConfigMutationAtom, + updateImageGenConfigMutationAtom, +} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; +import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; +import { Spinner } from "@/components/ui/spinner"; +import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; +import type { + GlobalImageGenConfig, + ImageGenerationConfig, + ImageGenProvider, +} from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; + +interface ImageConfigDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: ImageGenerationConfig | GlobalImageGenConfig | null; + isGlobal: boolean; + searchSpaceId: number; + mode: "create" | "edit" | "view"; +} + +const INITIAL_FORM = { + name: "", + description: "", + provider: "", + model_name: "", + api_key: "", + api_base: "", + api_version: "", +}; + +export function ImageConfigDialog({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, +}: ImageConfigDialogProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [formData, setFormData] = useState(INITIAL_FORM); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const scrollRef = useRef(null); + + useEffect(() => { + if (open) { + if (mode === "edit" && config && !isGlobal) { + setFormData({ + name: config.name || "", + description: config.description || "", + provider: config.provider || "", + model_name: config.model_name || "", + api_key: (config as ImageGenerationConfig).api_key || "", + api_base: config.api_base || "", + api_version: config.api_version || "", + }); + } else if (mode === "create") { + setFormData(INITIAL_FORM); + } + setScrollPos("top"); + } + }, [open, mode, config, isGlobal]); + + const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const handleScroll = useCallback((e: React.UIEvent) => { + const el = e.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + + const suggestedModels = useMemo(() => { + if (!formData.provider) return []; + return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); + }, [formData.provider]); + + const getTitle = () => { + if (mode === "create") return "Add Image Model"; + if (isGlobal) return "View Global Image Model"; + return "Edit Image Model"; + }; + + const getSubtitle = () => { + if (mode === "create") return "Set up a new image generation provider"; + if (isGlobal) return "Read-only global configuration"; + return "Update your image model settings"; + }; + + const handleSubmit = useCallback(async () => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + name: formData.name, + provider: formData.provider as ImageGenProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + description: formData.description || undefined, + search_space_id: searchSpaceId, + }); + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: result.id }, + }); + } + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: formData.name, + description: formData.description || undefined, + provider: formData.provider as ImageGenProvider, + model_name: formData.model_name, + api_key: formData.api_key, + api_base: formData.api_base || undefined, + api_version: formData.api_version || undefined, + }, + }); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save image config:", error); + toast.error("Failed to save image model"); + } finally { + setIsSubmitting(false); + } + }, [ + mode, + isGlobal, + config, + formData, + searchSpaceId, + createConfig, + updateConfig, + updatePreferences, + onOpenChange, + ]); + + const handleUseGlobalConfig = useCallback(async () => { + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { image_generation_config_id: config.id }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set image model:", error); + toast.error("Failed to set image model"); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + + const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; + const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); + + return ( + + e.preventDefault()} + > + {getTitle()} + + {/* Header */} +
+
+
+

{getTitle()}

+ {isGlobal && mode !== "create" && ( + + Global + + )} +
+

{getSubtitle()}

+ {config && mode !== "create" && ( +

{config.model_name}

+ )} +
+
+ + {/* Scrollable content */} +
+ {isGlobal && config && ( + <> + + + + Global configurations are read-only. To customize, create a new model. + + +
+
+
+
+ Name +
+

{config.name}

+
+ {config.description && ( +
+
+ Description +
+

{config.description}

+
+ )} +
+ +
+
+
+ Provider +
+

{config.provider}

+
+
+
+ Model +
+

{config.model_name}

+
+
+
+ + )} + + {(mode === "create" || (mode === "edit" && !isGlobal)) && ( +
+
+ + setFormData((p) => ({ ...p, name: e.target.value }))} + /> +
+ +
+ + setFormData((p) => ({ ...p, description: e.target.value }))} + /> +
+ + + +
+ + +
+ +
+ + {suggestedModels.length > 0 ? ( + + + + + + + setFormData((p) => ({ ...p, model_name: val }))} + /> + + + + Type a custom model name + + + + {suggestedModels.map((m) => ( + { + setFormData((p) => ({ ...p, model_name: m.value })); + setModelComboboxOpen(false); + }} + > + + {m.value} + + {m.label} + + + ))} + + + + + + ) : ( + setFormData((p) => ({ ...p, model_name: e.target.value }))} + /> + )} +
+ +
+ + setFormData((p) => ({ ...p, api_key: e.target.value }))} + /> +
+ +
+ + setFormData((p) => ({ ...p, api_base: e.target.value }))} + /> +
+ + {formData.provider === "AZURE_OPENAI" && ( +
+ + setFormData((p) => ({ ...p, api_version: e.target.value }))} + /> +
+ )} +
+ )} +
+ + {/* Fixed footer */} +
+ + {mode === "create" || (mode === "edit" && !isGlobal) ? ( + + ) : isGlobal && config ? ( + + ) : null} +
+
+
+ ); +} diff --git a/surfsense_web/components/shared/llm-config-form.tsx b/surfsense_web/components/shared/llm-config-form.tsx index 38c67cfa6..732bf971e 100644 --- a/surfsense_web/components/shared/llm-config-form.tsx +++ b/surfsense_web/components/shared/llm-config-form.tsx @@ -3,9 +3,8 @@ import { zodResolver } from "@hookform/resolvers/zod"; import { useAtomValue } from "jotai"; import { Check, ChevronDown, ChevronsUpDown } from "lucide-react"; -import { AnimatePresence, motion } from "motion/react"; import { useEffect, useMemo, useState } from "react"; -import { useForm } from "react-hook-form"; +import { type Resolver, useForm } from "react-hook-form"; import { z } from "zod"; import { defaultSystemInstructionsAtom, @@ -41,7 +40,6 @@ import { SelectValue, } from "@/components/ui/select"; import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; import { Switch } from "@/components/ui/switch"; import { Textarea } from "@/components/ui/textarea"; import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers"; @@ -73,28 +71,18 @@ interface LLMConfigFormProps { initialData?: Partial; searchSpaceId: number; onSubmit: (data: LLMConfigFormData) => Promise; - onCancel?: () => void; - isSubmitting?: boolean; mode?: "create" | "edit"; - submitLabel?: string; showAdvanced?: boolean; - compact?: boolean; formId?: string; - hideActions?: boolean; } export function LLMConfigForm({ initialData, searchSpaceId, onSubmit, - onCancel, - isSubmitting = false, mode = "create", - submitLabel, showAdvanced = true, - compact = false, formId, - hideActions = false, }: LLMConfigFormProps) { const { data: defaultInstructions, isSuccess: defaultInstructionsLoaded } = useAtomValue( defaultSystemInstructionsAtom @@ -105,8 +93,7 @@ export function LLMConfigForm({ const [systemInstructionsOpen, setSystemInstructionsOpen] = useState(false); const form = useForm({ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - resolver: zodResolver(formSchema) as any, + resolver: zodResolver(formSchema) as Resolver, defaultValues: { name: initialData?.name ?? "", description: initialData?.description ?? "", @@ -233,33 +220,21 @@ export function LLMConfigForm({ /> {/* Custom Provider (conditional) */} - - {watchProvider === "CUSTOM" && ( - - ( - - Custom Provider Name - - - - - - )} - /> - - )} - + {watchProvider === "CUSTOM" && ( + ( + + Custom Provider Name + + + + + + )} + /> + )} {/* Model Name with Combobox */} {/* Ollama Quick Actions */} - - {watchProvider === "OLLAMA" && ( - + - - - )} - + localhost:11434 + + +
+ )}
{/* Advanced Parameters */} @@ -554,44 +522,6 @@ export function LLMConfigForm({ /> - - {!hideActions && ( -
- {onCancel && ( - - )} - -
- )} ); diff --git a/surfsense_web/components/shared/model-config-dialog.tsx b/surfsense_web/components/shared/model-config-dialog.tsx new file mode 100644 index 000000000..84ba821fc --- /dev/null +++ b/surfsense_web/components/shared/model-config-dialog.tsx @@ -0,0 +1,333 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { AlertCircle } from "lucide-react"; +import { useCallback, useRef, useState } from "react"; +import { toast } from "sonner"; +import { + createNewLLMConfigMutationAtom, + updateLLMPreferencesMutationAtom, + updateNewLLMConfigMutationAtom, +} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; +import { Alert, AlertDescription } from "@/components/ui/alert"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; +import { Spinner } from "@/components/ui/spinner"; +import type { + GlobalNewLLMConfig, + LiteLLMProvider, + NewLLMConfigPublic, +} from "@/contracts/types/new-llm-config.types"; + +interface ModelConfigDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + config: NewLLMConfigPublic | GlobalNewLLMConfig | null; + isGlobal: boolean; + searchSpaceId: number; + mode: "create" | "edit" | "view"; +} + +export function ModelConfigDialog({ + open, + onOpenChange, + config, + isGlobal, + searchSpaceId, + mode, +}: ModelConfigDialogProps) { + const [isSubmitting, setIsSubmitting] = useState(false); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); + const scrollRef = useRef(null); + + const handleScroll = useCallback((e: React.UIEvent) => { + const el = e.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + + const { mutateAsync: createConfig } = useAtomValue(createNewLLMConfigMutationAtom); + const { mutateAsync: updateConfig } = useAtomValue(updateNewLLMConfigMutationAtom); + const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + + const getTitle = () => { + if (mode === "create") return "Add New Configuration"; + if (isGlobal) return "View Global Configuration"; + return "Edit Configuration"; + }; + + const getSubtitle = () => { + if (mode === "create") return "Set up a new LLM provider for this search space"; + if (isGlobal) return "Read-only global configuration"; + return "Update your configuration settings"; + }; + + const handleSubmit = useCallback( + async (data: LLMConfigFormData) => { + setIsSubmitting(true); + try { + if (mode === "create") { + const result = await createConfig({ + ...data, + search_space_id: searchSpaceId, + }); + + if (result?.id) { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: result.id, + }, + }); + } + + onOpenChange(false); + } else if (!isGlobal && config) { + await updateConfig({ + id: config.id, + data: { + name: data.name, + description: data.description, + provider: data.provider, + custom_provider: data.custom_provider, + model_name: data.model_name, + api_key: data.api_key, + api_base: data.api_base, + litellm_params: data.litellm_params, + system_instructions: data.system_instructions, + use_default_system_instructions: data.use_default_system_instructions, + citations_enabled: data.citations_enabled, + }, + }); + onOpenChange(false); + } + } catch (error) { + console.error("Failed to save configuration:", error); + } finally { + setIsSubmitting(false); + } + }, + [ + mode, + isGlobal, + config, + searchSpaceId, + createConfig, + updateConfig, + updatePreferences, + onOpenChange, + ] + ); + + const handleUseGlobalConfig = useCallback(async () => { + if (!config || !isGlobal) return; + setIsSubmitting(true); + try { + await updatePreferences({ + search_space_id: searchSpaceId, + data: { + agent_llm_id: config.id, + }, + }); + toast.success(`Now using ${config.name}`); + onOpenChange(false); + } catch (error) { + console.error("Failed to set model:", error); + } finally { + setIsSubmitting(false); + } + }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + + return ( + + e.preventDefault()} + > + {getTitle()} + + {/* Header */} +
+
+
+

{getTitle()}

+ {isGlobal && mode !== "create" && ( + + Global + + )} + {!isGlobal && mode !== "create" && ( + + Custom + + )} +
+

{getSubtitle()}

+ {config && mode !== "create" && ( +

{config.model_name}

+ )} +
+
+ + {/* Scrollable content */} +
+ {isGlobal && mode !== "create" && ( + + + + Global configurations are read-only. To customize settings, create a new + configuration based on this template. + + + )} + + {mode === "create" ? ( + + ) : isGlobal && config ? ( +
+
+
+
+
+ Configuration Name +
+

{config.name}

+
+ {config.description && ( +
+
+ Description +
+

{config.description}

+
+ )} +
+ +
+ +
+
+
+ Provider +
+

{config.provider}

+
+
+
+ Model +
+

{config.model_name}

+
+
+ +
+ +
+
+
+ Citations +
+ + {config.citations_enabled ? "Enabled" : "Disabled"} + +
+
+ + {config.system_instructions && ( + <> +
+
+
+ System Instructions +
+
+

+ {config.system_instructions} +

+
+
+ + )} +
+
+ ) : config ? ( + + ) : null} +
+ + {/* Fixed footer */} +
+ + {mode === "create" || (!isGlobal && config) ? ( + + ) : isGlobal && config ? ( + + ) : null} +
+ +
+ ); +} diff --git a/surfsense_web/components/tool-ui/citation/_adapter.tsx b/surfsense_web/components/tool-ui/citation/_adapter.tsx new file mode 100644 index 000000000..ba8ea5080 --- /dev/null +++ b/surfsense_web/components/tool-ui/citation/_adapter.tsx @@ -0,0 +1,8 @@ +"use client"; + +export { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +export { cn } from "@/lib/utils"; diff --git a/surfsense_web/components/tool-ui/citation/citation-list.tsx b/surfsense_web/components/tool-ui/citation/citation-list.tsx new file mode 100644 index 000000000..3151917b6 --- /dev/null +++ b/surfsense_web/components/tool-ui/citation/citation-list.tsx @@ -0,0 +1,395 @@ +"use client"; + +import type { LucideIcon } from "lucide-react"; +import { Code2, Database, ExternalLink, File, FileText, Globe, Newspaper } from "lucide-react"; +import * as React from "react"; +import { openSafeNavigationHref, resolveSafeNavigationHref } from "../shared/media"; +import { cn, Popover, PopoverContent, PopoverTrigger } from "./_adapter"; +import { Citation } from "./citation"; +import type { CitationType, CitationVariant, SerializableCitation } from "./schema"; + +const TYPE_ICONS: Record = { + webpage: Globe, + document: FileText, + article: Newspaper, + api: Database, + code: Code2, + other: File, +}; + +function useHoverPopover(delay = 100) { + const [open, setOpen] = React.useState(false); + const timeoutRef = React.useRef | null>(null); + const containerRef = React.useRef(null); + + const handleMouseEnter = React.useCallback(() => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + timeoutRef.current = setTimeout(() => setOpen(true), delay); + }, [delay]); + + const handleMouseLeave = React.useCallback(() => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + timeoutRef.current = setTimeout(() => setOpen(false), delay); + }, [delay]); + + const handleFocus = React.useCallback(() => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + setOpen(true); + }, []); + + const handleBlur = React.useCallback( + (e: React.FocusEvent) => { + const relatedTarget = e.relatedTarget as HTMLElement | null; + if (containerRef.current?.contains(relatedTarget)) { + return; + } + if (relatedTarget?.closest("[data-radix-popper-content-wrapper]")) { + return; + } + if (timeoutRef.current) clearTimeout(timeoutRef.current); + timeoutRef.current = setTimeout(() => setOpen(false), delay); + }, + [delay] + ); + + React.useEffect(() => { + return () => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + }; + }, []); + + return { + open, + setOpen, + containerRef, + handleMouseEnter, + handleMouseLeave, + handleFocus, + handleBlur, + }; +} + +export interface CitationListProps { + id: string; + citations: SerializableCitation[]; + variant?: CitationVariant; + maxVisible?: number; + className?: string; + onNavigate?: (href: string, citation: SerializableCitation) => void; +} + +export function CitationList(props: CitationListProps) { + const { id, citations, variant = "default", maxVisible, className, onNavigate } = props; + + const shouldTruncate = maxVisible !== undefined && citations.length > maxVisible; + const visibleCitations = shouldTruncate ? citations.slice(0, maxVisible) : citations; + const overflowCitations = shouldTruncate ? citations.slice(maxVisible) : []; + const overflowCount = overflowCitations.length; + + const wrapperClass = + variant === "inline" ? "flex flex-wrap items-center gap-1.5" : "flex flex-col gap-2"; + + // Stacked variant: overlapping favicons with popover + if (variant === "stacked") { + return ( + + ); + } + + if (variant === "default") { + return ( +
+ {visibleCitations.map((citation) => ( + + ))} + {shouldTruncate && ( + + )} +
+ ); + } + + return ( +
+ {visibleCitations.map((citation) => ( + + ))} + {shouldTruncate && ( + + )} +
+ ); +} + +interface OverflowIndicatorProps { + citations: SerializableCitation[]; + count: number; + variant: CitationVariant; + onNavigate?: (href: string, citation: SerializableCitation) => void; +} + +function OverflowIndicator({ citations, count, variant, onNavigate }: OverflowIndicatorProps) { + const { open, handleMouseEnter, handleMouseLeave } = useHoverPopover(); + + const handleClick = (citation: SerializableCitation) => { + const href = resolveSafeNavigationHref(citation.href); + if (!href) return; + if (onNavigate) { + onNavigate(href, citation); + } else { + openSafeNavigationHref(href); + } + }; + + const popoverContent = ( +
+ {citations.map((citation) => ( + handleClick(citation)} /> + ))} +
+ ); + + if (variant === "inline") { + return ( + + + + + e.preventDefault()} + > + {popoverContent} + + + ); + } + + // Default variant + return ( + + + + + e.preventDefault()} + > + {popoverContent} + + + ); +} + +interface OverflowItemProps { + citation: SerializableCitation; + onClick: () => void; +} + +function OverflowItem({ citation, onClick }: OverflowItemProps) { + const TypeIcon = TYPE_ICONS[citation.type ?? "webpage"] ?? Globe; + + return ( + + ); +} + +interface StackedCitationsProps { + id: string; + citations: SerializableCitation[]; + className?: string; + onNavigate?: (href: string, citation: SerializableCitation) => void; +} + +function StackedCitations({ id, citations, className, onNavigate }: StackedCitationsProps) { + const { open, setOpen, containerRef, handleMouseEnter, handleMouseLeave, handleBlur } = + useHoverPopover(); + const maxIcons = 4; + const visibleCitations = citations.slice(0, maxIcons); + const remainingCount = Math.max(0, citations.length - maxIcons); + + const handleClick = (citation: SerializableCitation) => { + const href = resolveSafeNavigationHref(citation.href); + if (!href) return; + if (onNavigate) { + onNavigate(href, citation); + } else { + openSafeNavigationHref(href); + } + }; + + return ( + // biome-ignore lint/a11y/noStaticElementInteractions: blur boundary for popover focus management +
+ + + + + setOpen(false)} + > +
+ {citations.map((citation) => ( + handleClick(citation)} + /> + ))} +
+
+
+
+ ); +} diff --git a/surfsense_web/components/tool-ui/citation/citation.tsx b/surfsense_web/components/tool-ui/citation/citation.tsx new file mode 100644 index 000000000..523169f49 --- /dev/null +++ b/surfsense_web/components/tool-ui/citation/citation.tsx @@ -0,0 +1,248 @@ +"use client"; + +import type { LucideIcon } from "lucide-react"; +import { Code2, Database, ExternalLink, File, FileText, Globe, Newspaper } from "lucide-react"; +import * as React from "react"; +import { openSafeNavigationHref, sanitizeHref } from "../shared/media"; +import { cn, Popover, PopoverContent, PopoverTrigger } from "./_adapter"; +import type { CitationType, CitationVariant, SerializableCitation } from "./schema"; + +const FALLBACK_LOCALE = "en-US"; + +const TYPE_ICONS: Record = { + webpage: Globe, + document: FileText, + article: Newspaper, + api: Database, + code: Code2, + other: File, +}; + +function extractDomain(url: string): string | undefined { + try { + const urlObj = new URL(url); + return urlObj.hostname.replace(/^www\./, ""); + } catch { + return undefined; + } +} + +function formatDate(isoString: string, locale: string): string { + try { + const date = new Date(isoString); + return date.toLocaleDateString(locale, { + year: "numeric", + month: "short", + }); + } catch { + return isoString; + } +} + +function useHoverPopover(delay = 100) { + const [open, setOpen] = React.useState(false); + const timeoutRef = React.useRef | null>(null); + + const handleMouseEnter = React.useCallback(() => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + timeoutRef.current = setTimeout(() => setOpen(true), delay); + }, [delay]); + + const handleMouseLeave = React.useCallback(() => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + timeoutRef.current = setTimeout(() => setOpen(false), delay); + }, [delay]); + + React.useEffect(() => { + return () => { + if (timeoutRef.current) clearTimeout(timeoutRef.current); + }; + }, []); + + return { open, setOpen, handleMouseEnter, handleMouseLeave }; +} + +export interface CitationProps extends SerializableCitation { + variant?: CitationVariant; + className?: string; + onNavigate?: (href: string, citation: SerializableCitation) => void; +} + +export function Citation(props: CitationProps) { + const { variant = "default", className, onNavigate, ...serializable } = props; + + const { + id, + href: rawHref, + title, + snippet, + domain: providedDomain, + favicon, + author, + publishedAt, + type = "webpage", + locale: providedLocale, + } = serializable; + + const locale = providedLocale ?? FALLBACK_LOCALE; + const sanitizedHref = sanitizeHref(rawHref); + const domain = providedDomain ?? extractDomain(rawHref); + + const citationData: SerializableCitation = { + ...serializable, + href: sanitizedHref ?? rawHref, + domain, + locale, + }; + + const TypeIcon = TYPE_ICONS[type] ?? Globe; + + const handleClick = () => { + if (!sanitizedHref) return; + if (onNavigate) { + onNavigate(sanitizedHref, citationData); + } else { + openSafeNavigationHref(sanitizedHref); + } + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (sanitizedHref && (e.key === "Enter" || e.key === " ")) { + e.preventDefault(); + handleClick(); + } + }; + + const iconElement = favicon ? ( + // biome-ignore lint/performance/noImgElement: external favicon from arbitrary domain — next/image requires remotePatterns config + + ) : ( +