Skip to content

Commit 55b17c1

Browse files
committed
make ingest/delete idempotent and scale celery workers with configurable concurrency
1 parent 638c47a commit 55b17c1

File tree

16 files changed

+477
-21
lines changed

16 files changed

+477
-21
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ The UI provides a query page for asking questions about your documents and an up
8484
- **Redis** — Celery broker, rate limiting, cache payloads
8585
- **GCS** — document upload storage
8686

87+
### Scaling the worker
88+
89+
The Celery worker uses the prefork pool with process-level concurrency. Each worker process has its own KG service, Neo4j driver, postgres engine, and LlamaIndex indexes — this per-process isolation contains crashes and avoids races in LlamaIndex internals.
90+
91+
- **Vertical** — tune the `CELERY_WORKER_CONCURRENCY` env var (default `4`) to change the number of worker processes per container. `WORKER_MAX_TASKS_PER_CHILD` (default `100`) recycles processes to bound memory leaks.
92+
- **Horizontal** — run multiple worker containers behind the same Redis broker.
93+
94+
Tasks use `task_acks_late` + `task_reject_on_worker_lost`, so a task is redelivered to another worker if its worker crashes mid-execution.
95+
8796
## Project Structure
8897

8998
```

services/query-engine/.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,5 @@ RATE_LIMIT_INGEST=10/minute
4242

4343
# Celery
4444
CELERY_BROKER_URL=redis://localhost:6379/0
45+
CELERY_WORKER_CONCURRENCY=4
46+
WORKER_MAX_TASKS_PER_CHILD=100

services/query-engine/app/api/v1/knowledge_graph.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastapi import APIRouter, Depends, Query, Request, UploadFile
55

66
from app.core.config import settings
7-
from app.core.errors import ServiceUnavailableError
7+
from app.core.errors import NotFoundError, ServiceUnavailableError
88
from app.core.rate_limit import limiter
99
from app.dependencies import get_kg_service, get_upload_service
1010
from app.models.knowledge_graph import (
@@ -52,13 +52,19 @@ async def list_documents(
5252
async def delete_document(
5353
request: Request,
5454
doc_id: str,
55+
service: KnowledgeGraphService = Depends(get_kg_service),
5556
) -> TaskAcceptedResponse:
5657
"""
5758
Submit a document deletion job for background processing.
5859
59-
Deletes the document from all storage layers (Neo4j, pgvector, docstore).
60-
Returns a task ID for polling. Retries automatically on partial failure.
60+
Validates that the document exists synchronously so a typoed doc_id
61+
returns 404 immediately. The actual deletion runs as a Celery task
62+
that deletes from all storage layers (Neo4j, pgvector, docstore).
63+
Returns a task ID for polling.
6164
"""
65+
if not await service.document_exists(doc_id):
66+
raise NotFoundError(detail=f"Document {doc_id} not found")
67+
6268
try:
6369
result = delete_document_task.delay(doc_id=doc_id)
6470
except Exception as exc:

services/query-engine/app/connectors/gcs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Google Cloud Storage document connector."""
22

3+
import os
34
import tempfile
45
from collections.abc import Iterator
56
from pathlib import Path
@@ -77,6 +78,20 @@ def load_documents(self, config: dict[str, Any]) -> Iterator[Document]:
7778

7879
documents = reader.load_data()
7980

81+
# Attach a stable, fully qualified source path so downstream
82+
# consumers can derive deterministic doc_ids that survive task
83+
# retries and process restarts. SimpleDirectoryReader sets
84+
# metadata["file_path"] to the absolute temp-dir path, which
85+
# is not stable across runs — strip the temp prefix to get
86+
# the GCS-relative blob name.
87+
for doc in documents:
88+
abs_path = doc.metadata.get("file_path", "") if doc.metadata else ""
89+
if abs_path:
90+
rel_path = os.path.relpath(abs_path, str(tmp_path))
91+
doc.metadata["source_path"] = f"gs://{bucket_name}/{rel_path}"
92+
else:
93+
doc.metadata["source_path"] = f"gs://{bucket_name}/{prefix}"
94+
8095
logger.info(
8196
"gcs_documents_loaded",
8297
bucket=bucket_name,

services/query-engine/app/core/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class Settings(BaseSettings):
5959

6060
# Celery
6161
celery_broker_url: str = ""
62+
celery_worker_concurrency: int = 4
63+
worker_max_tasks_per_child: int = 100
6264

6365

6466
settings = Settings()

services/query-engine/app/core/gcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Singleton GCS client factory."""
22

33
import json
4+
from typing import Any
45

56
from google.cloud import storage as gcs_storage # type: ignore[import-untyped]
67

78
from app.core.config import Settings
89

9-
# NOTE: This lazy singleton is safe because the API server is single-threaded
10-
# async and the Celery worker runs with concurrency=1. If worker concurrency
11-
# is ever increased, this must be replaced with thread-safe init.
12-
_client: gcs_storage.Client | None = None
10+
# Lazy per-process singleton. Safe under FastAPI (single-threaded async) and
11+
# Celery's prefork pool (each worker process has its own copy of this global).
12+
_client: Any = None
1313

1414

1515
def get_gcs_client(config: Settings) -> gcs_storage.Client:

services/query-engine/app/core/postgres.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
from sqlalchemy import Engine, create_engine
44

5-
# NOTE: This lazy singleton is safe because the API server is single-threaded
6-
# async and the Celery worker runs with concurrency=1. If worker concurrency
7-
# is ever increased, this must be replaced with thread-safe init.
5+
# Lazy per-process singleton. Safe under FastAPI (single-threaded async) and
6+
# Celery's prefork pool (each worker process has its own copy of this global).
87
_engine: Engine | None = None
98

109

services/query-engine/app/services/ingestion_pipeline.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Ingestion pipeline: connects document connectors to the KG service."""
22

33
import asyncio
4+
import hashlib
45
from typing import Any
56

67
import structlog
@@ -49,9 +50,27 @@ async def run(
4950

5051
for doc in documents:
5152
try:
53+
content = doc.get_content()
54+
# Derive a stable source identifier so the KG service can
55+
# produce a deterministic doc_id. This makes the task safe
56+
# to retry: a Celery redelivery (worker crash, OOM, time
57+
# limit) re-runs with the same source_id, which hashes to
58+
# the same doc_id, and the storage layers replace prior
59+
# state instead of creating duplicates.
60+
# The content hash is included so an in-place file
61+
# replacement (same path, new content) is treated as a
62+
# different document, not a stale duplicate.
63+
source_path = (
64+
doc.metadata.get("source_path")
65+
or doc.metadata.get("file_name")
66+
or "unknown"
67+
)
68+
content_hash = hashlib.sha256(content.encode()).hexdigest()
69+
source_id = f"{source_type.value}:{source_path}:{content_hash}"
5270
_doc_id, triplets = await self._kg_service.ingest(
53-
text=doc.get_content(),
71+
text=content,
5472
metadata=doc.metadata,
73+
source_id=source_id,
5574
)
5675
total_triplets += triplets
5776
ingested_count += 1

services/query-engine/app/services/knowledge_graph.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import hashlib
55
import time
6-
import uuid
76
from datetime import UTC, datetime
87
from functools import partial
98
from typing import Any
@@ -344,6 +343,22 @@ async def check_cache_health(self) -> dict[str, str] | None:
344343
kg_cache_up.set(1 if result.get("status") == "ok" else 0)
345344
return result
346345

346+
async def document_exists(self, doc_id: str) -> bool:
347+
"""Return True if a document with this doc_id exists in the docstore.
348+
349+
Cheap key lookup (single docstore query) used by the delete endpoint
350+
to validate input synchronously, so a typoed doc_id returns 404
351+
immediately instead of dispatching a background task that would
352+
eventually report failure.
353+
"""
354+
loop = asyncio.get_running_loop()
355+
356+
def _check() -> bool:
357+
info = self._storage_context.docstore.get_ref_doc_info(doc_id)
358+
return info is not None
359+
360+
return await loop.run_in_executor(None, _check)
361+
347362
async def list_documents(
348363
self,
349364
limit: int = 20,
@@ -536,14 +551,22 @@ def _delete_sync() -> list[str]:
536551
async def ingest(
537552
self,
538553
text: str,
554+
source_id: str,
539555
metadata: dict[str, Any] | None = None,
540556
) -> tuple[str, int]:
541557
"""
542558
Ingest a document into both KG and vector indexes.
543559
560+
`source_id` must be a stable identifier for this document (e.g.
561+
derived from the source path and content hash). The resulting
562+
`doc_id` is `sha256(source_id)`, which makes the call idempotent —
563+
re-running with the same source_id replaces any prior vector-store
564+
state for that document instead of creating duplicates. This makes
565+
the ingest path safe to retry after a Celery worker crash.
566+
544567
Returns a tuple of (document_id, triplet_count).
545568
"""
546-
doc_id = str(uuid.uuid4())
569+
doc_id = hashlib.sha256(source_id.encode()).hexdigest()
547570
# Store all metadata in the docstore (for grouping, display, etc.)
548571
# but exclude everything from LLM triplet extraction so it doesn't
549572
# pollute the knowledge graph.
@@ -581,6 +604,23 @@ def _stable_id(i: int, doc: Document) -> str:
581604
for node in nodes:
582605
node.excluded_llm_metadata_keys = list(node.metadata.keys())
583606

607+
# Idempotency guard: PGVectorStore.add() does NOT enforce
608+
# uniqueness on node_id, so a Celery retry of an ingest
609+
# would otherwise accumulate duplicate vector rows.
610+
# Explicitly purge any prior rows for this doc_id before
611+
# re-inserting. Safe no-op on the first run. Neo4j MERGE and
612+
# the postgres docstore (`allow_update=True`) handle their
613+
# own dedupe; the vector store is the only layer that needs
614+
# this.
615+
try:
616+
self._vector_index.vector_store.delete(ref_doc_id=doc_id)
617+
except Exception as exc:
618+
logger.warning(
619+
"vector_store_predelete_failed",
620+
doc_id=doc_id,
621+
error=str(exc),
622+
)
623+
584624
# Vector-first: embedding/pgvector write is more likely to fail
585625
# (external API call). If it fails, Neo4j is untouched.
586626
# If it succeeds and KG insert fails, we have embeddings without

services/query-engine/app/worker/celery_app.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717
accept_content=["json"],
1818
result_expires=86400, # 24 hours
1919
task_track_started=True,
20-
worker_concurrency=1,
21-
worker_prefetch_multiplier=1,
20+
# Prefork pool (Celery default on Unix). Each worker process has its own
21+
# KG service, Neo4j driver, postgres engine, GCS client, and LlamaIndex
22+
# indexes — process isolation makes the lazy singletons safe and avoids
23+
# races in LlamaIndex internals. Scale horizontally by running more worker
24+
# containers.
25+
worker_concurrency=settings.celery_worker_concurrency,
26+
worker_prefetch_multiplier=1, # fair distribution for long-running tasks
27+
worker_max_tasks_per_child=settings.worker_max_tasks_per_child, # recycle to bound memory leaks
28+
task_acks_late=True, # redeliver task if worker crashes mid-execution
29+
task_reject_on_worker_lost=True, # pairs with acks_late for crash recovery
2230
task_soft_time_limit=180, # 3 minutes
2331
task_time_limit=240, # 4 minutes
2432
beat_schedule={

0 commit comments

Comments
 (0)