From 0d6279c70e4178102b63702bd3f3d74eb025ed28 Mon Sep 17 00:00:00 2001 From: Varun Patel Date: Fri, 12 Jun 2026 12:42:56 -0700 Subject: [PATCH 1/4] Optimize ingestion load performance --- .env.example | 3 + client/src/pages/UploadPage.tsx | 47 ++++++++++++++++ docker-compose.yml | 17 ++++-- scripts/ingestion_load_test.py | 52 ++++++++++++++++-- server/.env.example | 3 + server/app/api/documents.py | 30 ++++++++-- server/app/config.py | 3 + server/app/core/ingestion_runs.py | 35 ++++++++---- server/app/db/session.py | 18 +++++- worker/jobs/ingest_index.py | 64 ++++++++++++---------- worker/tests/test_ingestion_run_refresh.py | 14 +++-- 11 files changed, 225 insertions(+), 61 deletions(-) diff --git a/.env.example b/.env.example index 58821a8..3b34032 100644 --- a/.env.example +++ b/.env.example @@ -14,6 +14,9 @@ OPENAI_API_KEY=sk-... # ===== Database (Local Dev) ===== # Production uses Supabase Postgres DATABASE_URL=postgresql://postgres:postgres@localhost:5432/enterprise_rag +DB_POOL_SIZE=1 +DB_MAX_OVERFLOW=0 +DB_POOL_TIMEOUT_SECONDS=30 # ===== Redis (Required for queue + cache) ===== REDIS_URL=redis://localhost:6379/0 diff --git a/client/src/pages/UploadPage.tsx b/client/src/pages/UploadPage.tsx index 0c69894..26da807 100644 --- a/client/src/pages/UploadPage.tsx +++ b/client/src/pages/UploadPage.tsx @@ -231,6 +231,7 @@ export default function UploadPage() { setActiveDocumentId, } = useAppShellContext(); const [deletingId, setDeletingId] = useState(null); + const [deleteAllProgress, setDeleteAllProgress] = useState<{ done: number; total: number } | null>(null); const [retryingIds, setRetryingIds] = useState>(new Set()); const [activeRunId, setActiveRunId] = useState(null); const [activeRun, setActiveRun] = useState(null); @@ -353,6 +354,42 @@ export default function UploadPage() { } }; + const handleDeleteAll = async () => { + if (documents.length === 0) { + return; + } + + const confirmed = window.confirm( + `Delete all ${documents.length} documents in this workspace? This cannot be undone.`, + ); + if (!confirmed) { + return; + } + + setDeleteAllProgress({ done: 0, total: documents.length }); + const failures: string[] = []; + + for (const [index, document] of documents.entries()) { + try { + await apiDeleteDocument(accessToken, document.id); + if (activeDocument?.id === document.id) { + setActiveDocumentId(null); + } + } catch (error) { + failures.push(`${document.filename}: ${error instanceof Error ? error.message : "Delete failed"}`); + } finally { + setDeleteAllProgress({ done: index + 1, total: documents.length }); + } + } + + setDeleteAllProgress(null); + await Promise.all([refreshDocuments(), refreshWorkspace(), refreshIngestionState()]); + + if (failures.length > 0) { + window.alert(`Some documents could not be deleted:\n${failures.join("\n")}`); + } + }; + const retryDocuments = async (targets: DocumentRecord[]) => { const retryableTargets = targets.filter((doc) => isRetryableFailure(doc.error_message)); if (retryableTargets.length === 0) { @@ -463,6 +500,16 @@ export default function UploadPage() { + +
diff --git a/docker-compose.yml b/docker-compose.yml index de76f0f..3da29d6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,6 +22,9 @@ services: container_name: rag-server environment: - DATABASE_URL=${DATABASE_URL} + - DB_POOL_SIZE=${DB_POOL_SIZE:-1} + - DB_MAX_OVERFLOW=${DB_MAX_OVERFLOW:-0} + - DB_POOL_TIMEOUT_SECONDS=${DB_POOL_TIMEOUT_SECONDS:-30} - REDIS_URL=redis://redis:6379/0 - SUPABASE_URL=${SUPABASE_URL} - SUPABASE_SERVICE_ROLE_KEY=${SUPABASE_SERVICE_ROLE_KEY} @@ -53,6 +56,9 @@ services: dockerfile: Dockerfile environment: - DATABASE_URL=${DATABASE_URL} + - DB_POOL_SIZE=${DB_POOL_SIZE:-1} + - DB_MAX_OVERFLOW=${DB_MAX_OVERFLOW:-0} + - DB_POOL_TIMEOUT_SECONDS=${DB_POOL_TIMEOUT_SECONDS:-30} - REDIS_URL=redis://redis:6379/0 - SUPABASE_URL=${SUPABASE_URL} - SUPABASE_SERVICE_ROLE_KEY=${SUPABASE_SERVICE_ROLE_KEY} @@ -63,14 +69,14 @@ services: - INGEST_INDEX_JOB_TIMEOUT_SECONDS=${INGEST_INDEX_JOB_TIMEOUT_SECONDS:-1800} - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-32} - QUEUE_NAME=ingest_extract - - WORKER_COUNT=5 + - WORKER_COUNT=4 volumes: - ./worker:/app - ./server/app:/app/shared # Share code depends_on: - redis deploy: - replicas: 5 + replicas: 4 command: python worker.py ingest_extract # RQ Workers (Index) @@ -80,6 +86,9 @@ services: dockerfile: Dockerfile environment: - DATABASE_URL=${DATABASE_URL} + - DB_POOL_SIZE=${DB_POOL_SIZE:-1} + - DB_MAX_OVERFLOW=${DB_MAX_OVERFLOW:-0} + - DB_POOL_TIMEOUT_SECONDS=${DB_POOL_TIMEOUT_SECONDS:-30} - REDIS_URL=redis://redis:6379/0 - SUPABASE_URL=${SUPABASE_URL} - SUPABASE_SERVICE_ROLE_KEY=${SUPABASE_SERVICE_ROLE_KEY} @@ -90,14 +99,14 @@ services: - INGEST_INDEX_JOB_TIMEOUT_SECONDS=${INGEST_INDEX_JOB_TIMEOUT_SECONDS:-1800} - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-32} - QUEUE_NAME=ingest_index - - WORKER_COUNT=3 + - WORKER_COUNT=4 volumes: - ./worker:/app - ./server/app:/app/shared depends_on: - redis deploy: - replicas: 3 + replicas: 4 command: python worker.py ingest_index # RQ Dashboard (monitoring) diff --git a/scripts/ingestion_load_test.py b/scripts/ingestion_load_test.py index 26020c8..5799e5c 100644 --- a/scripts/ingestion_load_test.py +++ b/scripts/ingestion_load_test.py @@ -7,6 +7,7 @@ import urllib.error import urllib.parse import urllib.request +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -46,7 +47,7 @@ def _json_request( return json.loads(response_body) if response_body else {} -def _put_file(url: str, path: Path) -> None: +def _put_file(url: str, path: Path, timeout_seconds: float) -> None: _validate_http_url(url) request = urllib.request.Request( url, @@ -55,7 +56,7 @@ def _put_file(url: str, path: Path) -> None: headers={"Content-Type": "application/pdf"}, ) try: - with urllib.request.urlopen(request, timeout=120) as response: # nosec B310 + with urllib.request.urlopen(request, timeout=timeout_seconds) as response: # nosec B310 response.read() except urllib.error.HTTPError as exc: detail = exc.read().decode("utf-8", errors="replace") @@ -169,27 +170,63 @@ def run_load_test(args: argparse.Namespace) -> dict[str, Any]: request_timeout_seconds=args.request_timeout_seconds, payload=prepare_payload, ) - upload_items = [] + prepared_uploads = [] file_by_name = {path.name: path for path in files} for item in prepare.get("items", []): if item.get("status") not in {"prepared", "accepted"}: continue path = file_by_name[str(item["filename"])] - _put_file(str(item["upload_url"]), path) - upload_items.append( + prepared_uploads.append( { + "index": int(item["index"]), "document_id": item["document_id"], "bucket": item["bucket"], "storage_path": item["storage_path"], + "upload_url": item["upload_url"], + "path": path, } ) + upload_items = [] + max_workers = max(1, min(args.upload_concurrency, len(prepared_uploads) or 1)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_upload = { + executor.submit( + _put_file, + str(upload["upload_url"]), + upload["path"], + args.upload_timeout_seconds, + ): upload + for upload in prepared_uploads + } + for future in as_completed(future_to_upload): + upload = future_to_upload[future] + future.result() + upload_items.append( + { + "index": upload["index"], + "document_id": upload["document_id"], + "bucket": upload["bucket"], + "storage_path": upload["storage_path"], + } + ) + + upload_items.sort(key=lambda item: int(item["index"])) + complete_files = [ + { + "document_id": item["document_id"], + "bucket": item["bucket"], + "storage_path": item["storage_path"], + } + for item in upload_items + ] + complete = _json_request( method="POST", url=f"{api_base}/documents/upload-complete-batch", token=token, request_timeout_seconds=args.request_timeout_seconds, - payload={"ingestion_run_id": prepare.get("ingestion_run_id"), "files": upload_items}, + payload={"ingestion_run_id": prepare.get("ingestion_run_id"), "files": complete_files}, ) prepare_upload_complete_seconds = round(time.perf_counter() - prepare_started, 3) @@ -244,6 +281,7 @@ def run_load_test(args: argparse.Namespace) -> dict[str, Any]: "prepare_upload_complete": prepare_upload_complete_seconds, "total": total_seconds, }, + "upload_concurrency": max_workers, } artifact_dir = Path(args.artifact_dir) @@ -267,6 +305,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--artifact-dir", default="artifacts/ingestion-load") parser.add_argument("--poll-seconds", type=float, default=5) parser.add_argument("--request-timeout-seconds", type=float, default=180) + parser.add_argument("--upload-timeout-seconds", type=float, default=120) + parser.add_argument("--upload-concurrency", type=int, default=8) parser.add_argument("--timeout-seconds", type=int, default=1800) return parser.parse_args() diff --git a/server/.env.example b/server/.env.example index 6160540..f75283f 100644 --- a/server/.env.example +++ b/server/.env.example @@ -2,6 +2,9 @@ SUPABASE_URL=https://your-project.supabase.co SUPABASE_SERVICE_ROLE_KEY=your-service-role-key DATABASE_URL=postgresql://postgres:postgres@localhost:5432/enterprise_rag +DB_POOL_SIZE=1 +DB_MAX_OVERFLOW=0 +DB_POOL_TIMEOUT_SECONDS=30 REDIS_URL=redis://localhost:6379/0 ENVIRONMENT=development API_HOST=0.0.0.0 diff --git a/server/app/api/documents.py b/server/app/api/documents.py index 596be56..cf71c21 100644 --- a/server/app/api/documents.py +++ b/server/app/api/documents.py @@ -16,7 +16,7 @@ ScheduledJobRegistry, StartedJobRegistry, ) -from sqlalchemy import func, inspect, select, text +from sqlalchemy import func, select, text from sqlalchemy.orm import Session from app.api.deps import get_workspace_id @@ -167,8 +167,17 @@ def _sanitize_filename(filename: str) -> str: def _document_columns(db: Session) -> set[str]: - bind = db.get_bind() - return {col["name"] for col in inspect(bind).get_columns("documents")} + rows = db.execute( + text( + """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = current_schema() + AND table_name = 'documents' + """ + ) + ).scalars() + return {str(column_name) for column_name in rows} def _timing_select_fields(columns: set[str]) -> list[str]: @@ -230,7 +239,20 @@ def _failure_summary(error_messages: list[str | None]) -> IngestionHealthFailure def _table_exists(db: Session, table_name: str) -> bool: - return inspect(db.get_bind()).has_table(table_name) + return bool( + db.execute( + text( + """ + SELECT 1 + FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name = :table_name + LIMIT 1 + """ + ), + {"table_name": table_name}, + ).first() + ) def _require_ingestion_run_schema(db: Session) -> None: diff --git a/server/app/config.py b/server/app/config.py index 551fbc0..4f35ecd 100644 --- a/server/app/config.py +++ b/server/app/config.py @@ -10,6 +10,9 @@ class Settings(BaseSettings): api_port: int = 8000 DATABASE_URL: str = "postgresql://postgres:postgres@localhost:5432/enterprise_rag" + DB_POOL_SIZE: int = 1 + DB_MAX_OVERFLOW: int = 0 + DB_POOL_TIMEOUT_SECONDS: int = 30 REDIS_URL: str = "redis://localhost:6379/0" SUPABASE_URL: str = "" diff --git a/server/app/core/ingestion_runs.py b/server/app/core/ingestion_runs.py index 7b3e88a..cd51005 100644 --- a/server/app/core/ingestion_runs.py +++ b/server/app/core/ingestion_runs.py @@ -96,30 +96,43 @@ def refresh_ingestion_run_status( run_id: uuid.UUID, updated_at: datetime | None = None, ) -> str | None: - run_row = ( + rows = ( db.execute( text( """ - SELECT accepted_documents, rejected_documents - FROM ingestion_runs - WHERE id = :run_id - AND workspace_id = :workspace_id - LIMIT 1 + SELECT + r.accepted_documents, + r.rejected_documents, + d.status, + COUNT(d.id) AS count + FROM ingestion_runs r + LEFT JOIN documents d + ON d.workspace_id = r.workspace_id + AND d.ingestion_run_id = r.id + WHERE r.id = :run_id + AND r.workspace_id = :workspace_id + GROUP BY r.accepted_documents, r.rejected_documents, d.status """ ), {"workspace_id": workspace_id, "run_id": run_id}, ) .mappings() - .first() + .all() ) - if run_row is None: + if not rows: return None - status_counts = document_status_counts_for_run(db=db, workspace_id=workspace_id, run_id=run_id) + first_row = rows[0] + status_counts = empty_document_status_counts() + for row in rows: + status_name = row["status"] + if status_name is not None: + status_counts[str(status_name)] = int(row["count"] or 0) + status_counts["total"] = sum(status_counts.values()) run_status = derive_ingestion_run_status( status_counts=status_counts, - accepted_documents=int(run_row["accepted_documents"] or 0), - rejected_documents=int(run_row["rejected_documents"] or 0), + accepted_documents=int(first_row["accepted_documents"] or 0), + rejected_documents=int(first_row["rejected_documents"] or 0), ) db.execute( text( diff --git a/server/app/db/session.py b/server/app/db/session.py index e3241c1..051910a 100644 --- a/server/app/db/session.py +++ b/server/app/db/session.py @@ -5,13 +5,27 @@ from app.config import settings -engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True) + +def _engine_kwargs(database_url: str) -> dict[str, object]: + kwargs: dict[str, object] = {"pool_pre_ping": True} + if not database_url.startswith("sqlite"): + kwargs.update( + { + "pool_size": settings.DB_POOL_SIZE, + "max_overflow": settings.DB_MAX_OVERFLOW, + "pool_timeout": settings.DB_POOL_TIMEOUT_SECONDS, + } + ) + return kwargs + + +engine = create_engine(settings.DATABASE_URL, **_engine_kwargs(settings.DATABASE_URL)) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() def create_session_local(database_url: str): - worker_engine = create_engine(database_url, pool_pre_ping=True) + worker_engine = create_engine(database_url, **_engine_kwargs(database_url)) return sessionmaker(autocommit=False, autoflush=False, bind=worker_engine) diff --git a/worker/jobs/ingest_index.py b/worker/jobs/ingest_index.py index b489de5..5b13773 100644 --- a/worker/jobs/ingest_index.py +++ b/worker/jobs/ingest_index.py @@ -404,20 +404,23 @@ def ingest_index( embedding_batch_size = max(1, int(settings.EMBEDDING_BATCH_SIZE)) embedding_count = 0 total_embedding_tokens = 0 + total_reserved_tokens = sum(int(row["token_count"]) for row in chunk_rows) + with SessionLocal() as db: + reserve_tokens( + db=db, + workspace_id=workspace_uuid, + amount=total_reserved_tokens, + usage_date_utc=usage_date, + reservation_ttl_seconds=settings.RESERVATION_TTL_SECONDS, + ) + db.commit() + outstanding_reservations.append(total_reserved_tokens) + + embedding_rows: list[dict[str, object]] = [] for batch_rows in _batched(chunk_rows, embedding_batch_size): batch_started_at = time.perf_counter() estimated_tokens = sum(int(row["token_count"]) for row in batch_rows) - with SessionLocal() as db: - reserve_tokens( - db=db, - workspace_id=workspace_uuid, - amount=estimated_tokens, - usage_date_utc=usage_date, - reservation_ttl_seconds=settings.RESERVATION_TTL_SECONDS, - ) - db.commit() - outstanding_reservations.append(estimated_tokens) response = client.embeddings.create( model=model, @@ -434,7 +437,6 @@ def ingest_index( if all(getattr(item, "index", None) is not None for item in response_data): response_data.sort(key=lambda item: int(item.index)) - embedding_rows: list[dict[str, object]] = [] for row, embedding_data in zip(batch_rows, response_data, strict=True): embedding_rows.append( { @@ -448,25 +450,6 @@ def ingest_index( } ) - with SessionLocal() as db: - db.execute( - text( - """ - INSERT INTO chunk_embeddings (chunk_id, workspace_id, document_id, embedding, embedding_model) - VALUES (:chunk_id, :workspace_id, :document_id, CAST(:embedding AS vector), :embedding_model) - """ - ), - embedding_rows, - ) - commit_usage( - db=db, - workspace_id=workspace_uuid, - amount=estimated_tokens, - usage_date_utc=usage_date, - ) - db.commit() - outstanding_reservations.pop() - embedding_count += len(batch_rows) total_embedding_tokens += estimated_tokens logger.info( @@ -481,6 +464,27 @@ def ingest_index( }, ) + with SessionLocal() as db: + db.execute( + text( + """ + INSERT INTO chunk_embeddings (chunk_id, workspace_id, document_id, embedding, embedding_model) + VALUES (:chunk_id, :workspace_id, :document_id, CAST(:embedding AS vector), :embedding_model) + """ + ), + embedding_rows, + ) + db.commit() + with SessionLocal() as db: + commit_usage( + db=db, + workspace_id=workspace_uuid, + amount=total_reserved_tokens, + usage_date_utc=usage_date, + ) + db.commit() + outstanding_reservations.pop() + final_status = "indexed" _mark_document_indexed( workspace_id=workspace_uuid, diff --git a/worker/tests/test_ingestion_run_refresh.py b/worker/tests/test_ingestion_run_refresh.py index 8eaf5f9..e12439d 100644 --- a/worker/tests/test_ingestion_run_refresh.py +++ b/worker/tests/test_ingestion_run_refresh.py @@ -54,12 +54,18 @@ def execute(self, stmt, params=None): return FakeResult(scalar_values=list(self.columns)) if "SELECT status, error_message" in sql: return FakeResult(first_row=self.document_row) - if "SELECT accepted_documents, rejected_documents" in sql: + if "r.accepted_documents" in sql and "COUNT(d.id) AS count" in sql: return FakeResult( - first_row={"accepted_documents": 1, "rejected_documents": 0} + all_rows=[ + { + "accepted_documents": 1, + "rejected_documents": 0, + "status": row["status"], + "count": row["count"], + } + for row in self.status_count_rows + ] ) - if "SELECT status, COUNT(*) AS count" in sql: - return FakeResult(all_rows=self.status_count_rows) return FakeResult() def commit(self) -> None: From f3e22ba38fd4b76b4136afe7f0746a78ce8370b5 Mon Sep 17 00:00:00 2001 From: Varun Patel Date: Fri, 12 Jun 2026 12:51:16 -0700 Subject: [PATCH 2/4] Fix Trivy dependency findings --- server/requirements.txt | 10 +++++----- worker/requirements.txt | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/server/requirements.txt b/server/requirements.txt index c537ee5..73491b9 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -24,14 +24,14 @@ storage3==0.7.4 openai==1.10.0 # PDF Processing -unstructured==0.12.0 +unstructured==0.18.18 pypdf==4.2.0 pdf2image==1.17.0 -pillow==10.2.0 +pillow==10.3.0 # Auth -python-jose[cryptography]==3.3.0 -python-multipart==0.0.6 +python-jose[cryptography]==3.4.0 +python-multipart==0.0.27 # Text Processing tiktoken==0.5.2 @@ -48,6 +48,6 @@ prometheus-client==0.19.0 pytest==7.4.4 pytest-asyncio==0.23.3 pytest-cov==4.1.0 -black==24.1.1 +black==26.3.1 ruff==0.1.14 mypy==1.8.0 diff --git a/worker/requirements.txt b/worker/requirements.txt index 37d8b34..5a582f8 100644 --- a/worker/requirements.txt +++ b/worker/requirements.txt @@ -24,13 +24,13 @@ storage3==0.7.4 openai==1.10.0 # PDF Processing -unstructured==0.12.0 +unstructured==0.18.18 pdf2image==1.17.0 -pillow==10.2.0 +pillow==10.3.0 # Auth -python-jose[cryptography]==3.3.0 -python-multipart==0.0.6 +python-jose[cryptography]==3.4.0 +python-multipart==0.0.27 # Text Processing tiktoken==0.5.2 @@ -47,6 +47,6 @@ prometheus-client==0.19.0 pytest==7.4.4 pytest-asyncio==0.23.3 pytest-cov==4.1.0 -black==24.1.1 +black==26.3.1 ruff==0.1.14 mypy==1.8.0 From 91a474fa79f2d4cf7ef033a1f837c001d9c491ad Mon Sep 17 00:00:00 2001 From: Varun Patel Date: Fri, 12 Jun 2026 13:04:00 -0700 Subject: [PATCH 3/4] Fix Trivy scan findings --- .github/trivy/Dockerfile | 7 +++++++ .github/workflows/deploy.yml | 1 + client/Dockerfile | 7 +++++-- client/nginx.conf | 2 +- server/Dockerfile | 2 +- server/requirements.txt | 2 +- worker/Dockerfile | 2 +- worker/requirements.txt | 2 +- 8 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/trivy/Dockerfile b/.github/trivy/Dockerfile index 2028c59..51c4af7 100644 --- a/.github/trivy/Dockerfile +++ b/.github/trivy/Dockerfile @@ -1 +1,8 @@ FROM aquasec/trivy:0.70.0 + +RUN adduser -D trivyuser \ + && mkdir -p /home/trivyuser/.cache/trivy \ + && chown -R trivyuser:trivyuser /home/trivyuser + +ENV HOME=/home/trivyuser +USER trivyuser diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index ac51b8b..01c3a5d 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -35,6 +35,7 @@ jobs: - name: Run Trivy FS Scan run: | docker run --rm \ + --user root \ -v "${{ github.workspace }}:/workspace" \ -w /workspace \ local-trivy fs . \ diff --git a/client/Dockerfile b/client/Dockerfile index 96396e1..46d786f 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -23,7 +23,10 @@ RUN npm run build # Production stage FROM nginx:alpine AS production -COPY --from=build /app/dist /usr/share/nginx/html +COPY --from=build --chown=nginx:nginx /app/dist /usr/share/nginx/html COPY nginx.conf /etc/nginx/conf.d/default.conf -EXPOSE 80 +RUN touch /var/run/nginx.pid \ + && chown -R nginx:nginx /var/cache/nginx /var/run/nginx.pid /usr/share/nginx/html +USER nginx +EXPOSE 8080 CMD ["nginx", "-g", "daemon off;"] diff --git a/client/nginx.conf b/client/nginx.conf index 84fa1c0..d4f2f14 100644 --- a/client/nginx.conf +++ b/client/nginx.conf @@ -1,5 +1,5 @@ server { - listen 80; + listen 8080; server_name _; root /usr/share/nginx/html; index index.html; diff --git a/server/Dockerfile b/server/Dockerfile index 2952281..396c9e0 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -3,7 +3,7 @@ FROM python:3.11-slim WORKDIR /app # Install system dependencies -RUN apt-get update && apt-get install -y \ +RUN apt-get update && apt-get install -y --no-install-recommends \ gcc \ g++ \ libpq-dev \ diff --git a/server/requirements.txt b/server/requirements.txt index 73491b9..05feaf4 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -27,7 +27,7 @@ openai==1.10.0 unstructured==0.18.18 pypdf==4.2.0 pdf2image==1.17.0 -pillow==10.3.0 +pillow==12.2.0 # Auth python-jose[cryptography]==3.4.0 diff --git a/worker/Dockerfile b/worker/Dockerfile index 425d61a..cbe9a1e 100644 --- a/worker/Dockerfile +++ b/worker/Dockerfile @@ -3,7 +3,7 @@ FROM python:3.11-slim WORKDIR /app # Install system dependencies (same as server) -RUN apt-get update && apt-get install -y \ +RUN apt-get update && apt-get install -y --no-install-recommends \ gcc \ g++ \ libpq-dev \ diff --git a/worker/requirements.txt b/worker/requirements.txt index 5a582f8..4116499 100644 --- a/worker/requirements.txt +++ b/worker/requirements.txt @@ -26,7 +26,7 @@ openai==1.10.0 # PDF Processing unstructured==0.18.18 pdf2image==1.17.0 -pillow==10.3.0 +pillow==12.2.0 # Auth python-jose[cryptography]==3.4.0 From d7efa94237cb202fa04ecb0f5588402dc31e2510 Mon Sep 17 00:00:00 2001 From: Varun Patel Date: Fri, 12 Jun 2026 18:28:02 -0700 Subject: [PATCH 4/4] Add multi-document chat querying --- .env.example | 6 +- .../src/components/chat/ChatSessionList.tsx | 6 +- client/src/components/layout/AppShell.tsx | 53 ++- .../components/sidebar/DocumentSidebar.tsx | 39 +- client/src/lib/api.ts | 7 +- client/src/pages/ChatPage.tsx | 130 +++++-- docker-compose.yml | 12 +- scripts/schema.local.sql | 11 + scripts/schema.supabase.sql | 11 + server/.env.example | 4 +- server/app/api/chats.py | 355 +++++++++++------- server/app/config.py | 4 +- server/app/core/retrieval.py | 65 ++-- server/app/db/models.py | 4 + server/app/schemas/chat.py | 6 + ...0612_0002_add_chat_session_document_ids.py | 36 ++ 16 files changed, 550 insertions(+), 199 deletions(-) create mode 100644 server/migrations/versions/20260612_0002_add_chat_session_document_ids.py diff --git a/.env.example b/.env.example index 3b34032..a45ab05 100644 --- a/.env.example +++ b/.env.example @@ -14,8 +14,10 @@ OPENAI_API_KEY=sk-... # ===== Database (Local Dev) ===== # Production uses Supabase Postgres DATABASE_URL=postgresql://postgres:postgres@localhost:5432/enterprise_rag -DB_POOL_SIZE=1 -DB_MAX_OVERFLOW=0 +SERVER_DB_POOL_SIZE=3 +SERVER_DB_MAX_OVERFLOW=2 +WORKER_DB_POOL_SIZE=1 +WORKER_DB_MAX_OVERFLOW=0 DB_POOL_TIMEOUT_SECONDS=30 # ===== Redis (Required for queue + cache) ===== diff --git a/client/src/components/chat/ChatSessionList.tsx b/client/src/components/chat/ChatSessionList.tsx index 095c67f..9fb957f 100644 --- a/client/src/components/chat/ChatSessionList.tsx +++ b/client/src/components/chat/ChatSessionList.tsx @@ -38,9 +38,11 @@ export default function ChatSessionList({ {!loading && items.length === 0 ?

