From b543a2158528941e249a106665ed1b8f4e83e3a8 Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Tue, 12 May 2026 11:53:10 +0300 Subject: [PATCH] feat: add get_static_dependencies tool and DI-aware forward tracing to agentic enhancer Adds `get_static_dependencies` tool so the exploration agent can retrieve call graph data (callee IDs, caller IDs) resolved from static analysis, then use `read_function` to examine service/repository methods for authorization and validation checks. Also updates the enhancer prompt to instruct the agent to trace forward into called functions before classifying a finding as EXPLOITABLE, since security controls are often delegated to the service layer in NestJS codebases. Co-Authored-By: Claude Sonnet 4.6 --- .../openant-core/tests/test_enhancer_tools.py | 127 ++++++++++++++++++ .../utilities/agentic_enhancer/agent.py | 3 + .../utilities/agentic_enhancer/prompts.py | 41 ++++-- .../agentic_enhancer/repository_index.py | 54 +++++++- .../utilities/agentic_enhancer/tools.py | 36 +++++ 5 files changed, 248 insertions(+), 13 deletions(-) create mode 100644 libs/openant-core/tests/test_enhancer_tools.py diff --git a/libs/openant-core/tests/test_enhancer_tools.py b/libs/openant-core/tests/test_enhancer_tools.py new file mode 100644 index 0000000..a862f05 --- /dev/null +++ b/libs/openant-core/tests/test_enhancer_tools.py @@ -0,0 +1,127 @@ +"""Tests for the agentic enhancer tools, specifically the get_static_dependencies tool.""" +import pytest + +from utilities.agentic_enhancer.repository_index import RepositoryIndex +from utilities.agentic_enhancer.tools import ToolExecutor + + +def _make_index(functions: dict) -> RepositoryIndex: + """Create a RepositoryIndex from a minimal functions dict.""" + return RepositoryIndex({"functions": functions}) + + +SAMPLE_FUNCTIONS = { + "src/user.controller.ts:UserController.getUser": { + "name": "UserController.getUser", + "code": "async getUser(id) { return this.userService.findById(id); }", + "className": "UserController", + "unitType": "class_method", + "startLine": 10, + "endLine": 12, + }, + "src/user.service.ts:UserService.findById": { + "name": "UserService.findById", + "code": "async findById(id) { return this.repo.findOne(id); }", + "className": "UserService", + "unitType": "class_method", + "startLine": 5, + "endLine": 7, + }, + "src/auth.guard.ts:AuthGuard.canActivate": { + "name": "AuthGuard.canActivate", + "code": "canActivate(context) { return this.validate(context); }", + "className": "AuthGuard", + "unitType": "class_method", + "startLine": 3, + "endLine": 5, + }, +} + + +class TestResolveDependencies: + """Test RepositoryIndex.resolve_dependencies.""" + + def test_resolves_by_function_id(self): + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies([ + "src/user.service.ts:UserService.findById" + ]) + assert len(result) == 1 + assert result[0]["id"] == "src/user.service.ts:UserService.findById" + assert result[0]["className"] == "UserService" + + def test_resolves_by_qualified_name(self): + """Resolve using Class.method format when full ID is unknown.""" + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies(["AuthGuard.canActivate"]) + assert len(result) == 1 + assert "AuthGuard.canActivate" in result[0]["id"] + + def test_returns_empty_for_unknown(self): + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies(["nonExistentFunction"]) + assert result == [] + + def test_deduplicates_results(self): + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies([ + "src/user.service.ts:UserService.findById", + "src/user.service.ts:UserService.findById", + ]) + assert len(result) == 1 + + +class TestGetStaticDependenciesTool: + """Test the get_static_dependencies tool via ToolExecutor.""" + + def test_returns_resolved_deps(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + executor.set_unit_context( + static_deps=["src/user.service.ts:UserService.findById"], + static_callers=[], + ) + + result = executor.execute("get_static_dependencies", {}) + assert result["dependencies"]["count"] == 1 + assert len(result["dependencies"]["resolved"]) == 1 + assert result["dependencies"]["resolved"][0]["className"] == "UserService" + assert result["callers"]["count"] == 0 + + def test_returns_resolved_callers(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + executor.set_unit_context( + static_deps=[], + static_callers=["src/user.controller.ts:UserController.getUser"], + ) + + result = executor.execute("get_static_dependencies", {}) + assert result["callers"]["count"] == 1 + assert result["callers"]["resolved"][0]["className"] == "UserController" + + def test_empty_context(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + executor.set_unit_context([], []) + + result = executor.execute("get_static_dependencies", {}) + assert result["dependencies"]["count"] == 0 + assert result["callers"]["count"] == 0 + + def test_context_resets_between_units(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + + # First unit + executor.set_unit_context( + static_deps=["src/user.service.ts:UserService.findById"], + static_callers=[], + ) + result1 = executor.execute("get_static_dependencies", {}) + assert result1["dependencies"]["count"] == 1 + + # Second unit - different context + executor.set_unit_context(static_deps=[], static_callers=[]) + result2 = executor.execute("get_static_dependencies", {}) + assert result2["dependencies"]["count"] == 0 diff --git a/libs/openant-core/utilities/agentic_enhancer/agent.py b/libs/openant-core/utilities/agentic_enhancer/agent.py index 62061b7..513f728 100644 --- a/libs/openant-core/utilities/agentic_enhancer/agent.py +++ b/libs/openant-core/utilities/agentic_enhancer/agent.py @@ -161,6 +161,9 @@ def analyze_unit( entry_point_path = self.reachability.get_entry_point_path(unit_id) reaching_entry_point = self.reachability.get_reaching_entry_point(unit_id) + # Set static deps on tool executor for get_static_dependencies tool + self.tool_executor.set_unit_context(static_deps, static_callers) + # Build initial prompt with reachability info user_prompt = get_user_prompt( unit_id=unit_id, diff --git a/libs/openant-core/utilities/agentic_enhancer/prompts.py b/libs/openant-core/utilities/agentic_enhancer/prompts.py index dd9ca83..0594bbc 100644 --- a/libs/openant-core/utilities/agentic_enhancer/prompts.py +++ b/libs/openant-core/utilities/agentic_enhancer/prompts.py @@ -39,25 +39,40 @@ ## Your Analysis Process -1. **Identify Dangerous Operations** +1. **Get Static Dependencies First** + Call `get_static_dependencies` to see what functions this code calls and what calls it. + Then use `read_function` to examine key dependencies — especially service methods + that may contain authorization, validation, or sanitization. + +2. **Identify Dangerous Operations** Look for: eval, exec, SQL queries, file I/O, deserialization, command execution, innerHTML -2. **Trace User Input Reachability** +3. **Trace User Input Reachability (Backward)** If dangerous operations exist, trace BACKWARDS: - Who calls this function? - Who calls those callers? - Does the chain lead to an entry point (route handler, CLI parser, stdin)? -3. **Apply Classification Logic** +4. **Trace Forward Into Called Functions** + Check what the function CALLS — especially service/repository methods: + - Use `search_definitions` to find implementations of called methods + - Look for authorization checks (auth, permission, guard, can, allow, authorize) + - Look for validation/sanitization in called code + - A function may delegate security to its callees (e.g., service-layer auth) + - For `this.someService.method()` patterns, search for the method name definition + +5. **Apply Classification Logic** ``` Has dangerous sink? ├─ No → NEUTRAL or SECURITY_CONTROL └─ Yes → Is reachable from entry point? - ├─ Yes → EXPLOITABLE + ├─ Yes → Are there security controls in called functions? + │ ├─ Yes → May be SECURITY_CONTROL or lower severity + │ └─ No → EXPLOITABLE └─ No → VULNERABLE_INTERNAL ``` -4. **Complete with finish tool** +6. **Complete with finish tool** Provide classification, reasoning, and confidence level. ## Entry Point Examples @@ -150,19 +165,25 @@ def get_user_prompt( ## Your Task -1. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc. +1. **Start with `get_static_dependencies`** to see resolved callees and callers. + Then use `read_function` to examine called service/repository methods. -2. **Consider reachability**: Can user input reach any dangerous operations? +2. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc. + +3. **Consider reachability**: Can user input reach any dangerous operations? - If this is an entry point or reachable from one: vulnerabilities are EXPLOITABLE - If not reachable: vulnerabilities are VULNERABLE_INTERNAL -3. **Classify the code**: - - **EXPLOITABLE**: Dangerous ops + user input can reach them +4. **Trace forward**: Check called functions for authorization, validation, or security controls. + A function may delegate security to its service layer. + +5. **Classify the code**: + - **EXPLOITABLE**: Dangerous ops + user input can reach them + no security controls in callees - **VULNERABLE_INTERNAL**: Dangerous ops but no user input path - **SECURITY_CONTROL**: Defensive code (validators, sanitizers) - **NEUTRAL**: No security relevance -4. Call the `finish` tool with your classification and reasoning. +6. Call the `finish` tool with your classification and reasoning. Begin your analysis.""" diff --git a/libs/openant-core/utilities/agentic_enhancer/repository_index.py b/libs/openant-core/utilities/agentic_enhancer/repository_index.py index 5af649c..e027335 100644 --- a/libs/openant-core/utilities/agentic_enhancer/repository_index.py +++ b/libs/openant-core/utilities/agentic_enhancer/repository_index.py @@ -14,12 +14,11 @@ load_index_from_file: Load index from analyzer_output.json file """ +import json import re from pathlib import Path from typing import Optional -from utilities.file_io import read_json - class RepositoryIndex: """ @@ -247,6 +246,54 @@ def read_file_section(self, file_path: str, start_line: int, end_line: int) -> O except Exception: return None + def resolve_dependencies(self, dep_names: list[str]) -> list[dict]: + """ + Resolve dependency names from static analysis to function entries. + + Handles both full function IDs (file:Class.method) and simple names. + + Args: + dep_names: List of function IDs or names from static analysis + + Returns: + List of {name, id, file, className} for each resolved dependency + """ + results = [] + seen_ids = set() + + for name in dep_names: + # First try as a direct function ID + func = self.functions.get(name) + if func and name not in seen_ids: + seen_ids.add(name) + results.append({ + "name": name, + "id": name, + "file": name.rsplit(":", 1)[0] if ":" in name else "", + "className": func.get("className") + }) + continue + + # Try exact name match + matches = self.search_by_name(name, exact=True) + if not matches: + # Try just the method part (e.g., "Class.method" -> "method") + parts = name.rsplit(".", 1) + if len(parts) == 2: + matches = self.search_by_name(parts[1], exact=True) + + for m in matches: + if m["id"] not in seen_ids: + seen_ids.add(m["id"]) + results.append({ + "name": name, + "id": m["id"], + "file": m["id"].rsplit(":", 1)[0] if ":" in m["id"] else "", + "className": m.get("className") + }) + + return results + def get_all_function_ids(self) -> list[str]: """ Get list of all function IDs. @@ -284,6 +331,7 @@ def load_index_from_file(analyzer_output_path: str, repo_path: str = None) -> Re Returns: RepositoryIndex instance """ - analyzer_output = read_json(analyzer_output_path) + with open(analyzer_output_path, 'r') as f: + analyzer_output = json.load(f) return RepositoryIndex(analyzer_output, repo_path) diff --git a/libs/openant-core/utilities/agentic_enhancer/tools.py b/libs/openant-core/utilities/agentic_enhancer/tools.py index b380c2c..8cf0947 100644 --- a/libs/openant-core/utilities/agentic_enhancer/tools.py +++ b/libs/openant-core/utilities/agentic_enhancer/tools.py @@ -102,6 +102,15 @@ "required": ["file_path", "start_line", "end_line"] } }, + { + "name": "get_static_dependencies", + "description": "Get the statically-analyzed dependencies (functions called) and callers for the unit being analyzed. Returns resolved function IDs that can be read with read_function. Use this first to understand what the code calls and to trace into service methods for auth/validation checks.", + "input_schema": { + "type": "object", + "properties": {}, + "required": [] + } + }, { "name": "finish", "description": "Complete the analysis and return the final result. Call this when you have gathered enough context to understand the code's intent and security implications.", @@ -165,6 +174,13 @@ def __init__(self, index: RepositoryIndex): index: RepositoryIndex instance for searching """ self.index = index + self._unit_static_deps: list[str] = [] + self._unit_static_callers: list[str] = [] + + def set_unit_context(self, static_deps: list[str], static_callers: list[str]): + """Set static dependency data for the current unit being analyzed.""" + self._unit_static_deps = static_deps or [] + self._unit_static_callers = static_callers or [] def execute(self, tool_name: str, tool_input: dict) -> dict: """ @@ -188,6 +204,8 @@ def execute(self, tool_name: str, tool_input: dict) -> dict: return self._list_functions(tool_input) elif tool_name == "read_file_section": return self._read_file_section(tool_input) + elif tool_name == "get_static_dependencies": + return self._get_static_dependencies(tool_input) elif tool_name == "finish": return self._finish(tool_input) else: @@ -315,6 +333,24 @@ def _read_file_section(self, input: dict) -> dict: "content": content } + def _get_static_dependencies(self, input: dict) -> dict: + """Get resolved static dependencies and callers for the current unit.""" + resolved_deps = self.index.resolve_dependencies(self._unit_static_deps) + resolved_callers = self.index.resolve_dependencies(self._unit_static_callers) + + return { + "dependencies": { + "raw": self._unit_static_deps[:20], + "resolved": resolved_deps[:20], + "count": len(self._unit_static_deps) + }, + "callers": { + "raw": self._unit_static_callers[:20], + "resolved": resolved_callers[:20], + "count": len(self._unit_static_callers) + } + } + def _finish(self, input: dict) -> dict: """Process finish tool - just validate and return the input.""" required = ["include_functions", "usage_context", "security_classification", "classification_reasoning", "confidence"]