Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/neurostack/cli/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def cmd_stats(args):

def cmd_prediction_errors(args):
from ..schema import DB_PATH, get_db
from ..search import _normalize_workspace
from ..search import PREDICTION_ERROR_MIN_OCCURRENCES, _normalize_workspace
conn = get_db(DB_PATH)

if args.resolve:
Expand Down Expand Up @@ -534,10 +534,11 @@ def cmd_prediction_errors(args):
FROM prediction_errors
{where}
GROUP BY note_path, error_type
HAVING COUNT(*) >= ?
ORDER BY occurrences DESC, avg_distance DESC
LIMIT ?
""",
params + [args.limit],
params + [PREDICTION_ERROR_MIN_OCCURRENCES, args.limit],
).fetchall()

total_where = "WHERE resolved_at IS NULL"
Expand All @@ -547,8 +548,14 @@ def cmd_prediction_errors(args):
total_params.append(ws + "/")

total = conn.execute(
f"SELECT COUNT(DISTINCT note_path) FROM prediction_errors {total_where}",
total_params,
f"""
SELECT COUNT(*) FROM (
SELECT note_path FROM prediction_errors {total_where}
GROUP BY note_path, error_type
HAVING COUNT(*) >= ?
)
""",
total_params + [PREDICTION_ERROR_MIN_OCCURRENCES],
).fetchone()[0]

if args.json:
Expand Down
27 changes: 24 additions & 3 deletions src/neurostack/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@
# distance = 1 - cosine_sim; values above 0.62 (sim < 0.38) indicate high prediction error.
PREDICTION_ERROR_SIM_THRESHOLD = 0.38

# Upper bound on similarity for a contextual_mismatch flag. The mismatch branch only
# fires when the top note is BOTH outside the query's context set AND a weak fit
# (sim in [PREDICTION_ERROR_SIM_THRESHOLD, this]). Without this ceiling the branch
# flagged correct retrievals — exact-title hits with strong cosine — purely because
# they were absent from the recall-limited in_context_notes boost set.
CONTEXTUAL_MISMATCH_MAX_SIM = 0.45

# A single ad-hoc query surprising a note is query difficulty, not note health.
# Surface (and demote on) a flag only once a note has surprised this many distinct
# retrieval events — the recurrence is what distinguishes a defective note from a
# weak/exploratory query that hit the least-bad target once.
PREDICTION_ERROR_MIN_OCCURRENCES = 2


def log_prediction_error(
conn: sqlite3.Connection,
Expand Down Expand Up @@ -749,14 +762,17 @@ def hybrid_search(
#
# Demotion is bounded: score *= 1 / (1 + 0.1 * error_count)
# 2 errors → 0.83x, 3 errors → 0.77x, 10 errors → 0.50x
# Only notes that have surprised >= PREDICTION_ERROR_MIN_OCCURRENCES distinct
# retrieval events are demoted; a single ad-hoc flag is query noise and must
# not deprioritise an otherwise-correct note in future searches.
if meta_paths:
try:
placeholders = ",".join("?" * len(meta_paths))
error_rows = conn.execute(
f"SELECT note_path, COUNT(*) as cnt FROM prediction_errors "
f"WHERE note_path IN ({placeholders}) AND resolved_at IS NULL "
f"GROUP BY note_path",
meta_paths,
f"GROUP BY note_path HAVING COUNT(*) >= ?",
meta_paths + [PREDICTION_ERROR_MIN_OCCURRENCES],
).fetchall()
error_counts = {r["note_path"]: r["cnt"] for r in error_rows}
for r in valid_results:
Expand Down Expand Up @@ -873,7 +889,12 @@ def hybrid_search(
log_prediction_error(
conn, top["note_path"], query, top_cosine, "low_overlap", context
)
elif context and in_context_notes and top["note_path"] not in in_context_notes:
elif (
context
and in_context_notes
and top["note_path"] not in in_context_notes
and top_cosine < CONTEXTUAL_MISMATCH_MAX_SIM
):
log_prediction_error(
conn, top["note_path"], query, top_cosine, "contextual_mismatch", context
)
Expand Down
22 changes: 17 additions & 5 deletions src/neurostack/tools/search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,12 @@ def vault_prediction_errors(

Error types:
- low_overlap: cosine distance > 0.62 — note is semantically distant from what retrieved it
- contextual_mismatch: note surfaced outside its expected domain context
- contextual_mismatch: note surfaced outside its expected domain context AND was only a
weak fit (sim < 0.45) — a strong hit outside the context boost set is not a mismatch

Only notes that have surprised >= 2 distinct retrieval events are surfaced: a single
ad-hoc flag reflects query difficulty, not note health. Single-occurrence rows still
accumulate in the log toward that threshold.

Args:
error_type: Filter by type — "low_overlap" or "contextual_mismatch". None = all.
Expand All @@ -498,7 +503,7 @@ def vault_prediction_errors(
results (e.g. "work/acme-cloud")
"""
from ..schema import DB_PATH, get_db
from ..search import _normalize_workspace
from ..search import PREDICTION_ERROR_MIN_OCCURRENCES, _normalize_workspace

