Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion internal/agent/retrieval/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ type QueryRequest struct {
TimeTo time.Time // if set, filter memories created before this time
Source string // if set, filter by memory source (mcp, filesystem, terminal, clipboard)
State string // if set, filter by memory state (active, fading, archived)
Type string // if set, filter by memory type (decision, error, insight, learning, general)
MinSalience float32 // if > 0, filter out memories below this salience
IncludeSuppressed bool // if true, include recall-suppressed memories
}
Expand Down Expand Up @@ -265,7 +266,7 @@ func (ra *RetrievalAgent) Query(ctx context.Context, req QueryRequest) (QueryRes
ranked := ra.rankResults(ctx, activated, req.IncludeReasoning)

// Step 7: Apply filters (project, time, source, state, salience)
if req.Project != "" || !req.TimeFrom.IsZero() || !req.TimeTo.IsZero() || req.Source != "" || req.State != "" || req.MinSalience > 0 {
if req.Project != "" || !req.TimeFrom.IsZero() || !req.TimeTo.IsZero() || req.Source != "" || req.State != "" || req.Type != "" || req.MinSalience > 0 {
ranked = ra.applyFilters(ranked, req)
}

Expand Down Expand Up @@ -1061,6 +1062,9 @@ func (ra *RetrievalAgent) applyFilters(results []store.RetrievalResult, req Quer
if req.State != "" && r.Memory.State != req.State {
continue
}
if req.Type != "" && r.Memory.Type != req.Type {
continue
}
if req.MinSalience > 0 && r.Memory.Salience < req.MinSalience {
continue
}
Expand Down
38 changes: 30 additions & 8 deletions internal/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ func (srv *MCPServer) handleRecall(ctx context.Context, args map[string]interfac
state = s
}

memType := ""
if t, ok := args["type"].(string); ok {
memType = t
}

var minSalience float32
if ms, ok := args["min_salience"].(float64); ok {
minSalience = float32(ms)
Expand Down Expand Up @@ -457,7 +462,7 @@ func (srv *MCPServer) handleRecall(ctx context.Context, args map[string]interfac
srv.log.Error("concept recall failed", "concepts", concepts, "error", err)
return nil, fmt.Errorf("concept recall failed: %w", err)
}
filtered := filterMemories(memories, source, state, minSalience)
filtered := filterMemories(memories, source, state, memType, minSalience)
text := fmt.Sprintf("Found %d memories matching concepts %v:\n\n", len(filtered), concepts)
for i, mem := range filtered {
text += fmt.Sprintf("%d. %s\n Summary: %s\n Concepts: %v\n\n",
Expand All @@ -477,6 +482,7 @@ func (srv *MCPServer) handleRecall(ctx context.Context, args map[string]interfac
Project: project,
Source: source,
State: state,
Type: memType,
MinSalience: minSalience,
}

Expand Down Expand Up @@ -779,8 +785,12 @@ func (srv *MCPServer) handleRecallProject(ctx context.Context, args map[string]i
limit = int(l)
}

// Parse optional filters
source, state, minSalience := parseRecallFilters(args)
// Parse optional filters — default min_salience to 0.7 for project recall
// to filter out watcher noise that agents don't need.
source, state, memType, minSalience := parseRecallFilters(args)
if _, explicit := args["min_salience"]; !explicit && minSalience == 0 {
minSalience = 0.7
}

// Get project summary
summary, err := srv.store.GetProjectSummary(ctx, project)
Expand Down Expand Up @@ -824,6 +834,7 @@ func (srv *MCPServer) handleRecallProject(ctx context.Context, args map[string]i
Project: project,
Source: source,
State: state,
Type: memType,
MinSalience: minSalience,
}

Expand All @@ -849,7 +860,7 @@ func (srv *MCPServer) handleRecallProject(ctx context.Context, args map[string]i
srv.log.Error("project recall failed", "project", project, "error", err)
return nil, fmt.Errorf("project recall failed: %w", err)
}
filtered := filterMemories(memories, source, state, minSalience)
filtered := filterMemories(memories, source, state, memType, minSalience)

text += fmt.Sprintf("\nMemories (%d):\n\n", len(filtered))
for i, mem := range filtered {
Expand All @@ -875,7 +886,7 @@ func (srv *MCPServer) handleRecallTimeline(ctx context.Context, args map[string]
limit = int(l)
}

source, state, minSalience := parseRecallFilters(args)
source, state, memType, minSalience := parseRecallFilters(args)

from := time.Now().Add(-time.Duration(hoursBack) * time.Hour)
to := time.Now()
Expand All @@ -886,7 +897,7 @@ func (srv *MCPServer) handleRecallTimeline(ctx context.Context, args map[string]
return nil, fmt.Errorf("timeline recall failed: %w", err)
}

filtered := filterMemories(memories, source, state, minSalience)
filtered := filterMemories(memories, source, state, memType, minSalience)

text := fmt.Sprintf("Timeline (last %dh, %d memories):\n\n", hoursBack, len(filtered))
for i, mem := range filtered {
Expand Down Expand Up @@ -1511,21 +1522,24 @@ func toolError(text string) map[string]interface{} {
}

// parseRecallFilters extracts optional source/state/min_salience from MCP args.
func parseRecallFilters(args map[string]interface{}) (source, state string, minSalience float32) {
func parseRecallFilters(args map[string]interface{}) (source, state, memType string, minSalience float32) {
if s, ok := args["source"].(string); ok {
source = s
}
if s, ok := args["state"].(string); ok {
state = s
}
if t, ok := args["type"].(string); ok {
memType = t
}
if ms, ok := args["min_salience"].(float64); ok {
minSalience = float32(ms)
}
return
}

// filterMemories filters a slice of memories by source, state, and minimum salience.
func filterMemories(memories []store.Memory, source, state string, minSalience float32) []store.Memory {
func filterMemories(memories []store.Memory, source, state, memType string, minSalience float32) []store.Memory {
var filtered []store.Memory
for _, m := range memories {
if source != "" && m.Source != source {
Expand All @@ -1534,6 +1548,9 @@ func filterMemories(memories []store.Memory, source, state string, minSalience f
if state != "" && m.State != state {
continue
}
if memType != "" && m.Type != memType {
continue
}
if minSalience > 0 && m.Salience < minSalience {
continue
}
Expand Down Expand Up @@ -1627,6 +1644,11 @@ func (srv *MCPServer) handleRecallSession(ctx context.Context, args map[string]i
return nil, fmt.Errorf("session_id parameter is required")
}

// Allow "current" to resolve to the active MCP session.
if sessionID == "current" {
sessionID = srv.sessionID
}

limit := 20
if l, ok := args["limit"].(float64); ok && int(l) > 0 {
limit = int(l)
Expand Down
7 changes: 6 additions & 1 deletion internal/mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ func recallToolDef() ToolDefinition {
"type": "string",
"description": "Filter by memory source: mcp, filesystem, terminal, clipboard",
},
"type": map[string]interface{}{
"type": "string",
"description": "Filter by memory type: decision, error, insight, learning, general",
"enum": []string{"decision", "error", "insight", "learning", "general"},
},
"min_salience": map[string]interface{}{
"type": "number",
"description": "Minimum salience threshold (0.0-1.0). Filters out low-quality memories.",
Expand Down Expand Up @@ -370,7 +375,7 @@ func listSessionsToolDef() ToolDefinition {
func recallSessionToolDef() ToolDefinition {
return ToolDefinition{
Name: "recall_session",
Description: "Retrieve all memories from a specific MCP session, ordered by creation time. Use list_sessions to find session IDs.",
Description: "Retrieve all memories from a specific MCP session, ordered by creation time. Use \"current\" for the active session, or list_sessions to find past session IDs.",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
Expand Down