Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ async def run_bench(
name: str,
items: Iterable[BenchItem],
config: BenchConfig,
pool,
backend,
) -> BenchReport:
"""Orchestrate the full benchmark run."""
from pgkg.memory import Memory
Expand Down Expand Up @@ -434,7 +434,7 @@ async def process_item(item: BenchItem) -> None:

async with semaphore:
ns = item.namespace
memory = Memory(pool, namespace=ns, extract_propositions=config.extract_propositions)
memory = Memory(backend, namespace=ns, extract_propositions=config.extract_propositions)

# Ingest all conversation turns grouped by session_id
sessions: dict[str, list[dict]] = {}
Expand Down
11 changes: 8 additions & 3 deletions bench/locomo.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,21 @@ async def main() -> None:

print(f"Loaded {len(items)} LoCoMo conversations")

from pgkg.db import pool_from_settings
from pgkg.backends.postgres import PostgresBackend
from pgkg.config import get_settings
from bench.common import run_bench

async with pool_from_settings() as pool:
settings = get_settings()
backend = await PostgresBackend.create(settings.database_url)
try:
report = await run_bench(
name="locomo",
items=items,
config=config,
pool=pool,
backend=backend,
)
finally:
await backend.close()

print(f"\nFinal: {report.accuracy:.1%} accuracy over {report.total} questions")

Expand Down
11 changes: 8 additions & 3 deletions bench/longmemeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,21 @@ async def main() -> None:

print(f"Loaded {len(items)} LongMemEval ({args.variant}) records")

from pgkg.db import pool_from_settings
from pgkg.backends.postgres import PostgresBackend
from pgkg.config import get_settings
from bench.common import run_bench

async with pool_from_settings() as pool:
settings = get_settings()
backend = await PostgresBackend.create(settings.database_url)
try:
report = await run_bench(
name=f"longmemeval-{args.variant}",
items=items,
config=config,
pool=pool,
backend=backend,
)
finally:
await backend.close()

print(f"\nFinal: {report.accuracy:.1%} accuracy over {report.total} questions")

Expand Down
3 changes: 2 additions & 1 deletion pgkg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pgkg.backend import StorageBackend
from pgkg.memory import Memory, Result
from pgkg.config import MemoryConfig

__all__ = ["Memory", "MemoryConfig", "Result"]
__all__ = ["Memory", "MemoryConfig", "Result", "StorageBackend"]
24 changes: 8 additions & 16 deletions pgkg/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@
from pydantic import BaseModel

from pgkg import ml
from pgkg.backends.postgres import PostgresBackend
from pgkg.config import get_settings
from pgkg.db import make_pool, close_pool
from pgkg.memory import Memory, IngestResult, Result

_pool = None
_backend: PostgresBackend | None = None
_memory: Memory | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
global _pool, _memory
global _backend, _memory
settings = get_settings()
_pool = await make_pool(settings.database_url)
_backend = await PostgresBackend.create(settings.database_url)
_memory = Memory(
_pool,
_backend,
namespace=settings.default_namespace,
extract_propositions=settings.extract_propositions,
)
yield
if _pool:
await close_pool(_pool)
if _backend:
await _backend.close()


app = FastAPI(title="pgkg", lifespan=lifespan)
Expand Down Expand Up @@ -90,15 +90,7 @@ async def forget(req: ForgetRequest) -> Response:

@app.get("/health")
async def health() -> dict:
db_ok = False
if _pool:
try:
async with _pool.acquire() as conn:
await conn.fetchval("SELECT 1")
db_ok = True
except Exception:
pass