conn = get_db(DB_PATH)

Expand Down Expand Up @@ -534,10 +539,11 @@ def vault_prediction_errors(
FROM prediction_errors
{where}
GROUP BY note_path, error_type
HAVING COUNT(*) >= ?
ORDER BY occurrences DESC, avg_distance DESC
LIMIT ?
""",
params + [limit],
params + [PREDICTION_ERROR_MIN_OCCURRENCES, limit],
).fetchall()

results = [
Expand All @@ -560,8 +566,14 @@ def vault_prediction_errors(
total_params.append(ws + "/")

total_unresolved = conn.execute(
f"SELECT COUNT(DISTINCT note_path) FROM prediction_errors {total_where}",
total_params,
f"""
SELECT COUNT(*) FROM (
SELECT note_path FROM prediction_errors {total_where}
GROUP BY note_path, error_type
HAVING COUNT(*) >= ?
)
""",
total_params + [PREDICTION_ERROR_MIN_OCCURRENCES],
).fetchone()[0]

return {
Expand Down
218 changes: 218 additions & 0 deletions tests/test_prediction_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""Tests for the prediction-error detection branch and the gating that surfaces it.

The writer ``log_prediction_error`` (insert + rate-limit) is covered in test_search.py.
What's exercised here is the *decision* logic — when ``hybrid_search`` actually flags a
result (search.py detection branch) — plus the occurrence/similarity gates that decide
which flags get surfaced and which demote notes in later retrieval.

Style mirrors TestCooccurrenceBoost in test_search.py: a real in-memory sqlite DB with
``get_db``/``get_embedding`` monkeypatched, never MagicMock — see
[[community-workspace-filter-fix-2026-05-12]] for why that matters.
"""

import struct

import numpy as np
import pytest

from neurostack.search import (
CONTEXTUAL_MISMATCH_MAX_SIM,
PREDICTION_ERROR_MIN_OCCURRENCES,
PREDICTION_ERROR_SIM_THRESHOLD,
hybrid_search,
)

# Length of every embedding vector built in this module: the default `dim` of
# _emb() and the size of the query vector returned by _query_emb().
DIM = 768


def _emb(a: float, b: float, dim: int = DIM) -> bytes:
    """Pack a ``dim``-float embedding blob whose only non-zero components are (a, b).

    Measured against the query vector e0 = (1, 0, 0, ...) the cosine similarity is
    a / sqrt(a^2 + b^2), so choosing (a, b) on the unit circle makes cosine == a.
    """
    components = [0.0] * dim
    components[0], components[1] = a, b
    return struct.Struct(f"{dim}f").pack(*components)


def _query_emb() -> np.ndarray:
    """Build the float32 query vector e0 = (1, 0, 0, ...) of length DIM.

    Its cosine with ``_emb(a, b)`` equals ``a`` whenever a^2 + b^2 == 1.
    """
    e0 = np.zeros((DIM,), dtype=np.float32)
    e0[0] = 1.0
    return e0


def _add_note(conn, path, *, content="predtoken body text", emb=None, title="N"):
    """Insert one ``notes`` row and, when ``emb`` is given, one embedded chunk.

    The chunk (heading "## H", position 0) is only written when ``emb`` is not
    None; ``content`` is unused otherwise. The connection is committed before
    returning.
    """
    note_row = (path, title, f"h_{path}", "2026-01-01")
    conn.execute(
        "INSERT INTO notes (path, title, content_hash, updated_at) VALUES (?, ?, ?, ?)",
        note_row,
    )

    wants_chunk = emb is not None
    if wants_chunk:
        chunk_sql = (
            "INSERT INTO chunks (note_path, heading_path, content, content_hash, "
            "position, embedding) VALUES (?, ?, ?, ?, ?, ?)"
        )
        chunk_row = (path, "## H", content, f"hc_{path}", 0, emb)
        conn.execute(chunk_sql, chunk_row)

    conn.commit()


def _patch_search(monkeypatch, conn):
    """Point hybrid_search at ``conn`` and pin its query embedding to e0.

    Replaces ``get_db`` and ``get_embedding`` on the ``neurostack.search`` module
    for the duration of the test via ``monkeypatch``.
    """
    import neurostack.search as search_mod

    def _fake_get_db(path):
        # Whatever path is requested, hand back the in-memory connection.
        return conn

    def _fake_get_embedding(q, base_url=None):
        # Deterministic query vector regardless of the query text.
        return _query_emb()

    monkeypatch.setattr(search_mod, "get_db", _fake_get_db)
    monkeypatch.setattr(search_mod, "get_embedding", _fake_get_embedding)


def _errors(conn, error_type=None):
    """Fetch (note_path, error_type, cosine_distance) rows from prediction_errors.

    A truthy ``error_type`` restricts the result to that type; otherwise every
    logged row is returned.
    """
    base_sql = "SELECT note_path, error_type, cosine_distance FROM prediction_errors"
    if not error_type:
        return conn.execute(base_sql, []).fetchall()
    return conn.execute(base_sql + " WHERE error_type = ?", [error_type]).fetchall()


# --- cosines chosen relative to the thresholds (a^2 + b^2 == 1, so cosine == a) ---
# Each (a, b) pair is unpacked into _emb(a, b). The bounds quoted in the comments are
# PREDICTION_ERROR_SIM_THRESHOLD (0.38) and CONTEXTUAL_MISMATCH_MAX_SIM (0.45).
_LOW = (0.10, 0.99499) # sim 0.10 < 0.38 -> low_overlap
_STRONG = (0.95, 0.31225) # sim 0.95 -> no flag at all
_BAND = (0.40, 0.91652) # sim 0.40 in [0.38, 0.45) -> contextual_mismatch
_MID = (0.60, 0.80) # sim 0.60 >= 0.45 -> NOT a contextual_mismatch


class TestLowOverlapDetection:
    """low_overlap flagging as decided by hybrid_search for the top-ranked result."""

    def test_fires_below_threshold(self, in_memory_db, monkeypatch):
        """A lone hit with sim 0.10 is logged exactly once as low_overlap."""
        db = in_memory_db
        _add_note(db, "stale.md", emb=_emb(*_LOW))
        _patch_search(monkeypatch, db)

        hybrid_search("predtoken", top_k=5, embed_url="http://fake")

        flagged = _errors(db, "low_overlap")
        assert len(flagged) == 1
        first = flagged[0]
        assert first["note_path"] == "stale.md"
        # the stored distance is 1 - sim
        expected_distance = 1.0 - _LOW[0]
        assert first["cosine_distance"] == pytest.approx(expected_distance, abs=1e-3)

    def test_no_flag_above_threshold(self, in_memory_db, monkeypatch):
        """A strong hit (sim 0.95) leaves the prediction_errors table empty."""
        db = in_memory_db
        _add_note(db, "good.md", emb=_emb(*_STRONG))
        _patch_search(monkeypatch, db)

        hybrid_search("predtoken", top_k=5, embed_url="http://fake")

        assert _errors(db) == []

    def test_fts_only_hit_not_flagged(self, in_memory_db, monkeypatch):
        """A note that FTS-matches but has no chunk embedding is dropped before
        the rerank, so it never reaches the detection branch — no flag."""
        db = in_memory_db
        _add_note(db, "noembed.md", emb=None)
        # _add_note writes no chunk when emb is None, so insert an FTS-matchable
        # chunk by hand and leave its embedding column unset.
        embeddingless_chunk = ("noembed.md", "## H", "predtoken body text", "hc_noembed", 0)
        db.execute(
            "INSERT INTO chunks (note_path, heading_path, content, content_hash, position) "
            "VALUES (?, ?, ?, ?, ?)",
            embeddingless_chunk,
        )
        db.commit()
        _patch_search(monkeypatch, db)

        hybrid_search("predtoken", top_k=5, embed_url="http://fake")

        assert _errors(db) == []

    def test_only_top_result_checked(self, in_memory_db, monkeypatch):
        """Detection inspects deduped[0] only. A strong #1 result shields a weak
        #2 from being flagged."""
        db = in_memory_db
        _add_note(db, "winner.md", emb=_emb(*_STRONG))  # cosine 0.95 -> ranks #1
        _add_note(db, "loser.md", emb=_emb(*_LOW))  # cosine 0.10 -> ranks #2
        _patch_search(monkeypatch, db)

        ranked = hybrid_search("predtoken", top_k=5, embed_url="http://fake")

        assert ranked[0].note_path == "winner.md"
        assert _errors(db) == []


class TestContextualMismatchDetection:
    """contextual_mismatch flagging: top hit outside the context set AND a weak fit."""

    def _setup(self, conn, target_emb):
        """Seed a context-named decoy (no chunk) plus the embedded target note."""
        # the decoy's path contains the context substring -> populates in_context_notes
        _add_note(conn, "azure-decoy.md", emb=None, title="Decoy")
        # the target FTS-matches the query but is NOT in the context set
        _add_note(conn, "target.md", emb=target_emb)

    def test_fires_in_weak_band(self, in_memory_db, monkeypatch):
        """sim 0.40 with a context argument logs one contextual_mismatch for the target."""
        db = in_memory_db
        self._setup(db, _emb(*_BAND))
        _patch_search(monkeypatch, db)

        hybrid_search("predtoken", top_k=5, embed_url="http://fake", context="azure")

        mismatches = _errors(db, "contextual_mismatch")
        assert len(mismatches) == 1
        assert mismatches[0]["note_path"] == "target.md"

    def test_suppressed_for_strong_hit(self, in_memory_db, monkeypatch):
        """The fix: a strong semantic hit outside the context boost set is NOT a
        mismatch. Without the CONTEXTUAL_MISMATCH_MAX_SIM ceiling this flagged
        correct retrievals (exact-title hits)."""
        db = in_memory_db
        self._setup(db, _emb(*_MID))  # sim 0.60 >= ceiling
        _patch_search(monkeypatch, db)

        hybrid_search("predtoken", top_k=5, embed_url="http://fake", context="azure")

        assert _errors(db) == []

    def test_no_context_no_mismatch(self, in_memory_db, monkeypatch):
        """Without a context argument the mismatch branch can't fire; a band-sim
        result above the low_overlap floor produces no flag at all."""
        db = in_memory_db
        self._setup(db, _emb(*_BAND))
        _patch_search(monkeypatch, db)

        # deliberately called without the context keyword
        hybrid_search("predtoken", top_k=5, embed_url="http://fake")

        assert _errors(db) == []


class TestSurfacingGate:
    """vault_prediction_errors surfaces a note only once it has surprised
    PREDICTION_ERROR_MIN_OCCURRENCES distinct retrieval events."""

    def _seed(self, conn, note, n):
        """Insert ``n`` low_overlap rows for ``note`` (queries q0..q{n-1}) and commit."""
        insert_sql = (
            "INSERT INTO prediction_errors (note_path, query, cosine_distance, error_type) "
            "VALUES (?, ?, ?, ?)"
        )
        for idx in range(n):
            conn.execute(insert_sql, (note, f"q{idx}", 0.7, "low_overlap"))
        conn.commit()

    def test_single_occurrence_not_surfaced(self, in_memory_db, monkeypatch):
        """A note flagged once is hidden; one flagged at the threshold is listed."""
        db = in_memory_db
        self._seed(db, "oneshot.md", 1)
        self._seed(db, "recurrent.md", PREDICTION_ERROR_MIN_OCCURRENCES)

        import neurostack.schema as schema_mod
        from neurostack.tools.search_tools import vault_prediction_errors

        # vault_prediction_errors imports get_db from the schema module when it is
        # called, so patching the module attribute redirects it to the test DB.
        monkeypatch.setattr(schema_mod, "get_db", lambda path: db)

        report = vault_prediction_errors()
        surfaced_paths = {entry["note_path"] for entry in report["errors"]}

        assert "recurrent.md" in surfaced_paths
        assert "oneshot.md" not in surfaced_paths
        assert report["total_flagged_notes"] == 1


def test_thresholds_ordered():
    """Sanity: the mismatch ceiling sits strictly above the low-overlap floor (and
    below 1.0), leaving a real band for contextual_mismatch, and the occurrence
    gate requires at least two events."""
    floor = PREDICTION_ERROR_SIM_THRESHOLD
    ceiling = CONTEXTUAL_MISMATCH_MAX_SIM
    assert floor < ceiling < 1.0
    assert PREDICTION_ERROR_MIN_OCCURRENCES >= 2
Loading