diff --git a/docs/ideas.md b/docs/ideas.md index bc4db60..444a5a2 100644 --- a/docs/ideas.md +++ b/docs/ideas.md @@ -430,6 +430,44 @@ shape future adapter work and category implementations: and `baselines/jp_realm_v0_1_*.json` made it obvious where to add `jp_realm_v0_2`, `family_palace_v0_1`, etc. Don't refactor. +## RlmAdapter — research scaffold (2026-04-26) + +`sme/adapters/rlm_adapter.py` ships an adapter that treats RLM +(jphein/rlm fork of alexzhang13/rlm) as the read-side orchestrator: +the LLM itself decides when to call `mempalace_search`, with what +queries, and how to compose results. familiar's deterministic +retrieve→rerank→decay→compress pipeline becomes the *baseline* this +adapter is benchmarked against. + +**Design:** RLM gets `mempalace_search` registered as a +`custom_tools` callable. The adapter wraps that callable to capture +every search result into a per-query buffer. After `rlm.completion()` +returns, the buffer's contents become `context_string` (in tool-call +order) and `retrieved_entities` (one Entity per drawer). The Cat 1 / +retrieve substring scorer measures whether `expected_sources` ended +up in `context_string` — same contract as every other adapter. + +**Test coverage** (`tests/test_rlm_adapter.py`, 5 tests): tool-call +aggregation, capture-buffer reset across queries, error-dict +graceful handling on palace network failure, empty graph snapshot +(Cat 8 N/A for RLM), ingest_corpus skipped. + +**To benchmark live:** +```bash +PORTKEY_API_KEY=... PALACE_DAEMON_URL=http://disks.jphe.in:8085 \ +PALACE_API_KEY=... \ +venv/bin/sme-eval retrieve --adapter rlm \ + --questions sme/corpora/jp_realm_v0_1/questions.yaml \ + --json baselines/jp_realm_v0_1_rlm_$(date +%Y%m%d).json +``` + +A live benchmark will reveal whether RLM-orchestration recovers +recall on questions familiar misses, plateaus at the same number, +or regresses — the answer determines whether v0.4 should +productionize RLM into familiar's chat path. Without that data, +the design spec's "RLM and familiar are complementary" hypothesis +stays a hypothesis. + ## What's next ### Categories that aren't implemented yet diff --git a/docs/sme_spec_v8.md b/docs/sme_spec_v8.md index 262024b..0c01467 100644 --- a/docs/sme_spec_v8.md +++ b/docs/sme_spec_v8.md @@ -176,9 +176,16 @@ class SMEAdapter(ABC): Returns: {'type': 'declared'|'readme'|'inferred', 'schema': [...], 'documentation': str}""" return {'type': 'inferred', 'schema': [], 'documentation': ''} + + def get_harness_manifest(self) -> list[HarnessDescriptor]: + """Return invocation surfaces this memory system exposes. + Used by Category 9 (Harness Integration). Adapters that don't + expose a harness surface (pure library APIs) return [] — Cat 9b + reports empty_manifest=True rather than crashing.""" + return [] ``` -Three required methods. That's the minimum viable adapter. `get_flat_retrieval` and `get_ontology_source` have defaults — SME fills in its own flat baseline and infers ontology from the graph if the adapter doesn't provide them. +Three required methods. That's the minimum viable adapter. `get_flat_retrieval`, `get_ontology_source`, and `get_harness_manifest` have defaults — SME fills in its own flat baseline, infers ontology from the graph, and treats an empty harness manifest as "Cat 9 not applicable" when the adapter doesn't provide them. ### Default Adapters diff --git a/pyproject.toml b/pyproject.toml index e0505c3..d600201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,17 @@ ladybugdb = [ neo4j = [ "neo4j>=5.0", ] +# RlmAdapter — Recursive Language Models. Pinned to jphein/rlm (fork +# of alexzhang13/rlm) because the baselines/jp_realm_v0_1_rlm_*.json +# artifacts were produced against this fork; reproducing them on a +# fresh clone needs the same code path, even when the package surface +# is currently identical to upstream. Distribution name is `rlms`; +# importable module is `rlm`. Requires Python >=3.11 (one minor +# higher than SME's >=3.10 floor), so this extra is opt-in and not +# part of `all`. Tests are guarded with `pytest.importorskip`. +rlm = [ + "rlms @ git+https://github.com/jphein/rlm.git", +] all = [ "sme-eval[topology,viz,ladybugdb,neo4j]", ] diff --git a/sme/adapters/base.py b/sme/adapters/base.py index b71e3c3..8336952 100644 --- a/sme/adapters/base.py +++ b/sme/adapters/base.py @@ -4,7 +4,7 @@ never touches a database directly — it talks to this thin interface. Three required methods: ingest_corpus, query, get_graph_snapshot. -Two optional: get_flat_retrieval, get_ontology_source. +Three optional: get_flat_retrieval, get_ontology_source, get_harness_manifest. """ from __future__ import annotations @@ -116,7 +116,7 @@ class QueryResult: class SMEAdapter(ABC): """Implement this for your database/memory system. - Three required methods. Two optional. + Three required methods. Three optional. """ # --- Required ------------------------------------------------------ diff --git a/sme/adapters/rlm_adapter.py b/sme/adapters/rlm_adapter.py new file mode 100644 index 0000000..c137293 --- /dev/null +++ b/sme/adapters/rlm_adapter.py @@ -0,0 +1,376 @@ +"""RlmAdapter — Recursive Language Models as an SME adapter. + +Treats RLM (jphein/rlm fork of alexzhang13/rlm) as the read-side +orchestrator: instead of familiar's deterministic retrieve→ground→ +answer pipeline, the LLM itself decides when to call mempalace_search, +how to compose results, when to recurse. + +For SME scoring purposes: + - `query(text)` calls `rlm.completion(text)` once + - `mempalace_search` is exposed as an RLM custom_tool — every call's + result is captured into a per-query buffer + - `context_string` becomes the concatenation of every drawer text + RLM pulled during the run (in call order). The substring scorer + in Cat 1 / retrieve scoring sees exactly what RLM saw. + - `retrieved_entities` mirror the captured drawers as SME Entity rows + so Cat 7 token counting and Cat 8 hop-counts have something real. + +Backend: configurable via constructor. Defaults to portkey + the +familiar router URL when both are available; falls back to direct +OpenAI/Anthropic/Bedrock per RLM's own backend resolution. + +Spec ref: + ~/Projects/familiar.realm.watch/docs/superpowers/specs/2026-04-23-familiar-realm-watch-design.md + § "Composition with RLM (Recursive Language Models) — added 2026-04-25" +""" + +from __future__ import annotations + +import json +import os +import time +from typing import Any, Optional +from urllib import error as _urlerror +from urllib import request as _urlrequest + +from sme.adapters.base import Edge, Entity, QueryResult, SMEAdapter + + +_DEFAULT_DAEMON_TIMEOUT = 10.0 +_DEFAULT_LIMIT = 5 + +_CLOUD_CHAT_CONFIG_CANDIDATES = [ + os.path.expanduser("~/.config/cloud-chat-assistant/config.json"), + os.path.expanduser("~/.cloud-chat-assistant.json"), +] + + +def _resolve_default_backend(bk: dict) -> tuple[str, dict]: + """Pick a backend + defaults when none was explicitly passed. + + Priority order: + 1. RLM_BASE_URL / RLM_MODEL env vars — point the openai backend + at any OpenAI-compat endpoint (familiar's own /v1, katana's + llama.cpp, vLLM, etc.). Most explicit; wins. + 2. Cloud-chat-assistant config file (JP's home env) — Azure + OpenAI endpoint with key. Reads as `azure_openai`. + 3. Standard env vars (AZURE_OPENAI / OPENAI / ANTHROPIC / PORTKEY). + 4. Fall through to openai default (will fail without a key). + """ + if os.environ.get("RLM_BASE_URL"): + bk.setdefault("base_url", os.environ["RLM_BASE_URL"]) + bk.setdefault("model_name", os.environ.get("RLM_MODEL", "qwen2.5-7b")) + bk.setdefault("api_key", os.environ.get("RLM_API_KEY", "no-auth-needed")) + return "openai", bk + + for path in _CLOUD_CHAT_CONFIG_CANDIDATES: + try: + with open(path) as f: + cfg = json.load(f) + except (OSError, ValueError): + continue + endpoint = cfg.get("endpoint") or cfg.get("azure_endpoint") + api_key = cfg.get("api_key") + deployment = cfg.get("deployment") + if endpoint and api_key: + bk.setdefault("azure_endpoint", endpoint) + bk.setdefault("api_key", api_key) + if deployment: + bk.setdefault("azure_deployment", deployment) + bk.setdefault("model_name", deployment) + bk.setdefault("api_version", cfg.get("api_version", "2024-08-01-preview")) + return "azure_openai", bk + + if os.environ.get("AZURE_OPENAI_API_KEY") and os.environ.get("AZURE_OPENAI_ENDPOINT"): + return "azure_openai", bk + if os.environ.get("OPENAI_API_KEY"): + return "openai", bk + if os.environ.get("ANTHROPIC_API_KEY"): + return "anthropic", bk + if os.environ.get("PORTKEY_API_KEY"): + return "portkey", bk + return "openai", bk + + +class RlmAdapter(SMEAdapter): + """RLM-orchestrated palace consumer. + + Args: + api_url: palace-daemon HTTP base URL (e.g. http://disks.jphe.in:8085) + api_key: PALACE_API_KEY for the daemon (read from env if unset) + backend: RLM backend identifier ("portkey", "openai", "anthropic", ...) + backend_kwargs: passed through to RLM(...) — model_name, api_key, etc. + environment: RLM REPL environment ("local" by default) + verbose: forwards to RLM verbose flag + kind: palace-daemon /search kind filter (default "content") + timeout_s: per-search HTTP timeout + """ + + def __init__( + self, + *, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + backend: Optional[str] = None, + backend_kwargs: Optional[dict[str, Any]] = None, + environment: str = "local", + verbose: bool = False, + kind: str = "content", + timeout_s: float = _DEFAULT_DAEMON_TIMEOUT, + invocation_mode: Optional[str] = None, + **_unused: Any, + ) -> None: + """invocation_mode controls system-prompt augmentation for the + SME #3 Step 2 discriminating experiment: + - None (default): vanilla RLM system prompt; LLM decides when + to invoke mempalace_search based on training. Matches Step 1 + baseline behavior. + - "forced": prepend a directive requiring at least one + mempalace_search call before FINAL. Tests whether the + invocation-rate ceiling (gemma4: 60% zero-call) is the + dominant Cat 9a lever on substring-shaped corpora. + - "grounded": prepend a directive requiring the answer to + quote at least one retrieved source filename. Tests whether + the substring-scorer-vs-LLM-synthesis gap is the lever on + file-shaped corpora (n=200 git-derived probes, etc.). + """ + self.invocation_mode = invocation_mode + # Lazy-import RLM so multipass doesn't require it for non-rlm runs. + from rlm import RLM + from rlm.utils.prompts import RLM_SYSTEM_PROMPT + + self.api_url = (api_url or os.environ.get("PALACE_DAEMON_URL", "http://disks.jphe.in:8085")).rstrip("/") + self.api_key = api_key or os.environ.get("PALACE_API_KEY", "") + self.kind = kind + self.timeout_s = timeout_s + self.verbose = verbose + + # Per-query capture buffer — populated by _mempalace_search and + # drained at the end of each query() call. _capture grows by one + # entry per drawer returned (across all tool invocations); + # _tool_call_count tracks actual invocations of _mempalace_search + # so 9a-shaped invocation-rate reads aren't conflated with + # per-call drawer counts. + self._capture: list[dict] = [] + self._tool_call_count: int = 0 + + # Backend resolution. JP's home environment ships a multi-provider + # config in ~/.config/cloud-chat-assistant/config.json. If no + # explicit backend was passed, prefer that file's Azure OpenAI + # entry (which is what cloud-chat-assistant uses by default). + bk = dict(backend_kwargs or {}) + if backend is None: + backend, bk = _resolve_default_backend(bk) + # Per-backend defaults for model + key. + if "model_name" not in bk: + if backend == "portkey": + bk["model_name"] = "@openai/gpt-5-nano" + elif backend == "openai": + bk["model_name"] = "gpt-5-nano" + elif backend == "azure_openai": + # Filled in by _resolve_default_backend if the file is present. + bk.setdefault("model_name", "gpt-4o") + if "api_key" not in bk: + if backend == "portkey": + bk["api_key"] = os.environ.get("PORTKEY_API_KEY", "") + elif backend == "openai": + bk["api_key"] = os.environ.get("OPENAI_API_KEY", "") + elif backend == "anthropic": + bk["api_key"] = os.environ.get("ANTHROPIC_API_KEY", "") + elif backend == "azure_openai": + bk["api_key"] = os.environ.get("AZURE_OPENAI_API_KEY", "") + + # Build the system prompt. By default, use RLM's own. In + # invocation_mode="forced"/"grounded", prepend an extra + # paragraph of constraints BEFORE the standard RLM prompt + # (which still owns the {custom_tools_section} format + # placeholder that RLM fills at completion time). + custom_system_prompt: Optional[str] = None + if invocation_mode == "forced": + custom_system_prompt = ( + "MANDATORY RETRIEVAL CONSTRAINT (test condition, do not ignore):\n" + "Before you provide FINAL(...) or FINAL_VAR(...), you MUST call\n" + "`mempalace_search(...)` at least once with a query relevant to the\n" + "user's question. Even if you believe you can answer from training\n" + "data, you MUST first invoke the search tool. Never produce FINAL\n" + "without at least one mempalace_search call in your history.\n" + "\n" + + RLM_SYSTEM_PROMPT + ) + elif invocation_mode == "grounded": + custom_system_prompt = ( + "MANDATORY GROUNDING CONSTRAINT (test condition, do not ignore):\n" + "Before you provide FINAL(...) or FINAL_VAR(...), you MUST (1) call\n" + "`mempalace_search(...)` at least once with a query relevant to the\n" + "user's question, AND (2) include in your final answer at least one\n" + "source filename or drawer_id from the retrieved results. Quote the\n" + "source verbatim from the mempalace_search return value. If no\n" + "retrieved drawer is relevant, say so explicitly in FINAL and quote\n" + "the search query you used.\n" + "\n" + + RLM_SYSTEM_PROMPT + ) + + rlm_kwargs: dict[str, Any] = dict( + backend=backend, + backend_kwargs=bk, + environment=environment, + custom_tools={ + "mempalace_search": { + "tool": self._mempalace_search, + "description": ( + "Search JP's palace for drawers semantically related to a query. " + "Returns a list of dicts with text, wing, room, source_file, similarity. " + "Default limit is 5. Use this to ground factual claims about JP, " + "his projects, his realm, and any past events." + ), + }, + }, + verbose=verbose, + ) + if custom_system_prompt is not None: + rlm_kwargs["custom_system_prompt"] = custom_system_prompt + self._rlm = RLM(**rlm_kwargs) + + # ------------------------------------------------------------------ + # mempalace_search — exposed to RLM's REPL via custom_tools. + # ------------------------------------------------------------------ + + def _mempalace_search(self, query: str, limit: int = _DEFAULT_LIMIT) -> list[dict]: + """HTTP call to palace-daemon /search; capture results for SME scoring.""" + self._tool_call_count += 1 + params = {"q": query, "limit": str(limit), "kind": self.kind} + url = f"{self.api_url}/search?" + "&".join(f"{k}={_urlrequest.quote(v)}" for k, v in params.items()) + req = _urlrequest.Request(url) + if self.api_key: + req.add_header("x-api-key", self.api_key) + try: + with _urlrequest.urlopen(req, timeout=self.timeout_s) as resp: + payload = json.loads(resp.read().decode("utf-8")) + except (_urlerror.URLError, _urlerror.HTTPError, OSError) as e: + return [{"error": str(e), "results": []}] + + results = payload.get("results", []) or [] + # Trim what we return to RLM (it has limited context); keep the same + # shape stable so the model's prompt isn't re-tokenized by surprise. + # NOTE: source_file is load-bearing — SME's substring scorer matches + # against expected_sources which are filenames; dropping source_file + # from the trimmed entry meant the LLM couldn't quote it AND the + # context_string used by the scorer didn't contain it, so retrieval + # that landed the right drawer would still score 0. Fixed 2026-05-16. + trimmed: list[dict] = [] + for r in results[:limit]: + entry = { + "drawer_id": r.get("drawer_id") or r.get("id"), + "text": (r.get("text") or "")[:500], + "wing": r.get("wing"), + "room": r.get("room"), + "source_file": r.get("source_file"), + "similarity": r.get("similarity"), + } + trimmed.append(entry) + self._capture.append(entry) + return trimmed + + # ------------------------------------------------------------------ + # SMEAdapter contract. + # ------------------------------------------------------------------ + + def ingest_corpus(self, corpus: list[dict]) -> dict: + """No-op stub. RLM consumes a palace it didn't author; ingestion + happens upstream via `mempalace mine` / familiar reflect. + + Returns the full SMEAdapter contract dict (with errors/warnings + empty lists) so downstream harness code reading those keys + doesn't KeyError. Prior to 2026-05-16 this returned an + incomplete dict missing both required keys. + """ + return { + "entities_created": 0, + "edges_created": 0, + "errors": [], + "warnings": [], + "skipped": True, + } + + def get_graph_snapshot(self) -> tuple[list[Entity], list[Edge]]: + """RLM doesn't maintain a graph view. Return empty lists — Cat 8 + ontology coherence isn't applicable here. Cat 1/2/7 still work + because they only need query() output.""" + return [], [] + + def query(self, question: str) -> QueryResult: + self._capture = [] + self._tool_call_count = 0 + t0 = time.time() + try: + result = self._rlm.completion(question) + answer = getattr(result, "response", str(result)) + except Exception as e: # pragma: no cover — backend / network + return QueryResult( + answer="", + context_string="", + error=f"{type(e).__name__}: {e}", + ) + + # Build context_string from BOTH the captured search results AND + # the synthesized answer. RLM's "context" is split between what + # the LM pulled via tool calls (the search captures) and what it + # produced via training-data synthesis (the answer). Familiar's + # equivalent — the system prompt — is purely retrieval. To give + # the substring scorer a fair shake at RLM's full output, we + # include both sides. + ctx_lines = [f"── RLM-orchestrated retrieval ({len(self._capture)} drawers) ──"] + for r in self._capture: + tags = [] + if r.get("drawer_id"): + tags.append(f"drawer_id={r['drawer_id']}") + if r.get("source_file"): + tags.append(f"source_file={r['source_file']}") + if r.get("wing"): + tags.append(f"wing={r['wing']}") + if r.get("room"): + tags.append(f"room={r['room']}") + if isinstance(r.get("similarity"), (int, float)): + tags.append(f"similarity={r['similarity']:.3f}") + ctx_lines.append("[" + " · ".join(tags) + "]") + ctx_lines.append(r.get("text", "")) + ctx_lines.append("") + ctx_lines.append("── RLM answer ──") + ctx_lines.append(answer) + context_string = "\n".join(ctx_lines) + + # SME entities — one per captured drawer. + entities: list[Entity] = [] + for r in self._capture: + if not r.get("drawer_id"): + continue + entities.append(Entity( + id=str(r["drawer_id"]), + name=str(r["drawer_id"]), + entity_type="drawer", + properties={ + "text": r.get("text", ""), + "wing": r.get("wing"), + "room": r.get("room"), + "similarity": r.get("similarity"), + "matched_via": "rlm_tool_call", + }, + )) + + elapsed_ms = round((time.time() - t0) * 1000, 1) + return QueryResult( + answer=answer, + context_string=context_string, + retrieved_entities=entities, + retrieved_edges=[], + # Strings, not dicts — cli's `'; '.join(path)` expects strings. + retrieval_path=[ + f"rlm_completion ({elapsed_ms}ms, {self._tool_call_count} tool calls, {len(self._capture)} drawers)", + ], + error=None, + ) + + def close(self) -> None: + # RLM doesn't expose an explicit close; nothing to release here. + pass diff --git a/sme/categories/gap_detection.py b/sme/categories/gap_detection.py index 6e6b973..85226d8 100644 --- a/sme/categories/gap_detection.py +++ b/sme/categories/gap_detection.py @@ -311,19 +311,23 @@ def score_gap_detection( } recalled = 0 - considered = 0 + seeded_applicable = 0 # Renamed from `considered` to avoid shadowing + # the candidate-pair count from _candidate_gaps() above — that + # value is load-bearing for the report's `candidate_gaps_considered` + # field (pre-filter total over component pairs), distinct from + # the seeded-edge denominator used here for gap_recall. for u, v in seeded_missing_edges: cu = node_to_comp.get(u) cv = node_to_comp.get(v) if cu is None or cv is None: continue - considered += 1 + seeded_applicable += 1 if cu == cv: continue # Endpoints already in same component — not a cross-cluster gap if frozenset({cu, cv}) in reported_pairs: recalled += 1 - gap_recall = (recalled / considered) if considered else 0.0 + gap_recall = (recalled / seeded_applicable) if seeded_applicable else 0.0 gap_precision = ( (recalled / len(reported_pairs)) if reported_pairs else 0.0 ) diff --git a/sme/cli.py b/sme/cli.py index 50ade5a..6fd9e17 100644 --- a/sme/cli.py +++ b/sme/cli.py @@ -13,8 +13,9 @@ import logging import sys import time +from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Callable from sme.adapters.base import SMEAdapter from sme.topology import TopologyAnalyzer @@ -22,145 +23,170 @@ log = logging.getLogger("sme") +@dataclass(frozen=True) +class _AdapterSpec: + """Allowlist registration for one adapter. + + `accepts` enumerates the constructor kwargs the adapter understands. + Any CLI-level kwarg not in this set is silently dropped — this makes + drop-list drift (the PR #7 class of regression) structurally + impossible: a new CLI flag can't break an old adapter just by being + present in the bag of kwargs. + + `rename` translates CLI-side names to constructor-side names (e.g. + the CLI's --api-url maps to FamiliarAdapter's `base_url`). + """ + + aliases: tuple[str, ...] + loader: Callable[[], type[SMEAdapter]] + accepts: frozenset[str] + rename: dict[str, str] = field(default_factory=dict) + + +def _ladybugdb_loader() -> type[SMEAdapter]: + from sme.adapters.ladybugdb import LadybugDBAdapter + + return LadybugDBAdapter + + +def _mempalace_daemon_loader() -> type[SMEAdapter]: + from sme.adapters.mempalace_daemon import MemPalaceDaemonAdapter + + return MemPalaceDaemonAdapter + + + +def _rlm_loader() -> type[SMEAdapter]: + from sme.adapters.rlm_adapter import RlmAdapter + + return RlmAdapter + + +def _familiar_loader() -> type[SMEAdapter]: + from sme.adapters.familiar import FamiliarAdapter + + return FamiliarAdapter + + +def _mempalace_loader() -> type[SMEAdapter]: + from sme.adapters.mempalace import MemPalaceAdapter + + return MemPalaceAdapter + + +def _flat_loader() -> type[SMEAdapter]: + from sme.adapters.flat_baseline import FlatBaselineAdapter + + return FlatBaselineAdapter + + +def _full_context_loader() -> type[SMEAdapter]: + from sme.conditions.full_context import FullContextAdapter + + return FullContextAdapter + + + +def _karpathy_compiled_loader() -> type[SMEAdapter]: + from sme.conditions.karpathy_compiled import KarpathyCompiledAdapter + + return KarpathyCompiledAdapter + + +_ADAPTER_REGISTRY: tuple[_AdapterSpec, ...] = ( + _AdapterSpec( + aliases=("ladybugdb", "ladybug"), + loader=_ladybugdb_loader, + accepts=frozenset({ + "db_path", "read_only", "buffer_pool_size", + "include_node_tables", "include_edge_tables", "auto_discover", + "skip_infrastructure", "api_url", "default_query_mode", + "api_timeout", + }), + ), + _AdapterSpec( + aliases=("mempalace-daemon", "mempalace_daemon"), + loader=_mempalace_daemon_loader, + accepts=frozenset({ + "api_url", "api_key", "env_file", "kind", "api_timeout", + "prefer_graph_endpoint", "read_only", + }), + ), + _AdapterSpec( + aliases=("rlm",), + loader=_rlm_loader, + accepts=frozenset({ + "api_url", "api_key", "backend", "backend_kwargs", + "environment", "verbose", "kind", "timeout_s", + }), + ), + _AdapterSpec( + aliases=("familiar",), + loader=_familiar_loader, + accepts=frozenset({ + "base_url", "timeout_s", "mock_inference", "opener", + }), + rename={"api_url": "base_url"}, + ), + _AdapterSpec( + aliases=("mempalace",), + loader=_mempalace_loader, + accepts=frozenset({ + "db_path", "read_only", "kg_path", "collection_name", + "include_kg", "include_drawers", "max_drawer_nodes", + }), + ), + _AdapterSpec( + aliases=("flat", "flat_baseline"), + loader=_flat_loader, + accepts=frozenset({ + "db_path", "read_only", "collection_name", "n_results", + }), + ), + _AdapterSpec( + aliases=("full-context", "full_context"), + loader=_full_context_loader, + accepts=frozenset({"vault_dir", "read_only"}), + rename={"db_path": "vault_dir"}, + ), + _AdapterSpec( + aliases=("karpathy-compiled", "karpathy_compiled"), + loader=_karpathy_compiled_loader, + accepts=frozenset({"compiled_dir", "include_wiki"}), + rename={"db_path": "compiled_dir"}, + ), +) + + +def _registry_by_alias() -> dict[str, _AdapterSpec]: + out: dict[str, _AdapterSpec] = {} + for spec in _ADAPTER_REGISTRY: + for alias in spec.aliases: + out[alias] = spec + return out + + def _load_adapter(name: str, **kwargs) -> SMEAdapter: + """Build an adapter by name from the registry. + + Drops None-valued kwargs (so adapter defaults take over), applies + each spec's rename map (CLI-side → constructor-side), then keeps + only kwargs the adapter actually accepts. Unknown kwargs are + silently dropped — this is the structural fix for the PR #7 class + of drop-list drift (M0nkeyFl0wer/multipass-structural-memory-eval#20). + """ name = name.lower() - # Drop Nones so adapter defaults kick in + spec = _registry_by_alias().get(name) + if spec is None: + raise SystemExit(f"unknown adapter: {name}") + kwargs = {k: v for k, v in kwargs.items() if v is not None} + for src, dst in spec.rename.items(): + if src in kwargs: + kwargs[dst] = kwargs.pop(src) + filtered = {k: v for k, v in kwargs.items() if k in spec.accepts} - if name == "ladybugdb" or name == "ladybug": - from sme.adapters.ladybugdb import LadybugDBAdapter - - return LadybugDBAdapter(**kwargs) - - if name in ("mempalace-daemon", "mempalace_daemon"): - from sme.adapters.mempalace_daemon import MemPalaceDaemonAdapter - - # Drop kwargs the daemon adapter doesn't understand - for k in ( - "include_node_tables", - "include_edge_tables", - "auto_discover", - "kg_path", - "collection_name", - "default_query_mode", - "db_path", - "buffer_pool_size", - ): - kwargs.pop(k, None) - return MemPalaceDaemonAdapter(**kwargs) - - if name == "familiar": - from sme.adapters.familiar import FamiliarAdapter - - # Drop kwargs the familiar adapter doesn't understand - for k in ( - "include_node_tables", - "include_edge_tables", - "auto_discover", - "kg_path", - "collection_name", - "default_query_mode", - "db_path", - "buffer_pool_size", - "api_key", - "kind", - "read_only", - ): - kwargs.pop(k, None) - # CLI uses --api-url; familiar adapter constructor uses base_url. - if "api_url" in kwargs: - kwargs["base_url"] = kwargs.pop("api_url") - return FamiliarAdapter(**kwargs) - - if name == "mempalace": - from sme.adapters.mempalace import MemPalaceAdapter - - # LadybugDB-specific kwargs are silently ignored for other adapters - for k in ( - "include_node_tables", - "include_edge_tables", - "auto_discover", - "api_url", - "api_key", - "kind", - "default_query_mode", - ): - kwargs.pop(k, None) - return MemPalaceAdapter(**kwargs) - - if name == "flat" or name == "flat_baseline": - from sme.adapters.flat_baseline import FlatBaselineAdapter - - for k in ( - "include_node_tables", - "include_edge_tables", - "auto_discover", - "kg_path", - "api_url", - "api_key", - "kind", - "default_query_mode", - ): - kwargs.pop(k, None) - return FlatBaselineAdapter(**kwargs) - - if name in ("full-context", "full_context"): - # Karpathy-baseline Condition D1 — see - # docs/cross_validation_2026.md § (4) and - # sme/conditions/full_context.py. Treats `--db` as the vault - # path; loads every .md file under it as the prompt context. - from sme.conditions.full_context import FullContextAdapter - - # Drop kwargs other adapters use that don't apply to D1. - for k in ( - "include_node_tables", - "include_edge_tables", - "auto_discover", - "kg_path", - "api_url", - "api_key", - "kind", - "collection_name", - "default_query_mode", - "mock_inference", - "timeout_s", - "buffer_pool_size", - ): - kwargs.pop(k, None) - # FullContextAdapter takes vault_dir, not db_path. - if "db_path" in kwargs: - kwargs["vault_dir"] = kwargs.pop("db_path") - return FullContextAdapter(**kwargs) - - if name in ("karpathy-compiled", "karpathy_compiled"): - # Karpathy-baseline Condition D2 — see - # docs/cross_validation_2026.md § (4) and - # sme/conditions/karpathy_compiled.py. Reads a pre-compiled - # wiki + index produced by `sme-eval compile-wiki`. Treats - # `--db` as the path to the compiled output directory. - from sme.conditions.karpathy_compiled import KarpathyCompiledAdapter - - for k in ( - "include_node_tables", - "include_edge_tables", - "auto_discover", - "kg_path", - "api_url", - "api_key", - "kind", - "collection_name", - "default_query_mode", - "mock_inference", - "timeout_s", - "buffer_pool_size", - "read_only", # accepted for CLI parity; not a constructor arg - ): - kwargs.pop(k, None) - if "db_path" in kwargs: - kwargs["compiled_dir"] = kwargs.pop("db_path") - return KarpathyCompiledAdapter(**kwargs) - - raise SystemExit(f"unknown adapter: {name}") + return spec.loader()(**filtered) def _fmt_int(n: int) -> str: @@ -803,21 +829,7 @@ def cmd_cat5(args: argparse.Namespace) -> int: raw = doc.get("missing_edges") or doc.get("seeded_missing_edges") or [] seeded = [(pair[0], pair[1]) for pair in raw if len(pair) == 2] - adapter_kwargs: dict[str, Any] = { - "db_path": args.db, - "read_only": True, - "auto_discover": args.auto_discover, - } - if args.node_tables: - adapter_kwargs["include_node_tables"] = args.node_tables - if args.edge_tables: - adapter_kwargs["include_edge_tables"] = args.edge_tables - if args.kg_path: - adapter_kwargs["kg_path"] = args.kg_path - if args.collection_name: - adapter_kwargs["collection_name"] = args.collection_name - - adapter = _load_adapter(args.adapter, **adapter_kwargs) + adapter = _load_adapter_from_args(args) entities, edges = adapter.get_graph_snapshot() log.info("snapshot: %d entities, %d edges", len(entities), len(edges)) @@ -1372,9 +1384,9 @@ def main(argv: list[str] | None = None) -> int: "--adapter", required=True, help="adapter name (flat | mempalace | mempalace-daemon | familiar | " - "ladybugdb | full-context). full-context is the Karpathy-baseline " - "Condition D1 — pass --db and it loads every .md file " - "as the prompt context with no retrieval.", + "rlm | ladybugdb | full-context | karpathy-compiled). full-context " + "is the Karpathy-baseline Condition D1 — pass --db and " + "it loads every .md file as the prompt context with no retrieval.", ) ret.add_argument( "--db", diff --git a/tests/test_adapter_contract.py b/tests/test_adapter_contract.py new file mode 100644 index 0000000..ae4d0cd --- /dev/null +++ b/tests/test_adapter_contract.py @@ -0,0 +1,250 @@ +"""SMEAdapter conformance testkit (M0nkeyFl0wer/multipass-structural-memory-eval#8). + +Parametric contract tests verifying every registered adapter conforms to the +``sme.adapters.base.SMEAdapter`` ABC: ``query`` returns a typed +``QueryResult``, ``get_graph_snapshot`` is internally consistent, ``ingest_corpus`` +accepts a list of dicts (or raises ``NotImplementedError``), and the optional +``get_harness_manifest`` returns a list when present. + +Existing per-adapter unit tests (``test_familiar_adapter.py``, +``test_mempalace_daemon_adapter.py`` etc.) verify HTTP deserialization and +adapter-specific behavior. This module verifies *the contract*. + +To opt a new adapter in, register a factory at the bottom under +``ADAPTER_FACTORIES``. A factory takes a ``tmp_path`` and returns a constructed +adapter (or skips, via ``pytest.skip``, when its environment is missing). +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Callable + +import pytest + +from sme.adapters.base import ( + Edge, + Entity, + HarnessDescriptor, + QueryResult, + SMEAdapter, +) + + +# --- MockAdapter ------------------------------------------------------ +# +# Minimal in-memory SMEAdapter used as a contract reference. Concrete enough +# that substring queries return real hits and the graph snapshot has a +# self-consistent entity/edge set. + + +class MockAdapter(SMEAdapter): + """In-memory SMEAdapter for contract testing. + + Stores ingested corpus as a list of dicts. ``query`` does substring + matching over ``text``; ``get_graph_snapshot`` returns a simple + two-node, one-edge graph. + """ + + def __init__(self) -> None: + self._corpus: list[dict] = [] + + def ingest_corpus(self, corpus: list[dict]) -> dict: + self._corpus.extend(corpus) + return { + "entities_created": len(corpus), + "edges_created": 0, + "errors": [], + "warnings": [], + } + + def query(self, question: str, n_results: int = 5) -> QueryResult: + q = (question or "").lower() + hits = [d for d in self._corpus if q and q in str(d.get("text", "")).lower()] + hits = hits[:n_results] + ctx = "\n".join(str(d.get("text", "")) for d in hits) + entities = [ + Entity( + id=str(d.get("id", f"mock:{i}")), + name=str(d.get("id", f"mock:{i}")), + entity_type="chunk", + properties={"text": d.get("text", "")}, + ) + for i, d in enumerate(hits) + ] + return QueryResult( + answer=ctx, + context_string=ctx, + retrieved_entities=entities, + retrieval_path=["mock", f"q={question}"], + ) + + def get_graph_snapshot(self) -> tuple[list[Entity], list[Edge]]: + # Trivial but internally consistent: edge endpoints both exist. + entities = [ + Entity(id="a", name="A", entity_type="topic"), + Entity(id="b", name="B", entity_type="topic"), + ] + edges = [Edge(source_id="a", target_id="b", edge_type="related_to")] + return entities, edges + + def get_harness_manifest(self) -> list[HarnessDescriptor]: + return [] + + +# --- Adapter factories ------------------------------------------------ +# +# Each factory builds one adapter. Network-dependent adapters use ``pytest.skip`` +# from within the factory when their environment is unavailable; that keeps the +# parametrize list flat and the skip reason visible in pytest output. + + +AdapterFactory = Callable[[Path], SMEAdapter] + + +def _mock_factory(tmp_path: Path) -> SMEAdapter: + return MockAdapter() + + +def _flat_baseline_factory(tmp_path: Path) -> SMEAdapter: + """FlatBaselineAdapter over an empty ChromaDB collection in tmp_path.""" + try: + import chromadb + except ImportError: # pragma: no cover — chromadb is a project dep + pytest.skip("chromadb not installed") + + from sme.adapters.flat_baseline import FlatBaselineAdapter + + db_path = tmp_path / "chroma" + db_path.mkdir() + client = chromadb.PersistentClient(path=str(db_path)) + client.create_collection("mempalace_drawers") + # Drop the construction-time handle so the adapter opens its own. + del client + return FlatBaselineAdapter(db_path=str(db_path)) + + +def _full_context_factory(tmp_path: Path) -> SMEAdapter: + """FullContextAdapter over a tiny tmp vault.""" + from sme.conditions.full_context import FullContextAdapter + + vault = tmp_path / "vault" + vault.mkdir() + (vault / "note.md").write_text("# note\nhello world\n", encoding="utf-8") + return FullContextAdapter(vault) + + +# Register adapters here. Keep IDs stable — they show in pytest output. +ADAPTER_FACTORIES: dict[str, AdapterFactory] = { + "mock": _mock_factory, + "flat_baseline": _flat_baseline_factory, + "full_context": _full_context_factory, +} + + +@pytest.fixture(params=sorted(ADAPTER_FACTORIES.keys())) +def adapter(request: pytest.FixtureRequest, tmp_path: Path) -> SMEAdapter: + factory = ADAPTER_FACTORIES[request.param] + return factory(tmp_path) + + +# --- Contract tests --------------------------------------------------- + + +def test_is_sme_adapter_subclass(adapter: SMEAdapter) -> None: + assert isinstance(adapter, SMEAdapter) + + +def test_query_returns_QueryResult(adapter: SMEAdapter) -> None: + result = adapter.query("test", n_results=3) + assert isinstance(result, QueryResult) + assert isinstance(result.context_string, str) + assert isinstance(result.retrieval_path, list) + assert all(isinstance(p, (str, int, float)) for p in result.retrieval_path) + assert isinstance(result.retrieved_entities, list) + assert all(isinstance(e, Entity) for e in result.retrieved_entities) + assert isinstance(result.retrieved_edges, list) + assert all(isinstance(e, Edge) for e in result.retrieved_edges) + assert isinstance(result.answer, str) + # ``error`` is Optional[str]; type, not presence. + assert result.error is None or isinstance(result.error, str) + + +def test_query_without_n_results_kwarg(adapter: SMEAdapter) -> None: + """The ABC signature is ``query(question)``; ``n_results`` is an + adapter-level extension. The minimum contract is that ``query(question)`` + alone returns a ``QueryResult``.""" + result = adapter.query("test") + assert isinstance(result, QueryResult) + + +def test_graph_snapshot_returns_typed_pair(adapter: SMEAdapter) -> None: + snapshot = adapter.get_graph_snapshot() + assert isinstance(snapshot, tuple) + assert len(snapshot) == 2 + entities, edges = snapshot + assert isinstance(entities, list) + assert isinstance(edges, list) + assert all(isinstance(e, Entity) for e in entities) + assert all(isinstance(e, Edge) for e in edges) + + +def test_graph_snapshot_internally_consistent(adapter: SMEAdapter) -> None: + """Every edge's source/target id must exist in the entity list. + + Adapters with no graph return ``([], [])`` — vacuously consistent. + """ + entities, edges = adapter.get_graph_snapshot() + entity_ids = {e.id for e in entities} + for edge in edges: + assert edge.source_id in entity_ids, ( + f"edge source_id {edge.source_id!r} not in entity ids" + ) + assert edge.target_id in entity_ids, ( + f"edge target_id {edge.target_id!r} not in entity ids" + ) + + +def test_ingest_corpus_accepts_list_of_dicts(adapter: SMEAdapter) -> None: + """``ingest_corpus`` must either succeed with a result dict or raise + ``NotImplementedError`` — both are valid per the ABC. AttributeError, + TypeError, or KeyError on the canonical shape would be a contract bug. + """ + corpus = [{"id": "x", "text": "y"}] + try: + result = adapter.ingest_corpus(corpus) + except NotImplementedError: + return + assert isinstance(result, dict) + for key in ("entities_created", "edges_created", "errors", "warnings"): + assert key in result, f"ingest result missing required key {key!r}" + assert isinstance(result["entities_created"], int) + assert isinstance(result["edges_created"], int) + assert isinstance(result["errors"], list) + assert isinstance(result["warnings"], list) + + +def test_get_harness_manifest_returns_list(adapter: SMEAdapter) -> None: + """``get_harness_manifest`` defaults to ``[]`` on the ABC. Anything + that overrides it must still return a list of ``HarnessDescriptor``. + """ + manifest = adapter.get_harness_manifest() + assert isinstance(manifest, list) + assert all(isinstance(d, HarnessDescriptor) for d in manifest) + + +def test_get_ontology_source_returns_typed_dict(adapter: SMEAdapter) -> None: + """Optional method with a sensible default on the ABC. When present, + must return a dict with ``type``, ``schema``, ``documentation``.""" + src = adapter.get_ontology_source() + assert isinstance(src, dict) + assert src.get("type") in {"declared", "readme", "inferred"} + assert isinstance(src.get("schema"), list) + assert isinstance(src.get("documentation"), str) + + +def test_close_is_idempotent(adapter: SMEAdapter) -> None: + """``close`` is part of the lifecycle contract and must tolerate + being called more than once.""" + adapter.close() + adapter.close() diff --git a/tests/test_adapter_harness_manifest_contract.py b/tests/test_adapter_harness_manifest_contract.py new file mode 100644 index 0000000..ce677a3 --- /dev/null +++ b/tests/test_adapter_harness_manifest_contract.py @@ -0,0 +1,90 @@ +"""Regression test for the get_harness_manifest() ABC contract. + +Closes the silent-AttributeError class entirely: every adapter class must +either inherit the base ABC default (returns []) or override it to return +a list. Tests at the class level (no instantiation) so adapters with +heavy network/disk constructors are still covered. + +Upstream issue: M0nkeyFl0wer/multipass-structural-memory-eval#19. +""" + +from __future__ import annotations + +import inspect + +from sme.adapters.base import SMEAdapter +from sme.adapters.familiar import FamiliarAdapter +from sme.adapters.flat_baseline import FlatBaselineAdapter +from sme.adapters.ladybugdb import LadybugDBAdapter +from sme.adapters.mempalace import MemPalaceAdapter +from sme.adapters.mempalace_daemon import MemPalaceDaemonAdapter +from sme.conditions.full_context import FullContextAdapter +from sme.conditions.karpathy_compiled import KarpathyCompiledAdapter + +ADAPTER_CLASSES = [ + FamiliarAdapter, + FlatBaselineAdapter, + FullContextAdapter, + KarpathyCompiledAdapter, + LadybugDBAdapter, + MemPalaceAdapter, + MemPalaceDaemonAdapter, +] + + +def test_base_adapter_default_is_empty_list(): + """The ABC default must return an empty list, not raise.""" + + class _MinimalAdapter(SMEAdapter): + def ingest_corpus(self, corpus): + return { + "entities_created": 0, + "edges_created": 0, + "errors": [], + "warnings": [], + } + + def query(self, question): + from sme.adapters.base import QueryResult + + return QueryResult(answer="") + + def get_graph_snapshot(self): + return [], [] + + assert _MinimalAdapter().get_harness_manifest() == [] + + +def test_every_adapter_has_get_harness_manifest(): + """Each shipped adapter must resolve `get_harness_manifest` — either + inherited from `SMEAdapter` (default `[]`) or overridden on the class. + Catches the silent-AttributeError class at import time without needing + to instantiate heavyweight adapters.""" + for cls in ADAPTER_CLASSES: + method = getattr(cls, "get_harness_manifest", None) + assert callable(method), ( + f"{cls.__name__} is missing get_harness_manifest entirely " + f"(expected ABC default to be inherited)" + ) + sig = inspect.signature(method) + # One positional param: self. No required extras. + params = list(sig.parameters.values()) + assert len(params) >= 1, f"{cls.__name__}.get_harness_manifest must take self" + for p in params[1:]: + assert p.default is not inspect.Parameter.empty or p.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ), ( + f"{cls.__name__}.get_harness_manifest must be callable with no " + f"args beyond self; found required param {p.name!r}" + ) + + +def test_base_default_inherited_when_not_overridden(): + """Adapters that don't override should resolve to the ABC default.""" + for cls in ADAPTER_CLASSES: + if "get_harness_manifest" not in cls.__dict__: + assert cls.get_harness_manifest is SMEAdapter.get_harness_manifest, ( + f"{cls.__name__} did not override get_harness_manifest but " + f"does not resolve to SMEAdapter.get_harness_manifest" + ) diff --git a/tests/test_cli_adapter_forwarding.py b/tests/test_cli_adapter_forwarding.py new file mode 100644 index 0000000..e72bd7a --- /dev/null +++ b/tests/test_cli_adapter_forwarding.py @@ -0,0 +1,247 @@ +"""Tests for the allowlist-based adapter registry in sme.cli. + +Inverts the legacy drop-list pattern (PR #7 regression class) — see +M0nkeyFl0wer/multipass-structural-memory-eval#20. Each adapter declares +which kwargs it accepts via `_AdapterSpec.accepts`; unknown kwargs are +silently dropped at the registry boundary, so new CLI flags can't +break old adapters by drifting past stale drop-lists. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from sme.cli import _ADAPTER_REGISTRY, _load_adapter, _registry_by_alias + + +class _StubAdapter: + """Throw-away adapter that captures whatever kwargs reach it.""" + + def __init__(self, **kwargs: Any) -> None: + self.captured_kwargs = kwargs + + +@pytest.fixture +def stub_loader(monkeypatch): + """Yield a helper that swaps an adapter's loader for a stub. + + `_AdapterSpec` is a frozen dataclass — `object.__setattr__` bypasses + that, and this fixture restores the original loader on teardown so + tests don't bleed mutations into each other. + """ + restores: list[tuple[Any, Any]] = [] + + def _patch(alias: str) -> None: + spec = _registry_by_alias()[alias] + restores.append((spec, spec.loader)) + object.__setattr__(spec, "loader", lambda: _StubAdapter) + + yield _patch + + for spec, original in restores: + object.__setattr__(spec, "loader", original) + + +# --------------------------------------------------------------------- +# Registry shape checks + + +def test_unknown_adapter_raises(): + with pytest.raises(SystemExit, match="unknown adapter"): + _load_adapter("does-not-exist") + + +def test_every_declared_alias_is_loadable_by_name(): + """All declared aliases route to a real spec.""" + aliases = [a for spec in _ADAPTER_REGISTRY for a in spec.aliases] + by_alias = _registry_by_alias() + for alias in aliases: + assert alias in by_alias, f"alias {alias!r} not in registry index" + + +# --------------------------------------------------------------------- +# Core contract: rename, allowlist, None-strip, unknown-drop + + +@pytest.mark.parametrize("alias", ["ladybugdb", "ladybug"]) +def test_ladybugdb_aliases_resolve(stub_loader, alias): + stub_loader(alias) + out = _load_adapter(alias, db_path="/tmp/x", read_only=True) + assert isinstance(out, _StubAdapter) + assert out.captured_kwargs["db_path"] == "/tmp/x" + + +def test_familiar_renames_api_url_to_base_url(stub_loader): + stub_loader("familiar") + out = _load_adapter("familiar", api_url="http://nowhere:1", timeout_s=1.0) + assert out.captured_kwargs == { + "base_url": "http://nowhere:1", + "timeout_s": 1.0, + } + assert "api_url" not in out.captured_kwargs + + +def test_full_context_renames_db_path_to_vault_dir(stub_loader): + stub_loader("full-context") + out = _load_adapter("full-context", db_path="/tmp/vault", read_only=True) + assert out.captured_kwargs == { + "vault_dir": "/tmp/vault", + "read_only": True, + } + assert "db_path" not in out.captured_kwargs + + +def test_karpathy_compiled_renames_db_path_to_compiled_dir(stub_loader): + stub_loader("karpathy-compiled") + out = _load_adapter("karpathy-compiled", db_path="/tmp/wiki") + assert out.captured_kwargs == {"compiled_dir": "/tmp/wiki"} + + +def test_rlm_keeps_kind_and_api_url(stub_loader): + """`rlm` accepts `kind` — it forwards into mempalace_search /search.""" + stub_loader("rlm") + out = _load_adapter( + "rlm", + api_url="http://disks:8085", + api_key="abc", + kind="content", + include_node_tables=["X"], + db_path="/tmp/db", + ) + assert out.captured_kwargs == { + "api_url": "http://disks:8085", + "api_key": "abc", + "kind": "content", + } + + +def test_unknown_kwargs_silently_dropped(stub_loader): + """The PR #7 class of regression: a new CLI flag must not blow up + an old adapter just by being present in the kwargs bag.""" + stub_loader("familiar") + out = _load_adapter( + "familiar", + api_url="http://nowhere:1", + # Every one of these belongs to *some* other adapter and must + # not reach FamiliarAdapter's constructor. + include_node_tables=["X"], + include_edge_tables=["Y"], + auto_discover=True, + kg_path="/tmp/kg.sqlite3", + collection_name="drawers", + default_query_mode="hybrid", + db_path="/tmp/db", + buffer_pool_size=128, + api_key="secret", + kind="content", + read_only=True, + # Invented future flag that doesn't exist anywhere yet. + some_future_flag_that_doesnt_exist=42, + ) + assert out.captured_kwargs == {"base_url": "http://nowhere:1"} + + +def test_none_valued_kwargs_are_stripped(stub_loader): + """`None` means 'use the adapter default' — never forward as-is.""" + stub_loader("flat") + out = _load_adapter( + "flat", + db_path="/tmp/db", + read_only=True, + collection_name=None, + n_results=None, + ) + assert out.captured_kwargs == {"db_path": "/tmp/db", "read_only": True} + assert "collection_name" not in out.captured_kwargs + assert "n_results" not in out.captured_kwargs + + +def test_mempalace_daemon_drops_db_path(stub_loader): + """Daemon adapter has a legacy `db_path` constructor arg that's a + no-op — old drop-list dropped it; preserve that behavior.""" + stub_loader("mempalace-daemon") + out = _load_adapter( + "mempalace-daemon", + api_url="http://localhost:8085", + api_key="key", + kind="content", + db_path="/tmp/should-be-dropped", + buffer_pool_size=128, + ) + assert "db_path" not in out.captured_kwargs + assert "buffer_pool_size" not in out.captured_kwargs + assert out.captured_kwargs == { + "api_url": "http://localhost:8085", + "api_key": "key", + "kind": "content", + } + + +def test_mempalace_keeps_kg_path_and_collection_name(stub_loader): + stub_loader("mempalace") + out = _load_adapter( + "mempalace", + db_path="/tmp/chroma", + kg_path="/tmp/kg.sqlite3", + collection_name="drawers", + read_only=True, + # Should be dropped — not in mempalace's accepts + api_url="http://x", + kind="content", + ) + assert out.captured_kwargs == { + "db_path": "/tmp/chroma", + "kg_path": "/tmp/kg.sqlite3", + "collection_name": "drawers", + "read_only": True, + } + + +# --------------------------------------------------------------------- +# Integration: real construction for lightweight adapters + + +def test_familiar_real_construction(): + """FamiliarAdapter has no heavy deps — verify the registry routes + correctly all the way through to the real constructor.""" + adapter = _load_adapter("familiar", api_url="http://nowhere:1", timeout_s=1.0) + assert type(adapter).__name__ == "FamiliarAdapter" + assert adapter.base_url == "http://nowhere:1" + assert adapter.timeout_s == 1.0 + + +def test_full_context_real_construction(tmp_path): + """FullContextAdapter only validates the vault path exists.""" + vault = tmp_path / "vault" + vault.mkdir() + adapter = _load_adapter("full-context", db_path=str(vault)) + assert type(adapter).__name__ == "FullContextAdapter" + assert adapter.vault_dir == vault + + +# --------------------------------------------------------------------- +# Sanity: rename/accepts internal consistency + + +def test_rename_targets_are_in_accepts(): + """A spec that renames `foo` → `bar` must accept `bar`. Otherwise + the registry would rename a kwarg only to immediately drop it.""" + for spec in _ADAPTER_REGISTRY: + for src, dst in spec.rename.items(): + assert dst in spec.accepts, ( + f"adapter {spec.aliases[0]!r} renames {src!r} → {dst!r} " + f"but {dst!r} is not in accepts={sorted(spec.accepts)}" + ) + + +def test_rename_sources_are_not_in_accepts(): + """If `foo` renames to `bar`, `foo` itself shouldn't also appear + in accepts — that creates ambiguity about which the adapter wants.""" + for spec in _ADAPTER_REGISTRY: + for src in spec.rename: + assert src not in spec.accepts, ( + f"adapter {spec.aliases[0]!r} both renames {src!r} away " + f"and lists it in accepts — pick one" + ) diff --git a/tests/test_gap_detection.py b/tests/test_gap_detection.py index 4d0cb29..398f1ec 100644 --- a/tests/test_gap_detection.py +++ b/tests/test_gap_detection.py @@ -85,6 +85,43 @@ def test_seeded_gap_recall_is_one(gap_graph): assert report.gap_precision == pytest.approx(1.0) +def test_seeded_gap_does_not_corrupt_candidate_gaps_considered(gap_graph): + """Regression for the variable-shadowing bug at gap_detection.py:314. + + Before the fix, providing seeded_missing_edges silently rebound the + `considered` local from "component pairs examined" to "seeded edges + examined" — so `candidate_gaps_considered` in the returned report + carried the wrong number when graded-fixture runs were used. + + The two values must remain distinct. + """ + entities, edges, truth = gap_graph + + # Reading WITHOUT seeded edges — captures the true component-pair count. + plain = score_gap_detection(entities, edges, run_homology=False) + plain_considered = plain.candidate_gaps_considered + + # Reading WITH seeded edges — candidate_gaps_considered must stay + # the same (it's a property of the structural pass, not of the + # seeded scoring overlay). + graded = score_gap_detection( + entities, + edges, + seeded_missing_edges=truth["seeded_missing_edges"], + run_homology=False, + ) + assert graded.candidate_gaps_considered == plain_considered, ( + f"candidate_gaps_considered changed when seeded edges were " + f"supplied: plain={plain_considered}, graded=" + f"{graded.candidate_gaps_considered}. The seeded-edge " + f"counter should be a separate local, not shadow the " + f"component-pair counter." + ) + # Sanity: gap_recall is still computed correctly (the variable + # rename should be invisible to consumers reading this field). + assert graded.gap_recall == pytest.approx(1.0) + + def test_candidate_gap_has_score_and_examples(gap_graph): entities, edges, _ = gap_graph report = score_gap_detection(entities, edges, run_homology=False) diff --git a/tests/test_graph_mapping.py b/tests/test_graph_mapping.py new file mode 100644 index 0000000..ec5d489 --- /dev/null +++ b/tests/test_graph_mapping.py @@ -0,0 +1,410 @@ +"""Tests for sme.adapters._graph_mapping.project_graph. + +Pins the input/output contract for the shared palace-daemon /graph +projection. Both MemPalaceDaemonAdapter and FamiliarAdapter consume +this function, so any change to the output shape is a coordinated +break across two adapters and any downstream Cat 8 scoring. + +Covers: + +- Output type contract: ``(list[Entity], list[Edge])`` +- Wing entities: id prefix, type, drawer_count carried in properties +- Room entities: cross-wing aggregation, ``general`` filtered out +- room→wing ``member_of`` edges +- Tunnel edges: wing<->wing pairs (skip self-pair, all combinations) +- KG entities + KG triples: id prefixing, temporal/source props copied +- Edge cases: empty payload, missing keys, malformed entries, the + exact field shape MemPalace daemon emits. +""" +from __future__ import annotations + +from sme.adapters._graph_mapping import project_graph +from sme.adapters.base import Edge, Entity + + +# ── Output contract ──────────────────────────────────────────────── + + +def test_project_graph_returns_two_lists(): + entities, edges = project_graph({}) + assert isinstance(entities, list) + assert isinstance(edges, list) + + +def test_project_graph_empty_payload(): + entities, edges = project_graph({}) + assert entities == [] + assert edges == [] + + +def test_project_graph_missing_keys_are_safe(): + """All top-level keys are optional; missing ones don't raise.""" + entities, edges = project_graph({"wings": {"alpha": 3}}) + # Just the wing entity, no edges + assert len(entities) == 1 + assert entities[0].entity_type == "wing" + assert edges == [] + + +def test_project_graph_null_values_treated_as_empty(): + """None for any top-level field is normalized to empty.""" + payload = { + "wings": None, + "rooms": None, + "tunnels": None, + "kg_entities": None, + "kg_triples": None, + } + entities, edges = project_graph(payload) + assert entities == [] + assert edges == [] + + +# ── Wing entities ────────────────────────────────────────────────── + + +def test_wings_become_entities_with_drawer_count(): + payload = {"wings": {"alpha": 10, "beta": 5}} + entities, _ = project_graph(payload) + wings = [e for e in entities if e.entity_type == "wing"] + assert len(wings) == 2 + by_id = {e.id: e for e in wings} + assert "wing:alpha" in by_id + assert by_id["wing:alpha"].name == "alpha" + assert by_id["wing:alpha"].properties["drawer_count"] == 10 + assert by_id["wing:alpha"].properties["_table"] == "wing" + assert by_id["wing:beta"].properties["drawer_count"] == 5 + + +def test_wings_emitted_in_sorted_order(): + """Determinism matters for diffing graph snapshots.""" + payload = {"wings": {"zebra": 1, "alpha": 2, "mango": 3}} + entities, _ = project_graph(payload) + wing_ids = [e.id for e in entities if e.entity_type == "wing"] + assert wing_ids == ["wing:alpha", "wing:mango", "wing:zebra"] + + +# ── Room entities + edges ───────────────────────────────────────── + + +def test_rooms_aggregate_across_wings(): + """A room appearing in two wings has wings_list of both and the + drawer count is summed.""" + payload = { + "wings": {"alpha": 10, "beta": 5}, + "rooms": [ + {"wing": "alpha", "rooms": {"shared": 3, "alpha_only": 7}}, + {"wing": "beta", "rooms": {"shared": 2, "beta_only": 3}}, + ], + } + entities, edges = project_graph(payload) + rooms = {e.id: e for e in entities if e.entity_type == "room:untyped"} + assert set(rooms.keys()) == {"room:shared", "room:alpha_only", "room:beta_only"} + assert rooms["room:shared"].properties["wings"] == ["alpha", "beta"] + assert rooms["room:shared"].properties["drawer_count"] == 5 + assert rooms["room:alpha_only"].properties["wings"] == ["alpha"] + assert rooms["room:alpha_only"].properties["drawer_count"] == 7 + + +def test_room_general_is_filtered_out(): + """The 'general' room is a default catch-all; the projection skips it.""" + payload = { + "wings": {"alpha": 10}, + "rooms": [{"wing": "alpha", "rooms": {"general": 5, "real": 3}}], + } + entities, _ = project_graph(payload) + room_ids = [e.id for e in entities if e.entity_type == "room:untyped"] + assert "room:general" not in room_ids + assert "room:real" in room_ids + + +def test_room_empty_string_filtered_out(): + payload = { + "wings": {"alpha": 10}, + "rooms": [{"wing": "alpha", "rooms": {"": 5, "real": 3}}], + } + entities, _ = project_graph(payload) + room_ids = [e.id for e in entities if e.entity_type == "room:untyped"] + assert "room:" not in room_ids + + +def test_room_to_wing_member_of_edges(): + payload = { + "wings": {"alpha": 10, "beta": 5}, + "rooms": [ + {"wing": "alpha", "rooms": {"shared": 3}}, + {"wing": "beta", "rooms": {"shared": 2}}, + ], + } + _, edges = project_graph(payload) + member_edges = [e for e in edges if e.edge_type == "member_of"] + # One edge per (room, wing) — two wings sharing one room → 2 edges + assert len(member_edges) == 2 + edge_pairs = {(e.source_id, e.target_id) for e in member_edges} + assert ("room:shared", "wing:alpha") in edge_pairs + assert ("room:shared", "wing:beta") in edge_pairs + # The drawer_count is preserved on the edge + for e in member_edges: + assert e.properties["drawer_count"] == 5 # aggregate + assert e.properties["_table"] == "structural" + + +def test_rooms_emitted_in_sorted_order(): + payload = { + "wings": {"a": 1}, + "rooms": [{"wing": "a", "rooms": {"z": 1, "m": 1, "b": 1}}], + } + entities, _ = project_graph(payload) + room_ids = [ + e.id for e in entities if e.entity_type == "room:untyped" + ] + assert room_ids == ["room:b", "room:m", "room:z"] + + +def test_rooms_handle_null_rooms_field(): + """A wing whose rooms map is None doesn't crash.""" + payload = {"wings": {"a": 1}, "rooms": [{"wing": "a", "rooms": None}]} + entities, _ = project_graph(payload) + rooms = [e for e in entities if e.entity_type.startswith("room")] + assert rooms == [] + + +def test_rooms_drawer_count_handles_null(): + """Null counts default to zero, not crash.""" + payload = { + "wings": {"a": 1}, + "rooms": [{"wing": "a", "rooms": {"foo": None}}], + } + entities, _ = project_graph(payload) + foo = next(e for e in entities if e.id == "room:foo") + assert foo.properties["drawer_count"] == 0 + + +# ── Tunnel edges ─────────────────────────────────────────────────── + + +def test_tunnels_emit_wing_pair_edges(): + payload = { + "wings": {"a": 1, "b": 1}, + "tunnels": [{"room": "shared", "wings": ["a", "b"]}], + } + _, edges = project_graph(payload) + tunnel_edges = [e for e in edges if e.edge_type == "tunnel"] + assert len(tunnel_edges) == 1 + assert tunnel_edges[0].source_id == "wing:a" + assert tunnel_edges[0].target_id == "wing:b" + assert tunnel_edges[0].properties["via_room"] == "shared" + + +def test_tunnels_three_wings_yield_three_pairs(): + """Three wings sharing one room → C(3,2) = 3 tunnel edges.""" + payload = { + "wings": {"a": 1, "b": 1, "c": 1}, + "tunnels": [{"room": "shared", "wings": ["a", "b", "c"]}], + } + _, edges = project_graph(payload) + tunnel_edges = [e for e in edges if e.edge_type == "tunnel"] + assert len(tunnel_edges) == 3 + pairs = {(e.source_id, e.target_id) for e in tunnel_edges} + # Sorted within tunnel → no duplicate reversed pairs + assert pairs == { + ("wing:a", "wing:b"), + ("wing:a", "wing:c"), + ("wing:b", "wing:c"), + } + + +def test_single_wing_tunnel_yields_no_edges(): + """A tunnel with only one wing in the list has no pairs to form.""" + payload = {"tunnels": [{"room": "lonely", "wings": ["a"]}]} + _, edges = project_graph(payload) + assert [e for e in edges if e.edge_type == "tunnel"] == [] + + +def test_tunnel_null_wings_field_safe(): + payload = {"tunnels": [{"room": "r", "wings": None}]} + _, edges = project_graph(payload) + assert [e for e in edges if e.edge_type == "tunnel"] == [] + + +# ── KG entities and triples ─────────────────────────────────────── + + +def test_kg_entities_get_kg_prefix(): + payload = { + "kg_entities": [ + {"id": "max", "name": "Max", "type": "person", + "properties": {"age": 11}} + ] + } + entities, _ = project_graph(payload) + kg = [e for e in entities if e.id.startswith("kg:")] + assert len(kg) == 1 + assert kg[0].id == "kg:max" + assert kg[0].name == "Max" + assert kg[0].entity_type == "kg:person" + assert kg[0].properties["age"] == 11 + assert kg[0].properties["_table"] == "kg_entity" + + +def test_kg_entity_without_id_skipped(): + payload = { + "kg_entities": [ + {"name": "missing-id"}, + {"id": "valid", "type": "x"}, + ] + } + entities, _ = project_graph(payload) + kg = [e for e in entities if e.id.startswith("kg:")] + assert len(kg) == 1 + assert kg[0].id == "kg:valid" + + +def test_kg_entity_defaults_when_fields_missing(): + payload = {"kg_entities": [{"id": "bare"}]} + entities, _ = project_graph(payload) + bare = next(e for e in entities if e.id == "kg:bare") + # Name falls back to id; type falls back to 'unknown' + assert bare.name == "bare" + assert bare.entity_type == "kg:unknown" + + +def test_kg_triples_become_edges_with_temporal_props(): + payload = { + "kg_triples": [ + { + "subject": "max", + "predicate": "loves", + "object": "chess", + "valid_from": "2026-01-01", + "valid_to": None, + "confidence": 0.9, + "source_file": "diary.md", + } + ] + } + _, edges = project_graph(payload) + kg_edges = [e for e in edges if e.source_id.startswith("kg:")] + assert len(kg_edges) == 1 + e = kg_edges[0] + assert e.source_id == "kg:max" + assert e.target_id == "kg:chess" + assert e.edge_type == "loves" + assert e.properties["_created_at"] == "2026-01-01" + assert e.properties["confidence"] == 0.9 + assert e.properties["source_file"] == "diary.md" + assert e.properties["_table"] == "kg_triple" + + +def test_kg_triple_missing_subject_or_object_skipped(): + payload = { + "kg_triples": [ + {"subject": "max", "predicate": "loves"}, # no object + {"object": "chess", "predicate": "loves"}, # no subject + {"subject": "max", "predicate": "loves", "object": "chess"}, + ] + } + _, edges = project_graph(payload) + kg_edges = [e for e in edges if e.source_id.startswith("kg:")] + assert len(kg_edges) == 1 + + +def test_kg_triple_default_predicate_when_missing(): + payload = { + "kg_triples": [{"subject": "a", "object": "b"}] # no predicate + } + _, edges = project_graph(payload) + kg_edges = [e for e in edges if e.source_id.startswith("kg:")] + assert kg_edges[0].edge_type == "kg_related" + + +# ── Extra/unknown fields don't break projection ──────────────────── + + +def test_extra_fields_ignored(): + """Extra top-level keys like ``kg_stats`` and unknown sub-fields + don't affect the projection — forward compatibility.""" + payload = { + "wings": {"a": 1}, + "kg_stats": {"some_future_field": 123}, + "rooms": [], + "tunnels": [], + "kg_entities": [], + "kg_triples": [], + "rfc_status": "experimental", # unknown future field + } + entities, edges = project_graph(payload) + assert len(entities) == 1 # just the wing + assert edges == [] + + +# ── End-to-end realistic payload ─────────────────────────────────── + + +def test_realistic_palace_daemon_payload(): + """A payload shaped like an actual daemon /graph response covering + all five projection paths.""" + payload = { + "wings": {"code": 50, "decisions": 30, "personal": 20}, + "rooms": [ + {"wing": "code", "rooms": {"chromadb": 5, "general": 10}}, + {"wing": "decisions", "rooms": {"chromadb": 2}}, + {"wing": "personal", "rooms": {"family": 8}}, + ], + "tunnels": [ + {"room": "chromadb", "wings": ["code", "decisions"]}, + ], + "kg_entities": [ + {"id": "chroma", "name": "ChromaDB", "type": "tool"}, + {"id": "max", "name": "Max", "type": "person"}, + ], + "kg_triples": [ + { + "subject": "max", + "predicate": "uses", + "object": "chroma", + "valid_from": "2026-04-01", + } + ], + } + entities, edges = project_graph(payload) + + # All five entity classes are produced + types = {e.entity_type for e in entities} + assert "wing" in types + assert "room:untyped" in types + assert any(t.startswith("kg:") for t in types) + + # Wings: 3 + assert sum(1 for e in entities if e.entity_type == "wing") == 3 + # Rooms: 2 (chromadb + family; general filtered) + rooms = [e for e in entities if e.entity_type == "room:untyped"] + assert {e.id for e in rooms} == {"room:chromadb", "room:family"} + # KG entities: 2 + assert sum(1 for e in entities if e.id.startswith("kg:")) == 2 + + # Edges: member_of (3 room→wing edges), tunnel (1 pair), kg (1 triple) + member_edges = [e for e in edges if e.edge_type == "member_of"] + assert len(member_edges) == 3 + tunnel_edges = [e for e in edges if e.edge_type == "tunnel"] + assert len(tunnel_edges) == 1 + kg_edges = [e for e in edges if e.source_id.startswith("kg:")] + assert len(kg_edges) == 1 + assert kg_edges[0].edge_type == "uses" + + +def test_returned_entities_and_edges_are_correct_types(): + """Every returned item must be an Entity / Edge dataclass instance — + downstream code uses field access on these.""" + payload = { + "wings": {"a": 1}, + "rooms": [{"wing": "a", "rooms": {"r": 1}}], + "tunnels": [{"room": "r", "wings": ["a", "b"]}], + "kg_entities": [{"id": "x", "type": "t"}], + "kg_triples": [{"subject": "x", "object": "y", "predicate": "p"}], + } + entities, edges = project_graph(payload) + for e in entities: + assert isinstance(e, Entity) + for ed in edges: + assert isinstance(ed, Edge) diff --git a/tests/test_multi_hop.py b/tests/test_multi_hop.py new file mode 100644 index 0000000..df6c724 --- /dev/null +++ b/tests/test_multi_hop.py @@ -0,0 +1,368 @@ +"""Tests for sme.categories.multi_hop (Cat 2c). + +Covers: + +- Hop-bucket grouping over the retrieve-results JSON shape + (mean_recall, hit_rate, mean_tokens, correct_count per hop). +- A/B/C delta math (delta_B_minus_A, delta_B_minus_C, ratio_B_over_A) + including pp deltas, token deltas, and infinite ratios when A has + zero recall at a hop. +- Verdict logic: "structure earns complexity", "neutral tax", "harmful", + and the ratio-grows-with-depth branch. +- Edge cases: missing Condition A or C, single hop bucket, empty + question list, single question. +- to_dict() output shape — the JSON consumers downstream depend on + this. +""" +from __future__ import annotations + +import json + +import pytest + +from sme.categories.multi_hop import ( + Cat2cReport, + HopBreakdown, + _build_condition_report, + _verdict, + score_cat2c, +) + + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _write_retrieve_json(path, questions: list[dict]) -> str: + """Write a minimal retrieve-results JSON payload and return the path.""" + path.write_text(json.dumps({"questions": questions})) + return str(path) + + +def _q(min_hops: int, recall: float, tokens: float, hit: bool | None = None) -> dict: + """Build a single retrieve-results question dict.""" + if hit is None: + hit = recall > 0 + return { + "min_hops": min_hops, + "recall": recall, + "tokens": tokens, + "hit": hit, + } + + +# ── _build_condition_report ──────────────────────────────────────── + + +def test_build_condition_report_groups_by_hop(): + data = { + "questions": [ + _q(1, 1.0, 100), + _q(1, 0.5, 200), + _q(2, 1.0, 300), + _q(3, 0.0, 400, hit=False), + ] + } + rep = _build_condition_report("B", "graph", data) + assert rep.total_questions == 4 + # Three hop buckets + assert set(rep.by_hop.keys()) == {1, 2, 3} + # 1-hop bucket: two questions, mean recall = 0.75 + assert rep.by_hop[1].n == 2 + assert rep.by_hop[1].mean_recall == pytest.approx(0.75) + assert rep.by_hop[1].mean_tokens == pytest.approx(150.0) + assert rep.by_hop[1].correct_count == 1 # only the recall=1.0 one + # 2-hop bucket: single full-recall question + assert rep.by_hop[2].n == 1 + assert rep.by_hop[2].mean_recall == pytest.approx(1.0) + assert rep.by_hop[2].correct_count == 1 + # 3-hop bucket: missed + assert rep.by_hop[3].correct_count == 0 + assert rep.by_hop[3].hit_rate == 0.0 + + +def test_build_condition_report_hit_rate_distinct_from_recall(): + """A question with recall < 1 but hit=True still counts as a hit.""" + data = { + "questions": [ + _q(1, 0.5, 100, hit=True), + _q(1, 0.5, 100, hit=True), + _q(1, 0.0, 100, hit=False), + ] + } + rep = _build_condition_report("B", "graph", data) + assert rep.by_hop[1].hit_rate == pytest.approx(2 / 3) + assert rep.by_hop[1].correct_count == 0 # no full-recall queries + + +def test_build_condition_report_overall_totals(): + data = { + "questions": [ + _q(1, 1.0, 100), + _q(2, 1.0, 200), + _q(3, 0.5, 300, hit=True), + ] + } + rep = _build_condition_report("B", "graph", data) + assert rep.total_questions == 3 + assert rep.full_recall == 2 + assert rep.partial_hits == 3 + assert rep.mean_recall == pytest.approx((1 + 1 + 0.5) / 3) + assert rep.mean_tokens == pytest.approx((100 + 200 + 300) / 3) + # tokens_per_correct = total_tokens / full_recall + assert rep.tokens_per_correct == pytest.approx(600 / 2) + + +def test_build_condition_report_no_full_recall_yields_none_tokens_per_correct(): + data = {"questions": [_q(1, 0.5, 100, hit=True)]} + rep = _build_condition_report("B", "graph", data) + assert rep.full_recall == 0 + assert rep.tokens_per_correct is None + + +def test_build_condition_report_empty_questions(): + rep = _build_condition_report("B", "graph", {"questions": []}) + assert rep.total_questions == 0 + assert rep.full_recall == 0 + assert rep.mean_recall == 0.0 + assert rep.mean_tokens == 0.0 + assert rep.tokens_per_correct is None + assert rep.by_hop == {} + + +def test_build_condition_report_missing_min_hops_defaults_to_zero(): + """Question with no min_hops field still gets grouped — into bucket 0.""" + data = {"questions": [{"recall": 1.0, "tokens": 100, "hit": True}]} + rep = _build_condition_report("B", "graph", data) + assert 0 in rep.by_hop + assert rep.by_hop[0].n == 1 + + +# ── score_cat2c — requires graph_json ───────────────────────────── + + +def test_score_cat2c_requires_graph_json(): + with pytest.raises(ValueError, match="graph_json"): + score_cat2c() # type: ignore[call-arg] + + +def test_score_cat2c_b_only(tmp_path): + """With only Condition B, no deltas computed; verdict is 'incomplete'.""" + b = _write_retrieve_json( + tmp_path / "b.json", [_q(1, 0.5, 100), _q(2, 1.0, 200)] + ) + report = score_cat2c(graph_json=b) + assert "B" in report.conditions + assert "A" not in report.conditions + assert "C" not in report.conditions + assert report.delta_B_minus_A == {} + assert report.delta_B_minus_C == {} + assert report.ratio_B_over_A == {} + assert report.verdict == "incomplete" + + +# ── A/B/C delta math ────────────────────────────────────────────── + + +def test_score_cat2c_full_abc_deltas(tmp_path): + """All three conditions, two hop depths each.""" + a = _write_retrieve_json( + tmp_path / "a.json", [_q(1, 0.8, 100), _q(2, 0.2, 150)] + ) + b = _write_retrieve_json( + tmp_path / "b.json", [_q(1, 0.9, 120), _q(2, 0.8, 180)] + ) + c = _write_retrieve_json( + tmp_path / "c.json", [_q(1, 0.85, 110), _q(2, 0.4, 160)] + ) + report = score_cat2c(flat_json=a, graph_json=b, no_structure_json=c) + + # 1-hop: B-A = +10pp recall, +20 tokens + d1a = report.delta_B_minus_A[1] + assert d1a["recall_delta_pp"] == pytest.approx(10.0) + assert d1a["tokens_delta"] == pytest.approx(20.0) + # 2-hop: B-A = +60pp recall, +30 tokens + d2a = report.delta_B_minus_A[2] + assert d2a["recall_delta_pp"] == pytest.approx(60.0) + assert d2a["tokens_delta"] == pytest.approx(30.0) + # 1-hop: B-C = +5pp + d1c = report.delta_B_minus_C[1] + assert d1c["recall_delta_pp"] == pytest.approx(5.0) + # 2-hop: B-C = +40pp + d2c = report.delta_B_minus_C[2] + assert d2c["recall_delta_pp"] == pytest.approx(40.0) + + # Ratio B/A: 1-hop = 0.9/0.8, 2-hop = 0.8/0.2 = 4.0 — grows with hops + assert report.ratio_B_over_A[1] == pytest.approx(0.9 / 0.8) + assert report.ratio_B_over_A[2] == pytest.approx(4.0) + + +def test_score_cat2c_ratio_infinite_when_a_recall_zero(tmp_path): + """If Condition A has zero recall at a depth, ratio is +inf.""" + a = _write_retrieve_json(tmp_path / "a.json", [_q(3, 0.0, 100, hit=False)]) + b = _write_retrieve_json(tmp_path / "b.json", [_q(3, 0.7, 200, hit=True)]) + report = score_cat2c(flat_json=a, graph_json=b) + assert report.ratio_B_over_A[3] == float("inf") + + +def test_score_cat2c_negative_delta_when_b_worse_than_a(tmp_path): + a = _write_retrieve_json(tmp_path / "a.json", [_q(2, 0.9, 100)]) + b = _write_retrieve_json(tmp_path / "b.json", [_q(2, 0.4, 200)]) + report = score_cat2c(flat_json=a, graph_json=b) + assert report.delta_B_minus_A[2]["recall_delta_pp"] == pytest.approx(-50.0) + assert report.delta_B_minus_A[2]["tokens_delta"] == pytest.approx(100.0) + + +def test_score_cat2c_alignment_skips_unmatched_hops(tmp_path): + """If A has 1-hop and B has 2-hop only, no delta is computed for + hops missing from either side.""" + a = _write_retrieve_json(tmp_path / "a.json", [_q(1, 0.5, 100)]) + b = _write_retrieve_json(tmp_path / "b.json", [_q(2, 0.8, 200)]) + report = score_cat2c(flat_json=a, graph_json=b) + assert report.delta_B_minus_A == {} # no overlap → no delta rows + + +# ── Verdict logic ───────────────────────────────────────────────── + + +def test_verdict_incomplete_when_no_a(): + report = Cat2cReport() + # No delta_B_minus_A populated → verdict is "incomplete" + verdict, details = _verdict(report) + assert verdict == "incomplete" + assert any("Condition A" in d for d in details) + + +def test_verdict_neutral_tax_when_b_matches_a(): + """B - A is near zero at all depths → 'neutral tax'.""" + report = Cat2cReport() + report.delta_B_minus_A = { + 1: {"recall_delta_pp": 1.0, "tokens_delta": 0}, + 2: {"recall_delta_pp": -2.0, "tokens_delta": 0}, + } + report.delta_B_minus_C = { + 1: {"recall_delta_pp": 0.0, "tokens_delta": 0}, + 2: {"recall_delta_pp": 0.0, "tokens_delta": 0}, + } + verdict, details = _verdict(report) + assert verdict == "structure is a neutral tax" + # And the B-C narration should call it "neutral tax / nothing beyond metadata" + assert any("neutral tax" in d for d in details) + + +def test_verdict_earns_complexity_when_ratio_grows(): + """B beats A at multiple depths AND the ratio grows with hop depth.""" + report = Cat2cReport() + report.delta_B_minus_A = { + 1: {"recall_delta_pp": 10.0, "tokens_delta": 50}, + 2: {"recall_delta_pp": 40.0, "tokens_delta": 100}, + 3: {"recall_delta_pp": 60.0, "tokens_delta": 200}, + } + report.ratio_B_over_A = {1: 1.1, 2: 2.0, 3: 5.0} # clearly grows + report.delta_B_minus_C = { + 1: {"recall_delta_pp": 8.0, "tokens_delta": 0}, + 2: {"recall_delta_pp": 30.0, "tokens_delta": 0}, + } + verdict, details = _verdict(report) + assert verdict == "structure earns complexity (scales with depth)" + assert any("ratio grows" in d.lower() or "spec predicts" in d.lower() for d in details) + + +def test_verdict_uniform_scale_when_b_wins_flat_but_ratio_flat(): + """B beats A but the B/A ratio is roughly constant across depths.""" + report = Cat2cReport() + report.delta_B_minus_A = { + 1: {"recall_delta_pp": 10.0, "tokens_delta": 0}, + 2: {"recall_delta_pp": 10.0, "tokens_delta": 0}, + } + # 1.5x at 1-hop, 1.5x at 2-hop — last is NOT > first * 1.2 + report.ratio_B_over_A = {1: 1.5, 2: 1.5} + verdict, _ = _verdict(report) + assert verdict == "structure adds value at uniform scale" + + +def test_verdict_harmful_when_b_loses_to_a_only(): + report = Cat2cReport() + report.delta_B_minus_A = { + 1: {"recall_delta_pp": -10.0, "tokens_delta": 100}, + 2: {"recall_delta_pp": -20.0, "tokens_delta": 200}, + } + report.ratio_B_over_A = {1: 0.6, 2: 0.5} + verdict, _ = _verdict(report) + assert verdict == "structure harmful at multi-hop" + + +def test_verdict_mixed_when_some_win_some_lose(): + report = Cat2cReport() + report.delta_B_minus_A = { + 1: {"recall_delta_pp": -10.0, "tokens_delta": 0}, + 2: {"recall_delta_pp": 20.0, "tokens_delta": 0}, + } + report.ratio_B_over_A = {1: 0.7, 2: 1.5} + verdict, _ = _verdict(report) + assert verdict == "mixed: structure helps at some depths and hurts at others" + + +def test_verdict_b_minus_c_negative_flagged_in_details(): + """Even if B beats A, a negative B-C is surfaced as 'structurally harmful' + in the details.""" + report = Cat2cReport() + report.delta_B_minus_A = { + 2: {"recall_delta_pp": 20.0, "tokens_delta": 0}, + } + report.ratio_B_over_A = {2: 1.4} + report.delta_B_minus_C = { + 2: {"recall_delta_pp": -15.0, "tokens_delta": 0}, + } + _, details = _verdict(report) + assert any("structural routing is actively harmful" in d for d in details) + + +# ── to_dict() shape contract ───────────────────────────────────── + + +def test_to_dict_keys_and_string_hop_keys(tmp_path): + """to_dict converts integer hop keys to strings (JSON-friendly).""" + a = _write_retrieve_json(tmp_path / "a.json", [_q(1, 0.5, 100)]) + b = _write_retrieve_json(tmp_path / "b.json", [_q(1, 0.7, 110)]) + report = score_cat2c(flat_json=a, graph_json=b) + d = report.to_dict() + # Top-level keys + assert { + "conditions", + "delta_B_minus_A", + "delta_B_minus_C", + "ratio_B_over_A", + "verdict", + "verdict_details", + } <= set(d.keys()) + # Hop keys serialized as strings + assert all(isinstance(k, str) for k in d["delta_B_minus_A"].keys()) + assert all(isinstance(k, str) for k in d["ratio_B_over_A"].keys()) + # The condition payload also string-keyed by hop + by_hop = d["conditions"]["B"]["by_hop"] + assert all(isinstance(k, str) for k in by_hop.keys()) + + +def test_to_dict_roundtrip_through_json(tmp_path): + """The whole report should be JSON-serializable end-to-end.""" + b = _write_retrieve_json(tmp_path / "b.json", [_q(1, 1.0, 100)]) + report = score_cat2c(graph_json=b) + payload = json.dumps(report.to_dict()) + parsed = json.loads(payload) + assert parsed["conditions"]["B"]["full_recall"] == 1 + + +# ── HopBreakdown dataclass sanity ──────────────────────────────── + + +def test_hop_breakdown_dataclass_fields(): + bk = HopBreakdown( + hops=2, + n=4, + mean_recall=0.5, + hit_rate=0.75, + mean_tokens=200.0, + correct_count=2, + ) + assert bk.hops == 2 + assert bk.correct_count == 2 diff --git a/tests/test_ontology_coherence.py b/tests/test_ontology_coherence.py new file mode 100644 index 0000000..ccb35e5 --- /dev/null +++ b/tests/test_ontology_coherence.py @@ -0,0 +1,757 @@ +"""Tests for sme.categories.ontology_coherence (Cat 8). + +Covers all five sub-tests plus the claim library and hall usage: + +- 8a Type coverage (incl. prefix matching for ``drawer:hall_X``) +- 8b Edge vocabulary (incl. case-insensitive fallback) +- 8c Schema-data alignment (top-type concentration warning, entropy + passthrough from structural_health) +- 8d Drift score and hall_usage scorer for MemPalace-shaped graphs +- 8e Claim verification: library pattern matching, untestable denylist, + inline operational_override, cross-category deferral (Cat 7/3/2b), + temporal/provenance coverage metrics +- ``Cat8Report.to_dict()`` shape contract +- ``ImpliedOntology.load()`` round-trip via a YAML fixture +""" +from __future__ import annotations + +import textwrap + +import pytest + +from sme.adapters.base import Edge, Entity +from sme.categories.ontology_coherence import ( + Cat8Report, + ClaimResult, + ImpliedOntology, + _score_claim, + _score_hall_usage, + is_untestable, + load_claim_library, + match_claim_pattern, + score_cat8, +) + + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _ent(eid: str, etype: str, **props) -> Entity: + return Entity(id=eid, name=eid, entity_type=etype, properties=dict(props)) + + +def _edge(s: str, t: str, etype: str, **props) -> Edge: + return Edge(source_id=s, target_id=t, edge_type=etype, properties=dict(props)) + + +@pytest.fixture +def empty_claim_library() -> dict: + """A claim library with no claims and no untestable patterns.""" + return {"claims": [], "untestable_patterns": []} + + +# ── load_claim_library ───────────────────────────────────────────── + + +def test_load_claim_library_uses_repo_default(): + """The default path resolves relative to the package root and loads + the shipped structural_claims.yaml.""" + lib = load_claim_library() + assert "claims" in lib + assert any("hierarchical" in c.get("pattern", "") for c in lib["claims"]) + assert "untestable_patterns" in lib + + +# ── Pattern matching helpers ─────────────────────────────────────── + + +def test_match_claim_pattern_finds_hierarchical(): + lib = load_claim_library() + entry = match_claim_pattern( + "The graph is hierarchical with nested wings", lib + ) + assert entry is not None + assert entry["name"] == "Hierarchical structure" + + +def test_match_claim_pattern_returns_none_when_no_match(empty_claim_library): + assert match_claim_pattern("anything", empty_claim_library) is None + + +def test_match_claim_pattern_case_insensitive(): + lib = load_claim_library() + entry = match_claim_pattern("HIERARCHICAL", lib) + assert entry is not None + + +def test_is_untestable_flags_ux_claims(): + lib = load_claim_library() + assert is_untestable("It is intuitive and easy to use", lib) + assert is_untestable("Highly scalable", lib) + assert not is_untestable("Edges have provenance tracking", lib) + + +# ── ImpliedOntology.load() ──────────────────────────────────────── + + +def test_implied_ontology_load_round_trip(tmp_path): + p = tmp_path / "ont.yaml" + p.write_text( + textwrap.dedent( + """ + version: v0.1 + source: declared + entity_types: [drawer, wing, room] + edge_types: [member_of, tunnel] + hall_vocabulary: [code, decisions] + structural_claims: + - id: c1 + text: "The graph is hierarchical" + vocabulary_claims: + - id: v1 + text: "Hall vocabulary" + retrieval_claims: + - id: r1 + text: "Structure improves retrieval" + """ + ).strip() + ) + ont = ImpliedOntology.load(p) + assert ont.version == "v0.1" + assert ont.source == "declared" + assert "drawer" in ont.entity_types + assert "tunnel" in ont.edge_types + assert ont.hall_vocabulary == ["code", "decisions"] + assert ont.structural_claims[0]["id"] == "c1" + assert ont.retrieval_claims[0]["text"] == "Structure improves retrieval" + + +def test_implied_ontology_load_missing_fields(tmp_path): + """Missing optional keys default to empty lists, no crash.""" + p = tmp_path / "ont.yaml" + p.write_text("version: v0.1\nsource: inferred\n") + ont = ImpliedOntology.load(p) + assert ont.entity_types == [] + assert ont.edge_types == [] + assert ont.structural_claims == [] + + +# ── 8a Type coverage ────────────────────────────────────────────── + + +def test_8a_type_coverage_exact_match(empty_claim_library): + ont = ImpliedOntology( + version="t", + source="declared", + entity_types=["wing", "room"], + ) + entities = [_ent("e1", "wing"), _ent("e2", "room")] + report = score_cat8( + ont, entities, [], {"edge_type_entropy_bits": 0.0}, + claim_library=empty_claim_library, + ) + assert report.type_coverage == pytest.approx(1.0) + assert sorted(report.types_found) == ["room", "wing"] + assert report.types_missing == [] + + +def test_8a_type_coverage_prefix_match(empty_claim_library): + """A declared 'drawer' matches 'drawer:hall_code' via prefix rule.""" + ont = ImpliedOntology( + version="t", source="declared", entity_types=["drawer"] + ) + entities = [_ent("e1", "drawer:hall_code")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert "drawer" in report.types_found + assert report.types_missing == [] + + +def test_8a_undeclared_types_surfaced(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", entity_types=["wing"] + ) + entities = [_ent("e1", "wing"), _ent("e2", "drawer:hall_code")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + # 'drawer:hall_code' has prefix 'drawer' which isn't declared + assert "drawer:hall_code" in report.types_undeclared + + +def test_8a_missing_declared_type(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", entity_types=["wing", "ghost"] + ) + entities = [_ent("e1", "wing")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert "ghost" in report.types_missing + assert report.type_coverage == pytest.approx(0.5) + + +def test_8a_empty_declared_types_yields_perfect_coverage(empty_claim_library): + """No declared types → vacuously 1.0 coverage.""" + ont = ImpliedOntology(version="t", source="inferred") + entities = [_ent("e1", "x")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert report.type_coverage == pytest.approx(1.0) + + +# ── 8b Edge vocabulary ──────────────────────────────────────────── + + +def test_8b_edge_vocabulary_exact_match(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing"], + edge_types=["member_of", "tunnel"], + ) + entities = [_ent("a", "wing"), _ent("b", "wing")] + edges = [ + _edge("a", "b", "member_of"), + _edge("a", "b", "tunnel"), + ] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + assert report.edge_vocabulary_coverage == pytest.approx(1.0) + + +def test_8b_edge_case_insensitive_fallback(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing"], + edge_types=["Member_Of"], + ) + entities = [_ent("a", "wing"), _ent("b", "wing")] + edges = [_edge("a", "b", "member_of")] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + # Different case but same word — falls back via case-insensitive match + assert "Member_Of" in report.edges_found + assert report.edge_vocabulary_coverage == pytest.approx(1.0) + + +def test_8b_missing_and_undeclared(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing"], edge_types=["member_of", "missing_edge"], + ) + entities = [_ent("a", "wing"), _ent("b", "wing")] + edges = [_edge("a", "b", "member_of"), _edge("a", "b", "tunnel")] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + assert "missing_edge" in report.edges_missing + assert "tunnel" in report.edges_undeclared + + +def test_8b_empty_declared_yields_perfect_coverage(empty_claim_library): + ont = ImpliedOntology(version="t", source="inferred", entity_types=["x"]) + entities = [_ent("a", "x")] + edges = [_edge("a", "a", "self")] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + assert report.edge_vocabulary_coverage == pytest.approx(1.0) + + +# ── 8c Schema-data alignment ────────────────────────────────────── + + +def test_8c_concentration_warning_above_threshold(empty_claim_library): + ont = ImpliedOntology(version="t", source="declared", entity_types=["x"]) + # 9 of type 'x', 1 of type 'y' → 90% > 80% triggers warning + entities = [_ent(f"e{i}", "x") for i in range(9)] + [_ent("eY", "y")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert report.entity_type_concentration is not None + assert report.entity_type_concentration["top_type"] == "x" + assert report.entity_type_concentration["fraction"] == pytest.approx(0.9) + assert report.concentration_warning is not None + + +def test_8c_no_warning_below_threshold(empty_claim_library): + ont = ImpliedOntology(version="t", source="declared", entity_types=["x"]) + # 5 of 'x', 5 of 'y' → 50%, no warning + entities = [_ent(f"e{i}", "x") for i in range(5)] + [ + _ent(f"y{i}", "y") for i in range(5) + ] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert report.concentration_warning is None + + +def test_8c_entropy_bits_taken_from_structural_health(empty_claim_library): + ont = ImpliedOntology(version="t", source="declared", entity_types=["x"]) + entities = [_ent("e1", "x")] + report = score_cat8( + ont, entities, [], + {"edge_type_entropy_bits": 2.5}, + claim_library=empty_claim_library, + ) + assert report.edge_type_entropy_bits == pytest.approx(2.5) + + +def test_8c_no_entities_no_concentration(empty_claim_library): + ont = ImpliedOntology(version="t", source="declared") + report = score_cat8( + ont, [], [], {}, claim_library=empty_claim_library, + ) + assert report.entity_type_concentration is None + assert report.concentration_warning is None + + +# ── 8d Drift score ───────────────────────────────────────────────── + + +def test_8d_drift_zero_when_all_declared_present(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing"], edge_types=["member_of"], + ) + entities = [_ent("a", "wing"), _ent("b", "wing")] + edges = [_edge("a", "b", "member_of")] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + assert report.drift_score == pytest.approx(0.0) + + +def test_8d_drift_proportional_to_missing(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing", "ghost"], + edge_types=["member_of", "phantom"], + ) + entities = [_ent("a", "wing"), _ent("b", "wing")] + edges = [_edge("a", "b", "member_of")] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + # 2 of 4 declared names actually used → 50% drift + assert report.drift_score == pytest.approx(0.5) + + +def test_8d_drift_zero_when_nothing_declared(empty_claim_library): + ont = ImpliedOntology(version="t", source="inferred") + report = score_cat8( + ont, [], [], {}, claim_library=empty_claim_library, + ) + assert report.drift_score == pytest.approx(0.0) + + +# ── 8d hall_usage scorer (MemPalace-specific) ───────────────────── + + +def test_hall_usage_empty_drawers(): + out = _score_hall_usage([], ["code", "decisions"]) + assert out["total_drawers"] == 0 + assert out["fraction_populated"] == 0.0 + assert out["distribution"] == {} + + +def test_hall_usage_from_properties(): + drawers = [ + _ent("d1", "drawer", hall="code"), + _ent("d2", "drawer", hall="decisions"), + _ent("d3", "drawer", hall=""), # unpopulated + ] + out = _score_hall_usage(drawers, ["code", "decisions"]) + assert out["total_drawers"] == 3 + assert out["populated_count"] == 2 + assert out["fraction_populated"] == pytest.approx(2 / 3) + assert out["distribution"] == {"code": 1, "decisions": 1} + assert out["in_vocabulary_count"] == 2 + + +def test_hall_usage_falls_back_to_entity_type_suffix(): + """If hall property is empty, the suffix after ``drawer:`` is used.""" + drawers = [ + _ent("d1", "drawer:hall_code"), + _ent("d2", "drawer:untyped"), # 'untyped' is filtered out + ] + out = _score_hall_usage(drawers, ["hall_code"]) + assert out["populated_count"] == 1 + assert "hall_code" in out["distribution"] + # 'untyped' is filtered, not counted + assert "untyped" not in out["distribution"] + + +def test_hall_usage_in_vocab_accepts_prefix_form(): + """The 'hall_X' form matches a 'X' declared vocabulary entry.""" + drawers = [_ent("d1", "drawer", hall="hall_code")] + out = _score_hall_usage(drawers, ["code"]) + assert out["in_vocabulary_count"] == 1 + + +def test_hall_usage_ignores_non_drawers(): + """Entities whose type doesn't start with 'drawer' are excluded.""" + ents = [ + _ent("w1", "wing"), + _ent("r1", "room"), + _ent("d1", "drawer", hall="code"), + ] + out = _score_hall_usage(ents, ["code"]) + assert out["total_drawers"] == 1 + + +def test_score_cat8_populates_hall_usage_when_vocab_declared(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["drawer"], hall_vocabulary=["code"], + ) + entities = [_ent("d1", "drawer", hall="code")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert report.hall_usage is not None + assert report.hall_usage["populated_count"] == 1 + + +def test_score_cat8_hall_usage_none_when_no_vocab(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", entity_types=["drawer"] + ) + entities = [_ent("d1", "drawer", hall="code")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + assert report.hall_usage is None + + +# ── 8e Claim verification ────────────────────────────────────────── + + +def test_claim_untestable_via_denylist(): + """An 'easy to use' UX claim matches the untestable pattern.""" + lib = load_claim_library() + claim = {"id": "c_ux", "text": "intuitive and easy to use"} + result = _score_claim(claim, [], [], {}, lib) + assert result.status == "untestable" + assert "denylist" in result.operational_definition + + +def test_claim_untestable_when_no_library_match(empty_claim_library): + claim = {"id": "c_x", "text": "some unmatched claim"} + result = _score_claim(claim, [], [], {}, empty_claim_library) + assert result.status == "untestable" + + +def test_claim_temporal_coverage_pass(): + """An edge with _created_at → temporal claim passes when coverage > 0.5.""" + lib = load_claim_library() + claim = {"id": "c_t", "text": "Edges track temporal validity"} + edges = [ + _edge("a", "b", "x", _created_at="2026-01-01"), + _edge("a", "b", "y", _created_at="2026-01-02"), + _edge("a", "b", "z"), + ] + result = _score_claim(claim, [], edges, {}, lib) + assert result.status == "pass" + assert result.metrics["fraction_edges_with_created_at"] == pytest.approx(2 / 3) + + +def test_claim_temporal_coverage_fail_below_threshold(): + lib = load_claim_library() + claim = {"id": "c_t", "text": "Temporal tracking is supported"} + edges = [ + _edge("a", "b", "x"), + _edge("a", "b", "y"), + _edge("a", "b", "z", _created_at="2026-01-01"), + ] + result = _score_claim(claim, [], edges, {}, lib) + assert result.status == "fail" + + +def test_claim_provenance_coverage_pass(): + lib = load_claim_library() + claim = {"id": "c_p", "text": "We provide provenance tracking"} + edges = [_edge("a", "b", "x", _created_by="extractor_v1") for _ in range(3)] + result = _score_claim(claim, [], edges, {}, lib) + assert result.status == "pass" + + +def test_claim_cat7_deferral_skipped_without_results(): + lib = load_claim_library() + claim = {"id": "c_r", "text": "Structure improves retrieval"} + result = _score_claim(claim, [], [], {}, lib, cat7_results=None) + assert result.status == "skipped" + assert "Cat 7" in result.notes + + +def test_claim_cat7_pass_when_recall_lifts_more_than_5pp(): + lib = load_claim_library() + claim = {"id": "c_r", "text": "Structure improves retrieval"} + result = _score_claim( + claim, [], [], {}, lib, + cat7_results={"graph_mean_recall": 0.70, "flat_mean_recall": 0.60}, + ) + assert result.status == "pass" + assert result.metrics["delta_recall"] == pytest.approx(0.10) + + +def test_claim_cat7_fail_when_recall_within_5pp(): + lib = load_claim_library() + claim = {"id": "c_r", "text": "boost retrieval"} + result = _score_claim( + claim, [], [], {}, lib, + cat7_results={"graph_mean_recall": 0.62, "flat_mean_recall": 0.60}, + ) + assert result.status == "fail" + + +def test_claim_cat3_pass_when_contradiction_pairs_present(): + lib = load_claim_library() + claim = {"id": "c_c", "text": "We detect contradictions"} + result = _score_claim( + claim, [], [], {}, lib, + cat3_results={"contradiction_pairs": 3}, + ) + assert result.status == "pass" + + +def test_claim_cat3_skipped_without_results(): + lib = load_claim_library() + claim = {"id": "c_c", "text": "Disagreement surfaces contradictions"} + result = _score_claim(claim, [], [], {}, lib, cat3_results=None) + assert result.status == "skipped" + + +def test_claim_cat2b_pass_above_threshold(): + lib = load_claim_library() + claim = {"id": "c_d", "text": "We dedup entities"} + result = _score_claim( + claim, [], [], {}, lib, + cat2b_results={"canonicalization_recall": 0.7}, + ) + assert result.status == "pass" + + +def test_claim_inline_override_cat7_delta_recall_pass(): + """operational_override with metric=cat7_delta_recall — 'not a moat' + passes when |delta| < 10pp.""" + claim = { + "id": "c_o", + "text": "Structure is not a moat", + "operational_override": { + "metric": "cat7_delta_recall", + "pass_condition": "abs delta < 10pp", + "description": "within ±10pp band", + }, + } + result = _score_claim( + claim, [], [], {}, + {"claims": [], "untestable_patterns": []}, + cat7_results={"graph_mean_recall": 0.55, "flat_mean_recall": 0.52}, + ) + assert result.status == "pass" + + +def test_claim_inline_override_cat7_delta_recall_fail(): + claim = { + "id": "c_o", + "text": "Structure is not a moat", + "operational_override": { + "metric": "cat7_delta_recall", + "pass_condition": "abs delta < 10pp", + }, + } + result = _score_claim( + claim, [], [], {}, + {"claims": [], "untestable_patterns": []}, + cat7_results={"graph_mean_recall": 0.75, "flat_mean_recall": 0.50}, + ) + assert result.status == "fail" + + +def test_claim_inline_override_unknown_metric(): + claim = { + "id": "c_o", + "text": "anything", + "operational_override": {"metric": "unknown_metric"}, + } + result = _score_claim( + claim, [], [], {}, + {"claims": [], "untestable_patterns": []}, + ) + assert result.status == "untestable" + assert "no override handler" in result.notes + + +# ── score_cat8 end-to-end ──────────────────────────────────────── + + +def test_score_cat8_perfect_alignment(empty_claim_library): + """Graph matches the declared ontology perfectly.""" + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing", "room"], + edge_types=["member_of"], + ) + entities = [ + _ent("w1", "wing"), + _ent("w2", "wing"), + _ent("r1", "room"), + ] + edges = [_edge("r1", "w1", "member_of")] + report = score_cat8( + ont, entities, edges, {"edge_type_entropy_bits": 0.0}, + claim_library=empty_claim_library, + ) + assert report.type_coverage == pytest.approx(1.0) + assert report.edge_vocabulary_coverage == pytest.approx(1.0) + assert report.drift_score == pytest.approx(0.0) + assert report.types_missing == [] + assert report.edges_missing == [] + + +def test_score_cat8_complete_mismatch(empty_claim_library): + """Declared types don't appear; graph has only undeclared ones.""" + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing", "room"], + edge_types=["member_of"], + ) + entities = [_ent("a", "alien"), _ent("b", "stranger")] + edges = [_edge("a", "b", "weird")] + report = score_cat8( + ont, entities, edges, {}, claim_library=empty_claim_library, + ) + assert report.type_coverage == pytest.approx(0.0) + assert report.edge_vocabulary_coverage == pytest.approx(0.0) + assert report.drift_score == pytest.approx(1.0) + assert sorted(report.types_undeclared) == ["alien", "stranger"] + + +def test_score_cat8_empty_graph(empty_claim_library): + """No entities, no edges — everything is missing, coverage is zero.""" + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["wing"], edge_types=["member_of"], + ) + report = score_cat8( + ont, [], [], {}, claim_library=empty_claim_library, + ) + assert report.type_coverage == pytest.approx(0.0) + assert report.edge_vocabulary_coverage == pytest.approx(0.0) + assert report.entity_type_concentration is None + + +def test_score_cat8_claim_pass_rate(empty_claim_library): + """Two structural claims: one passes (temporal), one fails (provenance).""" + lib = load_claim_library() + ont = ImpliedOntology( + version="t", source="declared", entity_types=["x"], + structural_claims=[ + {"id": "c1", "text": "Edges have temporal validity"}, + {"id": "c2", "text": "Provenance lineage tracking"}, + ], + ) + entities = [_ent("a", "x"), _ent("b", "x")] + # All edges have _created_at (temporal passes); none have _created_by + edges = [ + _edge("a", "b", "e1", _created_at="2026-01-01"), + _edge("a", "b", "e2", _created_at="2026-01-02"), + ] + report = score_cat8(ont, entities, edges, {}, claim_library=lib) + assert report.claims_tested == 2 + assert report.claims_passed == 1 + assert report.claims_pass_rate == pytest.approx(0.5) + + +def test_score_cat8_vocabulary_claim_passes_when_halls_populated( + empty_claim_library, +): + """The 'five_standard_halls' vocab claim passes if ≥50% drawers have hall.""" + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["drawer"], + hall_vocabulary=["code", "decisions"], + vocabulary_claims=[ + {"id": "five_standard_halls", "text": "Five halls declared"} + ], + ) + entities = [ + _ent("d1", "drawer", hall="code"), + _ent("d2", "drawer", hall="decisions"), + ] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + hall_claim = next(c for c in report.claims if c.claim_id == "five_standard_halls") + assert hall_claim.status == "pass" + assert hall_claim.metrics["populated_count"] == 2 + + +def test_score_cat8_vocabulary_claim_fails_when_halls_empty(empty_claim_library): + ont = ImpliedOntology( + version="t", source="declared", + entity_types=["drawer"], + hall_vocabulary=["code"], + vocabulary_claims=[ + {"id": "five_standard_halls", "text": "Five halls declared"} + ], + ) + entities = [_ent("d1", "drawer")] # no hall property + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + hall_claim = next(c for c in report.claims if c.claim_id == "five_standard_halls") + assert hall_claim.status == "fail" + + +def test_score_cat8_introspection_score_is_zero_by_default(empty_claim_library): + ont = ImpliedOntology(version="t", source="declared") + report = score_cat8( + ont, [], [], {}, claim_library=empty_claim_library, + ) + assert report.introspection_score == 0.0 + assert report.introspection_available == [] + + +# ── to_dict() shape contract ───────────────────────────────────── + + +def test_to_dict_top_level_keys(empty_claim_library): + ont = ImpliedOntology(version="t", source="declared", entity_types=["x"]) + entities = [_ent("a", "x")] + report = score_cat8( + ont, entities, [], {}, claim_library=empty_claim_library, + ) + d = report.to_dict() + assert { + "8a_type_coverage", + "8b_edge_vocabulary", + "8c_schema_alignment", + "8d_drift", + "8e_claims", + "introspection", + } <= set(d.keys()) + + +def test_to_dict_claim_detail_round_trip(empty_claim_library): + report = Cat8Report() + report.claims = [ + ClaimResult( + claim_id="x", + claim_text="foo", + status="pass", + operational_definition="def", + metrics={"m": 1.0}, + notes="ok", + ) + ] + d = report.to_dict() + detail = d["8e_claims"]["detail"][0] + assert detail["id"] == "x" + assert detail["status"] == "pass" + assert detail["metrics"]["m"] == 1.0 diff --git a/tests/test_rlm_adapter.py b/tests/test_rlm_adapter.py new file mode 100644 index 0000000..17e1d06 --- /dev/null +++ b/tests/test_rlm_adapter.py @@ -0,0 +1,289 @@ +"""Unit tests for RlmAdapter. + +We don't exercise the real RLM/portkey/openai backend in tests because +that would burn API credits and require live network. Instead we patch +the RLM class so completion() returns a stubbed response and triggers +mempalace_search via the captured tool callable. + +A live A/B benchmark (rlm vs familiar on jp-realm-v0.1) belongs in +baselines/, not in unit tests — it's a research run, not a contract +check. + +The `rlm` package itself is not installable from PyPI under that name +(the distribution is `rlms`, source on GitHub) and requires Python +>=3.11, so it sits behind the `[rlm]` extra in pyproject. On a fresh +clone without that extra installed, this whole test module skips. +Install via: + pip install -e ".[rlm]" +""" + +from __future__ import annotations + +import pytest + +pytest.importorskip("rlm") + +import json # noqa: E402 (intentionally after importorskip) +from unittest.mock import MagicMock, patch # noqa: E402 +from urllib import request as _urlrequest # noqa: E402 + + +def _stub_palace_response(results: list[dict]) -> bytes: + return json.dumps({"query": "x", "results": results}).encode("utf-8") + + +def test_query_aggregates_tool_calls_into_context_string(monkeypatch): + captured_tool: list = [] + + class _StubRLM: + def __init__(self, *args, **kwargs): + # Capture the mempalace_search tool the adapter passed to RLM. + tools = kwargs["custom_tools"] + self._search = tools["mempalace_search"]["tool"] + captured_tool.append(self._search) + + def completion(self, q): + # Simulate the LM calling mempalace_search as part of its REPL. + self._search("hermes-agent", limit=2) + self._search("rlm recursive", limit=2) + return MagicMock(response="rlm-orchestrated answer") + + # Stub palace-daemon over urllib — return two different result sets. + call_count = {"n": 0} + + class _Resp: + def __init__(self, body): self._body = body + def __enter__(self): return self + def __exit__(self, *a): return False + def read(self): return self._body + + def _stub_urlopen(req, timeout=None): + call_count["n"] += 1 + if "hermes" in req.full_url: + return _Resp(_stub_palace_response([ + {"drawer_id": "drawer_x1", "text": "hermes-agent is a JP fork", "wing": "projects", "room": "forks", "similarity": 0.78}, + ])) + return _Resp(_stub_palace_response([ + {"drawer_id": "drawer_x2", "text": "rlm is a recursive language model paradigm", "wing": "projects", "room": "forks", "similarity": 0.82}, + ])) + + with patch("rlm.RLM", _StubRLM), patch.object(_urlrequest, "urlopen", _stub_urlopen): + from sme.adapters.rlm_adapter import RlmAdapter + a = RlmAdapter(api_url="http://test:8085", api_key="k", backend="openai") + out = a.query("tell me about hermes-agent and rlm") + + assert out.error is None + assert out.answer == "rlm-orchestrated answer" + # Both captured tool calls' results land in context_string. + assert "hermes-agent is a JP fork" in out.context_string + assert "rlm is a recursive language model paradigm" in out.context_string + # The synthesized answer is also there, so substring-scorers that + # match on what the system surfaced see both retrieval + synthesis. + assert "── RLM answer ──" in out.context_string + assert "rlm-orchestrated answer" in out.context_string + # And in retrieved_entities, in call order. + ids = [e.id for e in out.retrieved_entities] + assert ids == ["drawer_x1", "drawer_x2"] + # retrieval_path notes the rlm step + tool count (single string entry, + # cli formats it via '; '.join). + assert "rlm_completion" in out.retrieval_path[0] + assert "2 tool calls" in out.retrieval_path[0] + + +def test_source_file_preserved_in_capture_and_context_string(): + """Regression: source_file must round-trip from daemon response → trimmed + dict (visible to LLM) → context_string (visible to substring scorer). + + Pre-fix bug: source_file was dropped, so file-shaped expected_sources + in SME corpora silently scored 0 on RLM runs even when retrieval + landed the right drawer. Fixed 2026-05-16. + """ + captured_tool: list = [] + + class _StubRLM: + def __init__(self, *args, **kwargs): + tools = kwargs["custom_tools"] + self._search = tools["mempalace_search"]["tool"] + captured_tool.append(self._search) + + def completion(self, q): + results = self._search("vlan printer", limit=1) + # Confirm LLM sees source_file in the tool return. + assert results[0]["source_file"] == "VLAN-11-printer-notes.md" + return MagicMock(response="answered using printer notes") + + class _Resp: + def __init__(self, body): self._body = body + def __enter__(self): return self + def __exit__(self, *a): return False + def read(self): return self._body + + def _stub_urlopen(req, timeout=None): + return _Resp(_stub_palace_response([ + {"drawer_id": "d-vlan", "text": "VLAN 11 print server", "wing": "homelab", + "room": "infrastructure", "source_file": "VLAN-11-printer-notes.md", + "similarity": 0.9}, + ])) + + with patch("rlm.RLM", _StubRLM), patch.object(_urlrequest, "urlopen", _stub_urlopen): + from sme.adapters.rlm_adapter import RlmAdapter + a = RlmAdapter(api_url="http://test:8085", api_key="k", backend="openai") + out = a.query("which file documents the printer VLAN?") + + # source_file lands in context_string for the substring scorer. + assert "VLAN-11-printer-notes.md" in out.context_string + # And in the captured entity so Cat 7 / 8 readings can use it. + assert out.retrieved_entities[0].id == "d-vlan" + + +def test_invocation_mode_forced_prepends_directive(monkeypatch): + """The 'forced' mode wraps the standard RLM system prompt with an + invocation-required directive without losing the rest of the prompt. + """ + captured_kwargs: dict = {} + + class _StubRLM: + def __init__(self, *args, **kwargs): + captured_kwargs.update(kwargs) + + def completion(self, q): + return MagicMock(response="ok") + + with patch("rlm.RLM", _StubRLM): + from sme.adapters.rlm_adapter import RlmAdapter + RlmAdapter( + api_url="http://test:8085", api_key="k", + backend="openai", invocation_mode="forced", + ) + + sp = captured_kwargs.get("custom_system_prompt", "") + assert "MANDATORY RETRIEVAL CONSTRAINT" in sp + assert "mempalace_search" in sp + # Default RLM scaffolding is still present (REPL, FINAL, etc.). + assert "REPL" in sp + assert "FINAL" in sp + + +def test_invocation_mode_grounded_prepends_directive(): + """The 'grounded' mode wraps with a source-quoting directive.""" + captured_kwargs: dict = {} + + class _StubRLM: + def __init__(self, *args, **kwargs): + captured_kwargs.update(kwargs) + + def completion(self, q): + return MagicMock(response="ok") + + with patch("rlm.RLM", _StubRLM): + from sme.adapters.rlm_adapter import RlmAdapter + RlmAdapter( + api_url="http://test:8085", api_key="k", + backend="openai", invocation_mode="grounded", + ) + + sp = captured_kwargs.get("custom_system_prompt", "") + assert "MANDATORY GROUNDING CONSTRAINT" in sp + assert "source filename" in sp + + +def test_invocation_mode_default_no_custom_prompt(): + """Default behavior (no invocation_mode) passes RLM's own prompt through.""" + captured_kwargs: dict = {} + + class _StubRLM: + def __init__(self, *args, **kwargs): + captured_kwargs.update(kwargs) + + def completion(self, q): + return MagicMock(response="ok") + + with patch("rlm.RLM", _StubRLM): + from sme.adapters.rlm_adapter import RlmAdapter + RlmAdapter(api_url="http://test:8085", api_key="k", backend="openai") + + # No custom_system_prompt key when invocation_mode is None. + assert "custom_system_prompt" not in captured_kwargs + + +def test_query_capture_resets_between_calls(): + """Two consecutive query() calls should not leak entities between them.""" + + class _StubRLM: + def __init__(self, *args, **kwargs): + self._search = kwargs["custom_tools"]["mempalace_search"]["tool"] + + def completion(self, q): + self._search("first call", limit=1) + return MagicMock(response="ok") + + class _Resp: + def __enter__(self): return self + def __exit__(self, *a): return False + def read(self): return _stub_palace_response([ + {"drawer_id": "d1", "text": "t", "wing": "w", "room": "r", "similarity": 0.5}, + ]) + + def _stub_urlopen(req, timeout=None): + return _Resp() + + with patch("rlm.RLM", _StubRLM), patch.object(_urlrequest, "urlopen", _stub_urlopen): + from sme.adapters.rlm_adapter import RlmAdapter + a = RlmAdapter(api_url="http://test:8085", api_key="k", backend="openai") + first = a.query("q1") + second = a.query("q2") + + # Each call captures exactly one drawer; second call doesn't see first's. + assert len(first.retrieved_entities) == 1 + assert len(second.retrieved_entities) == 1 + + +def test_search_failure_returned_as_error_dict_not_raised(): + """Network failures inside mempalace_search shouldn't crash query().""" + + class _StubRLM: + def __init__(self, *args, **kwargs): + self._search = kwargs["custom_tools"]["mempalace_search"]["tool"] + + def completion(self, q): + results = self._search("anything") + return MagicMock(response=f"saw {len(results)} candidates") + + def _stub_urlopen(req, timeout=None): + raise OSError("connection refused") + + with patch("rlm.RLM", _StubRLM), patch.object(_urlrequest, "urlopen", _stub_urlopen): + from sme.adapters.rlm_adapter import RlmAdapter + a = RlmAdapter(api_url="http://test:8085", api_key="k", backend="openai") + out = a.query("x") + + assert out.error is None # query() didn't crash + # The failure is visible in the answer (RLM saw the error dict and stringified it). + assert "saw 1 candidates" in out.answer # one capture entry, the error one + # No entities captured (the failure dict has no drawer_id). + assert out.retrieved_entities == [] + + +def test_get_graph_snapshot_returns_empty(): + """RLM doesn't maintain a graph view — Cat 8 is N/A.""" + class _StubRLM: + def __init__(self, *args, **kwargs): pass + + with patch("rlm.RLM", _StubRLM): + from sme.adapters.rlm_adapter import RlmAdapter + a = RlmAdapter(api_url="http://x", api_key="k") + ents, edges = a.get_graph_snapshot() + assert ents == [] + assert edges == [] + + +def test_ingest_corpus_is_skipped(): + class _StubRLM: + def __init__(self, *args, **kwargs): pass + + with patch("rlm.RLM", _StubRLM): + from sme.adapters.rlm_adapter import RlmAdapter + a = RlmAdapter(api_url="http://x", api_key="k") + out = a.ingest_corpus([{"id": "ignored"}]) + assert out["skipped"] is True + assert out["entities_created"] == 0