No saved chats yet.

: null}
{items.map((item) => { - const docName = item.document_id - ? documents.find((doc) => doc.id === item.document_id)?.filename ?? "Document" + const documentIds = item.document_ids?.length ? item.document_ids : item.document_id ? [item.document_id] : []; + const firstDocName = documentIds[0] + ? documents.find((doc) => doc.id === documentIds[0])?.filename ?? "Document" : "All documents"; + const docName = documentIds.length > 1 ? `${firstDocName} + ${documentIds.length - 1} more` : firstDocName; return ( + + + {queryDocumentIds.length}/{maxQueryDocuments} docs, {queryPageCount}/{maxQueryPages} pages + +
+ {queryCapError ?

{queryCapError}

: null}
diff --git a/docker-compose.yml b/docker-compose.yml index 3da29d6..55d9531 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,8 +22,8 @@ services: container_name: rag-server environment: - DATABASE_URL=${DATABASE_URL} - - DB_POOL_SIZE=${DB_POOL_SIZE:-1} - - DB_MAX_OVERFLOW=${DB_MAX_OVERFLOW:-0} + - DB_POOL_SIZE=${SERVER_DB_POOL_SIZE:-3} + - DB_MAX_OVERFLOW=${SERVER_DB_MAX_OVERFLOW:-2} - DB_POOL_TIMEOUT_SECONDS=${DB_POOL_TIMEOUT_SECONDS:-30} - REDIS_URL=redis://redis:6379/0 - SUPABASE_URL=${SUPABASE_URL} @@ -56,8 +56,8 @@ services: dockerfile: Dockerfile environment: - DATABASE_URL=${DATABASE_URL} - - DB_POOL_SIZE=${DB_POOL_SIZE:-1} - - DB_MAX_OVERFLOW=${DB_MAX_OVERFLOW:-0} + - DB_POOL_SIZE=${WORKER_DB_POOL_SIZE:-1} + - DB_MAX_OVERFLOW=${WORKER_DB_MAX_OVERFLOW:-0} - DB_POOL_TIMEOUT_SECONDS=${DB_POOL_TIMEOUT_SECONDS:-30} - REDIS_URL=redis://redis:6379/0 - SUPABASE_URL=${SUPABASE_URL} @@ -86,8 +86,8 @@ services: dockerfile: Dockerfile environment: - DATABASE_URL=${DATABASE_URL} - - DB_POOL_SIZE=${DB_POOL_SIZE:-1} - - DB_MAX_OVERFLOW=${DB_MAX_OVERFLOW:-0} + - DB_POOL_SIZE=${WORKER_DB_POOL_SIZE:-1} + - DB_MAX_OVERFLOW=${WORKER_DB_MAX_OVERFLOW:-0} - DB_POOL_TIMEOUT_SECONDS=${DB_POOL_TIMEOUT_SECONDS:-30} - REDIS_URL=redis://redis:6379/0 - SUPABASE_URL=${SUPABASE_URL} diff --git a/scripts/schema.local.sql b/scripts/schema.local.sql index d3611f6..cb8f3ba 100644 --- a/scripts/schema.local.sql +++ b/scripts/schema.local.sql @@ -199,6 +199,7 @@ CREATE TABLE IF NOT EXISTS chat_sessions ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE, document_id UUID NULL REFERENCES documents(id) ON DELETE SET NULL, + document_ids UUID[] NOT NULL DEFAULT '{}', title TEXT NOT NULL DEFAULT '', messages JSONB NOT NULL DEFAULT '[]'::jsonb, started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), @@ -207,9 +208,19 @@ CREATE TABLE IF NOT EXISTS chat_sessions ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); +ALTER TABLE chat_sessions + ADD COLUMN IF NOT EXISTS document_ids UUID[] NOT NULL DEFAULT '{}'; + +UPDATE chat_sessions +SET document_ids = ARRAY[document_id]::UUID[] +WHERE document_id IS NOT NULL + AND COALESCE(array_length(document_ids, 1), 0) = 0; + CREATE INDEX IF NOT EXISTS idx_chat_sessions_workspace_updated ON chat_sessions(workspace_id, updated_at DESC); CREATE INDEX IF NOT EXISTS idx_chat_sessions_workspace_document ON chat_sessions(workspace_id, document_id); +CREATE INDEX IF NOT EXISTS idx_chat_sessions_document_ids + ON chat_sessions USING GIN (document_ids); COMMIT; diff --git a/scripts/schema.supabase.sql b/scripts/schema.supabase.sql index d3611f6..cb8f3ba 100644 --- a/scripts/schema.supabase.sql +++ b/scripts/schema.supabase.sql @@ -199,6 +199,7 @@ CREATE TABLE IF NOT EXISTS chat_sessions ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE, document_id UUID NULL REFERENCES documents(id) ON DELETE SET NULL, + document_ids UUID[] NOT NULL DEFAULT '{}', title TEXT NOT NULL DEFAULT '', messages JSONB NOT NULL DEFAULT '[]'::jsonb, started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), @@ -207,9 +208,19 @@ CREATE TABLE IF NOT EXISTS chat_sessions ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); +ALTER TABLE chat_sessions + ADD COLUMN IF NOT EXISTS document_ids UUID[] NOT NULL DEFAULT '{}'; + +UPDATE chat_sessions +SET document_ids = ARRAY[document_id]::UUID[] +WHERE document_id IS NOT NULL + AND COALESCE(array_length(document_ids, 1), 0) = 0; + CREATE INDEX IF NOT EXISTS idx_chat_sessions_workspace_updated ON chat_sessions(workspace_id, updated_at DESC); CREATE INDEX IF NOT EXISTS idx_chat_sessions_workspace_document ON chat_sessions(workspace_id, document_id); +CREATE INDEX IF NOT EXISTS idx_chat_sessions_document_ids + ON chat_sessions USING GIN (document_ids); COMMIT; diff --git a/server/.env.example b/server/.env.example index f75283f..ab44ec8 100644 --- a/server/.env.example +++ b/server/.env.example @@ -2,8 +2,8 @@ SUPABASE_URL=https://your-project.supabase.co SUPABASE_SERVICE_ROLE_KEY=your-service-role-key DATABASE_URL=postgresql://postgres:postgres@localhost:5432/enterprise_rag -DB_POOL_SIZE=1 -DB_MAX_OVERFLOW=0 +DB_POOL_SIZE=3 +DB_MAX_OVERFLOW=2 DB_POOL_TIMEOUT_SECONDS=30 REDIS_URL=redis://localhost:6379/0 ENVIRONMENT=development diff --git a/server/app/api/chats.py b/server/app/api/chats.py index 9f27fa1..12c3d01 100644 --- a/server/app/api/chats.py +++ b/server/app/api/chats.py @@ -38,26 +38,49 @@ def _normalize_title(title: str | None, messages: list[dict[str, object]]) -> st return "Untitled chat" -def _ensure_document_in_workspace(db: Session, workspace_id: uuid.UUID, document_id: uuid.UUID | None) -> None: - if document_id is None: +def _selected_document_ids( + document_id: uuid.UUID | None, + document_ids: list[uuid.UUID] | None, +) -> list[uuid.UUID]: + selected: list[uuid.UUID] = [] + if document_ids: + selected.extend(document_ids) + if document_id is not None and document_id not in selected: + selected.insert(0, document_id) + return selected + + +def _ensure_documents_in_workspace( + db: Session, workspace_id: uuid.UUID, document_ids: list[uuid.UUID] +) -> None: + if not document_ids: return - exists = db.execute( - text( - """ - SELECT 1 + rows = ( + db.execute( + text(""" + SELECT id FROM documents - WHERE id = :document_id + WHERE id IN :document_ids AND workspace_id = :workspace_id - LIMIT 1 - """ - ), - {"document_id": document_id, "workspace_id": workspace_id}, - ).scalar_one_or_none() - if exists is None: + """).bindparams(bindparam("document_ids", expanding=True)), + {"document_ids": document_ids, "workspace_id": workspace_id}, + ) + .scalars() + .all() + ) + found = set(rows) + missing = [document_id for document_id in document_ids if document_id not in found] + if missing: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found") +def _ensure_document_in_workspace( + db: Session, workspace_id: uuid.UUID, document_id: uuid.UUID | None +) -> None: + _ensure_documents_in_workspace(db, workspace_id, [document_id] if document_id else []) + + def _chat_sessions_table_exists(db: Session) -> bool: exists = db.execute(text("SELECT to_regclass('public.chat_sessions')")).scalar_one_or_none() return exists is not None @@ -78,7 +101,9 @@ def _build_payload( ) -def _parse_payload(raw: str | None, fallback_started_at: datetime) -> tuple[list[dict[str, object]], datetime, datetime | None]: +def _parse_payload( + raw: str | None, fallback_started_at: datetime +) -> tuple[list[dict[str, object]], datetime, datetime | None]: if not raw: return [], fallback_started_at, None @@ -122,15 +147,17 @@ def create_chat_session( user: AuthenticatedUser = Depends(get_current_user), db: Session = Depends(get_db), ) -> ChatSessionMetadata: - _ensure_document_in_workspace(db, workspace_id, payload.document_id) + document_ids = _selected_document_ids(payload.document_id, payload.document_ids) + primary_document_id = document_ids[0] if document_ids else None + _ensure_documents_in_workspace(db, workspace_id, document_ids) messages = [message.model_dump(mode="json") for message in payload.messages] title = _normalize_title(payload.title, messages) now_utc = datetime.now(UTC) if not _chat_sessions_table_exists(db): - row = db.execute( - text( - """ + row = ( + db.execute( + text(""" INSERT INTO query_logs ( workspace_id, user_id, @@ -167,58 +194,68 @@ def create_chat_session( NOW() ) RETURNING id, query_text AS title, created_at - """ - ), - { - "workspace_id": workspace_id, - "user_id": user.user_id, - "query_text": title, - "documents_searched": [payload.document_id] if payload.document_id else [], - "retrieved_chunk_ids": [], - "chunk_scores": [], - "answer_text": _build_payload(messages=messages, started_at=now_utc, ended_at=None), - "error_message": QUERY_LOG_CHAT_MARKER, - }, - ).mappings().one() + """), + { + "workspace_id": workspace_id, + "user_id": user.user_id, + "query_text": title, + "documents_searched": document_ids, + "retrieved_chunk_ids": [], + "chunk_scores": [], + "answer_text": _build_payload( + messages=messages, started_at=now_utc, ended_at=None + ), + "error_message": QUERY_LOG_CHAT_MARKER, + }, + ) + .mappings() + .one() + ) db.commit() return ChatSessionMetadata( id=row["id"], title=row["title"], - document_id=payload.document_id, + document_id=primary_document_id, + document_ids=document_ids, created_at=row["created_at"], updated_at=row["created_at"], ended_at=None, ) - row = db.execute( - text( - """ + row = ( + db.execute( + text(""" INSERT INTO chat_sessions ( workspace_id, document_id, + document_ids, title, messages, started_at, created_at, updated_at ) - VALUES (:workspace_id, :document_id, :title, :messages, NOW(), NOW(), NOW()) - RETURNING id, title, document_id, created_at, updated_at, ended_at - """ - ).bindparams(bindparam("messages", type_=JSONB)), - { - "workspace_id": workspace_id, - "document_id": payload.document_id, - "title": title, - "messages": messages, - }, - ).mappings().one() + VALUES (:workspace_id, :document_id, :document_ids, :title, :messages, NOW(), NOW(), NOW()) + RETURNING id, title, document_id, document_ids, created_at, updated_at, ended_at + """).bindparams(bindparam("messages", type_=JSONB)), + { + "workspace_id": workspace_id, + "document_id": primary_document_id, + "document_ids": document_ids, + "title": title, + "messages": messages, + }, + ) + .mappings() + .one() + ) db.commit() return ChatSessionMetadata( id=row["id"], title=row["title"], document_id=row["document_id"], + document_ids=list(row["document_ids"] or []), created_at=row["created_at"], updated_at=row["updated_at"], ended_at=row["ended_at"], @@ -233,51 +270,72 @@ def update_chat_session( db: Session = Depends(get_db), ) -> ChatSessionMetadata: if not _chat_sessions_table_exists(db): - existing = db.execute( - text( - """ + existing = ( + db.execute( + text(""" SELECT id, query_text, documents_searched, answer_text, created_at FROM query_logs WHERE id = :session_id AND workspace_id = :workspace_id AND error_message = :marker LIMIT 1 - """ - ), - { - "session_id": session_id, - "workspace_id": workspace_id, - "marker": QUERY_LOG_CHAT_MARKER, - }, - ).mappings().first() + """), + { + "session_id": session_id, + "workspace_id": workspace_id, + "marker": QUERY_LOG_CHAT_MARKER, + }, + ) + .mappings() + .first() + ) if existing is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found" + ) - old_messages, started_at, old_ended_at = _parse_payload(existing["answer_text"], existing["created_at"]) - messages = [message.model_dump(mode="json") for message in payload.messages] if payload.messages is not None else old_messages + old_messages, started_at, old_ended_at = _parse_payload( + existing["answer_text"], existing["created_at"] + ) + messages = ( + [message.model_dump(mode="json") for message in payload.messages] + if payload.messages is not None + else old_messages + ) title = _normalize_title(payload.title, messages) ended_at = datetime.now(UTC) if payload.ended else old_ended_at + old_document_ids = list(existing["documents_searched"] or []) + document_ids = ( + _selected_document_ids(payload.document_id, payload.document_ids) or old_document_ids + ) + _ensure_documents_in_workspace(db, workspace_id, document_ids) - row = db.execute( - text( - """ + row = ( + db.execute( + text(""" UPDATE query_logs SET query_text = :query_text, + documents_searched = :documents_searched, answer_text = :answer_text WHERE id = :session_id AND workspace_id = :workspace_id AND error_message = :marker RETURNING id, query_text AS title, documents_searched, created_at - """ - ), - { - "session_id": session_id, - "workspace_id": workspace_id, - "marker": QUERY_LOG_CHAT_MARKER, - "query_text": title, - "answer_text": _build_payload(messages=messages, started_at=started_at, ended_at=ended_at), - }, - ).mappings().one() + """), + { + "session_id": session_id, + "workspace_id": workspace_id, + "marker": QUERY_LOG_CHAT_MARKER, + "query_text": title, + "documents_searched": document_ids, + "answer_text": _build_payload( + messages=messages, started_at=started_at, ended_at=ended_at + ), + }, + ) + .mappings() + .one() + ) db.commit() docs = list(row["documents_searched"] or []) @@ -285,23 +343,26 @@ def update_chat_session( id=row["id"], title=row["title"], document_id=docs[0] if docs else None, + document_ids=docs, created_at=row["created_at"], updated_at=datetime.now(UTC), ended_at=ended_at, ) - existing = db.execute( - text( - """ - SELECT id, document_id, messages + existing = ( + db.execute( + text(""" + SELECT id, document_id, document_ids, messages FROM chat_sessions WHERE id = :session_id AND workspace_id = :workspace_id LIMIT 1 - """ - ), - {"session_id": session_id, "workspace_id": workspace_id}, - ).mappings().first() + """), + {"session_id": session_id, "workspace_id": workspace_id}, + ) + .mappings() + .first() + ) if existing is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found") @@ -311,34 +372,48 @@ def update_chat_session( else existing["messages"] ) title = _normalize_title(payload.title, messages) + document_ids = _selected_document_ids(payload.document_id, payload.document_ids) or list( + existing["document_ids"] or [] + ) + if not document_ids and existing["document_id"]: + document_ids = [existing["document_id"]] + primary_document_id = document_ids[0] if document_ids else None + _ensure_documents_in_workspace(db, workspace_id, document_ids) - row = db.execute( - text( - """ + row = ( + db.execute( + text(""" UPDATE chat_sessions SET title = :title, + document_id = :document_id, + document_ids = :document_ids, messages = :messages, ended_at = CASE WHEN :ended THEN NOW() ELSE ended_at END, updated_at = NOW() WHERE id = :session_id AND workspace_id = :workspace_id - RETURNING id, title, document_id, created_at, updated_at, ended_at - """ - ).bindparams(bindparam("messages", type_=JSONB)), - { - "session_id": session_id, - "workspace_id": workspace_id, - "title": title, - "messages": messages, - "ended": payload.ended, - }, - ).mappings().one() + RETURNING id, title, document_id, document_ids, created_at, updated_at, ended_at + """).bindparams(bindparam("messages", type_=JSONB)), + { + "session_id": session_id, + "workspace_id": workspace_id, + "title": title, + "document_id": primary_document_id, + "document_ids": document_ids, + "messages": messages, + "ended": payload.ended, + }, + ) + .mappings() + .one() + ) db.commit() return ChatSessionMetadata( id=row["id"], title=row["title"], document_id=row["document_id"], + document_ids=list(row["document_ids"] or []), created_at=row["created_at"], updated_at=row["updated_at"], ended_at=row["ended_at"], @@ -382,18 +457,20 @@ def list_chat_sessions( ).scalar_one() or 0 ) - rows = db.execute( - text( - f""" + rows = ( + db.execute( + text(f""" SELECT id, query_text AS title, documents_searched, answer_text, created_at FROM query_logs WHERE {where_sql} ORDER BY created_at DESC LIMIT :limit OFFSET :offset - """ - ), - params, - ).mappings().all() + """), + params, + ) + .mappings() + .all() + ) items: list[ChatSessionListItem] = [] for row in rows: @@ -404,6 +481,7 @@ def list_chat_sessions( id=row["id"], title=row["title"], document_id=docs[0] if docs else None, + document_ids=docs, updated_at=row["created_at"], ended_at=ended_at, ) @@ -419,7 +497,7 @@ def list_chat_sessions( "offset": offset, } if document_id is not None: - where_sql += " AND document_id = :document_id" + where_sql += " AND (:document_id = ANY(document_ids) OR document_id = :document_id)" total = int( db.execute( @@ -428,18 +506,20 @@ def list_chat_sessions( ).scalar_one() or 0 ) - rows = db.execute( - text( - f""" - SELECT id, title, document_id, updated_at, ended_at + rows = ( + db.execute( + text(f""" + SELECT id, title, document_id, document_ids, updated_at, ended_at FROM chat_sessions WHERE {where_sql} ORDER BY updated_at DESC LIMIT :limit OFFSET :offset - """ - ), - params, - ).mappings().all() + """), + params, + ) + .mappings() + .all() + ) return ChatSessionListResponse( items=[ @@ -447,6 +527,9 @@ def list_chat_sessions( id=row["id"], title=row["title"], document_id=row["document_id"], + document_ids=list( + row["document_ids"] or ([row["document_id"]] if row["document_id"] else []) + ), updated_at=row["updated_at"], ended_at=row["ended_at"], ) @@ -463,25 +546,29 @@ def get_chat_session( db: Session = Depends(get_db), ) -> ChatSessionDetailResponse: if not _chat_sessions_table_exists(db): - row = db.execute( - text( - """ + row = ( + db.execute( + text(""" SELECT id, query_text AS title, documents_searched, answer_text, created_at FROM query_logs WHERE id = :session_id AND workspace_id = :workspace_id AND error_message = :marker LIMIT 1 - """ - ), - { - "session_id": session_id, - "workspace_id": workspace_id, - "marker": QUERY_LOG_CHAT_MARKER, - }, - ).mappings().first() + """), + { + "session_id": session_id, + "workspace_id": workspace_id, + "marker": QUERY_LOG_CHAT_MARKER, + }, + ) + .mappings() + .first() + ) if row is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found" + ) docs = list(row["documents_searched"] or []) messages, started_at, ended_at = _parse_payload(row["answer_text"], row["created_at"]) @@ -489,23 +576,26 @@ def get_chat_session( id=row["id"], title=row["title"], document_id=docs[0] if docs else None, + document_ids=docs, messages=messages, started_at=started_at, ended_at=ended_at, ) - row = db.execute( - text( - """ - SELECT id, title, document_id, messages, started_at, ended_at + row = ( + db.execute( + text(""" + SELECT id, title, document_id, document_ids, messages, started_at, ended_at FROM chat_sessions WHERE id = :session_id AND workspace_id = :workspace_id LIMIT 1 - """ - ), - {"session_id": session_id, "workspace_id": workspace_id}, - ).mappings().first() + """), + {"session_id": session_id, "workspace_id": workspace_id}, + ) + .mappings() + .first() + ) if row is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found") @@ -513,6 +603,9 @@ def get_chat_session( id=row["id"], title=row["title"], document_id=row["document_id"], + document_ids=list( + row["document_ids"] or ([row["document_id"]] if row["document_id"] else []) + ), messages=row["messages"] or [], started_at=row["started_at"], ended_at=row["ended_at"], diff --git a/server/app/config.py b/server/app/config.py index 4f35ecd..28eb8e4 100644 --- a/server/app/config.py +++ b/server/app/config.py @@ -10,8 +10,8 @@ class Settings(BaseSettings): api_port: int = 8000 DATABASE_URL: str = "postgresql://postgres:postgres@localhost:5432/enterprise_rag" - DB_POOL_SIZE: int = 1 - DB_MAX_OVERFLOW: int = 0 + DB_POOL_SIZE: int = 3 + DB_MAX_OVERFLOW: int = 2 DB_POOL_TIMEOUT_SECONDS: int = 30 REDIS_URL: str = "redis://localhost:6379/0" diff --git a/server/app/core/retrieval.py b/server/app/core/retrieval.py index 53940cb..4f2449d 100644 --- a/server/app/core/retrieval.py +++ b/server/app/core/retrieval.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import math from typing import Any import uuid @@ -52,30 +53,49 @@ def retrieve_top_k_chunks_for_documents( if not document_ids: return [] vector_literal = _embedding_to_vector_literal(query_embedding) - sql = text( - """ + per_document_k = ( + top_k + if len(document_ids) == 1 + else max(1, min(top_k, math.ceil(top_k / len(document_ids)) + 2)) + ) + sql = text(""" + WITH ranked_chunks AS ( + SELECT + c.id AS chunk_id, + c.document_id AS document_id, + c.page_start AS page_number, + c.content AS chunk_text, + COALESCE(dp.content, c.content) AS page_text, + c.token_count AS token_count, + (ce.embedding <=> CAST(:query_embedding AS vector)) AS score, + ROW_NUMBER() OVER ( + PARTITION BY c.document_id + ORDER BY ce.embedding <=> CAST(:query_embedding AS vector) + ) AS document_rank + FROM chunk_embeddings ce + JOIN chunks c ON c.id = ce.chunk_id + LEFT JOIN document_pages dp + ON dp.workspace_id = :workspace_id + AND dp.document_id = c.document_id + AND dp.page_number = c.page_start + WHERE ce.workspace_id = :workspace_id + AND ce.document_id IN :document_ids + AND c.workspace_id = :workspace_id + AND c.document_id IN :document_ids + ) SELECT - c.id AS chunk_id, - c.document_id AS document_id, - c.page_start AS page_number, - c.content AS chunk_text, - COALESCE(dp.content, c.content) AS page_text, - c.token_count AS token_count, - (ce.embedding <=> CAST(:query_embedding AS vector)) AS score - FROM chunk_embeddings ce - JOIN chunks c ON c.id = ce.chunk_id - LEFT JOIN document_pages dp - ON dp.workspace_id = :workspace_id - AND dp.document_id = c.document_id - AND dp.page_number = c.page_start - WHERE ce.workspace_id = :workspace_id - AND ce.document_id IN :document_ids - AND c.workspace_id = :workspace_id - AND c.document_id IN :document_ids - ORDER BY ce.embedding <=> CAST(:query_embedding AS vector) + chunk_id, + document_id, + page_number, + chunk_text, + page_text, + token_count, + score + FROM ranked_chunks + WHERE document_rank <= :per_document_k + ORDER BY score LIMIT :top_k - """ - ).bindparams(bindparam("document_ids", expanding=True)) + """).bindparams(bindparam("document_ids", expanding=True)) rows: list[dict[str, Any]] = ( db.execute( sql, @@ -83,6 +103,7 @@ def retrieve_top_k_chunks_for_documents( "workspace_id": workspace_id, "document_ids": document_ids, "query_embedding": vector_literal, + "per_document_k": per_document_k, "top_k": top_k, }, ) diff --git a/server/app/db/models.py b/server/app/db/models.py index 4c84c81..e246598 100644 --- a/server/app/db/models.py +++ b/server/app/db/models.py @@ -13,6 +13,7 @@ func, ) from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.orm import Mapped, mapped_column from app.db.session import Base @@ -126,6 +127,9 @@ class ChatSession(Base): document_id: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("documents.id", ondelete="SET NULL"), nullable=True ) + document_ids: Mapped[list[uuid.UUID]] = mapped_column( + ARRAY(UUID(as_uuid=True)), nullable=False, default=list + ) title: Mapped[str] = mapped_column(Text, nullable=False, default="") messages: Mapped[list[dict[str, object]]] = mapped_column(JSONB, nullable=False, default=list) started_at: Mapped[datetime] = mapped_column( diff --git a/server/app/schemas/chat.py b/server/app/schemas/chat.py index 10b0dba..1a8a4c2 100644 --- a/server/app/schemas/chat.py +++ b/server/app/schemas/chat.py @@ -16,11 +16,14 @@ class ChatMessage(BaseModel): class ChatSessionCreateRequest(BaseModel): document_id: uuid.UUID | None = None + document_ids: list[uuid.UUID] | None = None title: str | None = None messages: list[ChatMessage] = Field(default_factory=list) class ChatSessionUpdateRequest(BaseModel): + document_id: uuid.UUID | None = None + document_ids: list[uuid.UUID] | None = None title: str | None = None messages: list[ChatMessage] | None = None ended: bool = False @@ -30,6 +33,7 @@ class ChatSessionMetadata(BaseModel): id: uuid.UUID title: str document_id: uuid.UUID | None = None + document_ids: list[uuid.UUID] = Field(default_factory=list) created_at: datetime updated_at: datetime ended_at: datetime | None = None @@ -39,6 +43,7 @@ class ChatSessionListItem(BaseModel): id: uuid.UUID title: str document_id: uuid.UUID | None = None + document_ids: list[uuid.UUID] = Field(default_factory=list) updated_at: datetime ended_at: datetime | None = None @@ -52,6 +57,7 @@ class ChatSessionDetailResponse(BaseModel): id: uuid.UUID title: str document_id: uuid.UUID | None = None + document_ids: list[uuid.UUID] = Field(default_factory=list) messages: list[ChatMessage] started_at: datetime ended_at: datetime | None = None diff --git a/server/migrations/versions/20260612_0002_add_chat_session_document_ids.py b/server/migrations/versions/20260612_0002_add_chat_session_document_ids.py new file mode 100644 index 0000000..421b736 --- /dev/null +++ b/server/migrations/versions/20260612_0002_add_chat_session_document_ids.py @@ -0,0 +1,36 @@ +"""add multi-document chat session selection + +Revision ID: 20260612_0002 +Revises: 20260217_0001 +Create Date: 2026-06-12 +""" + +from alembic import op + + +revision = "20260612_0002" +down_revision = "20260217_0001" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.execute( + """ + ALTER TABLE chat_sessions + ADD COLUMN IF NOT EXISTS document_ids UUID[] NOT NULL DEFAULT '{}'; + + UPDATE chat_sessions + SET document_ids = ARRAY[document_id]::UUID[] + WHERE document_id IS NOT NULL + AND COALESCE(array_length(document_ids, 1), 0) = 0; + + CREATE INDEX IF NOT EXISTS idx_chat_sessions_document_ids + ON chat_sessions USING GIN (document_ids); + """ + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS idx_chat_sessions_document_ids") + op.execute("ALTER TABLE chat_sessions DROP COLUMN IF EXISTS document_ids")