Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions sme/adapters/oracle_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Oracle retrieval baseline — TREC-standard upper bound.

Reads ``expected_sources`` from the gold question set and returns
exactly those items. This is the ceiling — any system that matches
this has perfect retrieval (at the substring level the current
scorer operates on).
"""

from __future__ import annotations

from sme.adapters.base import Edge, Entity, QueryResult, SMEAdapter


class OracleRetrievalAdapter(SMEAdapter):
    """Upper-bound baseline that answers each query with its gold sources.

    The adapter is primed with the gold question set and replays the
    ``expected_sources`` of whichever question text matches the query
    exactly. The ingested corpus is stored but never consulted.
    """

    def __init__(self, *, questions: list[dict] | None = None):
        """Index the gold question set by question text.

        Args:
            questions: Gold question dicts, read through their ``"text"``
                and ``"expected_sources"`` keys. Entries where either is
                missing or empty are skipped.
        """
        self._corpus: list[dict] = []
        self._questions = questions or []
        # Build a lookup: question text -> expected_sources
        self._gold: dict[str, list[str]] = {}
        for q in self._questions:
            text = q.get("text", "")
            sources = q.get("expected_sources", [])
            if not (text and sources):
                continue
            if isinstance(sources, str):
                # A bare string would otherwise be iterated character by
                # character in query(); treat it as a single source.
                sources = [sources]
            # Copy so later mutation of the caller's question dicts cannot
            # silently change what the oracle returns.
            self._gold[text] = list(sources)

    def ingest_corpus(self, corpus: list[dict]) -> dict:
        """Store a copy of ``corpus`` and return an ingestion summary.

        Returns:
            Dict with ``entities_created`` (number of stored items),
            ``edges_created`` (always 0), ``errors`` and ``warnings``
            (always empty lists).
        """
        self._corpus = list(corpus)
        return {
            # Count the stored copy so any iterable that list() accepts
            # is reported correctly.
            "entities_created": len(self._corpus),
            "edges_created": 0,
            "errors": [],
            "warnings": [],
        }

    def query(self, question: str, n_results: int | None = None) -> QueryResult:
        """Return the gold sources registered for ``question``.

        Args:
            question: Question text; must match a gold question exactly.
            n_results: Accepted for CLI/testkit parity but ignored.

        Returns:
            A ``QueryResult`` whose ``context_string`` (also used as
            ``answer``) contains every gold source verbatim, with one
            ``oracle_source`` entity per source. Unknown questions yield
            an empty result with ``error="NO_GOLD_ANSWER"``.
        """
        # n_results is accepted for CLI/testkit parity but ignored: the
        # oracle returns exactly the gold expected_sources, no more, no
        # fewer — that is what makes it the substring-scorer ceiling.
        sources = self._gold.get(question, [])
        if not sources:
            return QueryResult(
                answer="", context_string="", error="NO_GOLD_ANSWER"
            )
        # Build context_string that contains all expected source substrings.
        # The SME scorer uses substring matching, so including the expected
        # substrings verbatim guarantees a perfect score.
        context_parts: list[str] = []
        entities: list[Entity] = []
        for i, source in enumerate(sources):
            context_parts.append(f"[{i+1}] oracle\n{source}")
            entities.append(
                Entity(
                    # Keyed on the source string rather than the loop index.
                    id=f"oracle:{source}",
                    name=source,
                    entity_type="oracle_source",
                )
            )
        context_string = "\n\n".join(context_parts)
        return QueryResult(
            answer=context_string,
            context_string=context_string,
            retrieved_entities=entities,
        )

    def get_graph_snapshot(self) -> tuple[list[Entity], list[Edge]]:
        """Return an empty graph: this adapter builds no entities or edges."""
        return [], []
59 changes: 59 additions & 0 deletions sme/adapters/random_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Random retrieval baseline — TREC-standard lower bound.

Returns K uniformly random items from the ingested corpus, seeded
for reproducibility. Any memory system that can't beat this isn't
doing retrieval — it's doing random selection.
"""

from __future__ import annotations

import random as random_mod

from sme.adapters.base import Edge, Entity, QueryResult, SMEAdapter


