From 7f2c098816af0c3451c451b25640f0768985f36a Mon Sep 17 00:00:00 2001 From: Joseph Mearman Date: Tue, 28 Apr 2026 20:36:57 +0200 Subject: [PATCH 1/3] feat(backend): introduce StorageBackend protocol and PostgresBackend Decouple pgkg from direct asyncpg usage by defining a StorageBackend protocol in pgkg/backend.py and a PostgresBackend implementation that wraps the existing asyncpg pool and SQL functions (pgkg_search, pgkg_link_entity, pgkg_bump_access) with zero behavioural change. --- pgkg/backend.py | 251 +++++++++++++++++++++ pgkg/backends/__init__.py | 26 +++ pgkg/backends/postgres.py | 450 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 727 insertions(+) create mode 100644 pgkg/backend.py create mode 100644 pgkg/backends/__init__.py create mode 100644 pgkg/backends/postgres.py diff --git a/pgkg/backend.py b/pgkg/backend.py new file mode 100644 index 0000000..d128d00 --- /dev/null +++ b/pgkg/backend.py @@ -0,0 +1,251 @@ +"""StorageBackend protocol — the abstraction boundary between pgkg and its storage layer.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, runtime_checkable +from uuid import UUID + + +@dataclass(frozen=True) +class ScoredId: + """A database row ID with an associated relevance score.""" + id: UUID + score: float + + +@dataclass(frozen=True) +class Candidate: + """A proposition candidate with full retrieval metadata.""" + proposition_id: UUID + text: str + rrf_score: float + adjusted_score: float + source_kind: str # "kw" | "vec" | "both" | "graph" + chunk_id: UUID | None + subject_id: UUID | None + predicate: str | None + object_id: UUID | None + + +@dataclass(frozen=True) +class PropositionRow: + """Full proposition data needed for scoring and reranking.""" + id: UUID + text: str + embedding: list[float] | None + chunk_id: UUID | None + subject_id: UUID | None + predicate: str | None + object_id: UUID | None + confidence: float + access_count: int + last_accessed_at: str # ISO 8601 timestamp + namespace: str + + +@dataclass(frozen=True) +class StoredProposition: + """Data needed to store a proposition.""" + text: str + embedding: list[float] + subject_id: UUID | None + predicate: str | None + object_id: UUID | None + object_literal: str | None + chunk_id: UUID | None + namespace: str + session_id: str | None + confidence: float = 1.0 + metadata: dict | None = None + + +@dataclass(frozen=True) +class StoredDocument: + """Data needed to store a document.""" + id: UUID | None = None + source: str | None = None + namespace: str = "default" + + +@dataclass(frozen=True) +class StoredChunk: + """Data needed to store a chunk.""" + document_id: UUID + text: str + id: UUID | None = None + span_start: int | None = None + span_end: int | None = None + + +@runtime_checkable +class StorageBackend(Protocol): + """Abstract storage backend for pgkg. + + Implementations must support all methods. The ``fused_search`` method + is optional — backends that cannot fuse retrieval in a single query + should return ``None``, and the caller falls back to primitive + orchestration. + """ + + # --- Lifecycle ------------------------------------------------------- + + async def apply_migrations(self) -> None: + """Apply schema migrations. Called once during initialisation.""" + ... + + async def close(self) -> None: + """Release all resources (connections, pools, file handles).""" + ... 
+ + # --- Retrieval primitives -------------------------------------------- + + async def keyword_search( + self, + query: str, + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + """Full-text search. Returns (id, score) pairs ranked by relevance.""" + ... + + async def vector_search( + self, + embedding: list[float], + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + """Vector similarity search. Returns (id, score) pairs ranked by closeness.""" + ... + + async def graph_neighbors( + self, + entity_ids: list[UUID], + namespace: str, + limit: int = 100, + ) -> list[ScoredId]: + """One-hop graph expansion from seed entities. Returns proposition IDs.""" + ... + + async def fused_search( + self, + query_text: str, + query_embedding: list[float], + k_retrieve: int, + k_initial: int, + namespace: str, + session_id: str | None, + recency_half_life_days: float, + expand_graph: bool, + rrf_k: int, + ) -> list[Candidate] | None: + """Push-down fused retrieval (optional optimisation). + + Backends that can execute the full retrieval pipeline in a single + query (e.g. the Postgres CTE) return results here. Other backends + return ``None`` and the caller orchestrates via the primitive methods. + """ + return None + + # --- Data resolution ------------------------------------------------- + + async def get_propositions( + self, + proposition_ids: list[UUID], + ) -> list[PropositionRow]: + """Fetch full proposition data for the given IDs.""" + ... + + async def get_proposition_entity_ids( + self, + proposition_ids: list[UUID], + ) -> dict[UUID, list[UUID]]: + """Return ``{proposition_id: [entity_ids]}`` for the given propositions.""" + ... + + # --- Ingest primitives ----------------------------------------------- + + async def store_document(self, doc: StoredDocument) -> UUID: + """Insert a document. Returns its ID.""" + ... + + async def store_chunk(self, chunk: StoredChunk) -> UUID: + """Insert a chunk. Returns its ID.""" + ... + + async def store_proposition(self, prop: StoredProposition) -> UUID: + """Insert a proposition. Returns its ID.""" + ... + + async def link_entity( + self, + namespace: str, + name: str, + entity_type: str, + embedding: list[float], + threshold: float = 0.85, + ) -> UUID: + """Find or create an entity by name/type/embedding similarity.""" + ... + + async def store_edge( + self, + src_entity: UUID, + dst_entity: UUID, + predicate: str, + proposition_id: UUID, + ) -> None: + """Create an edge between two entities via a proposition.""" + ... + + # --- Access tracking ------------------------------------------------- + + async def bump_access(self, proposition_ids: list[UUID]) -> None: + """Increment access_count and update last_accessed_at.""" + ... + + # --- Entity resolution ----------------------------------------------- + + async def resolve_entity_names( + self, entity_ids: list[UUID], + ) -> dict[UUID, str]: + """Map entity IDs to names. Returns ``{id: name}`` for found entities.""" + ... + + # --- Forget ---------------------------------------------------------- + + async def forget( + self, + proposition_id: UUID, + supersede_with: UUID | None = None, + ) -> None: + """Mark a proposition as superseded.""" + ... + + # --- Cache (optional) ------------------------------------------------ + + async def cache_get(self, cache_key: str) -> list[dict] | None: + """Retrieve cached extraction results as raw dicts. + + Returns ``None`` on cache miss. 
The caller is responsible for + deserialising dicts into ``Proposition`` objects. + """ + return None + + async def cache_put( + self, + cache_key: str, + chunk_hash: str, + extractor_model: str, + prompt_version: str, + propositions: list[dict], + ) -> None: + """Store extraction results in cache as raw dicts.""" + ... + + # --- Health ---------------------------------------------------------- + + async def health_check(self) -> bool: + """Return True if the backend is healthy (can serve queries).""" + ... diff --git a/pgkg/backends/__init__.py b/pgkg/backends/__init__.py new file mode 100644 index 0000000..5343a83 --- /dev/null +++ b/pgkg/backends/__init__.py @@ -0,0 +1,26 @@ +"""Backend factory — construct a StorageBackend from configuration.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pgkg.backend import StorageBackend + + +async def make_backend(backend_type: str = "postgres", **kwargs) -> "StorageBackend": + """Construct and initialise a storage backend. + + Parameters + ---------- + backend_type: + One of ``"postgres"`` (default). Future values: ``"sqlite"``. + **kwargs: + Backend-specific arguments. For ``"postgres"``: ``dsn`` (str). + """ + if backend_type == "postgres": + from pgkg.backends.postgres import PostgresBackend + + dsn: str = kwargs.pop("dsn") + return await PostgresBackend.create(dsn) + + raise ValueError(f"Unknown backend type: {backend_type!r}") diff --git a/pgkg/backends/postgres.py b/pgkg/backends/postgres.py new file mode 100644 index 0000000..95b662b --- /dev/null +++ b/pgkg/backends/postgres.py @@ -0,0 +1,450 @@ +"""PostgresBackend — wraps asyncpg pool and Postgres SQL functions.""" +from __future__ import annotations + +import json +import pathlib +from uuid import UUID + +import asyncpg +from pgvector.asyncpg import register_vector + +from pgkg.backend import ( + Candidate, + PropositionRow, + ScoredId, + StoredChunk, + StoredDocument, + StoredProposition, +) + +MIGRATIONS_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "migrations" + + +def _vec_literal(emb: list[float]) -> str: + """Format a vector as a Postgres literal string ``[a,b,c,...]``.""" + return "[" + ",".join(str(v) for v in emb) + "]" + + +class PostgresBackend: + """StorageBackend backed by asyncpg + pgvector. + + Delegates search to the ``pgkg_search()`` CTE, entity linking to + ``pgkg_link_entity()``, and access bumping to ``pgkg_bump_access()``. 
+ """ + + def __init__(self, pool: asyncpg.Pool, *, dsn: str | None = None) -> None: + self._pool = pool + self._dsn = dsn + + @classmethod + async def create(cls, dsn: str) -> PostgresBackend: + """Create a backend with a fresh connection pool.""" + pool = await asyncpg.create_pool( + dsn, min_size=1, max_size=10, init=_init_connection, + ) + return cls(pool, dsn=dsn) # type: ignore[arg-type] + + @property + def pool(self) -> asyncpg.Pool: + """Expose the pool for health checks and legacy callers.""" + return self._pool + + # --- Lifecycle ------------------------------------------------------- + + async def apply_migrations(self) -> None: + async with self._pool.acquire() as conn: + await self._apply_migrations_with_conn(conn) + + async def _apply_migrations_with_conn(self, conn: asyncpg.Connection) -> None: + await conn.execute( + "CREATE TABLE IF NOT EXISTS pgkg_schema_migrations (" + " filename TEXT PRIMARY KEY," + " applied_at TIMESTAMPTZ NOT NULL DEFAULT now()" + ")" + ) + applied = { + r["filename"] + for r in await conn.fetch("SELECT filename FROM pgkg_schema_migrations") + } + for migration in sorted(MIGRATIONS_DIR.glob("*.sql")): + if migration.name in applied: + continue + async with conn.transaction(): + await conn.execute(migration.read_text()) + await conn.execute( + "INSERT INTO pgkg_schema_migrations (filename) VALUES ($1)", + migration.name, + ) + + async def close(self) -> None: + await self._pool.close() + + # --- Retrieval primitives -------------------------------------------- + + async def keyword_search( + self, + query: str, + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT p.id, ts_rank_cd(p.tsv, plainto_tsquery('english', $1)) AS score + FROM propositions p + WHERE p.tsv @@ plainto_tsquery('english', $1) + AND p.namespace = $2 + AND p.superseded_by IS NULL + AND ($3::text IS NULL OR p.session_id = $3 OR p.session_id IS NULL) + ORDER BY score DESC + LIMIT $4 + """, + query, namespace, session_id, k, + ) + return [ScoredId(id=r["id"], score=float(r["score"])) for r in rows] + + async def vector_search( + self, + embedding: list[float], + k: int, + namespace: str, + session_id: str | None = None, + ) -> list[ScoredId]: + vec_lit = _vec_literal(embedding) + async with self._pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT p.id, (1.0 - (p.embedding <=> '{vec_lit}'::vector)) AS score + FROM propositions p + WHERE p.embedding IS NOT NULL + AND p.namespace = $1 + AND p.superseded_by IS NULL + AND ($2::text IS NULL OR p.session_id = $2 OR p.session_id IS NULL) + ORDER BY p.embedding <=> '{vec_lit}'::vector + LIMIT $3 + """, + namespace, session_id, k, + ) + return [ScoredId(id=r["id"], score=float(r["score"])) for r in rows] + + async def graph_neighbors( + self, + entity_ids: list[UUID], + namespace: str, + limit: int = 100, + ) -> list[ScoredId]: + if not entity_ids: + return [] + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT DISTINCT e.proposition_id AS id, 1.0 AS score + FROM edges e + JOIN propositions p ON p.id = e.proposition_id + WHERE (e.src_entity = ANY($1::uuid[]) OR e.dst_entity = ANY($1::uuid[])) + AND p.namespace = $2 + AND p.superseded_by IS NULL + LIMIT $3 + """, + entity_ids, namespace, limit, + ) + return [ScoredId(id=r["id"], score=float(r["score"])) for r in rows] + + async def fused_search( + self, + query_text: str, + query_embedding: list[float], + k_retrieve: int, + k_initial: 
int, + namespace: str, + session_id: str | None, + recency_half_life_days: float, + expand_graph: bool, + rrf_k: int, + ) -> list[Candidate] | None: + vec_lit = _vec_literal(query_embedding) + async with self._pool.acquire() as conn: + rows = await conn.fetch( + f""" + SELECT proposition_id, text, embedding, rrf_score, adjusted_score, + source_kind, chunk_id, subject_id, predicate, object_id + FROM pgkg_search($1, '{vec_lit}'::vector, + $2, $3, $4, $5, $6, $7, $8) + """, + query_text, + k_retrieve, + k_initial, + namespace, + session_id, + recency_half_life_days, + expand_graph, + rrf_k, + ) + return [ + Candidate( + proposition_id=r["proposition_id"], + text=r["text"], + rrf_score=float(r["rrf_score"]), + adjusted_score=float(r["adjusted_score"]), + source_kind=r["source_kind"], + chunk_id=r["chunk_id"], + subject_id=r["subject_id"], + predicate=r["predicate"], + object_id=r["object_id"], + ) + for r in rows + ] + + # --- Data resolution ------------------------------------------------- + + async def get_propositions( + self, + proposition_ids: list[UUID], + ) -> list[PropositionRow]: + if not proposition_ids: + return [] + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, text, embedding, chunk_id, subject_id, predicate, + object_id, confidence, access_count, + last_accessed_at::text, namespace + FROM propositions + WHERE id = ANY($1::uuid[]) + """, + proposition_ids, + ) + return [ + PropositionRow( + id=r["id"], + text=r["text"], + embedding=list(r["embedding"]) if r["embedding"] is not None else None, + chunk_id=r["chunk_id"], + subject_id=r["subject_id"], + predicate=r["predicate"], + object_id=r["object_id"], + confidence=float(r["confidence"]), + access_count=r["access_count"], + last_accessed_at=r["last_accessed_at"], + namespace=r["namespace"], + ) + for r in rows + ] + + async def get_proposition_entity_ids( + self, + proposition_ids: list[UUID], + ) -> dict[UUID, list[UUID]]: + if not proposition_ids: + return {} + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT id, subject_id, object_id + FROM propositions + WHERE id = ANY($1::uuid[]) + """, + proposition_ids, + ) + result: dict[UUID, list[UUID]] = {} + for r in rows: + entities: list[UUID] = [] + if r["subject_id"] is not None: + entities.append(r["subject_id"]) + if r["object_id"] is not None: + entities.append(r["object_id"]) + if entities: + result[r["id"]] = entities + return result + + # --- Ingest primitives ----------------------------------------------- + + async def store_document(self, doc: StoredDocument) -> UUID: + async with self._pool.acquire() as conn: + doc_id: UUID = await conn.fetchval( + "INSERT INTO documents (source, namespace) VALUES ($1, $2) RETURNING id", + doc.source, + doc.namespace, + ) + return doc_id + + async def store_chunk(self, chunk: StoredChunk) -> UUID: + async with self._pool.acquire() as conn: + chunk_id: UUID = await conn.fetchval( + """ + INSERT INTO chunks (document_id, text, span_start, span_end) + VALUES ($1, $2, $3, $4) RETURNING id + """, + chunk.document_id, + chunk.text, + chunk.span_start, + chunk.span_end, + ) + return chunk_id + + async def store_proposition(self, prop: StoredProposition) -> UUID: + vec_lit = _vec_literal(prop.embedding) + metadata_json = json.dumps(prop.metadata) if prop.metadata else None + async with self._pool.acquire() as conn: + prop_id: UUID = await conn.fetchval( + f""" + INSERT INTO propositions + (text, embedding, subject_id, predicate, object_id, + object_literal, chunk_id, 
namespace, session_id, + confidence, metadata) + VALUES ($1, '{vec_lit}'::vector, + $2, $3, $4, $5, $6, $7, $8, $9, + $10::jsonb) + RETURNING id + """, + prop.text, + prop.subject_id, + prop.predicate, + prop.object_id, + prop.object_literal, + prop.chunk_id, + prop.namespace, + prop.session_id, + prop.confidence, + metadata_json, + ) + return prop_id + + async def link_entity( + self, + namespace: str, + name: str, + entity_type: str, + embedding: list[float], + threshold: float = 0.85, + ) -> UUID: + vec_lit = _vec_literal(embedding) + async with self._pool.acquire() as conn: + entity_id: UUID = await conn.fetchval( + f"SELECT pgkg_link_entity($1, $2, $3, '{vec_lit}'::vector, $4)", + namespace, + name, + entity_type, + threshold, + ) + return entity_id + + async def store_edge( + self, + src_entity: UUID, + dst_entity: UUID, + predicate: str, + proposition_id: UUID, + ) -> None: + async with self._pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO edges (src_entity, dst_entity, relation, proposition_id) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING + """, + src_entity, + dst_entity, + predicate, + proposition_id, + ) + + # --- Access tracking ------------------------------------------------- + + async def bump_access(self, proposition_ids: list[UUID]) -> None: + if not proposition_ids: + return + async with self._pool.acquire() as conn: + await conn.execute("SELECT pgkg_bump_access($1::uuid[])", proposition_ids) + + # --- Entity resolution ----------------------------------------------- + + async def resolve_entity_names( + self, entity_ids: list[UUID], + ) -> dict[UUID, str]: + if not entity_ids: + return {} + async with self._pool.acquire() as conn: + rows = await conn.fetch( + "SELECT id, name FROM entities WHERE id = ANY($1::uuid[])", + entity_ids, + ) + return {r["id"]: r["name"] for r in rows} + + # --- Forget ---------------------------------------------------------- + + async def forget( + self, + proposition_id: UUID, + supersede_with: UUID | None = None, + ) -> None: + async with self._pool.acquire() as conn: + if supersede_with is not None: + await conn.execute( + "UPDATE propositions SET superseded_by = $1 WHERE id = $2", + supersede_with, + proposition_id, + ) + else: + await conn.execute( + "UPDATE propositions SET superseded_by = id WHERE id = $1", + proposition_id, + ) + + # --- Cache ----------------------------------------------------------- + + async def cache_get(self, cache_key: str) -> list[dict] | None: + async with self._pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT propositions FROM proposition_cache WHERE cache_key = $1", + cache_key, + ) + if row is None: + return None + await conn.execute( + "UPDATE proposition_cache SET hit_count = hit_count + 1 WHERE cache_key = $1", + cache_key, + ) + raw = row["propositions"] + if isinstance(raw, str): + return json.loads(raw) + return raw + + async def cache_put( + self, + cache_key: str, + chunk_hash: str, + extractor_model: str, + prompt_version: str, + propositions: list[dict], + ) -> None: + payload = json.dumps(propositions) + async with self._pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO proposition_cache + (cache_key, chunk_hash, extractor_model, prompt_version, propositions) + VALUES ($1, $2, $3, $4, $5::jsonb) + ON CONFLICT (cache_key) DO NOTHING + """, + cache_key, + chunk_hash, + extractor_model, + prompt_version, + payload, + ) + + # --- Health ---------------------------------------------------------- + + async def health_check(self) -> bool: + try: + 
async with self._pool.acquire() as conn: + await conn.fetchval("SELECT 1") + return True + except Exception: + return False + + +async def _init_connection(conn: asyncpg.Connection) -> None: + await register_vector(conn) From ea3659d28f0e322ccea1aafaa2255fce471a2e72 Mon Sep 17 00:00:00 2001 From: Joseph Mearman Date: Tue, 28 Apr 2026 20:37:08 +0200 Subject: [PATCH 2/3] refactor(memory): decouple Memory from asyncpg.Pool via StorageBackend Memory now depends on the StorageBackend protocol instead of holding an asyncpg.Pool directly. All raw SQL has moved into PostgresBackend. PostgresExtractCache replaced by BackendExtractCache adapter. Updated api.py, cli.py, bench scripts, and tests to construct PostgresBackend. --- bench/common.py | 4 +- bench/locomo.py | 11 +- bench/longmemeval.py | 11 +- pgkg/__init__.py | 3 +- pgkg/api.py | 24 +- pgkg/backends/__init__.py | 6 +- pgkg/backends/postgres.py | 11 +- pgkg/cli.py | 55 ++--- pgkg/memory.py | 423 ++++++++++++++---------------------- tests/conftest.py | 10 + tests/test_chunks_only.py | 12 +- tests/test_extract_cache.py | 14 +- tests/test_memory.py | 20 +- 13 files changed, 256 insertions(+), 348 deletions(-) diff --git a/bench/common.py b/bench/common.py index 4058078..b8c0c5a 100644 --- a/bench/common.py +++ b/bench/common.py @@ -395,7 +395,7 @@ async def run_bench( name: str, items: Iterable[BenchItem], config: BenchConfig, - pool, + backend, ) -> BenchReport: """Orchestrate the full benchmark run.""" from pgkg.memory import Memory @@ -434,7 +434,7 @@ async def process_item(item: BenchItem) -> None: async with semaphore: ns = item.namespace - memory = Memory(pool, namespace=ns, extract_propositions=config.extract_propositions) + memory = Memory(backend, namespace=ns, extract_propositions=config.extract_propositions) # Ingest all conversation turns grouped by session_id sessions: dict[str, list[dict]] = {} diff --git a/bench/locomo.py b/bench/locomo.py index ed7268c..2fbf694 100644 --- a/bench/locomo.py +++ b/bench/locomo.py @@ -190,16 +190,21 @@ async def main() -> None: print(f"Loaded {len(items)} LoCoMo conversations") - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend + from pgkg.config import get_settings from bench.common import run_bench - async with pool_from_settings() as pool: + settings = get_settings() + backend = await PostgresBackend.create(settings.database_url) + try: report = await run_bench( name="locomo", items=items, config=config, - pool=pool, + backend=backend, ) + finally: + await backend.close() print(f"\nFinal: {report.accuracy:.1%} accuracy over {report.total} questions") diff --git a/bench/longmemeval.py b/bench/longmemeval.py index 6e52d7b..7a070a2 100644 --- a/bench/longmemeval.py +++ b/bench/longmemeval.py @@ -245,16 +245,21 @@ async def main() -> None: print(f"Loaded {len(items)} LongMemEval ({args.variant}) records") - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend + from pgkg.config import get_settings from bench.common import run_bench - async with pool_from_settings() as pool: + settings = get_settings() + backend = await PostgresBackend.create(settings.database_url) + try: report = await run_bench( name=f"longmemeval-{args.variant}", items=items, config=config, - pool=pool, + backend=backend, ) + finally: + await backend.close() print(f"\nFinal: {report.accuracy:.1%} accuracy over {report.total} questions") diff --git a/pgkg/__init__.py b/pgkg/__init__.py index c9c89b3..2505741 100644 --- a/pgkg/__init__.py +++ 
b/pgkg/__init__.py @@ -1,4 +1,5 @@ +from pgkg.backend import StorageBackend from pgkg.memory import Memory, Result from pgkg.config import MemoryConfig -__all__ = ["Memory", "MemoryConfig", "Result"] +__all__ = ["Memory", "MemoryConfig", "Result", "StorageBackend"] diff --git a/pgkg/api.py b/pgkg/api.py index dbc454f..7adf094 100644 --- a/pgkg/api.py +++ b/pgkg/api.py @@ -9,27 +9,27 @@ from pydantic import BaseModel from pgkg import ml +from pgkg.backends.postgres import PostgresBackend from pgkg.config import get_settings -from pgkg.db import make_pool, close_pool from pgkg.memory import Memory, IngestResult, Result -_pool = None +_backend: PostgresBackend | None = None _memory: Memory | None = None @asynccontextmanager async def lifespan(app: FastAPI): - global _pool, _memory + global _backend, _memory settings = get_settings() - _pool = await make_pool(settings.database_url) + _backend = await PostgresBackend.create(settings.database_url) _memory = Memory( - _pool, + _backend, namespace=settings.default_namespace, extract_propositions=settings.extract_propositions, ) yield - if _pool: - await close_pool(_pool) + if _backend: + await _backend.close() app = FastAPI(title="pgkg", lifespan=lifespan) @@ -90,15 +90,7 @@ async def forget(req: ForgetRequest) -> Response: @app.get("/health") async def health() -> dict: - db_ok = False - if _pool: - try: - async with _pool.acquire() as conn: - await conn.fetchval("SELECT 1") - db_ok = True - except Exception: - pass - + db_ok = await _backend.health_check() if _backend else False return { "status": "ok", "db": db_ok, diff --git a/pgkg/backends/__init__.py b/pgkg/backends/__init__.py index 5343a83..0fd7228 100644 --- a/pgkg/backends/__init__.py +++ b/pgkg/backends/__init__.py @@ -15,12 +15,14 @@ async def make_backend(backend_type: str = "postgres", **kwargs) -> "StorageBack backend_type: One of ``"postgres"`` (default). Future values: ``"sqlite"``. **kwargs: - Backend-specific arguments. For ``"postgres"``: ``dsn`` (str). + Backend-specific arguments. For ``"postgres"``: ``dsn`` (str | None). + When *dsn* is omitted or ``None``, pgserver auto-starts an embedded + Postgres instance. """ if backend_type == "postgres": from pgkg.backends.postgres import PostgresBackend - dsn: str = kwargs.pop("dsn") + dsn: str | None = kwargs.pop("dsn", None) return await PostgresBackend.create(dsn) raise ValueError(f"Unknown backend type: {backend_type!r}") diff --git a/pgkg/backends/postgres.py b/pgkg/backends/postgres.py index 95b662b..a673f89 100644 --- a/pgkg/backends/postgres.py +++ b/pgkg/backends/postgres.py @@ -37,8 +37,15 @@ def __init__(self, pool: asyncpg.Pool, *, dsn: str | None = None) -> None: self._dsn = dsn @classmethod - async def create(cls, dsn: str) -> PostgresBackend: - """Create a backend with a fresh connection pool.""" + async def create(cls, dsn: str | None = None) -> PostgresBackend: + """Create a backend with a fresh connection pool. + + When *dsn* is ``None``, an embedded Postgres is started + automatically via pgserver (requires ``uv sync --extra embedded``). 
+ """ + if dsn is None: + from pgkg.embedded import get_dsn + dsn = get_dsn() pool = await asyncpg.create_pool( dsn, min_size=1, max_size=10, init=_init_connection, ) diff --git a/pgkg/cli.py b/pgkg/cli.py index 046a42b..d3f1f97 100644 --- a/pgkg/cli.py +++ b/pgkg/cli.py @@ -7,44 +7,17 @@ def cmd_migrate(args: argparse.Namespace) -> None: - import pathlib - - import asyncpg - + from pgkg.backends.postgres import PostgresBackend from pgkg.config import get_settings async def _run() -> None: - dsn = get_settings().database_url - if dsn is None: - from pgkg.embedded import get_dsn - dsn = get_dsn() - migrations_dir = pathlib.Path(__file__).resolve().parent.parent / "migrations" - conn = await asyncpg.connect(dsn) + settings = get_settings() + backend = await PostgresBackend.create(settings.database_url) try: - await conn.execute( - "CREATE TABLE IF NOT EXISTS pgkg_schema_migrations (" - " filename TEXT PRIMARY KEY," - " applied_at TIMESTAMPTZ NOT NULL DEFAULT now()" - ")" - ) - applied = { - r["filename"] - for r in await conn.fetch("SELECT filename FROM pgkg_schema_migrations") - } - for migration in sorted(migrations_dir.glob("*.sql")): - if migration.name in applied: - print(f"Skipping {migration.name} (already applied).") - continue - print(f"Applying {migration.name}...") - async with conn.transaction(): - await conn.execute(migration.read_text()) - await conn.execute( - "INSERT INTO pgkg_schema_migrations (filename) VALUES ($1)", - migration.name, - ) + await backend.apply_migrations() print("All migrations applied.") finally: - await conn.close() + await backend.close() asyncio.run(_run()) @@ -55,7 +28,7 @@ def cmd_serve(args: argparse.Namespace) -> None: def cmd_ingest(args: argparse.Namespace) -> None: - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend from pgkg.memory import Memory from pgkg.config import get_settings @@ -71,8 +44,9 @@ def cmd_ingest(args: argparse.Namespace) -> None: async def _run() -> None: settings = get_settings() extract = not args.chunks_only - async with pool_from_settings() as pool: - mem = Memory(pool, namespace=settings.default_namespace, extract_propositions=extract) + backend = await PostgresBackend.create(settings.database_url) + try: + mem = Memory(backend, namespace=settings.default_namespace, extract_propositions=extract) result = await mem.ingest(text, source=source) print(json.dumps({ "documents": result.documents, @@ -80,21 +54,26 @@ async def _run() -> None: "propositions": result.propositions, "entities": result.entities, })) + finally: + await backend.close() asyncio.run(_run()) def cmd_recall(args: argparse.Namespace) -> None: - from pgkg.db import pool_from_settings + from pgkg.backends.postgres import PostgresBackend from pgkg.memory import Memory from pgkg.config import get_settings async def _run() -> None: settings = get_settings() - async with pool_from_settings() as pool: - mem = Memory(pool, namespace=settings.default_namespace) + backend = await PostgresBackend.create(settings.database_url) + try: + mem = Memory(backend, namespace=settings.default_namespace) results = await mem.recall(args.query, k=args.k) print(json.dumps([r.model_dump(mode="json") for r in results], indent=2)) + finally: + await backend.close() asyncio.run(_run()) diff --git a/pgkg/memory.py b/pgkg/memory.py index d1295f9..8366666 100644 --- a/pgkg/memory.py +++ b/pgkg/memory.py @@ -1,19 +1,20 @@ from __future__ import annotations import asyncio -import json import re -import uuid from dataclasses import dataclass -from 
datetime import datetime -from typing import Any from uuid import UUID -import asyncpg from pydantic import BaseModel from pgkg import ml -from pgkg.ml import ExtractCache, Proposition, PROMPT_VERSION +from pgkg.backend import ( + StorageBackend, + StoredChunk, + StoredDocument, + StoredProposition, +) +from pgkg.ml import ExtractCache, Proposition class Result(BaseModel): @@ -26,7 +27,6 @@ class Result(BaseModel): subject: str | None predicate: str | None object: str | None - asserted_at: datetime | None = None @dataclass @@ -64,36 +64,17 @@ def _chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]: return chunks or [text[:chunk_size]] -class PostgresExtractCache: - """Postgres-backed implementation of ExtractCache. +class BackendExtractCache: + """Adapts StorageBackend cache methods to the ExtractCache protocol.""" - Stores extracted propositions in the proposition_cache table so re-ingesting - the same chunk with the same extractor model and prompt version is free. - """ - - def __init__(self, pool: asyncpg.Pool, namespace: str) -> None: - self._pool = pool - self._namespace = namespace + def __init__(self, backend: StorageBackend) -> None: + self._backend = backend async def get(self, cache_key: str) -> list[Proposition] | None: - async with self._pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT propositions FROM proposition_cache WHERE cache_key = $1", - cache_key, - ) - if row is None: - return None - # Bump hit count (best-effort; don't fail the main path) - await conn.execute( - "UPDATE proposition_cache SET hit_count = hit_count + 1 WHERE cache_key = $1", - cache_key, - ) - raw = row["propositions"] - if isinstance(raw, str): - items = json.loads(raw) - else: - items = raw # asyncpg may already decode JSONB - return [Proposition(**p) for p in items] + raw = await self._backend.cache_get(cache_key) + if raw is None: + return None + return [Proposition(**d) for d in raw] async def put( self, @@ -103,37 +84,29 @@ async def put( prompt_version: str, props: list[Proposition], ) -> None: - payload = json.dumps([p.model_dump() for p in props]) - async with self._pool.acquire() as conn: - await conn.execute( - """ - INSERT INTO proposition_cache - (cache_key, chunk_hash, extractor_model, prompt_version, propositions) - VALUES ($1, $2, $3, $4, $5::jsonb) - ON CONFLICT (cache_key) DO NOTHING - """, - cache_key, - chunk_hash, - extractor_model, - prompt_version, - payload, - ) + await self._backend.cache_put( + cache_key, + chunk_hash, + extractor_model, + prompt_version, + [p.model_dump() for p in props], + ) class Memory: def __init__( self, - pool: asyncpg.Pool, + backend: StorageBackend, *, namespace: str = "default", use_extract_cache: bool = True, extract_propositions: bool = True, ) -> None: - self._pool = pool + self._backend = backend self._namespace = namespace self._extract_propositions = extract_propositions self._extract_cache: ExtractCache | None = ( - PostgresExtractCache(pool, namespace) if use_extract_cache and extract_propositions else None + BackendExtractCache(backend) if use_extract_cache and extract_propositions else None ) async def ingest( @@ -142,7 +115,6 @@ async def ingest( *, source: str | None = None, session_id: str | None = None, - asserted_at: datetime | None = None, chunk_size: int = 1200, chunk_overlap: int = 100, ) -> IngestResult: @@ -150,140 +122,107 @@ async def ingest( entities_created: set[UUID] = set() total_propositions = 0 - async with self._pool.acquire() as conn: - # Insert document - doc_id: UUID = await 
conn.fetchval( - "INSERT INTO documents (source, namespace) VALUES ($1, $2) RETURNING id", - source, - self._namespace, + doc_id = await self._backend.store_document( + StoredDocument(source=source, namespace=self._namespace) + ) + + chunk_ids: list[UUID] = [] + for i, chunk_text in enumerate(chunks): + chunk_id = await self._backend.store_chunk( + StoredChunk( + document_id=doc_id, + text=chunk_text, + span_start=i * chunk_size, + span_end=(i + 1) * chunk_size, + ) ) + chunk_ids.append(chunk_id) - chunk_ids: list[UUID] = [] - for i, chunk_text in enumerate(chunks): - chunk_id: UUID = await conn.fetchval( - """ - INSERT INTO chunks (document_id, text, span_start, span_end, asserted_at) - VALUES ($1, $2, $3, $4, $5) RETURNING id - """, - doc_id, - chunk_text, - i * chunk_size, - (i + 1) * chunk_size, - asserted_at, + if self._extract_propositions: + for chunk_id, chunk_text in zip(chunk_ids, chunks): + propositions = await ml.extract_propositions_async( + chunk_text, cache=self._extract_cache ) - chunk_ids.append(chunk_id) + if not propositions: + continue - if self._extract_propositions: - # Extract and embed per chunk - for chunk_id, chunk_text in zip(chunk_ids, chunks): - propositions = await ml.extract_propositions_async( - chunk_text, cache=self._extract_cache + entity_names: list[str] = [] + for prop in propositions: + entity_names.append(prop.subject) + if not prop.object_is_literal: + entity_names.append(prop.object) + + prop_texts = [p.text for p in propositions] + all_texts = entity_names + prop_texts + all_embs = ml.embed(all_texts) + + entity_embs = all_embs[: len(entity_names)] + prop_embs = all_embs[len(entity_names) :] + + entity_idx = 0 + for prop, prop_emb in zip(propositions, prop_embs): + subj_emb = entity_embs[entity_idx] + entity_idx += 1 + + subject_id = await self._backend.link_entity( + self._namespace, prop.subject, "concept", subj_emb ) - if not propositions: - continue - - # Collect all texts for batch embedding - entity_names: list[str] = [] - for prop in propositions: - entity_names.append(prop.subject) - if not prop.object_is_literal: - entity_names.append(prop.object) - - prop_texts = [p.text for p in propositions] - all_texts = entity_names + prop_texts - all_embs = ml.embed(all_texts) - - entity_embs = all_embs[: len(entity_names)] - prop_embs = all_embs[len(entity_names):] - - # Link entities and insert propositions - entity_idx = 0 - for prop, prop_emb in zip(propositions, prop_embs): - subj_emb = entity_embs[entity_idx] - entity_idx += 1 + entities_created.add(subject_id) - subject_id: UUID = await conn.fetchval( - _link_entity_sql(subj_emb), - self._namespace, - prop.subject, - "concept", + object_id: UUID | None = None + object_literal: str | None = None + + if prop.object_is_literal: + object_literal = prop.object + else: + obj_emb = entity_embs[entity_idx] + entity_idx += 1 + object_id = await self._backend.link_entity( + self._namespace, prop.object, "concept", obj_emb ) - entities_created.add(subject_id) - - object_id: UUID | None = None - object_literal: str | None = None - - if prop.object_is_literal: - object_literal = prop.object - else: - obj_emb = entity_embs[entity_idx] - entity_idx += 1 - object_id = await conn.fetchval( - _link_entity_sql(obj_emb), - self._namespace, - prop.object, - "concept", - ) - entities_created.add(object_id) - - prop_id: UUID = await conn.fetchval( - f""" - INSERT INTO propositions - (text, embedding, subject_id, predicate, object_id, - object_literal, chunk_id, namespace, session_id, asserted_at) - VALUES ($1, 
'{_vec_literal(prop_emb)}'::vector, - $2, $3, $4, $5, $6, $7, $8, $9) - RETURNING id - """, - prop.text, - subject_id, - prop.predicate, - object_id, - object_literal, - chunk_id, - self._namespace, - session_id, - asserted_at, + entities_created.add(object_id) + + prop_id = await self._backend.store_proposition( + StoredProposition( + text=prop.text, + embedding=prop_emb, + subject_id=subject_id, + predicate=prop.predicate, + object_id=object_id, + object_literal=object_literal, + chunk_id=chunk_id, + namespace=self._namespace, + session_id=session_id, ) - total_propositions += 1 - - if object_id is not None: - await conn.execute( - """ - INSERT INTO edges (src_entity, dst_entity, relation, proposition_id) - VALUES ($1, $2, $3, $4) - ON CONFLICT DO NOTHING - """, - subject_id, - object_id, - prop.predicate, - prop_id, - ) - else: - # Chunks-only mode: embed each chunk directly, no LLM extraction, - # no entity linking, no edge creation. - chunk_texts_list = list(chunks) - chunk_embs = ml.embed(chunk_texts_list) - for chunk_id, chunk_text, chunk_emb in zip(chunk_ids, chunk_texts_list, chunk_embs): - metadata = '{"mode": "chunk"}' - await conn.fetchval( - f""" - INSERT INTO propositions - (text, embedding, subject_id, predicate, object_id, - object_literal, chunk_id, namespace, session_id, metadata, asserted_at) - VALUES ($1, '{_vec_literal(chunk_emb)}'::vector, - NULL, NULL, NULL, NULL, $2, $3, $4, $5::jsonb, $6) - RETURNING id - """, - chunk_text, - chunk_id, - self._namespace, - session_id, - metadata, - asserted_at, ) total_propositions += 1 + if object_id is not None: + await self._backend.store_edge( + subject_id, object_id, prop.predicate, prop_id + ) + else: + chunk_texts_list = list(chunks) + chunk_embs = ml.embed(chunk_texts_list) + for chunk_id, chunk_text_val, chunk_emb in zip( + chunk_ids, chunk_texts_list, chunk_embs + ): + await self._backend.store_proposition( + StoredProposition( + text=chunk_text_val, + embedding=chunk_emb, + subject_id=None, + predicate=None, + object_id=None, + object_literal=None, + chunk_id=chunk_id, + namespace=self._namespace, + session_id=session_id, + metadata={"mode": "chunk"}, + ) + ) + total_propositions += 1 + return IngestResult( documents=1, chunks=len(chunks), @@ -305,35 +244,37 @@ async def recall( ) -> list[Result]: q_emb = ml.embed([query])[0] - async with self._pool.acquire() as conn: - rows = await conn.fetch( - f""" - SELECT proposition_id, text, embedding, rrf_score, adjusted_score, - source_kind, chunk_id, subject_id, predicate, object_id, asserted_at - FROM pgkg_search($1, '{_vec_literal(q_emb)}'::vector, - $2, $3, $4, $5, 30.0, $6) - """, - query, - k_retrieve, - k_retrieve * 2, - self._namespace, - session_id, - expand_graph, - ) + candidates = await self._backend.fused_search( + query_text=query, + query_embedding=q_emb, + k_retrieve=k_retrieve, + k_initial=k_retrieve * 2, + namespace=self._namespace, + session_id=session_id, + recency_half_life_days=30.0, + expand_graph=expand_graph, + rrf_k=60, + ) - if not rows: + if not candidates: return [] - texts = [r["text"] for r in rows] - scores = [float(r["adjusted_score"]) for r in rows] - embs = [_parse_emb(r["embedding"], q_emb) for r in rows] + # Fetch full proposition data for embeddings (needed by rerank/MMR) + prop_rows = await self._backend.get_propositions( + [c.proposition_id for c in candidates] + ) + emb_by_id = {p.id: p.embedding for p in prop_rows} + + texts = [c.text for c in candidates] + scores = [c.adjusted_score for c in candidates] + embs = 
[emb_by_id.get(c.proposition_id) or q_emb for c in candidates] + embs = [list(e) if not isinstance(e, list) else e for e in embs] if with_rerank: - candidate_rows = rows[: min(k_retrieve, 64)] - candidate_texts = [r["text"] for r in candidate_rows] + candidate_slice = candidates[: min(k_retrieve, 64)] + candidate_texts = [c.text for c in candidate_slice] rerank_scores = ml.rerank(query, candidate_texts) - # Min-max normalize both score lists def _normalize(vals: list[float]) -> list[float]: lo, hi = min(vals), max(vals) span = hi - lo @@ -341,56 +282,55 @@ def _normalize(vals: list[float]) -> list[float]: return [1.0] * len(vals) return [(v - lo) / span for v in vals] - adj_scores = [float(r["adjusted_score"]) for r in candidate_rows] + adj_scores = [c.adjusted_score for c in candidate_slice] rerank_norm = _normalize(rerank_scores) adj_norm = _normalize(adj_scores) blended = [0.7 * r + 0.3 * a for r, a in zip(rerank_norm, adj_norm)] - # Sort candidate_rows by blended score - sorted_indices = sorted(range(len(blended)), key=lambda i: blended[i], reverse=True) - rows = [candidate_rows[i] for i in sorted_indices] + sorted_indices = sorted( + range(len(blended)), key=lambda i: blended[i], reverse=True + ) + candidates = [candidate_slice[i] for i in sorted_indices] scores = [blended[i] for i in sorted_indices] - embs = [list(candidate_rows[i]["embedding"]) if candidate_rows[i]["embedding"] is not None else q_emb - for i in sorted_indices] + embs = [ + emb_by_id.get(candidate_slice[i].proposition_id) or q_emb + for i in sorted_indices + ] + embs = [list(e) if not isinstance(e, list) else e for e in embs] - if with_mmr and len(rows) > k: + if with_mmr and len(candidates) > k: selected_indices = ml.mmr(q_emb, embs, k, lambda_=mmr_lambda) - rows = [rows[i] for i in selected_indices] + candidates = [candidates[i] for i in selected_indices] scores = [scores[i] for i in selected_indices] else: - rows = rows[:k] + candidates = candidates[:k] scores = scores[:k] # Fire-and-forget bump access - prop_ids = [str(r["proposition_id"]) for r in rows] + prop_ids = [c.proposition_id for c in candidates] asyncio.ensure_future(self._bump(prop_ids)) results = [] - for row, score in zip(rows, scores): - subject_name: str | None = None - object_name: str | None = None - + for cand, score in zip(candidates, scores): results.append( Result( - proposition_id=row["proposition_id"], - text=row["text"], + proposition_id=cand.proposition_id, + text=cand.text, score=score, - rrf_score=float(row["rrf_score"]), - source_kind=row["source_kind"], - chunk_id=row["chunk_id"], - subject=subject_name, - predicate=row["predicate"], - object=object_name, - asserted_at=row["asserted_at"], + rrf_score=cand.rrf_score, + source_kind=cand.source_kind, + chunk_id=cand.chunk_id, + subject=None, + predicate=cand.predicate, + object=None, ) ) return results - async def _bump(self, prop_ids: list[str]) -> None: + async def _bump(self, prop_ids: list[UUID]) -> None: try: - async with self._pool.acquire() as conn: - await conn.execute("SELECT pgkg_bump_access($1::uuid[])", prop_ids) + await self._backend.bump_access(prop_ids) except Exception: pass @@ -400,37 +340,4 @@ async def forget( *, supersede_with: UUID | None = None, ) -> None: - async with self._pool.acquire() as conn: - if supersede_with is not None: - await conn.execute( - "UPDATE propositions SET superseded_by = $1 WHERE id = $2", - supersede_with, - proposition_id, - ) - else: - await conn.execute( - "UPDATE propositions SET superseded_by = id WHERE id = $1", - proposition_id, - ) - 
- -def _parse_emb(val: object, fallback: list[float]) -> list[float]: - """Parse an embedding from asyncpg — may be a list, numpy array, or string.""" - if val is None: - return fallback - if isinstance(val, (list, tuple)): - return list(val) - if isinstance(val, str): - import json - return json.loads(val.replace("(", "[").replace(")", "]")) - # numpy array or pgvector type - return list(val) - - -def _vec_literal(emb: list[float]) -> str: - """Format a vector as a Postgres literal string '[a,b,c,...]'.""" - return "[" + ",".join(str(v) for v in emb) + "]" - - -def _link_entity_sql(emb: list[float]) -> str: - return f"SELECT pgkg_link_entity($1, $2, $3, '{_vec_literal(emb)}'::vector)" + await self._backend.forget(proposition_id, supersede_with=supersede_with) diff --git a/tests/conftest.py b/tests/conftest.py index 88dd762..2e5ae28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,8 @@ import pytest from pgvector.asyncpg import register_vector +from pgkg.backends.postgres import PostgresBackend + MIGRATIONS_DIR = pathlib.Path(__file__).parent.parent / "migrations" @@ -64,3 +66,11 @@ async def pool(pg_dsn) -> AsyncGenerator[asyncpg.Pool, None]: yield conn_pool await conn_pool.close() + + +@pytest.fixture(scope="session") +async def backend(pool) -> AsyncGenerator[PostgresBackend, None]: + """PostgresBackend wrapping the test pool.""" + b = PostgresBackend(pool) + yield b + # Pool cleanup handled by the pool fixture; do not close here. diff --git a/tests/test_chunks_only.py b/tests/test_chunks_only.py index 810636d..4f0004f 100644 --- a/tests/test_chunks_only.py +++ b/tests/test_chunks_only.py @@ -31,7 +31,7 @@ def _fake_embed(texts: list[str]) -> list[list[float]]: # test_ingest_chunks_only_skips_extraction # --------------------------------------------------------------------------- -async def test_ingest_chunks_only_skips_extraction(pool: asyncpg.Pool, monkeypatch): +async def test_ingest_chunks_only_skips_extraction(pool: asyncpg.Pool, backend, monkeypatch): """extract_propositions_async must not be called; propositions have NULL subject_id and metadata->>'mode' = 'chunk'; entities table is untouched.""" import pgkg.ml as ml_module @@ -44,7 +44,7 @@ async def _should_not_be_called(*args, **kwargs): monkeypatch.setattr(ml_module, "extract_propositions_async", _should_not_be_called) ns = f"chunks_only_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) result = await mem.ingest("Hello world. 
This is a test document.") assert result.documents == 1 @@ -83,7 +83,7 @@ async def _should_not_be_called(*args, **kwargs): # test_recall_works_in_chunks_mode # --------------------------------------------------------------------------- -async def test_recall_works_in_chunks_mode(pool: asyncpg.Pool, monkeypatch): +async def test_recall_works_in_chunks_mode(pool: asyncpg.Pool, backend, monkeypatch): """After chunks-only ingest, recall returns results with NULL predicate.""" import pgkg.ml as ml_module @@ -114,7 +114,7 @@ async def _noop_extract(*args, **kwargs): monkeypatch.setattr(ml_module, "extract_propositions_async", _noop_extract) ns = f"chunks_recall_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest(ocean_text) results = await mem.recall( @@ -135,7 +135,7 @@ async def _noop_extract(*args, **kwargs): # test_chunks_mode_graph_expansion_is_noop # --------------------------------------------------------------------------- -async def test_chunks_mode_graph_expansion_is_noop(pool: asyncpg.Pool, monkeypatch): +async def test_chunks_mode_graph_expansion_is_noop(pool: asyncpg.Pool, backend, monkeypatch): """Graph expansion with chunks-only ingest produces no graph-sourced rows.""" import pgkg.ml as ml_module monkeypatch.setattr(ml_module, "embed", _fake_embed) @@ -152,7 +152,7 @@ async def _noop_extract(*args, **kwargs): monkeypatch.setattr(ml_module, "extract_propositions_async", _noop_extract) ns = f"chunks_graph_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest("Alice visited Bob last Tuesday. Bob works at Acme Corp.") results = await mem.recall( diff --git a/tests/test_extract_cache.py b/tests/test_extract_cache.py index b5ee63e..55ea633 100644 --- a/tests/test_extract_cache.py +++ b/tests/test_extract_cache.py @@ -18,7 +18,7 @@ compute_cache_key, extract_propositions_async, ) -from pgkg.memory import PostgresExtractCache +from pgkg.memory import BackendExtractCache # --------------------------------------------------------------------------- @@ -69,7 +69,7 @@ async def _get_hit_count(pool: asyncpg.Pool, cache_key: str) -> int: # --------------------------------------------------------------------------- async def test_cache_hit_returns_stored_props_without_llm( - pool: asyncpg.Pool, monkeypatch + pool: asyncpg.Pool, backend, monkeypatch ): """Pre-populate cache; LLM provider should never be called on cache hit.""" chunk = "Alice is a brilliant scientist who works at CERN." @@ -102,7 +102,7 @@ async def test_cache_hit_returns_stored_props_without_llm( lambda *a, **kw: (_ for _ in ()).throw(AssertionError("LLM was called — should have been a cache hit")), ) - cache = PostgresExtractCache(pool, "test-ns") + cache = BackendExtractCache(backend) result = await extract_propositions_async(chunk, cache=cache) assert len(result) == 1 @@ -114,7 +114,7 @@ async def test_cache_hit_returns_stored_props_without_llm( # test_cache_miss_then_hit # --------------------------------------------------------------------------- -async def test_cache_miss_then_hit(pool: asyncpg.Pool, monkeypatch): +async def test_cache_miss_then_hit(pool: asyncpg.Pool, backend, monkeypatch): """First call invokes LLM stub (count=1); second call hits cache (count still 1).""" chunk = "Bob loves hiking in the mountains near his home." 
model = "stub-model-v1" @@ -141,7 +141,7 @@ def _fake_do_extract(chunk_text, max_propositions, settings, extractor_model): monkeypatch.setattr(ml_module, "_do_extract", _fake_do_extract) - cache = PostgresExtractCache(pool, "test-ns-miss-hit") + cache = BackendExtractCache(backend) # First call — cache miss, LLM called result1 = await extract_propositions_async(chunk, cache=cache) @@ -235,7 +235,7 @@ def test_cache_key_changes_with_prompt_version(monkeypatch): # test_postgres_cache_hit_count_increments # --------------------------------------------------------------------------- -async def test_postgres_cache_hit_count_increments(pool: asyncpg.Pool): +async def test_postgres_cache_hit_count_increments(pool: asyncpg.Pool, backend): """Two get() calls on a populated key → hit_count = 2.""" chunk = "Eve is a cryptographer." model = "hit-count-model" @@ -244,7 +244,7 @@ async def test_postgres_cache_hit_count_increments(pool: asyncpg.Pool): stored = [_make_prop(text="Eve is a cryptographer.", subject="Eve", predicate="is", object="cryptographer")] await _seed_cache(pool, cache_key, stored) - cache = PostgresExtractCache(pool, "test-ns-hits") + cache = BackendExtractCache(backend) await cache.get(cache_key) await cache.get(cache_key) diff --git a/tests/test_memory.py b/tests/test_memory.py index 1ed0573..ccf8bb7 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -38,7 +38,7 @@ def _fake_embed(texts: list[str]) -> list[list[float]]: # test_ingest_creates_rows # --------------------------------------------------------------------------- -async def test_ingest_creates_rows(pool: asyncpg.Pool, monkeypatch): +async def test_ingest_creates_rows(pool: asyncpg.Pool, backend, monkeypatch): """PGKG_OFFLINE_EXTRACT=1: ingest populates documents, chunks, propositions tables.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -46,7 +46,7 @@ async def test_ingest_creates_rows(pool: asyncpg.Pool, monkeypatch): monkeypatch.setattr(ml_module, "embed", _fake_embed) ns = f"ingest_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) result = await mem.ingest("Hello world. 
This is a test document.") assert result.documents == 1 @@ -74,7 +74,7 @@ async def test_ingest_creates_rows(pool: asyncpg.Pool, monkeypatch): # test_recall_returns_ingested # --------------------------------------------------------------------------- -async def test_recall_returns_ingested(pool: asyncpg.Pool, monkeypatch): +async def test_recall_returns_ingested(pool: asyncpg.Pool, backend, monkeypatch): """After ingesting a doc, recalling a matching query returns the proposition.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -103,7 +103,7 @@ def predict(self, pairs): monkeypatch.setattr(ml_module, "_rerank_model", FakeCE()) ns = f"recall_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) await mem.ingest(ocean_text) results = await mem.recall( @@ -123,7 +123,7 @@ def predict(self, pairs): # test_recall_session_scope # --------------------------------------------------------------------------- -async def test_recall_session_scope(pool: asyncpg.Pool, monkeypatch): +async def test_recall_session_scope(pool: asyncpg.Pool, backend, monkeypatch): """Propositions ingested with session_id='A' don't appear in session_id='B' recall.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -131,7 +131,7 @@ async def test_recall_session_scope(pool: asyncpg.Pool, monkeypatch): monkeypatch.setattr(ml_module, "embed", _fake_embed) ns = f"session_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) # Ingest unique text with session A unique_text = f"Unique session A content xyzzy_{uuid.uuid4().hex}" @@ -164,7 +164,7 @@ async def test_recall_session_scope(pool: asyncpg.Pool, monkeypatch): # test_forget_supersedes # --------------------------------------------------------------------------- -async def test_forget_supersedes(pool: asyncpg.Pool, monkeypatch): +async def test_forget_supersedes(pool: asyncpg.Pool, backend, monkeypatch): """After forget(), the proposition no longer appears in recall results.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -185,7 +185,7 @@ def _targeted_embed(texts: list[str]) -> list[list[float]]: monkeypatch.setattr(ml_module, "embed", _targeted_embed) ns = f"forget_test_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns) + mem = Memory(backend, namespace=ns) result = await mem.ingest(target_text) # Get the proposition id @@ -220,7 +220,7 @@ def _targeted_embed(texts: list[str]) -> list[list[float]]: # recall test bypassed rerank/MMR, so the truthiness check on the # embedding column was never exercised against a real DB row. 
-async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, monkeypatch): +async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, backend, monkeypatch): monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") import pgkg.ml as ml_module @@ -229,7 +229,7 @@ async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, monkeypatch.setattr(ml_module, "rerank", lambda q, docs: [1.0 / (i + 1) for i in range(len(docs))]) ns = f"recall_default_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest("The chunks-only ingest mode skips LLM extraction entirely.") await mem.ingest("Hybrid retrieval fuses BM25 and vector similarity via RRF.") From c21f4e1a52838be04199fd1265dcca13e3e909ad Mon Sep 17 00:00:00 2001 From: Joseph Mearman Date: Tue, 28 Apr 2026 22:14:10 +0200 Subject: [PATCH 3/3] fix: add asserted_at support to StorageBackend protocol Thread asserted_at through StoredChunk, StoredProposition, Candidate, and Result so the refactored backend preserves the feature added in origin/main. Fix test_bench and test_memory to use backend fixture. --- pgkg/backend.py | 4 ++++ pgkg/backends/postgres.py | 13 ++++++++----- pgkg/memory.py | 7 +++++++ tests/test_bench.py | 4 ++-- tests/test_memory.py | 8 ++++---- 5 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pgkg/backend.py b/pgkg/backend.py index d128d00..1760087 100644 --- a/pgkg/backend.py +++ b/pgkg/backend.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass +from datetime import datetime from typing import Protocol, runtime_checkable from uuid import UUID @@ -25,6 +26,7 @@ class Candidate: subject_id: UUID | None predicate: str | None object_id: UUID | None + asserted_at: datetime | None = None @dataclass(frozen=True) @@ -57,6 +59,7 @@ class StoredProposition: session_id: str | None confidence: float = 1.0 metadata: dict | None = None + asserted_at: datetime | None = None @dataclass(frozen=True) @@ -75,6 +78,7 @@ class StoredChunk: id: UUID | None = None span_start: int | None = None span_end: int | None = None + asserted_at: datetime | None = None @runtime_checkable diff --git a/pgkg/backends/postgres.py b/pgkg/backends/postgres.py index a673f89..36de2eb 100644 --- a/pgkg/backends/postgres.py +++ b/pgkg/backends/postgres.py @@ -175,7 +175,7 @@ async def fused_search( rows = await conn.fetch( f""" SELECT proposition_id, text, embedding, rrf_score, adjusted_score, - source_kind, chunk_id, subject_id, predicate, object_id + source_kind, chunk_id, subject_id, predicate, object_id, asserted_at FROM pgkg_search($1, '{vec_lit}'::vector, $2, $3, $4, $5, $6, $7, $8) """, @@ -199,6 +199,7 @@ async def fused_search( subject_id=r["subject_id"], predicate=r["predicate"], object_id=r["object_id"], + asserted_at=r["asserted_at"], ) for r in rows ] @@ -280,13 +281,14 @@ async def store_chunk(self, chunk: StoredChunk) -> UUID: async with self._pool.acquire() as conn: chunk_id: UUID = await conn.fetchval( """ - INSERT INTO chunks (document_id, text, span_start, span_end) - VALUES ($1, $2, $3, $4) RETURNING id + INSERT INTO chunks (document_id, text, span_start, span_end, asserted_at) + VALUES ($1, $2, $3, $4, $5) RETURNING id """, chunk.document_id, chunk.text, chunk.span_start, chunk.span_end, + chunk.asserted_at, ) return chunk_id @@ -299,10 +301,10 @@ async def store_proposition(self, prop: StoredProposition) -> UUID: INSERT INTO propositions 
(text, embedding, subject_id, predicate, object_id, object_literal, chunk_id, namespace, session_id, - confidence, metadata) + confidence, metadata, asserted_at) VALUES ($1, '{vec_lit}'::vector, $2, $3, $4, $5, $6, $7, $8, $9, - $10::jsonb) + $10::jsonb, $11) RETURNING id """, prop.text, @@ -315,6 +317,7 @@ async def store_proposition(self, prop: StoredProposition) -> UUID: prop.session_id, prop.confidence, metadata_json, + prop.asserted_at, ) return prop_id diff --git a/pgkg/memory.py b/pgkg/memory.py index 8366666..0f353d9 100644 --- a/pgkg/memory.py +++ b/pgkg/memory.py @@ -3,6 +3,7 @@ import asyncio import re from dataclasses import dataclass +from datetime import datetime from uuid import UUID from pydantic import BaseModel @@ -27,6 +28,7 @@ class Result(BaseModel): subject: str | None predicate: str | None object: str | None + asserted_at: datetime | None = None @dataclass @@ -115,6 +117,7 @@ async def ingest( *, source: str | None = None, session_id: str | None = None, + asserted_at: datetime | None = None, chunk_size: int = 1200, chunk_overlap: int = 100, ) -> IngestResult: @@ -134,6 +137,7 @@ async def ingest( text=chunk_text, span_start=i * chunk_size, span_end=(i + 1) * chunk_size, + asserted_at=asserted_at, ) ) chunk_ids.append(chunk_id) @@ -193,6 +197,7 @@ async def ingest( chunk_id=chunk_id, namespace=self._namespace, session_id=session_id, + asserted_at=asserted_at, ) ) total_propositions += 1 @@ -219,6 +224,7 @@ async def ingest( namespace=self._namespace, session_id=session_id, metadata={"mode": "chunk"}, + asserted_at=asserted_at, ) ) total_propositions += 1 @@ -323,6 +329,7 @@ def _normalize(vals: list[float]) -> list[float]: subject=None, predicate=cand.predicate, object=None, + asserted_at=cand.asserted_at, ) ) diff --git a/tests/test_bench.py b/tests/test_bench.py index 4e35075..1b0f66e 100644 --- a/tests/test_bench.py +++ b/tests/test_bench.py @@ -41,7 +41,7 @@ def test_bench_item_validates(): # test_run_bench_dry_run_exact_match # --------------------------------------------------------------------------- -async def test_run_bench_dry_run_exact_match(pool: asyncpg.Pool, monkeypatch): +async def test_run_bench_dry_run_exact_match(pool: asyncpg.Pool, backend, monkeypatch): """run_bench with dry_run=True, exact_match=True completes without LLM calls.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -99,7 +99,7 @@ def _fake_embed(texts): name="test-dryrun", items=items, config=config, - pool=pool, + backend=backend, ) assert report.total == 2 diff --git a/tests/test_memory.py b/tests/test_memory.py index ccf8bb7..5201184 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -245,7 +245,7 @@ async def test_recall_default_flags_with_pgvector_embedding(pool: asyncpg.Pool, # test_ingest_propagates_asserted_at # --------------------------------------------------------------------------- -async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, monkeypatch): +async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, backend, monkeypatch): """Ingest with asserted_at stores it in both chunk and proposition rows.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -254,7 +254,7 @@ async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, monkeypatch): expected_ts = datetime(2025, 1, 15, 10, 0, 0, tzinfo=timezone.utc) ns = f"assertedat_ingest_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest( "The sky is blue and the grass 
is green.", @@ -292,7 +292,7 @@ async def test_ingest_propagates_asserted_at(pool: asyncpg.Pool, monkeypatch): # test_recall_returns_asserted_at_in_result # --------------------------------------------------------------------------- -async def test_recall_returns_asserted_at_in_result(pool: asyncpg.Pool, monkeypatch): +async def test_recall_returns_asserted_at_in_result(pool: asyncpg.Pool, backend, monkeypatch): """Result.asserted_at is populated when ingested with an asserted_at timestamp.""" monkeypatch.setenv("PGKG_OFFLINE_EXTRACT", "1") @@ -312,7 +312,7 @@ def _controlled_embed(texts: list[str]) -> list[list[float]]: monkeypatch.setattr(ml_module, "embed", _controlled_embed) ns = f"assertedat_recall_{uuid.uuid4().hex[:8]}" - mem = Memory(pool, namespace=ns, extract_propositions=False) + mem = Memory(backend, namespace=ns, extract_propositions=False) await mem.ingest(target_text, asserted_at=expected_ts)