diff --git a/__tests__/graph.test.ts b/__tests__/graph.test.ts index 7c771af0..fd1cf5d5 100644 --- a/__tests__/graph.test.ts +++ b/__tests__/graph.test.ts @@ -281,6 +281,19 @@ export { main }; expect(Array.isArray(callers)).toBe(true); }); + it('should get instantiating callers of a class', () => { + const nodes = cg.getNodesByKind('class'); + const derivedClass = nodes.find((n) => n.name === 'DerivedClass'); + + if (!derivedClass) { + return; + } + + const callers = cg.getCallers(derivedClass.id); + + expect(callers.some((c) => c.node.name === 'main' && c.edge.kind === 'instantiates')).toBe(true); + }); + it('should get callees of a function', () => { const nodes = cg.getNodesByKind('function'); const processValue = nodes.find((n) => n.name === 'processValue'); diff --git a/src/graph/traversal.ts b/src/graph/traversal.ts index c366721b..82c74691 100644 --- a/src/graph/traversal.ts +++ b/src/graph/traversal.ts @@ -248,7 +248,7 @@ export class GraphTraverser { } visited.add(nodeId); - const incomingEdges = this.queries.getIncomingEdges(nodeId, ['calls', 'references', 'imports']); + const incomingEdges = this.queries.getIncomingEdges(nodeId, ['calls', 'references', 'imports', 'instantiates']); if (incomingEdges.length === 0) return; // Batch-fetch all caller nodes in one round-trip instead of one