diff --git a/bench/common.py b/bench/common.py index 4058078..b8c0c5a 100644 --- a/bench/common.py +++ b/bench/common.py @@ -395,7 +395,7 @@ async def run_bench( name: str, items: Iterable[BenchItem], config: BenchConfig, - pool, + backend, ) -> BenchReport: """Orchestrate the full benchmark run.""" from pgkg.memory import Memory @@ -434,7 +434,7 @@ async def process_item(item: BenchItem) -> None: async with semaphore: ns = item.namespace - memory = Memory(pool, namespace=ns, extract_propositions=config.extract_propositions) + memory = Memory(backend, namespace=ns, extract_propositions=config.extract_propositions) # Ingest all conversation turns grouped by session_id sessions: dict[str, list[dict]] = {} diff --git a/bench/locomo.py b/bench/locomo.py index ed7268c..2fbf694 100644 --- a/bench/locomo.py +++ b/bench/locomo.py @@ -190,16 +190,21 @@ async def main() -> None: print(f"Loaded {len(items)} LoCoMo conversations") - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend + from pgkg.config import get_settings from bench.common import run_bench - async with pool_from_settings() as pool: + settings = get_settings() + backend = await PostgresBackend.create(settings.database_url) + try: report = await run_bench( name="locomo", items=items, config=config, - pool=pool, + backend=backend, ) + finally: + await backend.close() print(f"\nFinal: {report.accuracy:.1%} accuracy over {report.total} questions") diff --git a/bench/longmemeval.py b/bench/longmemeval.py index 6e52d7b..7a070a2 100644 --- a/bench/longmemeval.py +++ b/bench/longmemeval.py @@ -245,16 +245,21 @@ async def main() -> None: print(f"Loaded {len(items)} LongMemEval ({args.variant}) records") - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend + from pgkg.config import get_settings from bench.common import run_bench - async with pool_from_settings() as pool: + settings = get_settings() + backend = await PostgresBackend.create(settings.database_url) + try: report = await run_bench( name=f"longmemeval-{args.variant}", items=items, config=config, - pool=pool, + backend=backend, ) + finally: + await backend.close() print(f"\nFinal: {report.accuracy:.1%} accuracy over {report.total} questions") diff --git a/pgkg/__init__.py b/pgkg/__init__.py index c9c89b3..2505741 100644 --- a/pgkg/__init__.py +++ b/pgkg/__init__.py @@ -1,4 +1,5 @@ +from pgkg.backend import StorageBackend from pgkg.memory import Memory, Result from pgkg.config import MemoryConfig -__all__ = ["Memory", "MemoryConfig", "Result"] +__all__ = ["Memory", "MemoryConfig", "Result", "StorageBackend"] diff --git a/pgkg/api.py b/pgkg/api.py index dbc454f..7adf094 100644 --- a/pgkg/api.py +++ b/pgkg/api.py @@ -9,27 +9,27 @@ from pydantic import BaseModel from pgkg import ml +from pgkg.backends.postgres import PostgresBackend from pgkg.config import get_settings -from pgkg.db import make_pool, close_pool from pgkg.memory import Memory, IngestResult, Result -_pool = None +_backend: PostgresBackend | None = None _memory: Memory | None = None @asynccontextmanager async def lifespan(app: FastAPI): - global _pool, _memory + global _backend, _memory settings = get_settings() - _pool = await make_pool(settings.database_url) + _backend = await PostgresBackend.create(settings.database_url) _memory = Memory( - _pool, + _backend, namespace=settings.default_namespace, extract_propositions=settings.extract_propositions, ) yield - if _pool: - await close_pool(_pool) + if _backend: + await _backend.close() app = 
FastAPI(title="pgkg", lifespan=lifespan) @@ -90,15 +90,7 @@ async def forget(req: ForgetRequest) -> Response: @app.get("/health") async def health() -> dict: - db_ok = False - if _pool: - try: - async with _pool.acquire() as conn: - await conn.fetchval("SELECT 1") - db_ok = True - except Exception: - pass - + db_ok = await _backend.health_check() if _backend else False return { "status": "ok", "db": db_ok, diff --git a/pgkg/backend.py b/pgkg/backend.py new file mode 100644 index 0000000..1760087 --- /dev/null +++ b/pgkg/backend.py @@ -0,0 +1,255 @@ +"""StorageBackend protocol — the abstraction boundary between pgkg and its storage layer.""" +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Protocol, runtime_checkable +from uuid import UUID + + +@dataclass(frozen=True) +class ScoredId: + """A database row ID with an associated relevance score.""" + id: UUID + score: float + + +@dataclass(frozen=True) +class Candidate: + """A proposition candidate with full retrieval metadata.""" + proposition_id: UUID + text: str + rrf_score: float + adjusted_score: float + source_kind: str # "kw" | "vec" | "both" | "graph" + chunk_id: UUID | None + subject_id: UUID | None + predicate: str | None + object_id: UUID | None + asserted_at: datetime | None = None + + +@dataclass(frozen=True) +class PropositionRow: + """Full proposition data needed for scoring and reranking.""" + id: UUID + text: str + embedding: list[float] | None + chunk_id: UUID | None + subject_id: UUID | None + predicate: str | None + object_id: UUID | None + confidence: float + access_count: int + last_accessed_at: str # ISO 8601 timestamp + namespace: str + + +@dataclass(frozen=True) +class StoredProposition: + """Data needed to store a proposition.""" + text: str + embedding: list[float] + subject_id: UUID | None + predicate: str | None + object_id: UUID | None + object_literal: str | None + chunk_id: UUID | None + namespace: str + session_id: str | None + confidence: float = 1.0 + metadata: dict | None = None + asserted_at: datetime | None = None + + +@dataclass(frozen=True) +class StoredDocument: + """Data needed to store a document.""" + id: UUID | None = None + source: str | None = None + namespace: str = "default" + + +@dataclass(frozen=True) +class StoredChunk: + """Data needed to store a chunk.""" + document_id: UUID + text: str + id: UUID | None = None + span_start: int | None = None + span_end: int | None = None + asserted_at: datetime | None = None + + +@runtime_checkable +class StorageBackend(Protocol): + """Abstract storage backend for pgkg. + + Implementations must support all methods. The ``fused_search`` method + is optional — backends that cannot fuse retrieval in a single query + should return ``None``, and the caller falls back to primitive + orchestration. + """ + + # --- Lifecycle ------------------------------------------------------- + + async def apply_migrations(self) -> None: + """Apply schema migrations. Called once during initialisation.""" + ... + + async def close(self) -> None: + """Release all resources (connections, pools, file handles).""" + ... + + # --- Retrieval primitives -------------------------------------------- + + async def keyword_search( + self, + query: str, + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + """Full-text search. Returns (id, score) pairs ranked by relevance.""" + ... 
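
The class docstring above says that when ``fused_search`` returns ``None`` the caller falls back to orchestrating the primitives. A minimal sketch of that client-side fallback, assuming the conventional reciprocal-rank-fusion formula score(d) = sum over rankings of 1 / (rrf_k + rank(d)); the helper name ``rrf_fuse`` is illustrative and not part of this patch:

    from collections import defaultdict
    from uuid import UUID

    from pgkg.backend import ScoredId

    def rrf_fuse(rankings: list[list[ScoredId]], rrf_k: int = 60) -> list[ScoredId]:
        """Merge several rankings by reciprocal rank; raw scores are ignored,
        only each hit's position in its ranking contributes."""
        fused: dict[UUID, float] = defaultdict(float)
        for ranking in rankings:
            for rank, hit in enumerate(ranking, start=1):
                fused[hit.id] += 1.0 / (rrf_k + rank)
        return sorted(
            (ScoredId(id=pid, score=score) for pid, score in fused.items()),
            key=lambda hit: hit.score,
            reverse=True,
        )

A caller would feed it the ``keyword_search`` and ``vector_search`` results for the same query and truncate the merged list to ``k_retrieve``, mirroring what the ``pgkg_search()`` push-down is described as doing in a single round trip.
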
+ + async def vector_search( + self, + embedding: list[float], + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + """Vector similarity search. Returns (id, score) pairs ranked by closeness.""" + ... + + async def graph_neighbors( + self, + entity_ids: list[UUID], + namespace: str, + limit: int = 100, + ) -> list[ScoredId]: + """One-hop graph expansion from seed entities. Returns proposition IDs.""" + ... + + async def fused_search( + self, + query_text: str, + query_embedding: list[float], + k_retrieve: int, + k_initial: int, + namespace: str, + session_id: str | None, + recency_half_life_days: float, + expand_graph: bool, + rrf_k: int, + ) -> list[Candidate] | None: + """Push-down fused retrieval (optional optimisation). + + Backends that can execute the full retrieval pipeline in a single + query (e.g. the Postgres CTE) return results here. Other backends + return ``None`` and the caller orchestrates via the primitive methods. + """ + return None + + # --- Data resolution ------------------------------------------------- + + async def get_propositions( + self, + proposition_ids: list[UUID], + ) -> list[PropositionRow]: + """Fetch full proposition data for the given IDs.""" + ... + + async def get_proposition_entity_ids( + self, + proposition_ids: list[UUID], + ) -> dict[UUID, list[UUID]]: + """Return ``{proposition_id: [entity_ids]}`` for the given propositions.""" + ... + + # --- Ingest primitives ----------------------------------------------- + + async def store_document(self, doc: StoredDocument) -> UUID: + """Insert a document. Returns its ID.""" + ... + + async def store_chunk(self, chunk: StoredChunk) -> UUID: + """Insert a chunk. Returns its ID.""" + ... + + async def store_proposition(self, prop: StoredProposition) -> UUID: + """Insert a proposition. Returns its ID.""" + ... + + async def link_entity( + self, + namespace: str, + name: str, + entity_type: str, + embedding: list[float], + threshold: float = 0.85, + ) -> UUID: + """Find or create an entity by name/type/embedding similarity.""" + ... + + async def store_edge( + self, + src_entity: UUID, + dst_entity: UUID, + predicate: str, + proposition_id: UUID, + ) -> None: + """Create an edge between two entities via a proposition.""" + ... + + # --- Access tracking ------------------------------------------------- + + async def bump_access(self, proposition_ids: list[UUID]) -> None: + """Increment access_count and update last_accessed_at.""" + ... + + # --- Entity resolution ----------------------------------------------- + + async def resolve_entity_names( + self, entity_ids: list[UUID], + ) -> dict[UUID, str]: + """Map entity IDs to names. Returns ``{id: name}`` for found entities.""" + ... + + # --- Forget ---------------------------------------------------------- + + async def forget( + self, + proposition_id: UUID, + supersede_with: UUID | None = None, + ) -> None: + """Mark a proposition as superseded.""" + ... + + # --- Cache (optional) ------------------------------------------------ + + async def cache_get(self, cache_key: str) -> list[dict] | None: + """Retrieve cached extraction results as raw dicts. + + Returns ``None`` on cache miss. The caller is responsible for + deserialising dicts into ``Proposition`` objects. + """ + return None + + async def cache_put( + self, + cache_key: str, + chunk_hash: str, + extractor_model: str, + prompt_version: str, + propositions: list[dict], + ) -> None: + """Store extraction results in cache as raw dicts.""" + ... 
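
Because ``StorageBackend`` is ``runtime_checkable``, conformance is structural: any object exposing the same member names passes ``isinstance``, which makes it cheap to stand up a do-nothing double for unit tests. A sketch under that assumption (``NullBackend`` is hypothetical, not part of this patch; together with ``health_check``, declared just below, this covers every protocol member, and ``isinstance`` only verifies member presence, never signatures):

    from uuid import uuid4

    from pgkg.backend import StorageBackend

    class NullBackend:
        """Inert StorageBackend double: stores nothing, finds nothing."""

        async def apply_migrations(self) -> None: ...
        async def close(self) -> None: ...
        async def keyword_search(self, query, k, namespace, session_id=None): return []
        async def vector_search(self, embedding, k, namespace, session_id=None): return []
        async def graph_neighbors(self, entity_ids, namespace, limit=100): return []
        async def fused_search(self, *args, **kwargs): return None  # defer to primitives
        async def get_propositions(self, proposition_ids): return []
        async def get_proposition_entity_ids(self, proposition_ids): return {}
        async def store_document(self, doc): return uuid4()
        async def store_chunk(self, chunk): return uuid4()
        async def store_proposition(self, prop): return uuid4()
        async def link_entity(self, namespace, name, entity_type, embedding, threshold=0.85): return uuid4()
        async def store_edge(self, src_entity, dst_entity, predicate, proposition_id): ...
        async def bump_access(self, proposition_ids): ...
        async def resolve_entity_names(self, entity_ids): return {}
        async def forget(self, proposition_id, supersede_with=None): ...
        async def cache_get(self, cache_key): return None
        async def cache_put(self, cache_key, chunk_hash, extractor_model, prompt_version, propositions): ...
        async def health_check(self): return True

    assert isinstance(NullBackend(), StorageBackend)  # structural check only
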
+
+    # --- Health ----------------------------------------------------------
+
+    async def health_check(self) -> bool:
+        """Return True if the backend is healthy (can serve queries)."""
+        ...
diff --git a/pgkg/backends/__init__.py b/pgkg/backends/__init__.py
new file mode 100644
index 0000000..0fd7228
--- /dev/null
+++ b/pgkg/backends/__init__.py
@@ -0,0 +1,28 @@
+"""Backend factory — construct a StorageBackend from configuration."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pgkg.backend import StorageBackend
+
+
+async def make_backend(backend_type: str = "postgres", **kwargs) -> "StorageBackend":
+    """Construct and initialise a storage backend.
+
+    Parameters
+    ----------
+    backend_type:
+        One of ``"postgres"`` (default). Future values: ``"sqlite"``.
+    **kwargs:
+        Backend-specific arguments. For ``"postgres"``: ``dsn`` (str | None).
+        When *dsn* is omitted or ``None``, pgserver auto-starts an embedded
+        Postgres instance.
+    """
+    if backend_type == "postgres":
+        from pgkg.backends.postgres import PostgresBackend
+
+        dsn: str | None = kwargs.pop("dsn", None)
+        return await PostgresBackend.create(dsn)
+
+    raise ValueError(f"Unknown backend type: {backend_type!r}")
diff --git a/pgkg/backends/postgres.py b/pgkg/backends/postgres.py
new file mode 100644
index 0000000..36de2eb
--- /dev/null
+++ b/pgkg/backends/postgres.py
@@ -0,0 +1,460 @@
+"""PostgresBackend — wraps an asyncpg pool and the pgkg Postgres SQL functions."""
+from __future__ import annotations
+
+import json
+import pathlib
+from uuid import UUID
+
+import asyncpg
+from pgvector.asyncpg import register_vector
+
+from pgkg.backend import (
+    Candidate,
+    PropositionRow,
+    ScoredId,
+    StoredChunk,
+    StoredDocument,
+    StoredProposition,
+)
+
+MIGRATIONS_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "migrations"
+
+
+def _vec_literal(emb: list[float]) -> str:
+    """Format a vector as a Postgres literal string ``[a,b,c,...]``."""
+    return "[" + ",".join(str(v) for v in emb) + "]"
+
+
+class PostgresBackend:
+    """StorageBackend backed by asyncpg + pgvector.
+
+    Delegates search to the ``pgkg_search()`` SQL function, entity linking
+    to ``pgkg_link_entity()``, and access bumping to ``pgkg_bump_access()``.
+    """
+
+    def __init__(self, pool: asyncpg.Pool, *, dsn: str | None = None) -> None:
+        self._pool = pool
+        self._dsn = dsn
+
+    @classmethod
+    async def create(cls, dsn: str | None = None) -> PostgresBackend:
+        """Create a backend with a fresh connection pool.
+
+        When *dsn* is ``None``, an embedded Postgres is started
+        automatically via pgserver (requires ``uv sync --extra embedded``).
+ """ + if dsn is None: + from pgkg.embedded import get_dsn + dsn = get_dsn() + pool = await asyncpg.create_pool( + dsn, min_size=1, max_size=10, init=_init_connection, + ) + return cls(pool, dsn=dsn) # type: ignore[arg-type] + + @property + def pool(self) -> asyncpg.Pool: + """Expose the pool for health checks and legacy callers.""" + return self._pool + + # --- Lifecycle ------------------------------------------------------- + + async def apply_migrations(self) -> None: + async with self._pool.acquire() as conn: + await self._apply_migrations_with_conn(conn) + + async def _apply_migrations_with_conn(self, conn: asyncpg.Connection) -> None: + await conn.execute( + "CREATE TABLE IF NOT EXISTS pgkg_schema_migrations (" + " filename TEXT PRIMARY KEY," + " applied_at TIMESTAMPTZ NOT NULL DEFAULT now()" + ")" + ) + applied = { + r["filename"] + for r in await conn.fetch("SELECT filename FROM pgkg_schema_migrations") + } + for migration in sorted(MIGRATIONS_DIR.glob("*.sql")): + if migration.name in applied: + continue + async with conn.transaction(): + await conn.execute(migration.read_text()) + await conn.execute( + "INSERT INTO pgkg_schema_migrations (filename) VALUES ($1)", + migration.name, + ) + + async def close(self) -> None: + await self._pool.close() + + # --- Retrieval primitives -------------------------------------------- + + async def keyword_search( + self, + query: str, + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT p.id, ts_rank_cd(p.tsv, plainto_tsquery('english', $1)) AS score + FROM propositions p + WHERE p.tsv @@ plainto_tsquery('english', $1) + AND p.namespace = $2 + AND p.superseded_by IS NULL + AND ($3::text IS NULL OR p.session_id = $3 OR p.session_id IS NULL) + ORDER BY score DESC + LIMIT $4 + """, + query, namespace, session_id, k, + ) + return [ScoredId(id=r["id"], score=float(r["score"])) for r in rows] + + async def vector_search( + self, + embedding: list[float], + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + vec_lit = _vec_literal(embedding) + async with self._pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT p.id, (1.0 - (p.embedding <=> '{vec_lit}'::vector)) AS score + FROM propositions p + WHERE p.embedding IS NOT NULL + AND p.namespace = $1 + AND p.superseded_by IS NULL + AND ($2::text IS NULL OR p.session_id = $2 OR p.session_id IS NULL) + ORDER BY p.embedding <=> '{vec_lit}'::vector + LIMIT $3 + """, + namespace, session_id, k, + ) + return [ScoredId(id=r["id"], score=float(r["score"])) for r in rows] + + async def graph_neighbors( + self, + entity_ids: list[UUID], + namespace: str, + limit: int = 100, + ) -> list[ScoredId]: + if not entity_ids: + return [] + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT DISTINCT e.proposition_id AS id, 1.0 AS score + FROM edges e + JOIN propositions p ON p.id = e.proposition_id + WHERE (e.src_entity = ANY($1::uuid[]) OR e.dst_entity = ANY($1::uuid[])) + AND p.namespace = $2 + AND p.superseded_by IS NULL + LIMIT $3 + """, + entity_ids, namespace, limit, + ) + return [ScoredId(id=r["id"], score=float(r["score"])) for r in rows] + + async def fused_search( + self, + query_text: str, + query_embedding: list[float], + k_retrieve: int, + k_initial: int, + namespace: str, + session_id: str | None, + recency_half_life_days: float, + expand_graph: bool, + rrf_k: int, + ) -> list[Candidate] | None: + vec_lit = 
_vec_literal(query_embedding) + async with self._pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT proposition_id, text, embedding, rrf_score, adjusted_score, + source_kind, chunk_id, subject_id, predicate, object_id, asserted_at + FROM pgkg_search($1, '{vec_lit}'::vector, + $2, $3, $4, $5, $6, $7, $8) + """, + query_text, + k_retrieve, + k_initial, + namespace, + session_id, + recency_half_life_days, + expand_graph, + rrf_k, + ) + return [ + Candidate( + proposition_id=r["proposition_id"], + text=r["text"], + rrf_score=float(r["rrf_score"]), + adjusted_score=float(r["adjusted_score"]), + source_kind=r["source_kind"], + chunk_id=r["chunk_id"], + subject_id=r["subject_id"], + predicate=r["predicate"], + object_id=r["object_id"], + asserted_at=r["asserted_at"], + ) + for r in rows + ] + + # --- Data resolution ------------------------------------------------- + + async def get_propositions( + self, + proposition_ids: list[UUID], + ) -> list[PropositionRow]: + if not proposition_ids: + return [] + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, text, embedding, chunk_id, subject_id, predicate, + object_id, confidence, access_count, + last_accessed_at::text, namespace + FROM propositions + WHERE id = ANY($1::uuid[]) + """, + proposition_ids, + ) + return [ + PropositionRow( + id=r["id"], + text=r["text"], + embedding=list(r["embedding"]) if r["embedding"] is not None else None, + chunk_id=r["chunk_id"], + subject_id=r["subject_id"], + predicate=r["predicate"], + object_id=r["object_id"], + confidence=float(r["confidence"]), + access_count=r["access_count"], + last_accessed_at=r["last_accessed_at"], + namespace=r["namespace"], + ) + for r in rows + ] + + async def get_proposition_entity_ids( + self, + proposition_ids: list[UUID], + ) -> dict[UUID, list[UUID]]: + if not proposition_ids: + return {} + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, subject_id, object_id + FROM propositions + WHERE id = ANY($1::uuid[]) + """, + proposition_ids, + ) + result: dict[UUID, list[UUID]] = {} + for r in rows: + entities: list[UUID] = [] + if r["subject_id"] is not None: + entities.append(r["subject_id"]) + if r["object_id"] is not None: + entities.append(r["object_id"]) + if entities: + result[r["id"]] = entities + return result + + # --- Ingest primitives ----------------------------------------------- + + async def store_document(self, doc: StoredDocument) -> UUID: + async with self._pool.acquire() as conn: + doc_id: UUID = await conn.fetchval( + "INSERT INTO documents (source, namespace) VALUES ($1, $2) RETURNING id", + doc.source, + doc.namespace, + ) + return doc_id + + async def store_chunk(self, chunk: StoredChunk) -> UUID: + async with self._pool.acquire() as conn: + chunk_id: UUID = await conn.fetchval( + """ + INSERT INTO chunks (document_id, text, span_start, span_end, asserted_at) + VALUES ($1, $2, $3, $4, $5) RETURNING id + """, + chunk.document_id, + chunk.text, + chunk.span_start, + chunk.span_end, + chunk.asserted_at, + ) + return chunk_id + + async def store_proposition(self, prop: StoredProposition) -> UUID: + vec_lit = _vec_literal(prop.embedding) + metadata_json = json.dumps(prop.metadata) if prop.metadata else None + async with self._pool.acquire() as conn: + prop_id: UUID = await conn.fetchval( + f""" + INSERT INTO propositions + (text, embedding, subject_id, predicate, object_id, + object_literal, chunk_id, namespace, session_id, + confidence, metadata, asserted_at) + VALUES ($1, 
'{vec_lit}'::vector, + $2, $3, $4, $5, $6, $7, $8, $9, + $10::jsonb, $11) + RETURNING id + """, + prop.text, + prop.subject_id, + prop.predicate, + prop.object_id, + prop.object_literal, + prop.chunk_id, + prop.namespace, + prop.session_id, + prop.confidence, + metadata_json, + prop.asserted_at, + ) + return prop_id + + async def link_entity( + self, + namespace: str, + name: str, + entity_type: str, + embedding: list[float], + threshold: float = 0.85, + ) -> UUID: + vec_lit = _vec_literal(embedding) + async with self._pool.acquire() as conn: + entity_id: UUID = await conn.fetchval( + f"SELECT pgkg_link_entity($1, $2, $3, '{vec_lit}'::vector, $4)", + namespace, + name, + entity_type, + threshold, + ) + return entity_id + + async def store_edge( + self, + src_entity: UUID, + dst_entity: UUID, + predicate: str, + proposition_id: UUID, + ) -> None: + async with self._pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO edges (src_entity, dst_entity, relation, proposition_id) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING + """, + src_entity, + dst_entity, + predicate, + proposition_id, + ) + + # --- Access tracking ------------------------------------------------- + + async def bump_access(self, proposition_ids: list[UUID]) -> None: + if not proposition_ids: + return + async with self._pool.acquire() as conn: + await conn.execute("SELECT pgkg_bump_access($1::uuid[])", proposition_ids) + + # --- Entity resolution ----------------------------------------------- + + async def resolve_entity_names( + self, entity_ids: list[UUID], + ) -> dict[UUID, str]: + if not entity_ids: + return {} + async with self._pool.acquire() as conn: + rows = await conn.fetch( + "SELECT id, name FROM entities WHERE id = ANY($1::uuid[])", + entity_ids, + ) + return {r["id"]: r["name"] for r in rows} + + # --- Forget ---------------------------------------------------------- + + async def forget( + self, + proposition_id: UUID, + supersede_with: UUID | None = None, + ) -> None: + async with self._pool.acquire() as conn: + if supersede_with is not None: + await conn.execute( + "UPDATE propositions SET superseded_by = $1 WHERE id = $2", + supersede_with, + proposition_id, + ) + else: + await conn.execute( + "UPDATE propositions SET superseded_by = id WHERE id = $1", + proposition_id, + ) + + # --- Cache ----------------------------------------------------------- + + async def cache_get(self, cache_key: str) -> list[dict] | None: + async with self._pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT propositions FROM proposition_cache WHERE cache_key = $1", + cache_key, + ) + if row is None: + return None + await conn.execute( + "UPDATE proposition_cache SET hit_count = hit_count + 1 WHERE cache_key = $1", + cache_key, + ) + raw = row["propositions"] + if isinstance(raw, str): + return json.loads(raw) + return raw + + async def cache_put( + self, + cache_key: str, + chunk_hash: str, + extractor_model: str, + prompt_version: str, + propositions: list[dict], + ) -> None: + payload = json.dumps(propositions) + async with self._pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO proposition_cache + (cache_key, chunk_hash, extractor_model, prompt_version, propositions) + VALUES ($1, $2, $3, $4, $5::jsonb) + ON CONFLICT (cache_key) DO NOTHING + """, + cache_key, + chunk_hash, + extractor_model, + prompt_version, + payload, + ) + + # --- Health ---------------------------------------------------------- + + async def health_check(self) -> bool: + try: + async with self._pool.acquire() as 
conn: + await conn.fetchval("SELECT 1") + return True + except Exception: + return False + + +async def _init_connection(conn: asyncpg.Connection) -> None: + await register_vector(conn) diff --git a/pgkg/cli.py b/pgkg/cli.py index 046a42b..d3f1f97 100644 --- a/pgkg/cli.py +++ b/pgkg/cli.py @@ -7,44 +7,17 @@ def cmd_migrate(args: argparse.Namespace) -> None: - import pathlib - - import asyncpg - + from pgkg.backends.postgres import PostgresBackend from pgkg.config import get_settings async def _run() -> None: - dsn = get_settings().database_url - if dsn is None: - from pgkg.embedded import get_dsn - dsn = get_dsn() - migrations_dir = pathlib.Path(__file__).resolve().parent.parent / "migrations" - conn = await asyncpg.connect(dsn) + settings = get_settings() + backend = await PostgresBackend.create(settings.database_url) try: - await conn.execute( - "CREATE TABLE IF NOT EXISTS pgkg_schema_migrations (" - " filename TEXT PRIMARY KEY," - " applied_at TIMESTAMPTZ NOT NULL DEFAULT now()" - ")" - ) - applied = { - r["filename"] - for r in await conn.fetch("SELECT filename FROM pgkg_schema_migrations") - } - for migration in sorted(migrations_dir.glob("*.sql")): - if migration.name in applied: - print(f"Skipping {migration.name} (already applied).") - continue - print(f"Applying {migration.name}...") - async with conn.transaction(): - await conn.execute(migration.read_text()) - await conn.execute( - "INSERT INTO pgkg_schema_migrations (filename) VALUES ($1)", - migration.name, - ) + await backend.apply_migrations() print("All migrations applied.") finally: - await conn.close() + await backend.close() asyncio.run(_run()) @@ -55,7 +28,7 @@ def cmd_serve(args: argparse.Namespace) -> None: def cmd_ingest(args: argparse.Namespace) -> None: - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend from pgkg.memory import Memory from pgkg.config import get_settings @@ -71,8 +44,9 @@ def cmd_ingest(args: argparse.Namespace) -> None: async def _run() -> None: settings = get_settings() extract = not args.chunks_only - async with pool_from_settings() as pool: - mem = Memory(pool, namespace=settings.default_namespace, extract_propositions=extract) + backend = await PostgresBackend.create(settings.database_url) + try: + mem = Memory(backend, namespace=settings.default_namespace, extract_propositions=extract) result = await mem.ingest(text, source=source) print(json.dumps({ "documents": result.documents, @@ -80,21 +54,26 @@ async def _run() -> None: "propositions": result.propositions, "entities": result.entities, })) + finally: + await backend.close() asyncio.run(_run()) def cmd_recall(args: argparse.Namespace) -> None: - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend from pgkg.memory import Memory from pgkg.config import get_settings async def _run() -> None: settings = get_settings() - async with pool_from_settings() as pool: - mem = Memory(pool, namespace=settings.default_namespace) + backend = await PostgresBackend.create(settings.database_url) + try: + mem = Memory(backend, namespace=settings.default_namespace) results = await mem.recall(args.query, k=args.k) print(json.dumps([r.model_dump(mode="json") for r in results], indent=2)) + finally: + await backend.close() asyncio.run(_run()) diff --git a/pgkg/memory.py b/pgkg/memory.py index d1295f9..0f353d9 100644 --- a/pgkg/memory.py +++ b/pgkg/memory.py @@ -1,19 +1,21 @@ from __future__ import annotations import asyncio -import json import re -import uuid from dataclasses import 
dataclass from datetime import datetime -from typing import Any from uuid import UUID -import asyncpg from pydantic import BaseModel from pgkg import ml -from pgkg.ml import ExtractCache, Proposition, PROMPT_VERSION +from pgkg.backend import ( + StorageBackend, + StoredChunk, + StoredDocument, + StoredProposition, +) +from pgkg.ml import ExtractCache, Proposition class Result(BaseModel): @@ -64,36 +66,17 @@ def _chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]: return chunks or [text[:chunk_size]] -class PostgresExtractCache: - """Postgres-backed implementation of ExtractCache. +class BackendExtractCache: + """Adapts StorageBackend cache methods to the ExtractCache protocol.""" - Stores extracted propositions in the proposition_cache table so re-ingesting - the same chunk with the same extractor model and prompt version is free. - """ - - def __init__(self, pool: asyncpg.Pool, namespace: str) -> None: - self._pool = pool - self._namespace = namespace + def __init__(self, backend: StorageBackend) -> None: + self._backend = backend async def get(self, cache_key: str) -> list[Proposition] | None: - async with self._pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT propositions FROM proposition_cache WHERE cache_key = $1", - cache_key, - ) - if row is None: - return None - # Bump hit count (best-effort; don't fail the main path) - await conn.execute( - "UPDATE proposition_cache SET hit_count = hit_count + 1 WHERE cache_key = $1", - cache_key, - ) - raw = row["propositions"] - if isinstance(raw, str): - items = json.loads(raw) - else: - items = raw # asyncpg may already decode JSONB - return [Proposition(**p) for p in items] + raw = await self._backend.cache_get(cache_key) + if raw is None: + return None + return [Proposition(**d) for d in raw] async def put( self, @@ -103,37 +86,29 @@ async def put( prompt_version: str, props: list[Proposition], ) -> None: - payload = json.dumps([p.model_dump() for p in props]) - async with self._pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO proposition_cache - (cache_key, chunk_hash, extractor_model, prompt_version, propositions) - VALUES ($1, $2, $3, $4, $5::jsonb) - ON CONFLICT (cache_key) DO NOTHING - """, - cache_key, - chunk_hash, - extractor_model, - prompt_version, - payload, - ) + await self._backend.cache_put( + cache_key, + chunk_hash, + extractor_model, + prompt_version, + [p.model_dump() for p in props], + ) class Memory: def __init__( self, - pool: asyncpg.Pool, + backend: StorageBackend, *, namespace: str = "default", use_extract_cache: bool = True, extract_propositions: bool = True, ) -> None: - self._pool = pool + self._backend = backend self._namespace = namespace self._extract_propositions = extract_propositions self._extract_cache: ExtractCache | None = ( - PostgresExtractCache(pool, namespace) if use_extract_cache and extract_propositions else None + BackendExtractCache(backend) if use_extract_cache and extract_propositions else None ) async def ingest( @@ -150,140 +125,110 @@ async def ingest( entities_created: set[UUID] = set() total_propositions = 0 - async with self._pool.acquire() as conn: - # Insert document - doc_id: UUID = await conn.fetchval( - "INSERT INTO documents (source, namespace) VALUES ($1, $2) RETURNING id", - source, - self._namespace, + doc_id = await self._backend.store_document( + StoredDocument(source=source, namespace=self._namespace) + ) + + chunk_ids: list[UUID] = [] + for i, chunk_text in enumerate(chunks): + chunk_id = await self._backend.store_chunk( + 
StoredChunk( + document_id=doc_id, + text=chunk_text, + span_start=i * chunk_size, + span_end=(i + 1) * chunk_size, + asserted_at=asserted_at, + ) ) + chunk_ids.append(chunk_id) - chunk_ids: list[UUID] = [] - for i, chunk_text in enumerate(chunks): - chunk_id: UUID = await conn.fetchval( - """ - INSERT INTO chunks (document_id, text, span_start, span_end, asserted_at) - VALUES ($1, $2, $3, $4, $5) RETURNING id - """, - doc_id, - chunk_text, - i * chunk_size, - (i + 1) * chunk_size, - asserted_at, + if self._extract_propositions: + for chunk_id, chunk_text in zip(chunk_ids, chunks): + propositions = await ml.extract_propositions_async( + chunk_text, cache=self._extract_cache ) - chunk_ids.append(chunk_id) + if not propositions: + continue - if self._extract_propositions: - # Extract and embed per chunk - for chunk_id, chunk_text in zip(chunk_ids, chunks): - propositions = await ml.extract_propositions_async( - chunk_text, cache=self._extract_cache + entity_names: list[str] = [] + for prop in propositions: + entity_names.append(prop.subject) + if not prop.object_is_literal: + entity_names.append(prop.object) + + prop_texts = [p.text for p in propositions] + all_texts = entity_names + prop_texts + all_embs = ml.embed(all_texts) + + entity_embs = all_embs[: len(entity_names)] + prop_embs = all_embs[len(entity_names) :] + + entity_idx = 0 + for prop, prop_emb in zip(propositions, prop_embs): + subj_emb = entity_embs[entity_idx] + entity_idx += 1 + + subject_id = await self._backend.link_entity( + self._namespace, prop.subject, "concept", subj_emb ) - if not propositions: - continue - - # Collect all texts for batch embedding - entity_names: list[str] = [] - for prop in propositions: - entity_names.append(prop.subject) - if not prop.object_is_literal: - entity_names.append(prop.object) - - prop_texts = [p.text for p in propositions] - all_texts = entity_names + prop_texts - all_embs = ml.embed(all_texts) - - entity_embs = all_embs[: len(entity_names)] - prop_embs = all_embs[len(entity_names):] - - # Link entities and insert propositions - entity_idx = 0 - for prop, prop_emb in zip(propositions, prop_embs): - subj_emb = entity_embs[entity_idx] - entity_idx += 1 + entities_created.add(subject_id) - subject_id: UUID = await conn.fetchval( - _link_entity_sql(subj_emb), - self._namespace, - prop.subject, - "concept", + object_id: UUID | None = None + object_literal: str | None = None + + if prop.object_is_literal: + object_literal = prop.object + else: + obj_emb = entity_embs[entity_idx] + entity_idx += 1 + object_id = await self._backend.link_entity( + self._namespace, prop.object, "concept", obj_emb ) - entities_created.add(subject_id) - - object_id: UUID | None = None - object_literal: str | None = None - - if prop.object_is_literal: - object_literal = prop.object - else: - obj_emb = entity_embs[entity_idx] - entity_idx += 1 - object_id = await conn.fetchval( - _link_entity_sql(obj_emb), - self._namespace, - prop.object, - "concept", - ) - entities_created.add(object_id) - - prop_id: UUID = await conn.fetchval( - f""" - INSERT INTO propositions - (text, embedding, subject_id, predicate, object_id, - object_literal, chunk_id, namespace, session_id, asserted_at) - VALUES ($1, '{_vec_literal(prop_emb)}'::vector, - $2, $3, $4, $5, $6, $7, $8, $9) - RETURNING id - """, - prop.text, - subject_id, - prop.predicate, - object_id, - object_literal, - chunk_id, - self._namespace, - session_id, - asserted_at, + entities_created.add(object_id) + + prop_id = await self._backend.store_proposition( + 
StoredProposition( + text=prop.text, + embedding=prop_emb, + subject_id=subject_id, + predicate=prop.predicate, + object_id=object_id, + object_literal=object_literal, + chunk_id=chunk_id, + namespace=self._namespace, + session_id=session_id, + asserted_at=asserted_at, ) - total_propositions += 1 - - if object_id is not None: - await conn.execute( - """ - INSERT INTO edges (src_entity, dst_entity, relation, proposition_id) - VALUES ($1, $2, $3, $4) - ON CONFLICT DO NOTHING - """, - subject_id, - object_id, - prop.predicate, - prop_id, - ) - else: - # Chunks-only mode: embed each chunk directly, no LLM extraction, - # no entity linking, no edge creation. - chunk_texts_list = list(chunks) - chunk_embs = ml.embed(chunk_texts_list) - for chunk_id, chunk_text, chunk_emb in zip(chunk_ids, chunk_texts_list, chunk_embs): - metadata = '{"mode": "chunk"}' - await conn.fetchval( - f""" - INSERT INTO propositions - (text, embedding, subject_id, predicate, object_id, - object_literal, chunk_id, namespace, session_id, metadata, asserted_at) - VALUES ($1, '{_vec_literal(chunk_emb)}'::vector, - NULL, NULL, NULL, NULL, $2, $3, $4, $5::jsonb, $6) - RETURNING id - """, - chunk_text, - chunk_id, - self._namespace, - session_id, - metadata, - asserted_at, ) total_propositions += 1 + if object_id is not None: + await self._backend.store_edge( + subject_id, object_id, prop.predicate, prop_id + ) + else: + chunk_texts_list = list(chunks) + chunk_embs = ml.embed(chunk_texts_list) + for chunk_id, chunk_text_val, chunk_emb in zip( + chunk_ids, chunk_texts_list, chunk_embs + ): + await self._backend.store_proposition( + StoredProposition( + text=chunk_text_val, + embedding=chunk_emb, + subject_id=None, + predicate=None, + object_id=None, + object_literal=None, + chunk_id=chunk_id, + namespace=self._namespace, + session_id=session_id, + metadata={"mode": "chunk"}, + asserted_at=asserted_at, + ) + ) + total_propositions += 1 + return IngestResult( documents=1, chunks=len(chunks), @@ -305,35 +250,37 @@ async def recall( ) -> list[Result]: q_emb = ml.embed([query])[0] - async with self._pool.acquire() as conn: - rows = await conn.fetch( - f""" - SELECT proposition_id, text, embedding, rrf_score, adjusted_score, - source_kind, chunk_id, subject_id, predicate, object_id, asserted_at - FROM pgkg_search($1, '{_vec_literal(q_emb)}'::vector, - $2, $3, $4, $5, 30.0, $6) - """, - query, - k_retrieve, - k_retrieve * 2, - self._namespace, - session_id, - expand_graph, - ) + candidates = await self._backend.fused_search( + query_text=query, + query_embedding=q_emb, + k_retrieve=k_retrieve, + k_initial=k_retrieve * 2, + namespace=self._namespace, + session_id=session_id, + recency_half_life_days=30.0, + expand_graph=expand_graph, + rrf_k=60, + ) - if not rows: + if not candidates: return [] - texts = [r["text"] for r in rows] - scores = [float(r["adjusted_score"]) for r in rows] - embs = [_parse_emb(r["embedding"], q_emb) for r in rows] + # Fetch full proposition data for embeddings (needed by rerank/MMR) + prop_rows = await self._backend.get_propositions( + [c.proposition_id for c in candidates] + ) + emb_by_id = {p.id: p.embedding for p in prop_rows} + + texts = [c.text for c in candidates] + scores = [c.adjusted_score for c in candidates] + embs = [emb_by_id.get(c.proposition_id) or q_emb for c in candidates] + embs = [list(e) if not isinstance(e, list) else e for e in embs] if with_rerank: - candidate_rows = rows[: min(k_retrieve, 64)] - candidate_texts = [r["text"] for r in candidate_rows] + candidate_slice = candidates[: 
min(k_retrieve, 64)] + candidate_texts = [c.text for c in candidate_slice] rerank_scores = ml.rerank(query, candidate_texts) - # Min-max normalize both score lists def _normalize(vals: list[float]) -> list[float]: lo, hi = min(vals), max(vals) span = hi - lo @@ -341,56 +288,56 @@ def _normalize(vals: list[float]) -> list[float]: return [1.0] * len(vals) return [(v - lo) / span for v in vals] - adj_scores = [float(r["adjusted_score"]) for r in candidate_rows] + adj_scores = [c.adjusted_score for c in candidate_slice] rerank_norm = _normalize(rerank_scores) adj_norm = _normalize(adj_scores) blended = [0.7 * r + 0.3 * a for r, a in zip(rerank_norm, adj_norm)] - # Sort candidate_rows by blended score - sorted_indices = sorted(range(len(blended)), key=lambda i: blended[i], reverse=True) - rows = [candidate_rows[i] for i in sorted_indices] + sorted_indices = sorted( + range(len(blended)), key=lambda i: blended[i], reverse=True + ) + candidates = [candidate_slice[i] for i in sorted_indices] scores = [blended[i] for i in sorted_indices] - embs = [list(candidate_rows[i]["embedding"]) if candidate_rows[i]["embedding"] is not None else q_emb - for i in sorted_indices] + embs = [ + emb_by_id.get(candidate_slice[i].proposition_id) or q_emb + for i in sorted_indices + ] + embs = [list(e) if not isinstance(e, list) else e for e in embs] - if with_mmr and len(rows) > k: + if with_mmr and len(candidates) > k: selected_indices = ml.mmr(q_emb, embs, k, lambda_=mmr_lambda) - rows = [rows[i] for i in selected_indices] + candidates = [candidates[i] for i in selected_indices] scores = [scores[i] for i in selected_indices] else: - rows = rows[:k] + candidates = candidates[:k] scores = scores[:k] # Fire-and-forget bump access - prop_ids = [str(r["proposition_id"]) for r in rows] + prop_ids = [c.proposition_id for c in candidates] asyncio.ensure_future(self._bump(prop_ids)) results = [] - for row, score in zip(rows, scores): - subject_name: str | None = None - object_name: str | None = None - + for cand, score in zip(candidates, scores): results.append( Result( - proposition_id=row["proposition_id"], - text=row["text"], + proposition_id=cand.proposition_id, + text=cand.text, score=score, - rrf_score=float(row["rrf_score"]), - source_kind=row["source_kind"], - chunk_id=row["chunk_id"], - subject=subject_name, - predicate=row["predicate"], - object=object_name, - asserted_at=row["asserted_at"], + rrf_score=cand.rrf_score, + source_kind=cand.source_kind, + chunk_id=cand.chunk_id, + subject=None, + predicate=cand.predicate, + object=None, + asserted_at=cand.asserted_at, ) ) return results - async def _bump(self, prop_ids: list[str]) -> None: + async def _bump(self, prop_ids: list[UUID]) -> None: try: - async with self._pool.acquire() as conn: - await conn.execute("SELECT pgkg_bump_access($1::uuid[])", prop_ids) + await self._backend.bump_access(prop_ids) except Exception: pass @@ -400,37 +347,4 @@ async def forget( *, supersede_with: UUID | None = None, ) -> None: - async with self._pool.acquire() as conn: - if supersede_with is not None: - await conn.execute( - "UPDATE propositions SET superseded_by = $1 WHERE id = $2", - supersede_with, - proposition_id, - ) - else: - await conn.execute( - "UPDATE propositions SET superseded_by = id WHERE id = $1", - proposition_id, - ) - - -def _parse_emb(val: object, fallback: list[float]) -> list[float]: - """Parse an embedding from asyncpg — may be a list, numpy array, or string.""" - if val is None: - return fallback - if isinstance(val, (list, tuple)): - return list(val) - if 
isinstance(val, str): - import json - return json.loads(val.replace("(", "[").replace(")", "]")) - # numpy array or pgvector type - return list(val) - - -def _vec_literal(emb: list[float]) -> str: - """Format a vector as a Postgres literal string '[a,b,c,...]'.""" - return "[" + ",".join(str(v) for v in emb) + "]" - - -def _link_entity_sql(emb: list[float]) -> str: - return f"SELECT pgkg_link_entity($1, $2, $3, '{_vec_literal(emb)}'::vector)" + await self._backend.forget(proposition_id, supersede_with=supersede_with) diff --git a/tests/conftest.py b/tests/conftest.py index 88dd762..2e5ae28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,8 @@ import pytest from pgvector.asyncpg import register_vector +from pgkg.backends.postgres import PostgresBackend + MIGRATIONS_DIR = pathlib.Path(__file__).parent.parent / "migrations" @@ -64,3 +66,11 @@ async def pool(pg_dsn) -> AsyncGenerator[asyncpg.Pool, None]: yield conn_pool await conn_pool.close() + + +@pytest.fixture(scope="session") +async def backend(pool) -> AsyncGenerator[PostgresBackend, None]: + """PostgresBackend wrapping the test pool.""" + b = PostgresBackend(pool) + yield b + # Pool cleanup handled by the pool fixture; do not close here. diff --git a/tests/test_bench.py b/tests/test_bench.py index 4e35075..1b0f66e 100644 --- a/tests/test_bench.py +++ b/tests/test_bench.py @@ -41,7 +41,7 @@ def test_bench_item_validates(): # test_run_bench_dry_run_exact_match # --------------------------------------------------------------------------- -async def test_run_bench_dry_run_exact_match(pool: asyncpg.Pool, monkeypatch): +async def test_run_bench_dry_run_exact_match(pool: asyncpg.Pool, backend, monkeypatch): """run_bench with dry_run=True, exact_match=True completes without LLM calls.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -99,7 +99,7 @@ def _fake_embed(texts): name="test-dryrun", items=items, config=config, - pool=pool, + backend=backend, ) assert report.total == 2 diff --git a/tests/test_chunks_only.py b/tests/test_chunks_only.py index 810636d..4f0004f 100644 --- a/tests/test_chunks_only.py +++ b/tests/test_chunks_only.py @@ -31,7 +31,7 @@ def _fake_embed(texts: list[str]) -> list[list[float]]: # test_ingest_chunks_only_skips_extraction # --------------------------------------------------------------------------- -async def test_ingest_chunks_only_skips_extraction(pool: asyncpg.Pool, monkeypatch): +async def test_ingest_chunks_only_skips_extraction(pool: asyncpg.Pool, backend, monkeypatch): """extract_propositions_async must not be called; propositions have NULL subject_id and metadata->>'mode' = 'chunk'; entities table is untouched.""" import pgkg.ml as ml_module @@ -44,7 +44,7 @@ async def _should_not_be_called(*args, **kwargs): monkeypatch.setattr(ml_module, "extract_propositions_async", _should_not_be_called) ns = f"chunks_only_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) result = await mem.ingest("Hello world. 
This is a test document.") assert result.documents == 1 @@ -83,7 +83,7 @@ async def _should_not_be_called(*args, **kwargs): # test_recall_works_in_chunks_mode # --------------------------------------------------------------------------- -async def test_recall_works_in_chunks_mode(pool: asyncpg.Pool, monkeypatch): +async def test_recall_works_in_chunks_mode(pool: asyncpg.Pool, backend, monkeypatch): """After chunks-only ingest, recall returns results with NULL predicate.""" import pgkg.ml as ml_module @@ -114,7 +114,7 @@ async def _noop_extract(*args, **kwargs): monkeypatch.setattr(ml_module, "extract_propositions_async", _noop_extract) ns = f"chunks_recall_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest(ocean_text) results = await mem.recall( @@ -135,7 +135,7 @@ async def _noop_extract(*args, **kwargs): # test_chunks_mode_graph_expansion_is_noop # --------------------------------------------------------------------------- -async def test_chunks_mode_graph_expansion_is_noop(pool: asyncpg.Pool, monkeypatch): +async def test_chunks_mode_graph_expansion_is_noop(pool: asyncpg.Pool, backend, monkeypatch): """Graph expansion with chunks-only ingest produces no graph-sourced rows.""" import pgkg.ml as ml_module monkeypatch.setattr(ml_module, "embed", _fake_embed) @@ -152,7 +152,7 @@ async def _noop_extract(*args, **kwargs): monkeypatch.setattr(ml_module, "extract_propositions_async", _noop_extract) ns = f"chunks_graph_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest("Alice visited Bob last Tuesday. Bob works at Acme Corp.") results = await mem.recall( diff --git a/tests/test_extract_cache.py b/tests/test_extract_cache.py index b5ee63e..55ea633 100644 --- a/tests/test_extract_cache.py +++ b/tests/test_extract_cache.py @@ -18,7 +18,7 @@ compute_cache_key, extract_propositions_async, ) -from pgkg.memory import PostgresExtractCache +from pgkg.memory import BackendExtractCache # --------------------------------------------------------------------------- @@ -69,7 +69,7 @@ async def _get_hit_count(pool: asyncpg.Pool, cache_key: str) -> int: # --------------------------------------------------------------------------- async def test_cache_hit_returns_stored_props_without_llm( - pool: asyncpg.Pool, monkeypatch + pool: asyncpg.Pool, backend, monkeypatch ): """Pre-populate cache; LLM provider should never be called on cache hit.""" chunk = "Alice is a brilliant scientist who works at CERN." @@ -102,7 +102,7 @@ async def test_cache_hit_returns_stored_props_without_llm( lambda *a, **kw: (_ for _ in ()).throw(AssertionError("LLM was called — should have been a cache hit")), ) - cache = PostgresExtractCache(pool, "test-ns") + cache = BackendExtractCache(backend) result = await extract_propositions_async(chunk, cache=cache) assert len(result) == 1 @@ -114,7 +114,7 @@ async def test_cache_hit_returns_stored_props_without_llm( # test_cache_miss_then_hit # --------------------------------------------------------------------------- -async def test_cache_miss_then_hit(pool: asyncpg.Pool, monkeypatch): +async def test_cache_miss_then_hit(pool: asyncpg.Pool, backend, monkeypatch): """First call invokes LLM stub (count=1); second call hits cache (count still 1).""" chunk = "Bob loves hiking in the mountains near his home." 
model = "stub-model-v1" @@ -141,7 +141,7 @@ def _fake_do_extract(chunk_text, max_propositions, settings, extractor_model): monkeypatch.setattr(ml_module, "_do_extract", _fake_do_extract) - cache = PostgresExtractCache(pool, "test-ns-miss-hit") + cache = BackendExtractCache(backend) # First call — cache miss, LLM called result1 = await extract_propositions_async(chunk, cache=cache) @@ -235,7 +235,7 @@ def test_cache_key_changes_with_prompt_version(monkeypatch): # test_postgres_cache_hit_count_increments # --------------------------------------------------------------------------- -async def test_postgres_cache_hit_count_increments(pool: asyncpg.Pool): +async def test_postgres_cache_hit_count_increments(pool: asyncpg.Pool, backend): """Two get() calls on a populated key → hit_count = 2.""" chunk = "Eve is a cryptographer." model = "hit-count-model" @@ -244,7 +244,7 @@ async def test_postgres_cache_hit_count_increments(pool: asyncpg.Pool): stored = [_make_prop(text="Eve is a cryptographer.", subject="Eve", predicate="is", object="cryptographer")] await _seed_cache(pool, cache_key, stored) - cache = PostgresExtractCache(pool, "test-ns-hits") + cache = BackendExtractCache(backend) await cache.get(cache_key) await cache.get(cache_key) diff --git a/tests/test_memory.py b/tests/test_memory.py index 1ed0573..5201184 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -38,7 +38,7 @@ def _fake_embed(texts: list[str]) -> list[list[float]]: # test_ingest_creates_rows # --------------------------------------------------------------------------- -async def test_ingest_creates_rows(pool: asyncpg.Pool, monkeypatch): +async def test_ingest_creates_rows(pool: asyncpg.Pool, backend, monkeypatch): """PGKG_OFFLINE_EXTRACT=1: ingest populates documents, chunks, propositions tables.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -46,7 +46,7 @@ async def test_ingest_creates_rows(pool: asyncpg.Pool, monkeypatch): monkeypatch.setattr(ml_module, "embed", _fake_embed) ns = f"ingest_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) result = await mem.ingest("Hello world. 
This is a test document.") assert result.documents == 1 @@ -74,7 +74,7 @@ async def test_ingest_creates_rows(pool: asyncpg.Pool, monkeypatch): # test_recall_returns_ingested # --------------------------------------------------------------------------- -async def test_recall_returns_ingested(pool: asyncpg.Pool, monkeypatch): +async def test_recall_returns_ingested(pool: asyncpg.Pool, backend, monkeypatch): """After ingesting a doc, recalling a matching query returns the proposition.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -103,7 +103,7 @@ def predict(self, pairs): monkeypatch.setattr(ml_module, "_rerank_model", FakeCE()) ns = f"recall_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) await mem.ingest(ocean_text) results = await mem.recall( @@ -123,7 +123,7 @@ def predict(self, pairs): # test_recall_session_scope # --------------------------------------------------------------------------- -async def test_recall_session_scope(pool: asyncpg.Pool, monkeypatch): +async def test_recall_session_scope(pool: asyncpg.Pool, backend, monkeypatch): """Propositions ingested with session_id='A' don't appear in session_id='B' recall.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -131,7 +131,7 @@ async def test_recall_session_scope(pool: asyncpg.Pool, monkeypatch): monkeypatch.setattr(ml_module, "embed", _fake_embed) ns = f"session_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) # Ingest unique text with session A unique_text = f"Unique session A content xyzzy_{uuid.uuid4().hex}" @@ -164,7 +164,7 @@ async def test_recall_session_scope(pool: asyncpg.Pool, monkeypatch): # test_forget_supersedes # --------------------------------------------------------------------------- -async def test_forget_supersedes(pool: asyncpg.Pool, monkeypatch): +async def test_forget_supersedes(pool: asyncpg.Pool, backend, monkeypatch): """After forget(), the proposition no longer appears in recall results.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -185,7 +185,7 @@ def _targeted_embed(texts: list[str]) -> list[list[float]]: monkeypatch.setattr(ml_module, "embed", _targeted_embed) ns = f"forget_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) result = await mem.ingest(target_text) # Get the proposition id @@ -220,7 +220,7 @@ def _targeted_embed(texts: list[str]) -> list[list[float]]: # recall test bypassed rerank/MMR, so the truthiness check on the # embedding column was never exercised against a real DB row. 
-async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, monkeypatch): +async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, backend, monkeypatch): monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") import pgkg.ml as ml_module @@ -229,7 +229,7 @@ async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, monkeypatch.setattr(ml_module, "rerank", lambda q, docs: [1.0 / (i + 1) for i in range(len(docs))]) ns = f"recall_default_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest("The chunks-only ingest mode skips LLM extraction entirely.") await mem.ingest("Hybrid retrieval fuses BM25 and vector similarity via RRF.") @@ -245,7 +245,7 @@ async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, # test_ingest_propagates_asserted_at # --------------------------------------------------------------------------- -async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, monkeypatch): +async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, backend, monkeypatch): """Ingest with asserted_at stores it in both chunk and proposition rows.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -254,7 +254,7 @@ async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, monkeypatch): expected_ts = datetime(2025, 1, 15, 10, 0, 0, tzinfo=timezone.utc) ns = f"assertedat_ingest_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest( "The sky is blue and the grass is green.", @@ -292,7 +292,7 @@ async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, monkeypatch): # test_recall_returns_asserted_at_in_result # --------------------------------------------------------------------------- -async def test_recall_returns_asserted_at_in_result(pool: asyncpg.Pool, monkeypatch): +async def test_recall_returns_asserted_at_in_result(pool: asyncpg.Pool, backend, monkeypatch): """Result.asserted_at is populated when ingested with an asserted_at timestamp.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -312,7 +312,7 @@ def _controlled_embed(texts: list[str]) -> list[list[float]]: monkeypatch.setattr(ml_module, "embed", _controlled_embed) ns = f"assertedat_recall_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest(target_text, asserted_at=expected_ts)
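
Putting the pieces together, callers of the refactored API never import asyncpg directly. A hedged end-to-end sketch mirroring the ``cmd_ingest``/``cmd_recall`` flow above: it assumes a reachable Postgres (or the pgserver ``embedded`` extra for ``dsn=None``) plus a configured embedding model, uses chunks-only mode so no LLM extractor is needed, and the ``demo`` namespace and sample texts are illustrative:

    import asyncio

    from pgkg.backends import make_backend
    from pgkg.memory import Memory

    async def main() -> None:
        backend = await make_backend("postgres")  # dsn=None -> embedded pgserver
        try:
            await backend.apply_migrations()
            mem = Memory(backend, namespace="demo", extract_propositions=False)
            await mem.ingest("Hybrid retrieval fuses BM25 and vector similarity via RRF.")
            for result in await mem.recall("how does hybrid retrieval work?", k=3):
                print(f"{result.score:.3f}  {result.text}")
        finally:
            await backend.close()

    asyncio.run(main())

Swapping in a future ``"sqlite"`` backend would then be a one-argument change at the ``make_backend`` call site, which is the point of the ``StorageBackend`` seam.
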