class RandomRetrievalAdapter(SMEAdapter):
    """Lower-bound baseline that answers every query with a random sample.

    The question text is ignored. Items are drawn from the ingested corpus
    with one seeded RNG created at construction time, so a given adapter's
    results depend on the seed and on how many queries preceded the call.
    """

    def __init__(self, *, seed: int = 42, n_results: int = 10):
        """Create the adapter.

        Args:
            seed: Seed for the private ``random.Random`` instance.
            n_results: Default sample size when ``query`` is called
                without an explicit ``n_results``.
        """
        self._seed = seed
        self._rng = random_mod.Random(seed)
        self._n_results = n_results
        self._corpus: list[dict] = []

    def ingest_corpus(self, corpus: list[dict]) -> dict:
        """Store a copy of ``corpus`` and return an ingestion summary.

        Returns:
            Dict with ``entities_created`` (number of stored items),
            ``edges_created`` (always 0), ``errors`` and ``warnings``
            (always empty lists).
        """
        self._corpus = list(corpus)
        return {
            # Count the stored copy so any iterable that list() accepts
            # is reported correctly.
            "entities_created": len(self._corpus),
            "edges_created": 0,
            "errors": [],
            "warnings": [],
        }

    def query(self, question: str, n_results: int | None = None) -> QueryResult:
        """Return a random sample (without replacement) of corpus items.

        Args:
            question: Ignored; present for adapter-interface parity.
            n_results: Sample size for this call; falls back to the
                constructor default when ``None``. Clamped to the range
                ``[0, len(corpus)]``.

        Returns:
            A ``QueryResult`` whose ``context_string`` (also used as
            ``answer``) lists the sampled items, with one
            ``random_selection`` entity per item. An empty corpus yields
            an empty result with ``error="NO_CORPUS"``.
        """
        if not self._corpus:
            return QueryResult(answer="", context_string="", error="NO_CORPUS")
        requested = self._n_results if n_results is None else n_results
        # random.sample raises ValueError for a negative size, so clamp at
        # both ends: never below zero, never above the corpus size.
        k = max(0, min(requested, len(self._corpus)))
        selected = self._rng.sample(self._corpus, k)
        context_parts: list[str] = []
        entities: list[Entity] = []
        for i, item in enumerate(selected):
            # All three lookups use the same ``or`` fallback chain so that a
            # key that is present but None/empty falls through instead of
            # being rendered as "None".
            item_id = item.get("id") or item.get("source_file") or f"item_{i}"
            source = item.get("source_file") or item.get("id") or f"random_{i}"
            text = item.get("text") or item.get("content") or ""
            context_parts.append(f"[{i+1}] {source}\n{text}")
            entities.append(
                Entity(
                    # Keyed on the item's own id rather than the loop index
                    # whenever the item provides one.
                    id=f"random:{item_id}",
                    name=str(source),
                    entity_type="random_selection",
                )
            )
        context_string = "\n\n".join(context_parts)
        return QueryResult(
            answer=context_string,
            context_string=context_string,
            retrieved_entities=entities,
        )

    def get_graph_snapshot(self) -> tuple[list[Entity], list[Edge]]:
        """Return an empty graph: this adapter builds no entities or edges."""
        return [], []
22 changes: 22 additions & 0 deletions sme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def _karpathy_compiled_loader() -> type[SMEAdapter]:
return KarpathyCompiledAdapter


def _random_retrieval_loader() -> type[SMEAdapter]:
    """Import the random-retrieval adapter on demand and return its class."""
    # Imported inside the function so the module is only loaded when this
    # loader is actually called.
    from sme.adapters.random_retrieval import (
        RandomRetrievalAdapter as adapter_cls,
    )

    return adapter_cls


def _oracle_retrieval_loader() -> type[SMEAdapter]:
    """Import the oracle-retrieval adapter on demand and return its class."""
    # Imported inside the function so the module is only loaded when this
    # loader is actually called.
    from sme.adapters.oracle_retrieval import (
        OracleRetrievalAdapter as adapter_cls,
    )

    return adapter_cls


_ADAPTER_REGISTRY: tuple[_AdapterSpec, ...] = (
_AdapterSpec(
aliases=("ladybugdb", "ladybug"),
Expand Down Expand Up @@ -141,6 +153,16 @@ def _karpathy_compiled_loader() -> type[SMEAdapter]:
accepts=frozenset({"compiled_dir", "include_wiki"}),
rename={"db_path": "compiled_dir"},
),
_AdapterSpec(
aliases=("random", "random-retrieval", "random_retrieval"),
loader=_random_retrieval_loader,
accepts=frozenset({"seed", "n_results"}),
),
_AdapterSpec(
aliases=("oracle", "oracle-retrieval", "oracle_retrieval"),
loader=_oracle_retrieval_loader,
accepts=frozenset({"questions"}),
),
)


Expand Down
16 changes: 16 additions & 0 deletions tests/test_adapter_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,27 @@ def _full_context_factory(tmp_path: Path) -> SMEAdapter:
return FullContextAdapter(vault)


