From 6c9a7ceefde90e8988284ec1e8a6abfc83e9c288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20M=C3=A1rquez?= Date: Tue, 17 Mar 2026 18:50:38 +0100 Subject: [PATCH] feat(store): add recall tracking and promoted observations endpoint Add read-side instrumentation to observations: recall_count and last_recalled_at columns track how often each observation is retrieved via Search and GetObservation. Frequently recalled observations can be queried through a new /promoted HTTP endpoint and mem_promoted MCP tool. - Add recall_count and last_recalled_at to observations schema - Increment recall_count on Search() and GetObservation() (fire-and-forget) - Split GetObservation into public (with recall) and private (without) to prevent Timeline navigation from inflating counts - Add PromotedObservations() store method with configurable threshold - Add GET /promoted HTTP endpoint with project/scope/min_recalls/limit params - Add mem_promoted MCP tool to agent profile (deferred loading) - Add 4 tests covering recall increment and promotion filtering Closes #95 --- internal/mcp/mcp.go | 67 ++++++++++++++++- internal/mcp/mcp_test.go | 21 +++--- internal/server/server.go | 18 +++++ internal/store/store.go | 135 ++++++++++++++++++++++++++++------- internal/store/store_test.go | 130 +++++++++++++++++++++++++++++++++ 5 files changed, 335 insertions(+), 36 deletions(-) diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index 2dd50e2..a547b92 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -6,8 +6,8 @@ // // Tool profiles allow agents to load only the tools they need: // -// engram mcp → all 14 tools (default) -// engram mcp --tools=agent → 11 tools agents actually use (per skill files) +// engram mcp → all 15 tools (default) +// engram mcp --tools=agent → 12 tools agents actually use (per skill files) // engram mcp --tools=admin → 3 tools for TUI/CLI (delete, stats, timeline) // engram mcp --tools=agent,admin → combine profiles // engram mcp --tools=mem_save,mem_search → individual tool names @@ -56,6 +56,7 @@ var ProfileAgent = map[string]bool{ "mem_capture_passive": true, // extract learnings from text — referenced in Gemini/Codex protocol "mem_save_prompt": true, // save user prompts "mem_update": true, // update observation by ID — skills say "use mem_update when you have an exact ID to correct" + "mem_promoted": true, // frequently recalled observations — surface important context at session start } // ProfileAdmin contains tools for TUI, dashboards, and manual curation @@ -575,6 +576,34 @@ Duplicates are automatically detected and skipped — safe to call multiple time handleCapturePassive(s), ) } + + // ─── mem_promoted (profile: agent, deferred) ──────────────────────── + if shouldRegister("mem_promoted", allowlist) { + srv.AddTool( + mcp.NewTool("mem_promoted", + mcp.WithDescription("Get frequently recalled observations that have proven their value through repeated access. Use at session start to surface the most important context."), + mcp.WithDeferLoading(true), + mcp.WithTitleAnnotation("Get Promoted Memories"), + mcp.WithReadOnlyHintAnnotation(true), + mcp.WithDestructiveHintAnnotation(false), + mcp.WithIdempotentHintAnnotation(true), + mcp.WithOpenWorldHintAnnotation(false), + mcp.WithString("project", + mcp.Description("Filter by project name"), + ), + mcp.WithString("scope", + mcp.Description("Filter by scope: project (default) or personal"), + ), + mcp.WithNumber("min_recalls", + mcp.Description("Minimum recall count threshold (default: 5)"), + ), + mcp.WithNumber("limit", + mcp.Description("Max results (default: 7)"), + ), + ), + handlePromoted(s), + ) + } } // ─── Tool Handlers ─────────────────────────────────────────────────────────── @@ -1021,6 +1050,40 @@ func handleCapturePassive(s *store.Store) server.ToolHandlerFunc { } } +func handlePromoted(s *store.Store) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + project, _ := req.GetArguments()["project"].(string) + scope, _ := req.GetArguments()["scope"].(string) + minRecalls := intArg(req, "min_recalls", 5) + limit := intArg(req, "limit", 7) + + results, err := s.PromotedObservations(project, scope, minRecalls, limit) + if err != nil { + return mcp.NewToolResultError("Failed to fetch promoted memories: " + err.Error()), nil + } + + if len(results) == 0 { + return mcp.NewToolResultText("No promoted memories found."), nil + } + + var b strings.Builder + fmt.Fprintf(&b, "Found %d promoted memories:\n\n", len(results)) + for i, r := range results { + projectStr := "" + if r.Project != nil { + projectStr = fmt.Sprintf(" | project: %s", *r.Project) + } + preview := truncate(r.Content, 300) + fmt.Fprintf(&b, "[%d] #%d (%s) — %s [recalled %d times]\n %s\n %s%s | scope: %s\n\n", + i+1, r.ID, r.Type, r.Title, r.RecallCount, + preview, + r.CreatedAt, projectStr, r.Scope) + } + + return mcp.NewToolResultText(b.String()), nil + } +} + // ─── Helpers ───────────────────────────────────────────────────────────────── // defaultSessionID returns a project-scoped default session ID. diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go index 8576d39..a1cac30 100644 --- a/internal/mcp/mcp_test.go +++ b/internal/mcp/mcp_test.go @@ -929,7 +929,8 @@ func TestResolveToolsAgentProfile(t *testing.T) { "mem_save", "mem_search", "mem_context", "mem_session_summary", "mem_session_start", "mem_session_end", "mem_get_observation", "mem_suggest_topic_key", "mem_capture_passive", "mem_save_prompt", - "mem_update", // skills explicitly say "use mem_update when you have an exact ID to correct" + "mem_update", // skills explicitly say "use mem_update when you have an exact ID to correct" + "mem_promoted", // frequently recalled observations — surface important context at session start } for _, tool := range expectedTools { if !result[tool] { @@ -974,12 +975,12 @@ func TestResolveToolsCombinedProfiles(t *testing.T) { t.Fatal("expected non-nil allowlist for combined profiles") } - // Should have all 14 tools + // Should have all 15 tools allTools := []string{ "mem_save", "mem_search", "mem_context", "mem_session_summary", "mem_session_start", "mem_session_end", "mem_get_observation", "mem_suggest_topic_key", "mem_capture_passive", "mem_save_prompt", - "mem_update", "mem_delete", "mem_stats", "mem_timeline", + "mem_update", "mem_promoted", "mem_delete", "mem_stats", "mem_timeline", } for _, tool := range allTools { if !result[tool] { @@ -1164,7 +1165,7 @@ func TestNewServerWithToolsNilRegistersAll(t *testing.T) { "mem_save", "mem_search", "mem_context", "mem_session_summary", "mem_session_start", "mem_session_end", "mem_get_observation", "mem_suggest_topic_key", "mem_capture_passive", "mem_save_prompt", - "mem_update", "mem_delete", "mem_stats", "mem_timeline", + "mem_update", "mem_promoted", "mem_delete", "mem_stats", "mem_timeline", } for _, name := range allTools { @@ -1203,14 +1204,14 @@ func TestNewServerBackwardsCompatible(t *testing.T) { srv := NewServer(s) tools := srv.ListTools() - // 11 agent + 3 admin = 14 total - if len(tools) != 14 { - t.Errorf("NewServer should register all 14 tools, got %d", len(tools)) + // 12 agent + 3 admin = 15 total + if len(tools) != 15 { + t.Errorf("NewServer should register all 15 tools, got %d", len(tools)) } } func TestProfileConsistency(t *testing.T) { - // Verify that agent + admin = all 14 tools + // Verify that agent + admin = all 15 tools combined := make(map[string]bool) for tool := range ProfileAgent { combined[tool] = true @@ -1219,8 +1220,8 @@ func TestProfileConsistency(t *testing.T) { combined[tool] = true } - if len(combined) != 14 { - t.Errorf("agent + admin should cover all 14 tools, got %d", len(combined)) + if len(combined) != 15 { + t.Errorf("agent + admin should cover all 15 tools, got %d", len(combined)) } // Verify no overlap between profiles diff --git a/internal/server/server.go b/internal/server/server.go index abd7c39..a442483 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -131,6 +131,9 @@ func (s *Server) routes() { // Stats s.mux.HandleFunc("GET /stats", s.handleStats) + // Promoted (frequently recalled) + s.mux.HandleFunc("GET /promoted", s.handlePromoted) + // Project migration s.mux.HandleFunc("POST /projects/migrate", s.handleMigrateProject) @@ -493,6 +496,21 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) { jsonResponse(w, http.StatusOK, stats) } +func (s *Server) handlePromoted(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + scope := r.URL.Query().Get("scope") + minRecalls := queryInt(r, "min_recalls", 5) + limit := queryInt(r, "limit", 7) + + obs, err := s.store.PromotedObservations(project, scope, minRecalls, limit) + if err != nil { + jsonError(w, http.StatusInternalServerError, err.Error()) + return + } + + jsonResponse(w, http.StatusOK, obs) +} + // ─── Sync Status ───────────────────────────────────────────────────────────── func (s *Server) handleSyncStatus(w http.ResponseWriter, r *http.Request) { diff --git a/internal/store/store.go b/internal/store/store.go index c2fb25b..fae23ac 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -49,6 +49,8 @@ type Observation struct { RevisionCount int `json:"revision_count"` DuplicateCount int `json:"duplicate_count"` LastSeenAt *string `json:"last_seen_at,omitempty"` + RecallCount int `json:"recall_count"` + LastRecalledAt *string `json:"last_recalled_at,omitempty"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` DeletedAt *string `json:"deleted_at,omitempty"` @@ -88,6 +90,8 @@ type TimelineEntry struct { RevisionCount int `json:"revision_count"` DuplicateCount int `json:"duplicate_count"` LastSeenAt *string `json:"last_seen_at,omitempty"` + RecallCount int `json:"recall_count"` + LastRecalledAt *string `json:"last_recalled_at,omitempty"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` DeletedAt *string `json:"deleted_at,omitempty"` @@ -462,6 +466,8 @@ func (s *Store) migrate() error { revision_count INTEGER NOT NULL DEFAULT 1, duplicate_count INTEGER NOT NULL DEFAULT 1, last_seen_at TEXT, + recall_count INTEGER NOT NULL DEFAULT 0, + last_recalled_at TEXT, created_at TEXT NOT NULL DEFAULT (datetime('now')), updated_at TEXT NOT NULL DEFAULT (datetime('now')), deleted_at TEXT, @@ -551,6 +557,8 @@ func (s *Store) migrate() error { {name: "revision_count", definition: "INTEGER NOT NULL DEFAULT 1"}, {name: "duplicate_count", definition: "INTEGER NOT NULL DEFAULT 1"}, {name: "last_seen_at", definition: "TEXT"}, + {name: "recall_count", definition: "INTEGER NOT NULL DEFAULT 0"}, + {name: "last_recalled_at", definition: "TEXT"}, {name: "updated_at", definition: "TEXT NOT NULL DEFAULT ''"}, {name: "deleted_at", definition: "TEXT"}, } @@ -845,7 +853,7 @@ func (s *Store) AllObservations(project, scope string, limit int) ([]Observation query := ` SELECT o.id, ifnull(o.sync_id, '') as sync_id, o.session_id, o.type, o.title, o.content, o.tool_name, o.project, - o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.created_at, o.updated_at, o.deleted_at + o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.recall_count, o.last_recalled_at, o.created_at, o.updated_at, o.deleted_at FROM observations o WHERE o.deleted_at IS NULL ` @@ -874,7 +882,7 @@ func (s *Store) SessionObservations(sessionID string, limit int) ([]Observation, query := ` SELECT id, ifnull(sync_id, '') as sync_id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE session_id = ? AND deleted_at IS NULL ORDER BY created_at ASC @@ -1017,7 +1025,7 @@ func (s *Store) RecentObservations(project, scope string, limit int) ([]Observat query := ` SELECT o.id, ifnull(o.sync_id, '') as sync_id, o.session_id, o.type, o.title, o.content, o.tool_name, o.project, - o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.created_at, o.updated_at, o.deleted_at + o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.recall_count, o.last_recalled_at, o.created_at, o.updated_at, o.deleted_at FROM observations o WHERE o.deleted_at IS NULL ` @@ -1148,23 +1156,35 @@ func (s *Store) SearchPrompts(query string, project string, limit int) ([]Prompt // ─── Get Single Observation ────────────────────────────────────────────────── -func (s *Store) GetObservation(id int64) (*Observation, error) { +// getObservation retrieves a single observation by ID without side effects. +// Use this for internal reads that should not affect recall tracking. +func (s *Store) getObservation(id int64) (*Observation, error) { row := s.db.QueryRow( `SELECT id, ifnull(sync_id, '') as sync_id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE id = ? AND deleted_at IS NULL`, id, ) var o Observation if err := row.Scan( &o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, - &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt, + &o.RecallCount, &o.LastRecalledAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt, ); err != nil { return nil, err } return &o, nil } +// GetObservation retrieves a single observation by ID and increments its recall count. +func (s *Store) GetObservation(id int64) (*Observation, error) { + o, err := s.getObservation(id) + if err != nil { + return nil, err + } + s.incrementRecall([]int64{o.ID}) + return o, nil +} + func (s *Store) UpdateObservation(id int64, p UpdateObservationParams) (*Observation, error) { var updated *Observation err := s.withTx(func(tx *sql.Tx) error { @@ -1294,7 +1314,7 @@ func (s *Store) Timeline(observationID int64, before, after int) (*TimelineResul } // 1. Get the focus observation - focus, err := s.GetObservation(observationID) + focus, err := s.getObservation(observationID) if err != nil { return nil, fmt.Errorf("timeline: observation #%d not found: %w", observationID, err) } @@ -1309,7 +1329,7 @@ func (s *Store) Timeline(observationID int64, before, after int) (*TimelineResul // 3. Get observations BEFORE the focus (same session, older, chronological order) beforeRows, err := s.queryItHook(s.db, ` SELECT id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE session_id = ? AND id < ? AND deleted_at IS NULL ORDER BY id DESC @@ -1326,7 +1346,7 @@ func (s *Store) Timeline(observationID int64, before, after int) (*TimelineResul if err := beforeRows.Scan( &e.ID, &e.SessionID, &e.Type, &e.Title, &e.Content, &e.ToolName, &e.Project, &e.Scope, &e.TopicKey, &e.RevisionCount, &e.DuplicateCount, &e.LastSeenAt, - &e.CreatedAt, &e.UpdatedAt, &e.DeletedAt, + &e.RecallCount, &e.LastRecalledAt, &e.CreatedAt, &e.UpdatedAt, &e.DeletedAt, ); err != nil { return nil, err } @@ -1343,7 +1363,7 @@ func (s *Store) Timeline(observationID int64, before, after int) (*TimelineResul // 4. Get observations AFTER the focus (same session, newer, chronological order) afterRows, err := s.queryItHook(s.db, ` SELECT id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE session_id = ? AND id > ? AND deleted_at IS NULL ORDER BY id ASC @@ -1360,7 +1380,7 @@ func (s *Store) Timeline(observationID int64, before, after int) (*TimelineResul if err := afterRows.Scan( &e.ID, &e.SessionID, &e.Type, &e.Title, &e.Content, &e.ToolName, &e.Project, &e.Scope, &e.TopicKey, &e.RevisionCount, &e.DuplicateCount, &e.LastSeenAt, - &e.CreatedAt, &e.UpdatedAt, &e.DeletedAt, + &e.RecallCount, &e.LastRecalledAt, &e.CreatedAt, &e.UpdatedAt, &e.DeletedAt, ); err != nil { return nil, err } @@ -1401,7 +1421,7 @@ func (s *Store) Search(query string, opts SearchOptions) ([]SearchResult, error) sql := ` SELECT o.id, ifnull(o.sync_id, '') as sync_id, o.session_id, o.type, o.title, o.content, o.tool_name, o.project, - o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.created_at, o.updated_at, o.deleted_at, + o.scope, o.topic_key, o.revision_count, o.duplicate_count, o.last_seen_at, o.recall_count, o.last_recalled_at, o.created_at, o.updated_at, o.deleted_at, fts.rank FROM observations_fts fts JOIN observations o ON o.id = fts.rowid @@ -1439,16 +1459,79 @@ func (s *Store) Search(query string, opts SearchOptions) ([]SearchResult, error) if err := rows.Scan( &sr.ID, &sr.SyncID, &sr.SessionID, &sr.Type, &sr.Title, &sr.Content, &sr.ToolName, &sr.Project, &sr.Scope, &sr.TopicKey, &sr.RevisionCount, &sr.DuplicateCount, - &sr.LastSeenAt, &sr.CreatedAt, &sr.UpdatedAt, &sr.DeletedAt, + &sr.LastSeenAt, &sr.RecallCount, &sr.LastRecalledAt, &sr.CreatedAt, &sr.UpdatedAt, &sr.DeletedAt, &sr.Rank, ); err != nil { return nil, err } results = append(results, sr) } + if len(results) > 0 { + ids := make([]int64, len(results)) + for i, r := range results { + ids[i] = r.ID + } + s.incrementRecall(ids) + } return results, rows.Err() } +// ─── Recall Tracking ───────────────────────────────────────────────────────── + +// incrementRecall bumps the recall_count for the given observation IDs. +// This is fire-and-forget: errors are silently ignored because recall +// tracking must never break reads. +func (s *Store) incrementRecall(ids []int64) { + if len(ids) == 0 { + return + } + placeholders := make([]string, len(ids)) + args := make([]any, len(ids)) + for i, id := range ids { + placeholders[i] = "?" + args[i] = id + } + query := fmt.Sprintf( + `UPDATE observations SET recall_count = recall_count + 1, last_recalled_at = datetime('now') WHERE id IN (%s) AND deleted_at IS NULL`, + strings.Join(placeholders, ","), + ) + _, _ = s.db.Exec(query, args...) +} + +// PromotedObservations returns observations that have been frequently recalled, +// ordered by recall_count descending. Use minRecalls to set the promotion threshold. +func (s *Store) PromotedObservations(project, scope string, minRecalls, limit int) ([]Observation, error) { + if limit <= 0 { + limit = 7 + } + if minRecalls <= 0 { + minRecalls = 5 + } + + query := ` + SELECT o.id, ifnull(o.sync_id, '') as sync_id, o.session_id, o.type, o.title, o.content, + o.tool_name, o.project, o.scope, o.topic_key, o.revision_count, o.duplicate_count, + o.last_seen_at, o.recall_count, o.last_recalled_at, o.created_at, o.updated_at, o.deleted_at + FROM observations o + WHERE o.deleted_at IS NULL AND o.recall_count >= ? + ` + args := []any{minRecalls} + + if project != "" { + query += " AND o.project = ?" + args = append(args, project) + } + if scope != "" { + query += " AND o.scope = ?" + args = append(args, scope) + } + + query += " ORDER BY o.recall_count DESC, o.last_recalled_at DESC LIMIT ?" + args = append(args, limit) + + return s.queryObservations(query, args...) +} + // ─── Stats ─────────────────────────────────────────────────────────────────── func (s *Store) Stats() (*Stats, error) { @@ -1562,7 +1645,7 @@ func (s *Store) Export() (*ExportData, error) { // Observations obsRows, err := s.queryItHook(s.db, `SELECT id, ifnull(sync_id, '') as sync_id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations ORDER BY id`, ) if err != nil { @@ -1574,7 +1657,7 @@ func (s *Store) Export() (*ExportData, error) { if err := obsRows.Scan( &o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, - &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt, + &o.RecallCount, &o.LastRecalledAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt, ); err != nil { return nil, err } @@ -1632,8 +1715,8 @@ func (s *Store) Import(data *ExportData) (*ImportResult, error) { // Import observations (use new IDs — AUTOINCREMENT) for _, obs := range data.Observations { _, err := s.execHook(tx, - `INSERT INTO observations (sync_id, session_id, type, title, content, tool_name, project, scope, topic_key, normalized_hash, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + `INSERT INTO observations (sync_id, session_id, type, title, content, tool_name, project, scope, topic_key, normalized_hash, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, normalizeExistingSyncID(obs.SyncID, "obs"), obs.SessionID, obs.Type, @@ -1647,6 +1730,8 @@ func (s *Store) Import(data *ExportData) (*ImportResult, error) { maxInt(obs.RevisionCount, 1), maxInt(obs.DuplicateCount, 1), obs.LastSeenAt, + obs.RecallCount, + obs.LastRecalledAt, obs.CreatedAt, obs.UpdatedAt, obs.DeletedAt, @@ -1987,12 +2072,12 @@ func (s *Store) ApplyPulledMutation(targetKey string, mutation SyncMutation) err func (s *Store) GetObservationBySyncID(syncID string) (*Observation, error) { row := s.db.QueryRow( `SELECT id, ifnull(sync_id, '') as sync_id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE sync_id = ? AND deleted_at IS NULL ORDER BY id DESC LIMIT 1`, syncID, ) var o Observation - if err := row.Scan(&o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt); err != nil { + if err := row.Scan(&o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, &o.RecallCount, &o.LastRecalledAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt); err != nil { return nil, err } return &o, nil @@ -2426,11 +2511,11 @@ func decodeSyncPayload(payload []byte, dest any) error { func (s *Store) getObservationTx(tx *sql.Tx, id int64) (*Observation, error) { row := tx.QueryRow( `SELECT id, ifnull(sync_id, '') as sync_id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE id = ? AND deleted_at IS NULL`, id, ) var o Observation - if err := row.Scan(&o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt); err != nil { + if err := row.Scan(&o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, &o.RecallCount, &o.LastRecalledAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt); err != nil { return nil, err } return &o, nil @@ -2438,7 +2523,7 @@ func (s *Store) getObservationTx(tx *sql.Tx, id int64) (*Observation, error) { func (s *Store) getObservationBySyncIDTx(tx *sql.Tx, syncID string, includeDeleted bool) (*Observation, error) { query := `SELECT id, ifnull(sync_id, '') as sync_id, session_id, type, title, content, tool_name, project, - scope, topic_key, revision_count, duplicate_count, last_seen_at, created_at, updated_at, deleted_at + scope, topic_key, revision_count, duplicate_count, last_seen_at, recall_count, last_recalled_at, created_at, updated_at, deleted_at FROM observations WHERE sync_id = ?` if !includeDeleted { query += ` AND deleted_at IS NULL` @@ -2446,7 +2531,7 @@ func (s *Store) getObservationBySyncIDTx(tx *sql.Tx, syncID string, includeDelet query += ` ORDER BY id DESC LIMIT 1` row := tx.QueryRow(query, syncID) var o Observation - if err := row.Scan(&o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt); err != nil { + if err := row.Scan(&o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, &o.RecallCount, &o.LastRecalledAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt); err != nil { return nil, err } return &o, nil @@ -2559,7 +2644,7 @@ func (s *Store) queryObservations(query string, args ...any) ([]Observation, err if err := rows.Scan( &o.ID, &o.SyncID, &o.SessionID, &o.Type, &o.Title, &o.Content, &o.ToolName, &o.Project, &o.Scope, &o.TopicKey, &o.RevisionCount, &o.DuplicateCount, &o.LastSeenAt, - &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt, + &o.RecallCount, &o.LastRecalledAt, &o.CreatedAt, &o.UpdatedAt, &o.DeletedAt, ); err != nil { return nil, err } @@ -2650,6 +2735,8 @@ func (s *Store) migrateLegacyObservationsTable() error { revision_count INTEGER NOT NULL DEFAULT 1, duplicate_count INTEGER NOT NULL DEFAULT 1, last_seen_at TEXT, + recall_count INTEGER NOT NULL DEFAULT 0, + last_recalled_at TEXT, created_at TEXT NOT NULL DEFAULT (datetime('now')), updated_at TEXT NOT NULL DEFAULT (datetime('now')), deleted_at TEXT, diff --git a/internal/store/store_test.go b/internal/store/store_test.go index a72121b..68e2d48 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -4005,3 +4005,133 @@ func TestMigrateProjectIdempotent(t *testing.T) { t.Fatal("second migration should be a no-op") } } + +func TestRecallCountIncrementsOnSearch(t *testing.T) { + s := newTestStore(t) + s.CreateSession("s1", "proj", "/tmp") + id, err := s.AddObservation(AddObservationParams{ + SessionID: "s1", Type: "bugfix", Title: "recall test", + Content: "searchable recall content", Project: "proj", Scope: "project", + }) + if err != nil { + t.Fatalf("add observation: %v", err) + } + + obs, err := s.GetObservation(id) + if err != nil { + t.Fatalf("get observation: %v", err) + } + // GetObservation itself increments recall_count by 1 + initialRecall := obs.RecallCount + + results, err := s.Search("searchable recall", SearchOptions{Limit: 10}) + if err != nil { + t.Fatalf("search: %v", err) + } + if len(results) == 0 { + t.Fatal("expected search results") + } + + obs, err = s.GetObservation(id) + if err != nil { + t.Fatalf("get observation after search: %v", err) + } + // After search (+1) and this GetObservation (+1), total should be initialRecall + 2 + if obs.RecallCount != initialRecall+2 { + t.Fatalf("expected recall_count=%d after search + get, got %d", initialRecall+2, obs.RecallCount) + } + if obs.LastRecalledAt == nil { + t.Fatal("expected last_recalled_at to be set") + } +} + +func TestRecallCountIncrementsOnGetObservation(t *testing.T) { + s := newTestStore(t) + s.CreateSession("s1", "proj", "/tmp") + id, err := s.AddObservation(AddObservationParams{ + SessionID: "s1", Type: "decision", Title: "get test", + Content: "some content", Project: "proj", Scope: "project", + }) + if err != nil { + t.Fatalf("add observation: %v", err) + } + + // First get: scans recall_count=0 (pre-increment), then bumps DB to 1 + obs, _ := s.GetObservation(id) + if obs.RecallCount != 0 { + t.Fatalf("expected recall_count=0 on first get (pre-increment), got %d", obs.RecallCount) + } + + // Second get: scans recall_count=1 (pre-increment), then bumps DB to 2 + obs, _ = s.GetObservation(id) + if obs.RecallCount != 1 { + t.Fatalf("expected recall_count=1 on second get (pre-increment), got %d", obs.RecallCount) + } + + // Third get: scans recall_count=2 (pre-increment), confirming increments stack + obs, _ = s.GetObservation(id) + if obs.RecallCount != 2 { + t.Fatalf("expected recall_count=2 on third get (pre-increment), got %d", obs.RecallCount) + } +} + +func TestPromotedObservationsReturnsHighRecallOnly(t *testing.T) { + s := newTestStore(t) + s.CreateSession("s1", "proj", "/tmp") + + // Create two observations + id1, _ := s.AddObservation(AddObservationParams{ + SessionID: "s1", Type: "decision", Title: "popular fact", + Content: "widely recalled", Project: "proj", Scope: "project", + }) + _, _ = s.AddObservation(AddObservationParams{ + SessionID: "s1", Type: "decision", Title: "obscure fact", + Content: "never recalled", Project: "proj", Scope: "project", + }) + + // Manually bump recall_count for the popular observation + for i := 0; i < 5; i++ { + s.incrementRecall([]int64{id1}) + } + + // Query with threshold of 5 + promoted, err := s.PromotedObservations("proj", "", 5, 10) + if err != nil { + t.Fatalf("promoted observations: %v", err) + } + if len(promoted) != 1 { + t.Fatalf("expected 1 promoted observation, got %d", len(promoted)) + } + if promoted[0].ID != id1 { + t.Fatalf("expected promoted observation id=%d, got %d", id1, promoted[0].ID) + } + if promoted[0].RecallCount < 5 { + t.Fatalf("expected recall_count >= 5, got %d", promoted[0].RecallCount) + } +} + +func TestPromotedObservationsExcludesDeleted(t *testing.T) { + s := newTestStore(t) + s.CreateSession("s1", "proj", "/tmp") + + id, _ := s.AddObservation(AddObservationParams{ + SessionID: "s1", Type: "decision", Title: "deleted popular", + Content: "was popular", Project: "proj", Scope: "project", + }) + + // Bump recall count + for i := 0; i < 10; i++ { + s.incrementRecall([]int64{id}) + } + + // Soft delete + s.DeleteObservation(id, false) + + promoted, err := s.PromotedObservations("proj", "", 1, 10) + if err != nil { + t.Fatalf("promoted observations: %v", err) + } + if len(promoted) != 0 { + t.Fatalf("expected 0 promoted observations after delete, got %d", len(promoted)) + } +}