-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathvec_search.py
More file actions
462 lines (362 loc) · 15.4 KB
/
vec_search.py
File metadata and controls
462 lines (362 loc) · 15.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
"""Optional vector search layer using sqlite-vec + sentence-transformers.
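
Adds semantic similarity search alongside FTS5 BM25. Embeddings are
L2-normalized and the vec0 tables use sqlite-vec's default L2 distance, which
on unit vectors ranks results exactly like cosine similarity.
Gracefully degrades to FTS5-only when dependencies are not installed.

Architecture:
    embed_text()       -> float32 bytes (384-dim all-MiniLM-L6-v2)
    vec_sync_entity()  -> called after FTS sync on entity writes
    vector_search()    -> KNN query against the entity_embeddings vec0 table
    rrf_merge()        -> Reciprocal Rank Fusion of FTS5 + vector results

Tasks have parallel helpers (vec_sync_task, task_vector_search,
task_rrf_merge, backfill_task_embeddings) backed by a task_embeddings table.

The merged results feed into the existing 6-signal reranker
(smart_retrieval.py) without any changes to that module.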
"""
from __future__ import annotations
import logging
import sqlite3
import threading
from typing import Any
logger = logging.getLogger("sqlite-kb")
_VEC_LOAD_ERRORS = (AttributeError, OSError, sqlite3.Error)
_EMBEDDING_ERRORS = (AttributeError, OSError, RuntimeError, TypeError, ValueError)
# ── Availability check ─────────────────────────────────────────────────
try:
import sqlite_vec
_HAS_VEC = True
except ImportError:
_HAS_VEC = False
try:
from sentence_transformers import SentenceTransformer
_HAS_ST = True
except ImportError:
_HAS_ST = False
VEC_AVAILABLE: bool = _HAS_VEC and _HAS_ST
# ── Model singleton (lazy loaded) ─────────────────────────────────────
_model: SentenceTransformer | None = None
_model_lock = threading.Lock()
_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DIM = 384
MAX_OBS_FOR_EMBEDDING = 20 # MiniLM-L6-v2 has 256-token limit; cap observations to fit
def _get_model() -> SentenceTransformer:
    """Return the shared embedding model, loading it on first use.

    Uses double-checked locking on ``_model_lock`` so that concurrent first
    callers construct the model only once.

    Raises:
        RuntimeError: if sentence-transformers is not installed. RuntimeError
            is one of ``_EMBEDDING_ERRORS``, so ``_embed_text_or_none`` turns
            it into a logged ``None`` rather than an uncaught ``NameError``.
    """
    global _model
    if not _HAS_ST:
        # Without this guard the name SentenceTransformer is unbound here and
        # the constructor call below would fail with a confusing NameError.
        raise RuntimeError(
            "sentence-transformers is not installed; vector search is unavailable"
        )
    if _model is None:
        with _model_lock:
            # Re-check under the lock: another thread may have loaded it first.
            if _model is None:
                _model = SentenceTransformer(_MODEL_NAME)
                logger.info("Loaded embedding model: %s", _MODEL_NAME)
    return _model
# ── Extension loader ───────────────────────────────────────────────────
def load_vec(conn: sqlite3.Connection) -> bool:
    """Ensure the sqlite-vec extension is loaded on ``conn``.

    First probes for the extension by calling ``vec_version()``. If the probe
    succeeds, the extension is already present and nothing else is done, which
    avoids redundant ``enable_load_extension`` cycles. Otherwise the extension
    is loaded, and extension loading is switched back off afterwards.

    Returns:
        True if sqlite-vec is usable on ``conn``. False if the package is not
        installed or loading failed; the failure is logged at DEBUG.
    """
    if not _HAS_VEC:
        return False

    # Cheap probe: success means the extension is already loaded.
    try:
        conn.execute("SELECT vec_version()")
    except sqlite3.Error:
        pass  # not loaded on this connection yet
    else:
        return True

    try:
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            # Always re-disable extension loading. A failure here must not
            # mask an exception raised by sqlite_vec.load().
            try:
                conn.enable_load_extension(False)
            except _VEC_LOAD_ERRORS:
                pass
    except _VEC_LOAD_ERRORS as err:
        logger.debug("sqlite-vec load failed: %s", err)
        return False
    return True
# ── Table management ───────────────────────────────────────────────────
def _init_vec_table(conn: sqlite3.Connection, table_name: str) -> bool:
    """Create the vec0 virtual table ``table_name`` unless it already exists.

    ``table_name`` is interpolated directly into the SQL, so it must be a
    trusted identifier. The callers visible here pass fixed literals.

    Returns:
        True when the table exists after the call. False when sqlite-vec
        could not be loaded or the CREATE statement failed (logged as a
        warning).
    """
    if not load_vec(conn):
        return False
    ddl = (
        f"CREATE VIRTUAL TABLE IF NOT EXISTS {table_name} "
        f"USING vec0(embedding float[{EMBEDDING_DIM}])"
    )
    try:
        conn.execute(ddl)
    except sqlite3.Error as err:
        logger.warning("Failed to create %s: %s", table_name, err)
        return False
    logger.info("%s vec0 table ready (dim=%d)", table_name, EMBEDDING_DIM)
    return True
def _embed_text_or_none(text: str, *, context: str) -> bytes | None:
    """Embed ``text``, returning None instead of raising on embedding errors.

    Any exception in ``_EMBEDDING_ERRORS`` is logged as a warning tagged with
    ``context`` (e.g. ``"entity:42"``) and then swallowed. Other exceptions
    propagate.
    """
    try:
        embedding = embed_text(text)
    except _EMBEDDING_ERRORS as exc:
        logger.warning("Embedding failed for %s: %s", context, exc)
        embedding = None
    return embedding
def _existing_embedding_rowids(conn: sqlite3.Connection, table_name: str) -> set[int]:
rowids: set[int] = set()
try:
for row in conn.execute(f"SELECT rowid FROM {table_name}"):
rowids.add(row[0])
except sqlite3.Error as exc:
logger.debug("Failed to read %s rowids: %s", table_name, exc)
return rowids
def init_vec_table(conn: sqlite3.Connection) -> bool:
    """Ensure the ``entity_embeddings`` vec0 table exists; True on success."""
    table = "entity_embeddings"
    return _init_vec_table(conn, table)
# ── Embedding generation ───────────────────────────────────────────────
def _entity_text(name: str, entity_type: str, observations: list[str]) -> str:
    """Build the string that is embedded for an entity.

    The format is ``"<name> (<entity_type>): <obs1>. <obs2>. ..."``. Only the
    first ``MAX_OBS_FOR_EMBEDDING`` observations are used, in the order given.
    """
    kept = observations[:MAX_OBS_FOR_EMBEDDING]
    return "{} ({}): {}".format(name, entity_type, ". ".join(kept))
def embed_text(text: str) -> bytes:
    """Embed ``text`` and return the vector as raw float32 bytes for vec0.

    The model output is L2-normalized (``normalize_embeddings=True``) and cast
    to float32 before serialization. With all-MiniLM-L6-v2 this is
    EMBEDDING_DIM (384) floats.
    """
    encoded = _get_model().encode(text, normalize_embeddings=True)
    return encoded.astype("float32").tobytes()
# ── Sync helpers (called after FTS sync on writes) ─────────────────────
def vec_sync_entity(conn: sqlite3.Connection, entity_id: int) -> bool:
    """(Re)compute and store the embedding for entity ``entity_id``.

    The embedded text is built from the entity's name, type and observations,
    with observations ordered by id. The vector is stored under the entity id
    as its vec0 rowid, replacing any previous row. If the entity no longer
    exists, its embedding is removed instead.

    Rows are read by column name, so ``conn.row_factory`` presumably yields
    ``sqlite3.Row``-like rows (configured by the caller; verify).

    Returns:
        True if a new embedding was written. False when vector support is
        unavailable, the entity is missing, embedding failed, or a SQLite
        error occurred (logged as a warning).
    """
    if not (VEC_AVAILABLE and load_vec(conn)):
        return False
    try:
        entity = conn.execute(
            "SELECT name, entity_type FROM entities WHERE id = ?",
            (entity_id,),
        ).fetchone()
        if entity is None:
            # Entity is gone: drop any stale embedding it left behind.
            vec_remove_entity(conn, entity_id)
            return False
        observations = [
            obs["content"]
            for obs in conn.execute(
                "SELECT content FROM observations WHERE entity_id = ? ORDER BY id",
                (entity_id,),
            ).fetchall()
        ]
        embedding = _embed_text_or_none(
            _entity_text(entity["name"], entity["entity_type"], observations),
            context=f"entity:{entity_id}",
        )
        if embedding is None:
            return False
        # Replace any existing vector for this rowid with the fresh one.
        conn.execute("DELETE FROM entity_embeddings WHERE rowid = ?", (entity_id,))
        conn.execute(
            "INSERT INTO entity_embeddings(rowid, embedding) VALUES (?, ?)",
            (entity_id, embedding),
        )
    except sqlite3.Error as exc:
        logger.warning("vec_sync_entity(%d) failed: %s", entity_id, exc)
        return False
    return True
def vec_remove_entity(conn: sqlite3.Connection, entity_id: int) -> None:
    """Delete the stored embedding whose rowid is ``entity_id``, if any.

    This is best effort. It does nothing when vector support is unavailable
    or the extension cannot be loaded, and SQLite errors are logged at DEBUG.
    """
    if not VEC_AVAILABLE:
        return
    try:
        if not load_vec(conn):
            return
        conn.execute("DELETE FROM entity_embeddings WHERE rowid = ?", (entity_id,))
    except sqlite3.Error as e:
        logger.debug("vec_remove_entity(%d) failed: %s", entity_id, e)
# ── Vector search ──────────────────────────────────────────────────────
def vector_search(conn: sqlite3.Connection, query: str, limit: int = 50) -> list[dict]:
    """Run a KNN search (``k = limit``) for entities similar to ``query``.

    Each result dict has the keys ``eid``, ``distance``, ``name``,
    ``entity_type`` and ``project``, ordered by ascending distance (closest
    first). Returns an empty list when vector support is unavailable, the
    query cannot be embedded, or the SQL fails (logged as a warning).
    """
    if not (VEC_AVAILABLE and load_vec(conn)):
        return []
    query_embedding = _embed_text_or_none(query, context="entity_search_query")
    if query_embedding is None:
        return []
    sql = (
        "SELECT ee.rowid AS eid, ee.distance, "
        "e.name, e.entity_type, e.project "
        "FROM entity_embeddings ee "
        "JOIN entities e ON e.id = ee.rowid "
        "WHERE ee.embedding MATCH ? AND k = ? "
        "ORDER BY ee.distance"
    )
    try:
        matches = conn.execute(sql, (query_embedding, limit)).fetchall()
    except sqlite3.Error as e:
        logger.warning("vector_search failed: %s", e)
        return []
    return [dict(match) for match in matches]
# ── Reciprocal Rank Fusion ─────────────────────────────────────────────
def rrf_merge(
    fts_results: list[Any],
    vec_results: list[dict],
    k: int = 60,
) -> list[dict]:
    """Merge FTS5 and vector results using Reciprocal Rank Fusion.

    RRF(d) = sum(1 / (k + rank_i(d))) over the ranking sources, where
    rank_i(d) is the 1-based position of d in source i. Each source
    contributes at most one term per entity. If a source lists the same
    ``eid`` more than once, only its first (best-ranked) occurrence counts,
    and ranks are positions among that source's distinct entities.

    Args:
        fts_results: FTS5 rows (may be ``sqlite3.Row`` objects) indexable by
            ``eid``, ``name``, ``entity_type`` and ``project``, best first.
        vec_results: dicts as returned by ``vector_search()``, closest first.
            ``project`` may be absent.
        k: RRF smoothing constant; larger values flatten rank differences.

    Returns:
        One dict per distinct entity with ``eid``, ``name``, ``entity_type``,
        ``project`` and ``rank``, ordered by RRF score (descending). ``rank``
        is the negated RRF score, so lower is better, matching the FTS5 row
        format expected by rerank_entities(). Ties keep first-seen order,
        with FTS5 results ahead of vector-only results.
    """
    scores: dict[int, float] = {}
    entity_data: dict[int, dict] = {}

    # FTS5 contributions: one term per distinct entity, first occurrence wins.
    seen: set[int] = set()
    for item in fts_results:
        eid = item["eid"]
        if eid in seen:
            continue
        seen.add(eid)
        rank = len(seen)  # 1-based rank among distinct FTS entities
        scores[eid] = scores.get(eid, 0.0) + 1.0 / (k + rank)
        entity_data[eid] = {
            "eid": eid,
            "name": item["name"],
            "entity_type": item["entity_type"],
            "project": item["project"],
        }

    # Vector contributions. Entities already seen via FTS keep FTS metadata.
    seen = set()
    for item in vec_results:
        eid = item["eid"]
        if eid in seen:
            continue
        seen.add(eid)
        rank = len(seen)
        scores[eid] = scores.get(eid, 0.0) + 1.0 / (k + rank)
        if eid not in entity_data:
            entity_data[eid] = {
                "eid": eid,
                "name": item["name"],
                "entity_type": item["entity_type"],
                "project": item.get("project"),
            }

    # Highest RRF score first. sorted() is stable, so ties keep insertion order.
    ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    results = []
    for eid, rrf_score in ranked:
        data = entity_data[eid]
        # Negative RRF score as rank (matches FTS5 convention: lower = better).
        data["rank"] = -rrf_score
        results.append(data)
    return results
# ── Task vector search ─────────────────────────────────────────────────
def init_task_vec_table(conn: sqlite3.Connection) -> bool:
    """Ensure the ``task_embeddings`` vec0 table exists; True on success."""
    table = "task_embeddings"
    return _init_vec_table(conn, table)
def _task_text(title: str, description: str | None, notes: str | None) -> str:
"""Compose the text to embed for a task."""
parts = [title or ""]
if description:
parts.append(description[:500])
if notes:
parts.append(notes[:300])
return ". ".join(parts)
def vec_sync_task(conn: sqlite3.Connection, task_id: str) -> bool:
    """(Re)compute and store the embedding for the task with UUID ``task_id``.

    The task's SQLite rowid is used as its ``task_embeddings`` rowid. Any
    existing vector row for that rowid is deleted and a fresh one inserted.

    Returns:
        True if a new embedding was written. False when vector support is
        unavailable, the task does not exist, embedding failed, or a SQLite
        error occurred (logged as a warning). A missing task does not remove
        a stale embedding, because its rowid can no longer be looked up.
    """
    if not (VEC_AVAILABLE and load_vec(conn)):
        return False
    try:
        task = conn.execute(
            "SELECT rowid, title, description, notes FROM tasks WHERE id = ?",
            (task_id,),
        ).fetchone()
        if task is None:
            return False
        embedding = _embed_text_or_none(
            _task_text(task["title"], task["description"], task["notes"]),
            context=f"task:{task_id}",
        )
        if embedding is None:
            return False
        target = task["rowid"]
        conn.execute("DELETE FROM task_embeddings WHERE rowid = ?", (target,))
        conn.execute(
            "INSERT INTO task_embeddings(rowid, embedding) VALUES (?, ?)",
            (target, embedding),
        )
    except sqlite3.Error as exc:
        logger.warning("vec_sync_task(%s) failed: %s", task_id, exc)
        return False
    return True
def vec_remove_task(conn: sqlite3.Connection, task_id: str) -> None:
    """Delete the stored embedding for the task with UUID ``task_id``.

    The embedding is keyed by the task's rowid, which is looked up in
    ``tasks``. If the task row is already gone, nothing is deleted, so this
    must run before the task itself is deleted. This is best effort: it does
    nothing without vector support, and SQLite errors are logged at DEBUG.
    """
    if not VEC_AVAILABLE:
        return
    try:
        if not load_vec(conn):
            return
        found = conn.execute(
            "SELECT rowid FROM tasks WHERE id = ?", (task_id,)
        ).fetchone()
        if not found:
            return
        conn.execute(
            "DELETE FROM task_embeddings WHERE rowid = ?", (found["rowid"],)
        )
    except sqlite3.Error as e:
        logger.debug("vec_remove_task(%s) failed: %s", task_id, e)
def task_vector_search(
    conn: sqlite3.Connection, query: str, limit: int = 50
) -> list[dict]:
    """Run a KNN search (``k = limit``) over task embeddings for ``query``.

    Each dict holds the task columns id, title, description, notes, status,
    priority, section, due_date, project, parent_id, type and updated_at,
    plus ``distance``. Results are ordered closest first. Returns an empty
    list when vector support is unavailable, the query cannot be embedded,
    or the SQL fails (logged as a warning).
    """
    if not (VEC_AVAILABLE and load_vec(conn)):
        return []
    query_embedding = _embed_text_or_none(query, context="task_search_query")
    if query_embedding is None:
        return []
    sql = (
        "SELECT t.id, t.title, t.description, t.notes, t.status, "
        "t.priority, t.section, t.due_date, t.project, t.parent_id, "
        "t.type, t.updated_at, te.distance "
        "FROM task_embeddings te "
        "JOIN tasks t ON t.rowid = te.rowid "
        "WHERE te.embedding MATCH ? AND k = ? "
        "ORDER BY te.distance"
    )
    try:
        matches = conn.execute(sql, (query_embedding, limit)).fetchall()
    except sqlite3.Error as e:
        logger.warning("task_vector_search failed: %s", e)
        return []
    return [dict(match) for match in matches]
def task_rrf_merge(
    fts_results: list[dict],
    vec_results: list[dict],
    k: int = 60,
) -> list[dict]:
    """Merge FTS5 and vector task results using Reciprocal Rank Fusion.

    Results are keyed by task UUID (the ``id`` string), not by an integer eid.
    Each source contributes at most one 1 / (k + rank) term per task. If a
    source lists the same ``id`` more than once, only its first (best-ranked)
    occurrence counts, and ranks are 1-based positions among that source's
    distinct tasks.

    Returns:
        Shallow copies of the task rows, one per distinct task, ordered by
        RRF score (descending). A task found by both sources keeps its FTS5
        row, since FTS5 results are processed first. Ties keep first-seen
        order.
    """
    scores: dict[str, float] = {}
    task_data: dict[str, dict] = {}
    for results in (fts_results, vec_results):
        seen: set[str] = set()
        for item in results:
            tid = item["id"]
            if tid in seen:
                continue  # one RRF term per source
            seen.add(tid)
            rank = len(seen)  # 1-based rank among distinct tasks in this source
            scores[tid] = scores.get(tid, 0.0) + 1.0 / (k + rank)
            if tid not in task_data:
                task_data[tid] = dict(item)
    # sorted() is stable, so equal scores keep insertion order.
    ranked = sorted(scores, key=scores.__getitem__, reverse=True)
    return [task_data[tid] for tid in ranked]
def backfill_task_embeddings(conn: sqlite3.Connection) -> int:
    """Embed every task that has no row in ``task_embeddings`` yet.

    Tasks are matched to embeddings by rowid. If syncing one task raises an
    exception in ``_EMBEDDING_ERRORS``, it is logged and that task is skipped.

    Returns:
        The number of tasks for which an embedding was written (0 when vector
        support is unavailable).
    """
    if not (VEC_AVAILABLE and load_vec(conn)):
        return 0
    have_embedding = _existing_embedding_rowids(conn, "task_embeddings")
    # fetchall() drains the SELECT before any writes happen in the loop below.
    pending = [
        task["id"]
        for task in conn.execute("SELECT id, rowid FROM tasks").fetchall()
        if task["rowid"] not in have_embedding
    ]
    written = 0
    for task_id in pending:
        try:
            ok = vec_sync_task(conn, task_id)
        except _EMBEDDING_ERRORS as e:
            logger.warning("backfill failed for task %s: %s", task_id, e)
            continue
        if ok:
            written += 1
    if written:
        logger.info("Backfilled embeddings for %d tasks", written)
    return written
# ── Backfill utility ───────────────────────────────────────────────────
def backfill_embeddings(conn: sqlite3.Connection) -> int:
    """Embed every entity that has no row in ``entity_embeddings`` yet.

    Entity ids double as embedding rowids. If syncing one entity raises an
    exception in ``_EMBEDDING_ERRORS``, it is logged and that entity is
    skipped.

    Returns:
        The number of entities for which an embedding was written (0 when
        vector support is unavailable).
    """
    if not (VEC_AVAILABLE and load_vec(conn)):
        return 0
    # Find entities without embeddings.
    have_embedding = _existing_embedding_rowids(conn, "entity_embeddings")
    pending = [
        entity["id"]
        for entity in conn.execute("SELECT id FROM entities").fetchall()
        if entity["id"] not in have_embedding
    ]
    written = 0
    for eid in pending:
        try:
            ok = vec_sync_entity(conn, eid)
        except _EMBEDDING_ERRORS as e:
            logger.warning("backfill failed for entity %d: %s", eid, e)
            continue
        if ok:
            written += 1
    if written:
        logger.info("Backfilled embeddings for %d entities", written)
    return written