def _random_retrieval_factory(tmp_path: Path) -> SMEAdapter:
    """Build a default RandomRetrievalAdapter (TREC lower bound).

    The adapter is purely in-memory, so ``tmp_path`` is unused.
    """
    from sme.adapters.random_retrieval import (
        RandomRetrievalAdapter as adapter_cls,
    )

    adapter = adapter_cls()
    return adapter


def _oracle_retrieval_factory(tmp_path: Path) -> SMEAdapter:
    """Build a default OracleRetrievalAdapter (TREC upper bound).

    The adapter is purely in-memory, so ``tmp_path`` is unused.
    """
    from sme.adapters.oracle_retrieval import (
        OracleRetrievalAdapter as adapter_cls,
    )

    adapter = adapter_cls()
    return adapter


# Register adapters here. Keep IDs stable — they show in pytest output.
# Each key is the adapter ID; each value is the factory function that
# constructs that adapter.
ADAPTER_FACTORIES: dict[str, AdapterFactory] = {
    "mock": _mock_factory,
    "flat_baseline": _flat_baseline_factory,
    "full_context": _full_context_factory,
    "random_retrieval": _random_retrieval_factory,
    "oracle_retrieval": _oracle_retrieval_factory,
}


Expand Down
103 changes: 103 additions & 0 deletions tests/test_oracle_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Tests for the oracle retrieval baseline adapter (TREC upper bound).

See ``sme/adapters/oracle_retrieval.py``.
"""

from __future__ import annotations

from sme.adapters.oracle_retrieval import OracleRetrievalAdapter


def _make_questions() -> list[dict]:
return [
{
"text": "what is the capital of france?",
"expected_sources": ["Paris is the capital of France."],
},
{
"text": "who wrote hamlet?",
"expected_sources": [
"Hamlet was written by William Shakespeare.",
"Shakespeare's tragedies include Hamlet, Macbeth, and Othello.",
],
},
]


def test_ingest_corpus_stores_items() -> None:
    """Ingesting a corpus reports one entity per item and nothing else."""
    adapter = OracleRetrievalAdapter(questions=_make_questions())
    corpus = [{"id": "a", "text": "first"}, {"id": "b", "text": "second"}]
    result = adapter.ingest_corpus(corpus)
    assert result["entities_created"] == 2
    assert result["edges_created"] == 0
    assert result["errors"] == []
    # Checked for parity with the random-retrieval ingest test: the summary
    # dict always carries an empty warnings list as well.
    assert result["warnings"] == []


def test_query_returns_expected_sources_for_known_question() -> None:
    """A known question yields exactly its single gold source as an entity."""
    oracle = OracleRetrievalAdapter(questions=_make_questions())
    oracle.ingest_corpus([])
    outcome = oracle.query("what is the capital of france?")
    assert outcome.error is None
    assert len(outcome.retrieved_entities) == 1
    entity = outcome.retrieved_entities[0]
    assert entity.name == "Paris is the capital of France."
    assert entity.entity_type == "oracle_source"
    # The ID embeds the gold source string itself rather than a loop
    # index, so per-ID downstream analysis (Cat 4/Cat 5) can map back to
    # the source.
    assert entity.id == "oracle:Paris is the capital of France."


def test_query_returns_all_expected_sources() -> None:
    """A question with two gold sources yields two entities and no error."""
    oracle = OracleRetrievalAdapter(questions=_make_questions())
    outcome = oracle.query("who wrote hamlet?")
    assert outcome.error is None
    entity_count = len(outcome.retrieved_entities)
    assert entity_count == 2


def test_expected_sources_appear_in_context_string() -> None:
    """Every gold source is present verbatim in the returned context."""
    oracle = OracleRetrievalAdapter(questions=_make_questions())
    context = oracle.query("who wrote hamlet?").context_string
    # The scorer uses substring matching, so each expected source must
    # appear verbatim in context_string.
    expected_sources = (
        "Hamlet was written by William Shakespeare.",
        "Shakespeare's tragedies include Hamlet, Macbeth, and Othello.",
    )
    for expected in expected_sources:
        assert expected in context


def test_unknown_question_returns_error() -> None:
    """A question absent from the gold set yields an empty error result."""
    oracle = OracleRetrievalAdapter(questions=_make_questions())
    outcome = oracle.query("a question we have never seen")
    assert outcome.error == "NO_GOLD_ANSWER"
    assert outcome.context_string == ""
    assert outcome.retrieved_entities == []


def test_empty_questions_list_treats_every_query_as_unknown() -> None:
    """With no gold questions supplied, any query reports NO_GOLD_ANSWER."""
    oracle = OracleRetrievalAdapter()
    error = oracle.query("anything").error
    assert error == "NO_GOLD_ANSWER"


def test_questions_without_text_or_sources_are_skipped() -> None:
    """Gold entries lacking text or sources never enter the lookup."""
    gold = [
        {"text": "", "expected_sources": ["x"]},
        {"text": "q", "expected_sources": []},
        {"text": "valid", "expected_sources": ["answer"]},
    ]
    oracle = OracleRetrievalAdapter(questions=gold)
    # Only the third entry made it into the gold lookup.
    for skipped_text in ("", "q"):
        assert oracle.query(skipped_text).error == "NO_GOLD_ANSWER"
    assert oracle.query("valid").error is None


def test_graph_snapshot_is_empty() -> None:
    """The oracle adapter exposes no graph entities and no edges."""
    oracle = OracleRetrievalAdapter(questions=_make_questions())
    snapshot = oracle.get_graph_snapshot()
    entities, edges = snapshot
    assert entities == []
    assert edges == []
91 changes: 91 additions & 0 deletions tests/test_random_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for the random retrieval baseline adapter (TREC lower bound).

See ``sme/adapters/random_retrieval.py``.
"""

