From 23753d91d5abda3c6597f710c019bcad462bb28b Mon Sep 17 00:00:00 2001 From: jp Date: Sun, 24 May 2026 17:54:15 -0700 Subject: [PATCH] feat(adapters): random + oracle retrieval baselines (TREC bounds) (#23) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two reference adapters that establish the standard lower/upper bounds for retrieval evaluation per TREC methodology: - RandomRetrievalAdapter: returns K uniformly random items from the ingested corpus, seeded for reproducibility. Any system that can't beat this isn't doing retrieval. - OracleRetrievalAdapter: returns the gold expected_sources verbatim in context_string. Substring-scorer ceiling — no system can do better than perfect retrieval. Both are wired into sme/cli.py _load_adapter() under the names random/random-retrieval/random_retrieval and oracle/oracle-retrieval/oracle_retrieval. Adapter-specific kwargs from other adapters are silently dropped for CLI parity. Per the issue brief, baselines are intentionally NOT added to the adapter harness manifest contract — they're reference bounds, not adapters under test. 
Closes M0nkeyFl0wer/multipass-structural-memory-eval#23 Co-Authored-By: Claude Opus 4.6 --- sme/adapters/oracle_retrieval.py | 66 ++++++++++++++++++++ sme/adapters/random_retrieval.py | 59 ++++++++++++++++++ sme/cli.py | 22 +++++++ tests/test_adapter_contract.py | 16 +++++ tests/test_oracle_retrieval.py | 103 +++++++++++++++++++++++++++++++ tests/test_random_retrieval.py | 91 +++++++++++++++++++++++++++ 6 files changed, 357 insertions(+) create mode 100644 sme/adapters/oracle_retrieval.py create mode 100644 sme/adapters/random_retrieval.py create mode 100644 tests/test_oracle_retrieval.py create mode 100644 tests/test_random_retrieval.py diff --git a/sme/adapters/oracle_retrieval.py b/sme/adapters/oracle_retrieval.py new file mode 100644 index 0000000..2c43ba3 --- /dev/null +++ b/sme/adapters/oracle_retrieval.py @@ -0,0 +1,66 @@ +"""Oracle retrieval baseline — TREC-standard upper bound. + +Reads ``expected_sources`` from the gold question set and returns +exactly those items. This is the ceiling — any system that matches +this has perfect retrieval (at the substring level the current +scorer operates on). 
+""" + +from __future__ import annotations + +from sme.adapters.base import Edge, Entity, QueryResult, SMEAdapter + + +class OracleRetrievalAdapter(SMEAdapter): + def __init__(self, *, questions: list[dict] | None = None): + self._corpus: list[dict] = [] + self._questions = questions or [] + # Build a lookup: question text -> expected_sources + self._gold: dict[str, list[str]] = {} + for q in self._questions: + text = q.get("text", "") + sources = q.get("expected_sources", []) + if text and sources: + self._gold[text] = sources + + def ingest_corpus(self, corpus: list[dict]) -> dict: + self._corpus = list(corpus) + return { + "entities_created": len(corpus), + "edges_created": 0, + "errors": [], + "warnings": [], + } + + def query(self, question: str, n_results: int | None = None) -> QueryResult: + # n_results is accepted for CLI/testkit parity but ignored: the + # oracle returns exactly the gold expected_sources, no more, no + # fewer — that is what makes it the substring-scorer ceiling. + sources = self._gold.get(question, []) + if not sources: + return QueryResult( + answer="", context_string="", error="NO_GOLD_ANSWER" + ) + # Build context_string that contains all expected source substrings. + # The SME scorer uses substring matching, so including the expected + # substrings verbatim guarantees a perfect score. 
+ context_parts: list[str] = [] + entities: list[Entity] = [] + for i, source in enumerate(sources): + context_parts.append(f"[{i+1}] oracle\n{source}") + entities.append( + Entity( + id=f"oracle:{source}", + name=source, + entity_type="oracle_source", + ) + ) + context_string = "\n\n".join(context_parts) + return QueryResult( + answer=context_string, + context_string=context_string, + retrieved_entities=entities, + ) + + def get_graph_snapshot(self) -> tuple[list[Entity], list[Edge]]: + return [], [] diff --git a/sme/adapters/random_retrieval.py b/sme/adapters/random_retrieval.py new file mode 100644 index 0000000..ff4dcba --- /dev/null +++ b/sme/adapters/random_retrieval.py @@ -0,0 +1,59 @@ +"""Random retrieval baseline — TREC-standard lower bound. + +Returns K uniformly random items from the ingested corpus, seeded +for reproducibility. Any memory system that can't beat this isn't +doing retrieval — it's doing random selection. +""" + +from __future__ import annotations + +import random as random_mod + +from sme.adapters.base import Edge, Entity, QueryResult, SMEAdapter + + +class RandomRetrievalAdapter(SMEAdapter): + def __init__(self, *, seed: int = 42, n_results: int = 10): + self._seed = seed + self._rng = random_mod.Random(seed) + self._n_results = n_results + self._corpus: list[dict] = [] + + def ingest_corpus(self, corpus: list[dict]) -> dict: + self._corpus = list(corpus) + return { + "entities_created": len(corpus), + "edges_created": 0, + "errors": [], + "warnings": [], + } + + def query(self, question: str, n_results: int | None = None) -> QueryResult: + if not self._corpus: + return QueryResult(answer="", context_string="", error="NO_CORPUS") + n = self._n_results if n_results is None else n_results + k = min(n, len(self._corpus)) + selected = self._rng.sample(self._corpus, k) + context_parts: list[str] = [] + entities: list[Entity] = [] + for i, item in enumerate(selected): + item_id = item.get("id") or item.get("source_file") or f"item_{i}" + 
source = item.get("source_file", item.get("id", f"random_{i}")) + text = item.get("text", item.get("content", "")) + context_parts.append(f"[{i+1}] {source}\n{text}") + entities.append( + Entity( + id=f"random:{item_id}", + name=str(source), + entity_type="random_selection", + ) + ) + context_string = "\n\n".join(context_parts) + return QueryResult( + answer=context_string, + context_string=context_string, + retrieved_entities=entities, + ) + + def get_graph_snapshot(self) -> tuple[list[Entity], list[Edge]]: + return [], [] diff --git a/sme/cli.py b/sme/cli.py index 8f2af09..4efbe68 100644 --- a/sme/cli.py +++ b/sme/cli.py @@ -87,6 +87,18 @@ def _karpathy_compiled_loader() -> type[SMEAdapter]: return KarpathyCompiledAdapter +def _random_retrieval_loader() -> type[SMEAdapter]: + from sme.adapters.random_retrieval import RandomRetrievalAdapter + + return RandomRetrievalAdapter + + +def _oracle_retrieval_loader() -> type[SMEAdapter]: + from sme.adapters.oracle_retrieval import OracleRetrievalAdapter + + return OracleRetrievalAdapter + + _ADAPTER_REGISTRY: tuple[_AdapterSpec, ...] 
= ( _AdapterSpec( aliases=("ladybugdb", "ladybug"), @@ -141,6 +153,16 @@ def _karpathy_compiled_loader() -> type[SMEAdapter]: accepts=frozenset({"compiled_dir", "include_wiki"}), rename={"db_path": "compiled_dir"}, ), + _AdapterSpec( + aliases=("random", "random-retrieval", "random_retrieval"), + loader=_random_retrieval_loader, + accepts=frozenset({"seed", "n_results"}), + ), + _AdapterSpec( + aliases=("oracle", "oracle-retrieval", "oracle_retrieval"), + loader=_oracle_retrieval_loader, + accepts=frozenset({"questions"}), + ), ) diff --git a/tests/test_adapter_contract.py b/tests/test_adapter_contract.py index ae4d0cd..b9f04d4 100644 --- a/tests/test_adapter_contract.py +++ b/tests/test_adapter_contract.py @@ -134,11 +134,27 @@ def _full_context_factory(tmp_path: Path) -> SMEAdapter: return FullContextAdapter(vault) +def _random_retrieval_factory(tmp_path: Path) -> SMEAdapter: + """RandomRetrievalAdapter (TREC lower bound) — pure in-memory, no env.""" + from sme.adapters.random_retrieval import RandomRetrievalAdapter + + return RandomRetrievalAdapter() + + +def _oracle_retrieval_factory(tmp_path: Path) -> SMEAdapter: + """OracleRetrievalAdapter (TREC upper bound) — pure in-memory, no env.""" + from sme.adapters.oracle_retrieval import OracleRetrievalAdapter + + return OracleRetrievalAdapter() + + # Register adapters here. Keep IDs stable — they show in pytest output. ADAPTER_FACTORIES: dict[str, AdapterFactory] = { "mock": _mock_factory, "flat_baseline": _flat_baseline_factory, "full_context": _full_context_factory, + "random_retrieval": _random_retrieval_factory, + "oracle_retrieval": _oracle_retrieval_factory, } diff --git a/tests/test_oracle_retrieval.py b/tests/test_oracle_retrieval.py new file mode 100644 index 0000000..90fe7e7 --- /dev/null +++ b/tests/test_oracle_retrieval.py @@ -0,0 +1,103 @@ +"""Tests for the oracle retrieval baseline adapter (TREC upper bound). + +See ``sme/adapters/oracle_retrieval.py``. 
+""" + +from __future__ import annotations + +from sme.adapters.oracle_retrieval import OracleRetrievalAdapter + + +def _make_questions() -> list[dict]: + return [ + { + "text": "what is the capital of france?", + "expected_sources": ["Paris is the capital of France."], + }, + { + "text": "who wrote hamlet?", + "expected_sources": [ + "Hamlet was written by William Shakespeare.", + "Shakespeare's tragedies include Hamlet, Macbeth, and Othello.", + ], + }, + ] + + +def test_ingest_corpus_stores_items() -> None: + adapter = OracleRetrievalAdapter(questions=_make_questions()) + corpus = [{"id": "a", "text": "first"}, {"id": "b", "text": "second"}] + result = adapter.ingest_corpus(corpus) + assert result["entities_created"] == 2 + assert result["edges_created"] == 0 + assert result["errors"] == [] + + +def test_query_returns_expected_sources_for_known_question() -> None: + adapter = OracleRetrievalAdapter(questions=_make_questions()) + adapter.ingest_corpus([]) + result = adapter.query("what is the capital of france?") + assert result.error is None + assert len(result.retrieved_entities) == 1 + assert result.retrieved_entities[0].name == "Paris is the capital of France." + assert result.retrieved_entities[0].entity_type == "oracle_source" + # Entity ID tracks the gold source string, not a loop index, so + # per-ID downstream analysis (Cat 4/Cat 5) can map back to the + # source. + assert ( + result.retrieved_entities[0].id + == "oracle:Paris is the capital of France." 
+ ) + + +def test_query_returns_all_expected_sources() -> None: + adapter = OracleRetrievalAdapter(questions=_make_questions()) + result = adapter.query("who wrote hamlet?") + assert result.error is None + assert len(result.retrieved_entities) == 2 + + +def test_expected_sources_appear_in_context_string() -> None: + adapter = OracleRetrievalAdapter(questions=_make_questions()) + result = adapter.query("who wrote hamlet?") + # The scorer uses substring matching, so each expected source must + # appear verbatim in context_string. + assert "Hamlet was written by William Shakespeare." in result.context_string + assert ( + "Shakespeare's tragedies include Hamlet, Macbeth, and Othello." + in result.context_string + ) + + +def test_unknown_question_returns_error() -> None: + adapter = OracleRetrievalAdapter(questions=_make_questions()) + result = adapter.query("a question we have never seen") + assert result.error == "NO_GOLD_ANSWER" + assert result.context_string == "" + assert result.retrieved_entities == [] + + +def test_empty_questions_list_treats_every_query_as_unknown() -> None: + adapter = OracleRetrievalAdapter() + result = adapter.query("anything") + assert result.error == "NO_GOLD_ANSWER" + + +def test_questions_without_text_or_sources_are_skipped() -> None: + questions = [ + {"text": "", "expected_sources": ["x"]}, + {"text": "q", "expected_sources": []}, + {"text": "valid", "expected_sources": ["answer"]}, + ] + adapter = OracleRetrievalAdapter(questions=questions) + # Only the third entry made it into the gold lookup. 
+ assert adapter.query("").error == "NO_GOLD_ANSWER" + assert adapter.query("q").error == "NO_GOLD_ANSWER" + assert adapter.query("valid").error is None + + +def test_graph_snapshot_is_empty() -> None: + adapter = OracleRetrievalAdapter(questions=_make_questions()) + entities, edges = adapter.get_graph_snapshot() + assert entities == [] + assert edges == [] diff --git a/tests/test_random_retrieval.py b/tests/test_random_retrieval.py new file mode 100644 index 0000000..4556f3d --- /dev/null +++ b/tests/test_random_retrieval.py @@ -0,0 +1,91 @@ +"""Tests for the random retrieval baseline adapter (TREC lower bound). + +See ``sme/adapters/random_retrieval.py``. +""" + +from __future__ import annotations + +from sme.adapters.random_retrieval import RandomRetrievalAdapter + + +def _make_corpus(n: int) -> list[dict]: + return [ + {"id": f"item_{i}", "source_file": f"file_{i}.md", "text": f"body {i}"} + for i in range(n) + ] + + +def test_ingest_corpus_stores_items() -> None: + adapter = RandomRetrievalAdapter() + corpus = _make_corpus(5) + result = adapter.ingest_corpus(corpus) + assert result["entities_created"] == 5 + assert result["edges_created"] == 0 + assert result["errors"] == [] + assert result["warnings"] == [] + + +def test_query_returns_n_results_items() -> None: + adapter = RandomRetrievalAdapter(n_results=3) + adapter.ingest_corpus(_make_corpus(10)) + result = adapter.query("anything") + assert result.error is None + assert len(result.retrieved_entities) == 3 + # All selected items should reference real corpus sources. + for ent in result.retrieved_entities: + assert ent.name.startswith("file_") + assert ent.entity_type == "random_selection" + # Entity IDs track the corpus item's intrinsic id, not a loop + # index, so per-ID downstream analysis (Cat 4/Cat 5) can map a + # returned entity back to the source corpus item. + assert ent.id.startswith("random:item_") + # Context string is non-empty and contains source labels. 
+ assert result.context_string + assert "[1]" in result.context_string + + +def test_query_caps_at_corpus_size() -> None: + adapter = RandomRetrievalAdapter(n_results=100) + adapter.ingest_corpus(_make_corpus(4)) + result = adapter.query("anything") + assert len(result.retrieved_entities) == 4 + + +def test_seeded_queries_are_reproducible() -> None: + corpus = _make_corpus(20) + a = RandomRetrievalAdapter(seed=42, n_results=5) + b = RandomRetrievalAdapter(seed=42, n_results=5) + a.ingest_corpus(corpus) + b.ingest_corpus(corpus) + res_a = a.query("q") + res_b = b.query("q") + names_a = [e.name for e in res_a.retrieved_entities] + names_b = [e.name for e in res_b.retrieved_entities] + assert names_a == names_b + + +def test_different_seeds_diverge() -> None: + corpus = _make_corpus(50) + a = RandomRetrievalAdapter(seed=1, n_results=10) + b = RandomRetrievalAdapter(seed=2, n_results=10) + a.ingest_corpus(corpus) + b.ingest_corpus(corpus) + names_a = [e.name for e in a.query("q").retrieved_entities] + names_b = [e.name for e in b.query("q").retrieved_entities] + assert names_a != names_b + + +def test_empty_corpus_returns_error() -> None: + adapter = RandomRetrievalAdapter() + result = adapter.query("anything") + assert result.error == "NO_CORPUS" + assert result.context_string == "" + assert result.retrieved_entities == [] + + +def test_graph_snapshot_is_empty() -> None: + adapter = RandomRetrievalAdapter() + adapter.ingest_corpus(_make_corpus(3)) + entities, edges = adapter.get_graph_snapshot() + assert entities == [] + assert edges == []