diff --git a/internal/pipeline/pipeline.go b/internal/pipeline/pipeline.go index c2d5fe4..12720f3 100644 --- a/internal/pipeline/pipeline.go +++ b/internal/pipeline/pipeline.go @@ -345,6 +345,10 @@ func (p *Pipeline) runFullPasses(files []discover.FileInfo) error { p.passConfigLinker() slog.Info("pass.timing", "pass", "configlinker", "elapsed", time.Since(t)) + t = time.Now() + p.passPubSubLinks() + slog.Info("pass.timing", "pass", "pubsublinks", "elapsed", time.Since(t)) + t = time.Now() p.passGitHistory() slog.Info("pass.timing", "pass", "githistory", "elapsed", time.Since(t)) @@ -503,6 +507,7 @@ func (p *Pipeline) runIncrementalPasses( slog.Warn("pass.httplink.err", "err", err) } p.passConfigLinker() + p.passPubSubLinks() p.passImplements() p.passGitHistory() @@ -1214,6 +1219,36 @@ func (p *Pipeline) resolveCallWithTypes( if p.registry.Exists(candidate) { return ResolutionResult{QualifiedName: candidate, Strategy: "type_dispatch", Confidence: 0.90, CandidateCount: 1} } + + // Two-hop resolution for chained field access: h.svc.Method() + // where h is *Handler and svc is a field. Resolve the last + // segment as a method name, excluding methods from the same + // module as the receiver (avoids self-referencing). + // Primarily useful for Go receiver patterns; resolves only the + // last segment (a.b.c.d() → resolves d()), not intermediate hops. 
+ if strings.Contains(methodName, ".") { + chainParts := strings.Split(methodName, ".") + lastMethod := chainParts[len(chainParts)-1] + allCandidates := p.registry.FindByName(lastMethod) + // Exclude candidates from the receiver's module (same file) + receiverModule := modulePrefix(classQN) + var candidates []string + for _, c := range allCandidates { + if modulePrefix(c) != receiverModule { + candidates = append(candidates, c) + } + } + if len(candidates) == 1 { + return ResolutionResult{QualifiedName: candidates[0], Strategy: "type_dispatch", Confidence: 0.80, CandidateCount: 1} + } + if len(candidates) > 1 { + // Prefer candidates closest to caller's module tree + best := bestByImportDistance(candidates, moduleQN) + if best != "" { + return ResolutionResult{QualifiedName: best, Strategy: "type_dispatch", Confidence: 0.70, CandidateCount: len(candidates)} + } + } + } } } diff --git a/internal/pipeline/pipeline_cbm.go b/internal/pipeline/pipeline_cbm.go index e2d742d..ae55205 100644 --- a/internal/pipeline/pipeline_cbm.go +++ b/internal/pipeline/pipeline_cbm.go @@ -67,7 +67,29 @@ func cbmParseFileFromSource(projectName string, f discover.FileInfo, source []by // Convert CBM definitions to store.Node objects for i := range cbmResult.Definitions { - node, edge := cbmDefToNode(&cbmResult.Definitions[i], projectName, moduleQN) + // Fix: mark test functions as entry points. + // The C extractor only sets is_test on the module, not on individual functions. + // Test functions (Go Test*/Benchmark*/Example*, Python test_*, etc.) are invoked + // by the test runner, not by the call graph — they must be entry points. + def := &cbmResult.Definitions[i] + if cbmResult.IsTestFile && !def.IsEntryPoint && + (def.Label == "Function" || def.Label == "Method") && + isTestFunction(def.Name, f.Language) { + def.IsEntryPoint = true + def.IsTest = true + } + + // Mark exported Go handler methods as entry points. 
+ // Echo handlers are registered via method value references (g.POST("", h.Method)) + // which the C extractor doesn't track as explicit calls. + if !def.IsEntryPoint && def.Label == "Method" && def.IsExported && + f.Language == lang.Go && + strings.Contains(f.RelPath, "handler") && + strings.Contains(def.Signature, "echo.Context") { + def.IsEntryPoint = true + } + + node, edge := cbmDefToNode(def, projectName, moduleQN) result.Nodes = append(result.Nodes, node) result.PendingEdges = append(result.PendingEdges, edge) } @@ -200,6 +222,7 @@ func enrichModuleNodeCBM(moduleNode *store.Node, cbmResult *cbm.FileResult, _ *p // Replaces the 14 language-specific infer*Types() functions. func inferTypesCBM( typeAssigns []cbm.TypeAssign, + defs []cbm.Definition, registry *FunctionRegistry, moduleQN string, importMap map[string]string, @@ -216,15 +239,47 @@ func inferTypesCBM( } } - // Return type propagation is handled by CBM TypeAssigns which already - // detect constructor patterns. Additional return-type-based inference - // from the returnTypes map is still useful for non-constructor calls. - // This would require the call data which we have in CBM Calls. - // For now, constructor-based inference covers the primary use case. + // Receiver type inference: for Go methods like func (h *Handler) Foo(), + // the receiver "h" has type Handler. Extract this from the Receiver field + // and add to the TypeMap so calls like h.svc.Method() can resolve. + for i := range defs { + if defs[i].Receiver == "" || defs[i].Label != "Method" { + continue + } + varName, typeName := parseGoReceiver(defs[i].Receiver) + if varName == "" || typeName == "" { + continue + } + if _, exists := types[varName]; exists { + continue // don't overwrite explicit type assignments + } + classQN := resolveAsClass(typeName, registry, moduleQN, importMap) + if classQN != "" { + types[varName] = classQN + } + } return types } +// parseGoReceiver extracts (varName, typeName) from a Go receiver string. 
+// Examples: "(h *Handler)" → ("h", "Handler"), "(s MyService)" → ("s", "MyService") +func parseGoReceiver(recv string) (string, string) { + // Strip parens + recv = strings.TrimSpace(recv) + recv = strings.TrimPrefix(recv, "(") + recv = strings.TrimSuffix(recv, ")") + recv = strings.TrimSpace(recv) + + parts := strings.Fields(recv) + if len(parts) != 2 { + return "", "" + } + varName := parts[0] + typeName := strings.TrimPrefix(parts[1], "*") + return varName, typeName +} + // resolveFileCallsCBM resolves all call targets using pre-extracted CBM data. // Replaces resolveFileCalls() — no AST walking needed. func (p *Pipeline) resolveFileCallsCBM(relPath string, ext *cachedExtraction) []resolvedEdge { @@ -232,7 +287,7 @@ func (p *Pipeline) resolveFileCallsCBM(relPath string, ext *cachedExtraction) [] importMap := p.importMaps[moduleQN] // Build type map from CBM type assignments - typeMap := inferTypesCBM(ext.Result.TypeAssigns, p.registry, moduleQN, importMap) + typeMap := inferTypesCBM(ext.Result.TypeAssigns, ext.Result.Definitions, p.registry, moduleQN, importMap) var edges []resolvedEdge @@ -262,6 +317,9 @@ func (p *Pipeline) resolveFileCallsCBM(relPath string, ext *cachedExtraction) [] result := p.resolveCallWithTypes(calleeName, moduleQN, importMap, typeMap) if result.QualifiedName == "" { if fuzzyResult, ok := p.registry.FuzzyResolve(calleeName, moduleQN, importMap); ok { + if fuzzyResult.QualifiedName == callerQN { + continue // skip self-reference + } edges = append(edges, resolvedEdge{ CallerQN: callerQN, TargetQN: fuzzyResult.QualifiedName, @@ -276,6 +334,10 @@ func (p *Pipeline) resolveFileCallsCBM(relPath string, ext *cachedExtraction) [] continue } + if result.QualifiedName == callerQN { + continue // skip self-reference + } + edges = append(edges, resolvedEdge{ CallerQN: callerQN, TargetQN: result.QualifiedName, diff --git a/internal/pipeline/pipeline_cbm_test.go b/internal/pipeline/pipeline_cbm_test.go new file mode 100644 index 0000000..9d9283d --- 
/dev/null +++ b/internal/pipeline/pipeline_cbm_test.go @@ -0,0 +1,67 @@ +package pipeline + +import "testing" + +func TestParseGoReceiver(t *testing.T) { + tests := []struct { + name string + input string + wantVar string + wantType string + }{ + { + name: "pointer receiver", + input: "(h *Handler)", + wantVar: "h", + wantType: "Handler", + }, + { + name: "value receiver", + input: "(s MyService)", + wantVar: "s", + wantType: "MyService", + }, + { + name: "empty string", + input: "", + wantVar: "", + wantType: "", + }, + { + name: "single word no pair", + input: "invalid", + wantVar: "", + wantType: "", + }, + { + name: "too many parts", + input: "(a b c)", + wantVar: "", + wantType: "", + }, + { + name: "empty parens", + input: "()", + wantVar: "", + wantType: "", + }, + { + name: "whitespace only parens", + input: "( )", + wantVar: "", + wantType: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotVar, gotType := parseGoReceiver(tt.input) + if gotVar != tt.wantVar { + t.Errorf("parseGoReceiver(%q) varName = %q, want %q", tt.input, gotVar, tt.wantVar) + } + if gotType != tt.wantType { + t.Errorf("parseGoReceiver(%q) typeName = %q, want %q", tt.input, gotType, tt.wantType) + } + }) + } +} diff --git a/internal/pipeline/pubsub.go b/internal/pipeline/pubsub.go new file mode 100644 index 0000000..aff7e70 --- /dev/null +++ b/internal/pipeline/pubsub.go @@ -0,0 +1,476 @@ +package pipeline + +import ( + "log/slog" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + "time" + + "github.com/DeusData/codebase-memory-mcp/internal/store" +) + +// publishMethodNames are method names that indicate event publishing (case-insensitive). +// Only high-signal names are included to avoid false positives with generic methods. 
+var publishMethodNames = map[string]bool{ + "publish": true, + "emit": true, + "dispatch": true, + "fire": true, + "broadcast": true, +} + +// subscribeMethodNames are method names that indicate event subscribing (case-insensitive). +var subscribeMethodNames = map[string]bool{ + "subscribe": true, + "addlistener": true, + "listen": true, +} + +// Per-language regex patterns for extracting event names from Publish/Subscribe calls. +// Each pattern must have exactly one capture group for the event name. +var publishEventPatterns = []*regexp.Regexp{ + // Go: bus.Publish(events.EventCheckinCompleted, ...) or bus.Publish(EventCheckinCompleted, ...) + regexp.MustCompile(`\.(?i:Publish|Emit|Dispatch|Fire|Broadcast)\(\s*(?:\w+\.)?(\w+)`), + // JS/TS: emitter.emit('event-name', ...) or emitter.emit("event-name", ...) + regexp.MustCompile(`\.(?i:emit|dispatch|fire|broadcast)\(\s*['"]([^'"]+)['"]`), +} + +var subscribeEventPatterns = []*regexp.Regexp{ + // Go: bus.Subscribe(events.EventCheckinCompleted, func(...) { ... }) + regexp.MustCompile(`\.(?i:Subscribe|AddListener|Listen)\(\s*(?:\w+\.)?(\w+)`), + // JS/TS: emitter.addListener('event-name', handler) or emitter.subscribe('event-name', handler) + regexp.MustCompile(`\.(?i:addListener|listen|subscribe)\(\s*['"]([^'"]+)['"]`), +} + +// subscribeCallLine holds a Subscribe call's event name and its source line number. +type subscribeCallLine struct { + eventName string + line int +} + +// passPubSubLinks detects in-process event bus patterns and creates ASYNC_CALLS +// edges from publisher functions to the handler functions called by subscribers. +// +// Algorithm (event-routed): +// 1. Find all CALLS edges whose target is a known publish/subscribe method name. +// 2. For each publisher function, read its source and extract event names from Publish calls. +// 3. For each subscriber function, read its source and extract (eventName, line) pairs from +// Subscribe calls. 
Attribute handler CALLS edges to the nearest preceding Subscribe call. +// 4. Create ASYNC_CALLS edges routed by event name: publisher → event → handlers. +// +// This replaces the previous cartesian-product approach which connected every publisher +// to every handler regardless of event type. +func (p *Pipeline) passPubSubLinks() { + t := time.Now() + + callEdges, err := p.Store.FindEdgesByType(p.ProjectName, "CALLS") + if err != nil { + slog.Warn("pubsub.calls_err", "err", err) + return + } + if len(callEdges) == 0 { + return + } + + // Collect all node IDs referenced by CALLS edges for batch lookup. + nodeIDs := collectEdgeNodeIDs(callEdges) + nodeLookup, err := p.Store.FindNodesByIDs(nodeIDs) + if err != nil { + slog.Warn("pubsub.node_lookup_err", "err", err) + return + } + + // Partition callers into publishers and subscribers based on target method name. + publisherIDs := make(map[int64]bool) + subscriberIDs := make(map[int64]bool) + // Track the Publish/Subscribe target QNs to exclude them from handler list. + pubsubTargetQNs := make(map[string]bool) + + for _, e := range callEdges { + target := nodeLookup[e.TargetID] + if target == nil { + continue + } + nameLower := strings.ToLower(target.Name) + if publishMethodNames[nameLower] { + publisherIDs[e.SourceID] = true + pubsubTargetQNs[target.QualifiedName] = true + } else if subscribeMethodNames[nameLower] { + subscriberIDs[e.SourceID] = true + pubsubTargetQNs[target.QualifiedName] = true + } + } + + if len(publisherIDs) == 0 || len(subscriberIDs) == 0 { + slog.Info("pubsub.skip", "publishers", len(publisherIDs), "subscribers", len(subscriberIDs)) + return + } + + // For each subscriber function, collect the OTHER functions it calls + // (excluding Publish/Subscribe themselves and logging functions). + // These are the actual event handler functions. 
+ excludeNames := map[string]bool{ + "error": true, "warn": true, "info": true, "debug": true, + "printf": true, "println": true, "sprintf": true, "errorf": true, + } + + // subscriberHandlerCalls: subscriberID -> []handlerNodeID (all handler calls from that function) + subscriberHandlerCalls := make(map[int64][]int64) + for _, e := range callEdges { + if !subscriberIDs[e.SourceID] { + continue + } + target := nodeLookup[e.TargetID] + if target == nil { + continue + } + if pubsubTargetQNs[target.QualifiedName] { + continue + } + if excludeNames[strings.ToLower(target.Name)] { + continue + } + subscriberHandlerCalls[e.SourceID] = append(subscriberHandlerCalls[e.SourceID], e.TargetID) + } + + // --- Event-routed matching via source scanning --- + + fileCache := &sourceFileCache{files: make(map[string]string)} + + // Step 1: Build publisher event map: publisherNodeID -> []eventName + publisherEvents := make(map[int64][]string) + for pubID := range publisherIDs { + node := nodeLookup[pubID] + if node == nil || node.FilePath == "" { + continue + } + src, err := fileCache.readLines(p.RepoPath, node.FilePath, node.StartLine, node.EndLine) + if err != nil { + slog.Debug("pubsub.read_publisher", "file", node.FilePath, "err", err) + continue + } + events := extractEventNames(src, publishEventPatterns) + if len(events) > 0 { + publisherEvents[pubID] = events + slog.Debug("pubsub.publisher_events", "func", node.Name, "events", events) + } + } + + // Step 2: Build subscriber event-to-handler map: eventName -> []handlerNodeID + eventHandlers := make(map[string][]int64) + for subID := range subscriberIDs { + node := nodeLookup[subID] + if node == nil || node.FilePath == "" { + continue + } + handlers := subscriberHandlerCalls[subID] + if len(handlers) == 0 { + continue + } + src, err := fileCache.readLines(p.RepoPath, node.FilePath, node.StartLine, node.EndLine) + if err != nil { + slog.Debug("pubsub.read_subscriber", "file", node.FilePath, "err", err) + continue + } + + // Build 
handler name -> node ID map for this subscriber's handlers + handlerNameToIDs := make(map[string][]int64) + for _, hID := range handlers { + hNode := nodeLookup[hID] + if hNode != nil { + handlerNameToIDs[hNode.Name] = append(handlerNameToIDs[hNode.Name], hID) + } + } + + attribution := attributeHandlersToEvents(src, subscribeEventPatterns, handlerNameToIDs) + for eventName, hIDs := range attribution { + eventHandlers[eventName] = append(eventHandlers[eventName], hIDs...) + slog.Debug("pubsub.subscriber_event", "func", node.Name, "event", eventName, "handlers", len(hIDs)) + } + } + + // Deduplicate handlers per event + for evt, hIDs := range eventHandlers { + eventHandlers[evt] = deduplicateInt64(hIDs) + } + + // Step 3: Create event-routed ASYNC_CALLS edges + var edges []*store.Edge + seen := make(map[[2]int64]bool) + var fallbackCount int + + for pubID := range publisherIDs { + pubEvents := publisherEvents[pubID] + if len(pubEvents) == 0 { + // Fallback: publisher has no extracted events -> connect to all handlers with low confidence + for evt, hIDs := range eventHandlers { + for _, hID := range hIDs { + if pubID == hID { + continue + } + key := [2]int64{pubID, hID} + if seen[key] { + continue + } + seen[key] = true + fallbackCount++ + + handlerNode := nodeLookup[hID] + handlerName := "" + if handlerNode != nil { + handlerName = handlerNode.Name + } + edges = append(edges, &store.Edge{ + Project: p.ProjectName, + SourceID: pubID, + TargetID: hID, + Type: "ASYNC_CALLS", + Properties: map[string]any{ + "handler_name": handlerName, + "event_name": evt, + "confidence": 0.5, + "confidence_band": "medium", + "async_type": "event_bus", + "fallback": true, + }, + }) + } + } + continue + } + + for _, evt := range pubEvents { + hIDs := eventHandlers[evt] + for _, hID := range hIDs { + if pubID == hID { + continue + } + key := [2]int64{pubID, hID} + if seen[key] { + continue + } + seen[key] = true + + handlerNode := nodeLookup[hID] + handlerName := "" + if handlerNode != 
nil { + handlerName = handlerNode.Name + } + edges = append(edges, &store.Edge{ + Project: p.ProjectName, + SourceID: pubID, + TargetID: hID, + Type: "ASYNC_CALLS", + Properties: map[string]any{ + "handler_name": handlerName, + "event_name": evt, + "confidence": 0.9, + "confidence_band": "high", + "async_type": "event_bus", + }, + }) + } + } + } + + if len(edges) > 0 { + if err := p.Store.InsertEdgeBatch(edges); err != nil { + slog.Warn("pubsub.write_err", "err", err) + } + } + + slog.Info("pubsub.done", + "publishers", len(publisherIDs), + "subscribers", len(subscriberIDs), + "event_types", len(eventHandlers), + "edges_created", len(edges), + "fallback_edges", fallbackCount, + "elapsed", time.Since(t), + ) +} + +// extractEventNames extracts event constant names from source code using the given patterns. +// Each pattern must have exactly one capture group for the event name. +// Returns a deduplicated slice of event names found. +func extractEventNames(source string, patterns []*regexp.Regexp) []string { + seen := make(map[string]bool) + var result []string + for _, pat := range patterns { + matches := pat.FindAllStringSubmatch(source, -1) + for _, m := range matches { + if len(m) >= 2 && m[1] != "" { + name := m[1] + if !seen[name] { + seen[name] = true + result = append(result, name) + } + } + } + } + return result +} + +// attributeHandlersToEvents scans subscriber function source to attribute handler calls +// to specific events. It extracts (eventName, line) pairs from Subscribe calls, then +// for each handler function name, finds its call site(s) in the source and attributes +// each to the nearest preceding Subscribe call. +// +// This correctly handles the pattern where one function (e.g. RegisterListeners) calls +// Subscribe multiple times with different events and different inline handlers. 
+func attributeHandlersToEvents(source string, patterns []*regexp.Regexp, handlerNameToIDs map[string][]int64) map[string][]int64 { + result := make(map[string][]int64) + + // Extract Subscribe call sites: (eventName, lineNumber) + subCalls := extractSubscribeCallLines(source, patterns) + if len(subCalls) == 0 { + return result + } + + lines := strings.Split(source, "\n") + + // For each handler name, find its call sites in the source and attribute + // to the nearest preceding Subscribe call. + for handlerName, hIDs := range handlerNameToIDs { + // Find the handler name followed by '(' to confirm it's a call site, + // not a substring of another identifier or a string/comment. + for lineIdx, line := range lines { + lineNum := lineIdx + 1 // 1-based + idx := strings.Index(line, handlerName) + if idx < 0 { + continue + } + endIdx := idx + len(handlerName) + if endIdx >= len(line) || line[endIdx] != '(' { + continue + } + // Verify it looks like a call, not just a comment or string + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "/*") { + continue + } + + // Attribute to nearest preceding Subscribe call + bestEvent := findNearestPrecedingEvent(subCalls, lineNum) + if bestEvent != "" { + result[bestEvent] = append(result[bestEvent], hIDs...) + } + } + } + + // Deduplicate handler IDs per event + for evt, hIDs := range result { + result[evt] = deduplicateInt64(hIDs) + } + + return result +} + +// extractSubscribeCallLines extracts (eventName, lineNumber) pairs from Subscribe calls +// in the given source string. 
+func extractSubscribeCallLines(source string, patterns []*regexp.Regexp) []subscribeCallLine { + var calls []subscribeCallLine + lines := strings.Split(source, "\n") + for lineIdx, line := range lines { + for _, pat := range patterns { + m := pat.FindStringSubmatch(line) + if len(m) >= 2 && m[1] != "" { + calls = append(calls, subscribeCallLine{ + eventName: m[1], + line: lineIdx + 1, // 1-based + }) + } + } + } + // Sort by line number (should already be ordered, but be safe) + sort.Slice(calls, func(i, j int) bool { + return calls[i].line < calls[j].line + }) + return calls +} + +// findNearestPrecedingEvent finds the event name of the Subscribe call that is +// closest to (and before or on) the given line number. +func findNearestPrecedingEvent(subCalls []subscribeCallLine, lineNum int) string { + best := "" + bestLine := 0 + for _, sc := range subCalls { + if sc.line <= lineNum && sc.line > bestLine { + best = sc.eventName + bestLine = sc.line + } + } + return best +} + +// sourceFileCache caches file contents to avoid re-reading the same file multiple times. +type sourceFileCache struct { + mu sync.Mutex + files map[string]string // absPath -> file contents +} + +// readLines reads the given file and returns the lines in the range [startLine, endLine] (1-based, inclusive). +// File contents are cached. 
+func (c *sourceFileCache) readLines(repoPath, relPath string, startLine, endLine int) (string, error) { + absPath := filepath.Join(repoPath, relPath) + + c.mu.Lock() + content, ok := c.files[absPath] + c.mu.Unlock() + + if !ok { + data, err := os.ReadFile(absPath) + if err != nil { + return "", err + } + content = string(data) + c.mu.Lock() + c.files[absPath] = content + c.mu.Unlock() + } + + if startLine <= 0 || endLine <= 0 { + return content, nil + } + + lines := strings.Split(content, "\n") + if startLine > len(lines) { + return "", nil + } + if endLine > len(lines) { + endLine = len(lines) + } + // Convert to 0-based indexing + return strings.Join(lines[startLine-1:endLine], "\n"), nil +} + +// deduplicateInt64 returns a new slice with duplicate values removed, preserving order. +func deduplicateInt64(ids []int64) []int64 { + seen := make(map[int64]bool, len(ids)) + result := make([]int64, 0, len(ids)) + for _, id := range ids { + if !seen[id] { + seen[id] = true + result = append(result, id) + } + } + return result +} + +// collectEdgeNodeIDs returns a deduplicated slice of all source and target node IDs +// referenced by the given edges. 
+func collectEdgeNodeIDs(edges []*store.Edge) []int64 { + seen := make(map[int64]bool, len(edges)*2) + for _, e := range edges { + seen[e.SourceID] = true + seen[e.TargetID] = true + } + ids := make([]int64, 0, len(seen)) + for id := range seen { + ids = append(ids, id) + } + return ids +} diff --git a/internal/pipeline/pubsub_test.go b/internal/pipeline/pubsub_test.go new file mode 100644 index 0000000..d5902eb --- /dev/null +++ b/internal/pipeline/pubsub_test.go @@ -0,0 +1,402 @@ +package pipeline + +import ( + "os" + "path/filepath" + "regexp" + "sort" + "testing" +) + +// --- extractEventNames tests --- + +func TestExtractEventNames_GoPublish(t *testing.T) { + source := ` + go func() { + s.EventBus.Publish(events.EventCheckinCompleted, events.CheckinCompletedPayload{ + UserID: userID, + EventID: eventID, + }) + }() +` + names := extractEventNames(source, publishEventPatterns) + if len(names) != 1 { + t.Fatalf("expected 1 event, got %d: %v", len(names), names) + } + if names[0] != "EventCheckinCompleted" { + t.Errorf("expected EventCheckinCompleted, got %s", names[0]) + } +} + +func TestExtractEventNames_GoPublishMultiple(t *testing.T) { + source := ` + s.eventBus.Publish(events.EventUserMentioned, events.UserMentionedPayload{}) + s.eventBus.Publish(events.EventMessageSent, events.MessageSentPayload{}) +` + names := extractEventNames(source, publishEventPatterns) + if len(names) != 2 { + t.Fatalf("expected 2 events, got %d: %v", len(names), names) + } + sort.Strings(names) + if names[0] != "EventMessageSent" || names[1] != "EventUserMentioned" { + t.Errorf("unexpected events: %v", names) + } +} + +func TestExtractEventNames_GoPublishWithoutPackagePrefix(t *testing.T) { + source := `bus.Publish(EventCheckinCompleted, payload)` + names := extractEventNames(source, publishEventPatterns) + if len(names) != 1 || names[0] != "EventCheckinCompleted" { + t.Errorf("expected [EventCheckinCompleted], got %v", names) + } +} + +func TestExtractEventNames_GoSubscribe(t 
*testing.T) { + source := ` + bus.Subscribe(events.EventCheckinCompleted, func(p any) { + payload, ok := p.(events.CheckinCompletedPayload) + if !ok { return } + xpSvc.AwardXP(ctx, payload.UserID, XPActionCheckin) + }) +` + names := extractEventNames(source, subscribeEventPatterns) + if len(names) != 1 || names[0] != "EventCheckinCompleted" { + t.Errorf("expected [EventCheckinCompleted], got %v", names) + } +} + +func TestExtractEventNames_JSEmit(t *testing.T) { + source := `emitter.emit('user.created', { userId: 123 });` + names := extractEventNames(source, publishEventPatterns) + if len(names) != 1 || names[0] != "user.created" { + t.Errorf("expected [user.created], got %v", names) + } +} + +func TestExtractEventNames_NoMatches(t *testing.T) { + source := `fmt.Println("hello world")` + names := extractEventNames(source, publishEventPatterns) + if len(names) != 0 { + t.Errorf("expected 0 events, got %v", names) + } +} + +func TestExtractEventNames_Deduplicates(t *testing.T) { + source := ` + s.EventBus.Publish(events.EventCheckinCompleted, payload1) + s.EventBus.Publish(events.EventCheckinCompleted, payload2) +` + names := extractEventNames(source, publishEventPatterns) + if len(names) != 1 { + t.Errorf("expected 1 deduplicated event, got %d: %v", len(names), names) + } +} + +// --- attributeHandlersToEvents tests --- + +func TestAttributeHandlersToEvents_SingleSubscribe(t *testing.T) { + source := `bus.Subscribe(events.EventCheckinCompleted, func(p any) { + payload, ok := p.(events.CheckinCompletedPayload) + if !ok { return } + xpSvc.AwardXP(ctx, payload.UserID, XPActionCheckin) + badgeSvc.CheckAndAwardBadges(ctx, payload.UserID, "checkin") +})` + handlerNameToIDs := map[string][]int64{ + "AwardXP": {100}, + "CheckAndAwardBadges": {101}, + } + + result := attributeHandlersToEvents(source, subscribeEventPatterns, handlerNameToIDs) + + handlers, ok := result["EventCheckinCompleted"] + if !ok { + t.Fatal("expected EventCheckinCompleted in result") + } + 
sort.Slice(handlers, func(i, j int) bool { return handlers[i] < handlers[j] }) + if len(handlers) != 2 || handlers[0] != 100 || handlers[1] != 101 { + t.Errorf("expected [100, 101], got %v", handlers) + } +} + +func TestAttributeHandlersToEvents_MultipleSubscribes(t *testing.T) { + // Simulates RegisterListeners with multiple Subscribe calls + source := `bus.Subscribe(events.EventCheckinCompleted, func(p any) { + payload, ok := p.(events.CheckinCompletedPayload) + if !ok { return } + xpSvc.AwardXP(ctx, payload.UserID, XPActionCheckin) + badgeSvc.CheckAndAwardBadges(ctx, payload.UserID, "checkin") + missionSvc.IncrementMissionProgress(ctx, payload.UserID, "checkin") +}) + +bus.Subscribe(events.EventVoteCast, func(p any) { + payload, ok := p.(events.VoteCastPayload) + if !ok { return } + xpSvc.AwardXP(ctx, payload.UserID, XPActionVote) + badgeSvc.CheckAndAwardBadges(ctx, payload.UserID, "vote") + missionSvc.IncrementMissionProgress(ctx, payload.UserID, "vote") +}) + +bus.Subscribe(events.EventPlaceSuggested, func(p any) { + payload, ok := p.(events.PlaceSuggestedPayload) + if !ok { return } + xpSvc.AwardXP(ctx, payload.UserID, XPActionPlaceSuggest) +}) + +bus.Subscribe(events.EventReactionAdded, func(p any) { + payload, ok := p.(events.ReactionAddedPayload) + if !ok { return } + xpSvc.AwardXP(ctx, payload.UserID, XPActionReaction) +})` + + handlerNameToIDs := map[string][]int64{ + "AwardXP": {100}, + "CheckAndAwardBadges": {101}, + "IncrementMissionProgress": {102}, + } + + result := attributeHandlersToEvents(source, subscribeEventPatterns, handlerNameToIDs) + + // EventCheckinCompleted -> AwardXP, CheckAndAwardBadges, IncrementMissionProgress + checkin := result["EventCheckinCompleted"] + sort.Slice(checkin, func(i, j int) bool { return checkin[i] < checkin[j] }) + if len(checkin) != 3 || checkin[0] != 100 || checkin[1] != 101 || checkin[2] != 102 { + t.Errorf("EventCheckinCompleted: expected [100,101,102], got %v", checkin) + } + + // EventVoteCast -> AwardXP, 
CheckAndAwardBadges, IncrementMissionProgress + vote := result["EventVoteCast"] + sort.Slice(vote, func(i, j int) bool { return vote[i] < vote[j] }) + if len(vote) != 3 || vote[0] != 100 || vote[1] != 101 || vote[2] != 102 { + t.Errorf("EventVoteCast: expected [100,101,102], got %v", vote) + } + + // EventPlaceSuggested -> AwardXP only + suggested := result["EventPlaceSuggested"] + if len(suggested) != 1 || suggested[0] != 100 { + t.Errorf("EventPlaceSuggested: expected [100], got %v", suggested) + } + + // EventReactionAdded -> AwardXP only + reaction := result["EventReactionAdded"] + if len(reaction) != 1 || reaction[0] != 100 { + t.Errorf("EventReactionAdded: expected [100], got %v", reaction) + } + + // Verify no extra events + if len(result) != 4 { + t.Errorf("expected 4 events, got %d: %v", len(result), keysOf(result)) + } +} + +func TestAttributeHandlersToEvents_NoSubscribeCalls(t *testing.T) { + source := `fmt.Println("no subscribe here")` + handlerNameToIDs := map[string][]int64{"AwardXP": {100}} + result := attributeHandlersToEvents(source, subscribeEventPatterns, handlerNameToIDs) + if len(result) != 0 { + t.Errorf("expected empty result, got %v", result) + } +} + +func TestAttributeHandlersToEvents_HandlerInComment(t *testing.T) { + source := `bus.Subscribe(events.EventCheckinCompleted, func(p any) { + // AwardXP is called here for XP + xpSvc.AwardXP(ctx, payload.UserID, XPActionCheckin) +})` + handlerNameToIDs := map[string][]int64{"AwardXP": {100}} + result := attributeHandlersToEvents(source, subscribeEventPatterns, handlerNameToIDs) + + // The comment line is skipped, but the real call on the next line should still match + handlers := result["EventCheckinCompleted"] + if len(handlers) != 1 || handlers[0] != 100 { + t.Errorf("expected [100], got %v", handlers) + } +} + +// --- extractSubscribeCallLines tests --- + +func TestExtractSubscribeCallLines(t *testing.T) { + source := `bus.Subscribe(events.EventCheckinCompleted, func(p any) { + 
xpSvc.AwardXP(ctx, p) +}) +bus.Subscribe(events.EventVoteCast, func(p any) { + xpSvc.AwardXP(ctx, p) +})` + calls := extractSubscribeCallLines(source, subscribeEventPatterns) + if len(calls) != 2 { + t.Fatalf("expected 2 subscribe calls, got %d", len(calls)) + } + if calls[0].eventName != "EventCheckinCompleted" || calls[0].line != 1 { + t.Errorf("call[0]: expected EventCheckinCompleted@1, got %s@%d", calls[0].eventName, calls[0].line) + } + if calls[1].eventName != "EventVoteCast" || calls[1].line != 4 { + t.Errorf("call[1]: expected EventVoteCast@4, got %s@%d", calls[1].eventName, calls[1].line) + } +} + +// --- sourceFileCache tests --- + +func TestSourceFileCache_ReadLines(t *testing.T) { + dir := t.TempDir() + content := "line1\nline2\nline3\nline4\nline5\n" + path := filepath.Join(dir, "test.go") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + cache := &sourceFileCache{files: make(map[string]string)} + + // Read lines 2-4 + got, err := cache.readLines(dir, "test.go", 2, 4) + if err != nil { + t.Fatal(err) + } + expected := "line2\nline3\nline4" + if got != expected { + t.Errorf("expected %q, got %q", expected, got) + } + + // Verify caching: second call should use cache + got2, err := cache.readLines(dir, "test.go", 1, 1) + if err != nil { + t.Fatal(err) + } + if got2 != "line1" { + t.Errorf("expected 'line1', got %q", got2) + } +} + +func TestSourceFileCache_ReadWholeFile(t *testing.T) { + dir := t.TempDir() + content := "line1\nline2\n" + path := filepath.Join(dir, "test.go") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + cache := &sourceFileCache{files: make(map[string]string)} + got, err := cache.readLines(dir, "test.go", 0, 0) + if err != nil { + t.Fatal(err) + } + if got != content { + t.Errorf("expected full file, got %q", got) + } +} + +// --- findNearestPrecedingEvent tests --- + +func TestFindNearestPrecedingEvent(t *testing.T) { + calls := []subscribeCallLine{ 
+ {"EventA", 5}, + {"EventB", 15}, + {"EventC", 25}, + } + + tests := []struct { + line int + expected string + }{ + {3, ""}, // before any subscribe + {5, "EventA"}, // on the subscribe line itself + {10, "EventA"}, + {15, "EventB"}, + {20, "EventB"}, + {30, "EventC"}, + } + + for _, tt := range tests { + got := findNearestPrecedingEvent(calls, tt.line) + if got != tt.expected { + t.Errorf("line %d: expected %q, got %q", tt.line, tt.expected, got) + } + } +} + +// --- deduplicateInt64 tests --- + +func TestDeduplicateInt64(t *testing.T) { + input := []int64{1, 2, 3, 2, 1, 4} + got := deduplicateInt64(input) + if len(got) != 4 { + t.Errorf("expected 4 unique values, got %d: %v", len(got), got) + } +} + +// --- Pattern compilation tests --- + +func TestPublishEventPatterns_Compile(t *testing.T) { + // Verify the patterns compile and match expected strings + tests := []struct { + input string + patterns []*regexp.Regexp + want string + }{ + {`.Publish(events.EventCheckinCompleted,`, publishEventPatterns, "EventCheckinCompleted"}, + {`.Emit(events.EventCheckinCompleted,`, publishEventPatterns, "EventCheckinCompleted"}, + {`.Subscribe(events.EventVoteCast,`, subscribeEventPatterns, "EventVoteCast"}, + {`.emit('user.created',`, publishEventPatterns, "user.created"}, + } + + for _, tt := range tests { + found := false + for _, pat := range tt.patterns { + m := pat.FindStringSubmatch(tt.input) + if len(m) >= 2 && m[1] == tt.want { + found = true + break + } + } + if !found { + t.Errorf("no pattern matched %q for expected %q", tt.input, tt.want) + } + } +} + +func TestAttributeHandlersToEvents_SubstringNoFalseMatch(t *testing.T) { + source := ` +bus.Subscribe(events.EventCheckinCompleted, func(p any) { + svc.AwardXP(ctx, userID, 10) +}) +bus.Subscribe(events.EventVoteCast, func(p any) { + svc.Award(ctx, userID) +})` + + handlerNameToIDs := map[string][]int64{ + "Award": {100}, + "AwardXP": {200}, + } + + result := attributeHandlersToEvents(source, subscribeEventPatterns, 
handlerNameToIDs) + + // AwardXP should be attributed to EventCheckinCompleted + if !containsID(result["EventCheckinCompleted"], 200) { + t.Error("expected AwardXP under EventCheckinCompleted") + } + // Award should be attributed to EventVoteCast + if !containsID(result["EventVoteCast"], 100) { + t.Error("expected Award under EventVoteCast") + } + // Award should NOT appear under EventCheckinCompleted (substring false match) + if containsID(result["EventCheckinCompleted"], 100) { + t.Error("Award should not be under EventCheckinCompleted — substring false match") + } +} + +func containsID(ids []int64, target int64) bool { + for _, id := range ids { + if id == target { + return true + } + } + return false +} + +// keysOf returns the keys of a map for debugging. +func keysOf(m map[string][]int64) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} diff --git a/internal/pipeline/resolver.go b/internal/pipeline/resolver.go index b9db299..6fd4c49 100644 --- a/internal/pipeline/resolver.go +++ b/internal/pipeline/resolver.go @@ -135,7 +135,14 @@ func (r *FunctionRegistry) resolveViaNameLookup(calleeName, suffix, moduleQN str // Strategy 3: unique name — single candidate project-wide if len(candidates) == 1 { conf := 0.75 - if importMap != nil && !isImportReachable(candidates[0], importMap) { + reachable := importMap == nil || isImportReachable(candidates[0], importMap) + if !reachable { + // Cross-app guard: in monorepos, reject unique_name matches across + // app boundaries (e.g., apps/mobile → apps/api-go). These can only + // communicate via HTTP_CALLS, not direct CALLS edges. 
+ if isCrossApp(moduleQN, candidates[0]) {
+ return ResolutionResult{}
+ }
 conf *= 0.5
 }
 return ResolutionResult{QualifiedName: candidates[0], Strategy: "unique_name", Confidence: conf, CandidateCount: 1}
@@ -225,8 +232,12 @@ func (r *FunctionRegistry) FuzzyResolve(calleeName, moduleQN string, importMap m
 // If there's exactly one candidate, use it
 if len(candidates) == 1 {
+ reachable := importMap == nil || isImportReachable(candidates[0], importMap)
+ if !reachable && isCrossApp(moduleQN, candidates[0]) {
+ return ResolutionResult{}, false
+ }
 conf := 0.40
- if importMap != nil && !isImportReachable(candidates[0], importMap) {
+ if !reachable {
 conf *= 0.5
 }
 return ResolutionResult{
@@ -241,14 +252,23 @@ func (r *FunctionRegistry) FuzzyResolve(calleeName, moduleQN string, importMap m
 filtered = filterImportReachable(candidates, importMap)
 }
 if len(filtered) == 0 {
- // No import-reachable candidates — use original with penalty
- best := bestByImportDistance(candidates, moduleQN)
+ // No import-reachable candidates — filter cross-app, then pick best
+ var sameApp []string
+ for _, c := range candidates {
+ if !isCrossApp(moduleQN, c) {
+ sameApp = append(sameApp, c)
+ }
+ }
+ if len(sameApp) == 0 {
+ return ResolutionResult{}, false
+ }
+ best := bestByImportDistance(sameApp, moduleQN)
 if best == "" {
 return ResolutionResult{}, false
 }
 return ResolutionResult{
 QualifiedName: best, Strategy: "fuzzy",
- Confidence: 0.30 * 0.5, CandidateCount: len(sameApp),
+ Confidence: 0.30 * 0.5, CandidateCount: len(sameApp),
 }, true
 }
 if len(filtered) == 1 {
@@ -390,6 +410,33 @@ func filterImportReachable(candidates []string, importMap map[string]string) []s
 return filtered
 }
+
+// isCrossApp detects when caller and callee are in different app boundaries
+// within a monorepo. Qualified names encode the file path, so we extract the
+// first two dotted segments after the project name (e.g., "apps.mobile" vs
+// "apps.api-go") and reject if they diverge.
This prevents false +// CALLS edges between frontend and backend that can only communicate via HTTP. +func isCrossApp(callerQN, candidateQN string) bool { + callerApp := appSegment(callerQN) + candidateApp := appSegment(candidateQN) + if callerApp == "" || candidateApp == "" { + return false // can't determine, don't block + } + return callerApp != candidateApp +} + +// appSegment extracts the app boundary segment from a qualified name. +// For "Project.apps.mobile.src.components.X" → "apps.mobile" +// For "Project.apps.api-go.internal.event.X" → "apps.api-go" +// For "Project.scripts.foo.X" → "scripts.foo" +// Returns "" if the QN is too short to determine. +func appSegment(qn string) string { + parts := strings.SplitN(qn, ".", 4) // [project, dir1, dir2, rest] + if len(parts) < 3 { + return "" + } + return parts[1] + "." + parts[2] +} + // confidenceBand returns a human-readable band label for a confidence score. func confidenceBand(score float64) string { switch { diff --git a/internal/tools/tools.go b/internal/tools/tools.go index eeaf9a8..9e57b34 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -483,6 +483,10 @@ func (s *Server) registerIndexAndTraceTool() { "type": "number", "description": "Minimum confidence threshold (0.0-1.0) for CALLS edges. Filters out low-confidence fuzzy matches. Bands: high (>=0.7), medium (>=0.45), speculative (<0.45). Default 0 (no filter)." }, + "qualified_name": { + "type": "string", + "description": "Exact qualified name for disambiguation (e.g. 'auth.SessionService.Create'). Takes priority over function_name when both are provided. Use search_graph or the suggestions from a previous trace_call_path to find the correct qualified_name." + }, "project": { "type": "string", "description": "Project to trace in. Defaults to session project." 
@@ -854,3 +858,30 @@ func (s *Server) findNodeAcrossProjects(name string, projectFilter ...string) (* } return nil, "", fmt.Errorf("node not found: %s", name) } + +// findNodeByQNAcrossProjects searches for a node by exact qualified name. +func (s *Server) findNodeByQNAcrossProjects(qn string, projectFilter ...string) (*store.Node, string, error) { + filter := s.sessionProject + if len(projectFilter) > 0 && projectFilter[0] != "" { + filter = projectFilter[0] + } + if filter == "" { + return nil, "", fmt.Errorf("no project specified and no session project detected") + } + if !s.router.HasProject(filter) { + return nil, "", fmt.Errorf("project %q not found", filter) + } + + st, err := s.router.ForProject(filter) + if err != nil { + return nil, "", err + } + projects, _ := st.ListProjects() + for _, p := range projects { + node, findErr := st.FindNodeByQN(p.Name, qn) + if findErr == nil && node != nil { + return node, p.Name, nil + } + } + return nil, "", fmt.Errorf("node not found: %s", qn) +} diff --git a/internal/tools/trace.go b/internal/tools/trace.go index 8633dc7..ff1041a 100644 --- a/internal/tools/trace.go +++ b/internal/tools/trace.go @@ -37,11 +37,22 @@ func (s *Server) handleTraceCallPath(_ context.Context, req *mcp.CallToolRequest riskLabels := getBoolArg(args, "risk_labels") minConfidence := getFloatArg(args, "min_confidence", 0) + qualifiedName := getStringArg(args, "qualified_name") + project := getStringArg(args, "project") effectiveProject := s.resolveProjectName(project) - // Find the function node - rootNode, foundProject, findErr := s.findNodeAcrossProjects(funcName, effectiveProject) + // Find the function node — qualified_name takes priority when provided + var rootNode *store.Node + var foundProject string + var findErr error + + if qualifiedName != "" { + rootNode, foundProject, findErr = s.findNodeByQNAcrossProjects(qualifiedName, effectiveProject) + } + if rootNode == nil { + rootNode, foundProject, findErr = 
s.findNodeAcrossProjects(funcName, effectiveProject) + } if findErr != nil && !strings.HasPrefix(findErr.Error(), "node not found") { return errResult(findErr.Error()), nil }