from __future__ import annotations

from sme.adapters.random_retrieval import RandomRetrievalAdapter


def _make_corpus(n: int) -> list[dict]:
return [
{"id": f"item_{i}", "source_file": f"file_{i}.md", "text": f"body {i}"}
for i in range(n)
]


def test_ingest_corpus_stores_items() -> None:
    """Ingesting five items reports five entities and nothing else."""
    sampler = RandomRetrievalAdapter()
    summary = sampler.ingest_corpus(_make_corpus(5))
    assert summary["entities_created"] == 5
    assert summary["edges_created"] == 0
    for list_key in ("errors", "warnings"):
        assert summary[list_key] == []


def test_query_returns_n_results_items() -> None:
    """A query returns ``n_results`` entities drawn from the real corpus."""
    sampler = RandomRetrievalAdapter(n_results=3)
    sampler.ingest_corpus(_make_corpus(10))
    outcome = sampler.query("anything")
    assert outcome.error is None
    picked = outcome.retrieved_entities
    assert len(picked) == 3
    for entity in picked:
        # All selected items should reference real corpus sources.
        assert entity.name.startswith("file_")
        assert entity.entity_type == "random_selection"
        # Entity IDs carry the corpus item's own id, not a loop index, so
        # per-ID downstream analysis (Cat 4/Cat 5) can map a returned
        # entity back to the source corpus item.
        assert entity.id.startswith("random:item_")
    # Context string is non-empty and contains source labels.
    context = outcome.context_string
    assert context
    assert "[1]" in context


def test_query_caps_at_corpus_size() -> None:
    """Asking for more items than the corpus holds returns the whole corpus."""
    sampler = RandomRetrievalAdapter(n_results=100)
    sampler.ingest_corpus(_make_corpus(4))
    returned = sampler.query("anything").retrieved_entities
    assert len(returned) == 4


def test_seeded_queries_are_reproducible() -> None:
    """Two adapters with the same seed return the same first sample."""
    corpus = _make_corpus(20)
    first = RandomRetrievalAdapter(seed=42, n_results=5)
    second = RandomRetrievalAdapter(seed=42, n_results=5)
    for adapter in (first, second):
        adapter.ingest_corpus(corpus)
    # Query in the same order as construction, then compare entity names.
    first_names = [e.name for e in first.query("q").retrieved_entities]
    second_names = [e.name for e in second.query("q").retrieved_entities]
    assert first_names == second_names


def test_different_seeds_diverge() -> None:
    """Adapters seeded differently return different first samples."""
    corpus = _make_corpus(50)
    seeded_one = RandomRetrievalAdapter(seed=1, n_results=10)
    seeded_two = RandomRetrievalAdapter(seed=2, n_results=10)
    for adapter in (seeded_one, seeded_two):
        adapter.ingest_corpus(corpus)
    result_one = seeded_one.query("q")
    result_two = seeded_two.query("q")
    names_one = [e.name for e in result_one.retrieved_entities]
    names_two = [e.name for e in result_two.retrieved_entities]
    assert names_one != names_two


Comment on lines +67 to +77
def test_empty_corpus_returns_error() -> None:
    """Querying before any corpus is ingested yields an empty error result."""
    sampler = RandomRetrievalAdapter()
    outcome = sampler.query("anything")
    assert outcome.error == "NO_CORPUS"
    assert outcome.context_string == ""
    assert outcome.retrieved_entities == []


def test_graph_snapshot_is_empty() -> None:
    """Even after ingestion, the adapter exposes no entities and no edges."""
    sampler = RandomRetrievalAdapter()
    sampler.ingest_corpus(_make_corpus(3))
    snapshot = sampler.get_graph_snapshot()
    entities, edges = snapshot
    assert entities == []
    assert edges == []
Loading