db_ok = await _backend.health_check() if _backend else False
return {
"status": "ok",
"db": db_ok,
Expand Down
255 changes: 255 additions & 0 deletions pgkg/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
"""StorageBackend protocol — the abstraction boundary between pgkg and its storage layer."""
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Protocol, runtime_checkable
from uuid import UUID


@dataclass(frozen=True)
class ScoredId:
    """A database row ID with an associated relevance score."""
    # Row ID in the backing store; which table it refers to depends on the
    # method that produced it (propositions for keyword/vector/graph search).
    id: UUID
    # Relevance score. Scale and semantics are method-specific (FTS rank vs.
    # vector closeness vs. graph expansion), so scores returned by different
    # primitives are not directly comparable to each other.
    score: float


@dataclass(frozen=True)
class Candidate:
    """A proposition candidate with full retrieval metadata.

    Produced by ``StorageBackend.fused_search``; carries both the fused
    retrieval scores and the provenance needed for downstream reranking.
    """
    proposition_id: UUID
    # Proposition text as stored.
    text: str
    # Reciprocal-rank-fusion score from the combined keyword/vector ranking
    # (see the ``rrf_k`` parameter of ``fused_search``).
    rrf_score: float
    # Score after backend-side adjustments — presumably recency decay via
    # ``recency_half_life_days``; confirm against the Postgres CTE.
    adjusted_score: float
    # Which retrieval path surfaced this candidate.
    source_kind: str  # "kw" | "vec" | "both" | "graph"
    # Provenance: originating chunk and, when the proposition encodes a
    # triple, its subject/predicate/object. Any of these may be None.
    chunk_id: UUID | None
    subject_id: UUID | None
    predicate: str | None
    object_id: UUID | None
    # When the proposition was asserted; None when unknown.
    asserted_at: datetime | None = None


@dataclass(frozen=True)
class PropositionRow:
    """Full proposition data needed for scoring and reranking."""
    id: UUID
    text: str
    # Stored embedding vector; None when the row was written without one.
    embedding: list[float] | None
    # Originating chunk, if any.
    chunk_id: UUID | None
    # Triple fields; None for propositions that do not encode a triple.
    subject_id: UUID | None
    predicate: str | None
    object_id: UUID | None
    # Extraction confidence — assumed to lie in [0, 1]; confirm with writer.
    confidence: float
    # Access-tracking counters maintained via ``StorageBackend.bump_access``.
    access_count: int
    # NOTE(review): timestamp is an ISO 8601 *string* here while Candidate
    # and StoredProposition use ``datetime`` — consider unifying the
    # representation; confirm how the backend serialises this column.
    last_accessed_at: str  # ISO 8601 timestamp
    namespace: str


@dataclass(frozen=True)
class StoredProposition:
    """Data needed to store a proposition."""
    text: str
    embedding: list[float]
    # Triple fields. A proposition may point at another entity (object_id)
    # or carry a literal value (object_literal) — presumably mutually
    # exclusive; confirm against the backend's insert logic.
    subject_id: UUID | None
    predicate: str | None
    object_id: UUID | None
    object_literal: str | None
    # Originating chunk, if the proposition was extracted from one.
    chunk_id: UUID | None
    namespace: str
    session_id: str | None
    # Extraction confidence; defaults to fully confident.
    confidence: float = 1.0
    # Free-form extra data; None (not {}) so the frozen dataclass avoids a
    # shared mutable default.
    metadata: dict | None = None
    # Assertion time; None lets the backend choose — presumably now().
    asserted_at: datetime | None = None


@dataclass(frozen=True)
class StoredDocument:
    """Data needed to store a document."""
    # Explicit ID, or None to let the backend generate one —
    # ``store_document`` returns the final ID either way.
    id: UUID | None = None
    # Free-form origin marker (e.g. a path or URL); not interpreted here.
    source: str | None = None
    namespace: str = "default"


@dataclass(frozen=True)
class StoredChunk:
    """Data needed to store a chunk."""
    # Parent document (see ``StoredDocument`` / ``store_document``).
    document_id: UUID
    text: str
    # Explicit ID, or None to let the backend generate one —
    # ``store_chunk`` returns the final ID either way.
    id: UUID | None = None
    # Character offsets of this chunk within the parent document — assumed
    # half-open [start, end); confirm against the chunker.
    span_start: int | None = None
    span_end: int | None = None
    # Assertion time; None lets the backend choose.
    asserted_at: datetime | None = None


@runtime_checkable
class StorageBackend(Protocol):
    """Abstract storage backend for pgkg.

    Implementations must support all methods. The ``fused_search`` method
    is optional — backends that cannot fuse retrieval in a single query
    should return ``None``, and the caller falls back to primitive
    orchestration.

    NOTE(review): ``@runtime_checkable`` isinstance checks verify only
    method *presence*, not signatures or async-ness — do not rely on
    ``isinstance`` for full conformance checking.
    """

    # --- Lifecycle -------------------------------------------------------

    async def apply_migrations(self) -> None:
        """Apply schema migrations. Called once during initialisation.

        Presumably must be idempotent across restarts — confirm per
        implementation.
        """
        ...

    async def close(self) -> None:
        """Release all resources (connections, pools, file handles)."""
        ...

    # --- Retrieval primitives --------------------------------------------

    async def keyword_search(
        self,
        query: str,
        k: int,
        namespace: str,
        session_id: str | None = None,
    ) -> list[ScoredId]:
        """Full-text search. Returns (id, score) pairs ranked by relevance.

        Results are scoped to ``namespace`` and, when given, ``session_id``;
        ``k`` is presumably the maximum number of results returned.
        """
        ...

    async def vector_search(
        self,
        embedding: list[float],
        k: int,
        namespace: str,
        session_id: str | None = None,
    ) -> list[ScoredId]:
        """Vector similarity search. Returns (id, score) pairs ranked by closeness.

        Same scoping rules as ``keyword_search``. The embedding must match
        the dimensionality the backend was populated with — not checked here.
        """
        ...

    async def graph_neighbors(
        self,
        entity_ids: list[UUID],
        namespace: str,
        limit: int = 100,
    ) -> list[ScoredId]:
        """One-hop graph expansion from seed entities. Returns proposition IDs."""
        ...

    async def fused_search(
        self,
        query_text: str,
        query_embedding: list[float],
        k_retrieve: int,
        k_initial: int,
        namespace: str,
        session_id: str | None,
        recency_half_life_days: float,
        expand_graph: bool,
        rrf_k: int,
    ) -> list[Candidate] | None:
        """Push-down fused retrieval (optional optimisation).

        Backends that can execute the full retrieval pipeline in a single
        query (e.g. the Postgres CTE) return results here. Other backends
        return ``None`` and the caller orchestrates via the primitive methods.

        ``rrf_k`` is the reciprocal-rank-fusion constant; ``recency_half_life_days``
        presumably controls exponential recency decay of scores — confirm
        against the Postgres implementation.
        """
        return None

    # --- Data resolution -------------------------------------------------

    async def get_propositions(
        self,
        proposition_ids: list[UUID],
    ) -> list[PropositionRow]:
        """Fetch full proposition data for the given IDs.

        Ordering and behaviour for unknown IDs are implementation-defined —
        callers should not assume result order matches input order.
        """
        ...

    async def get_proposition_entity_ids(
        self,
        proposition_ids: list[UUID],
    ) -> dict[UUID, list[UUID]]:
        """Return ``{proposition_id: [entity_ids]}`` for the given propositions."""
        ...

    # --- Ingest primitives -----------------------------------------------

    async def store_document(self, doc: StoredDocument) -> UUID:
        """Insert a document. Returns its ID."""
        ...

    async def store_chunk(self, chunk: StoredChunk) -> UUID:
        """Insert a chunk. Returns its ID."""
        ...

    async def store_proposition(self, prop: StoredProposition) -> UUID:
        """Insert a proposition. Returns its ID."""
        ...

    async def link_entity(
        self,
        namespace: str,
        name: str,
        entity_type: str,
        embedding: list[float],
        threshold: float = 0.85,
    ) -> UUID:
        """Find or create an entity by name/type/embedding similarity.

        ``threshold`` is the minimum similarity for reusing an existing
        entity instead of creating a new one — presumably cosine similarity;
        confirm per backend.
        """
        ...

    async def store_edge(
        self,
        src_entity: UUID,
        dst_entity: UUID,
        predicate: str,
        proposition_id: UUID,
    ) -> None:
        """Create an edge between two entities via a proposition."""
        ...

    # --- Access tracking -------------------------------------------------

    async def bump_access(self, proposition_ids: list[UUID]) -> None:
        """Increment access_count and update last_accessed_at.

        Best-effort bookkeeping (feeds ``PropositionRow.access_count``) —
        whether failures are surfaced or swallowed is implementation-defined.
        """
        ...

    # --- Entity resolution -----------------------------------------------

    async def resolve_entity_names(
        self, entity_ids: list[UUID],
    ) -> dict[UUID, str]:
        """Map entity IDs to names. Returns ``{id: name}`` for found entities."""
        ...

    # --- Forget ----------------------------------------------------------

    async def forget(
        self,
        proposition_id: UUID,
        supersede_with: UUID | None = None,
    ) -> None:
        """Mark a proposition as superseded.

        When ``supersede_with`` is given, presumably records the replacing
        proposition rather than deleting outright — confirm per backend.
        """
        ...

    # --- Cache (optional) ------------------------------------------------

    async def cache_get(self, cache_key: str) -> list[dict] | None:
        """Retrieve cached extraction results as raw dicts.

        Returns ``None`` on cache miss. The caller is responsible for
        deserialising dicts into ``Proposition`` objects.
        """
        return None

    async def cache_put(
        self,
        cache_key: str,
        chunk_hash: str,
        extractor_model: str,
        prompt_version: str,
        propositions: list[dict],
    ) -> None:
        """Store extraction results in cache as raw dicts.

        NOTE(review): section header says the cache is optional and
        ``cache_get`` provides a no-op default, but this method has no
        default body — consider giving it a ``return None`` default so
        cache-less backends need not implement it.
        """
        ...

    # --- Health ----------------------------------------------------------

    async def health_check(self) -> bool:
        """Return True if the backend is healthy (can serve queries)."""
        ...
Loading