Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions libs/openant-core/tests/test_enhancer_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Tests for the agentic enhancer tools, specifically the get_static_dependencies tool."""
import pytest

from utilities.agentic_enhancer.repository_index import RepositoryIndex
from utilities.agentic_enhancer.tools import ToolExecutor


def _make_index(functions: dict) -> RepositoryIndex:
"""Create a RepositoryIndex from a minimal functions dict."""
return RepositoryIndex({"functions": functions})


SAMPLE_FUNCTIONS = {
"src/user.controller.ts:UserController.getUser": {
"name": "UserController.getUser",
"code": "async getUser(id) { return this.userService.findById(id); }",
"className": "UserController",
"unitType": "class_method",
"startLine": 10,
"endLine": 12,
},
"src/user.service.ts:UserService.findById": {
"name": "UserService.findById",
"code": "async findById(id) { return this.repo.findOne(id); }",
"className": "UserService",
"unitType": "class_method",
"startLine": 5,
"endLine": 7,
},
"src/auth.guard.ts:AuthGuard.canActivate": {
"name": "AuthGuard.canActivate",
"code": "canActivate(context) { return this.validate(context); }",
"className": "AuthGuard",
"unitType": "class_method",
"startLine": 3,
"endLine": 5,
},
}


class TestResolveDependencies:
"""Test RepositoryIndex.resolve_dependencies."""

def test_resolves_by_function_id(self):
index = _make_index(SAMPLE_FUNCTIONS)
result = index.resolve_dependencies([
"src/user.service.ts:UserService.findById"
])
assert len(result) == 1
assert result[0]["id"] == "src/user.service.ts:UserService.findById"
assert result[0]["className"] == "UserService"

def test_resolves_by_qualified_name(self):
"""Resolve using Class.method format when full ID is unknown."""
index = _make_index(SAMPLE_FUNCTIONS)
result = index.resolve_dependencies(["AuthGuard.canActivate"])
assert len(result) == 1
assert "AuthGuard.canActivate" in result[0]["id"]

def test_returns_empty_for_unknown(self):
index = _make_index(SAMPLE_FUNCTIONS)
result = index.resolve_dependencies(["nonExistentFunction"])
assert result == []

def test_deduplicates_results(self):
index = _make_index(SAMPLE_FUNCTIONS)
result = index.resolve_dependencies([
"src/user.service.ts:UserService.findById",
"src/user.service.ts:UserService.findById",
])
assert len(result) == 1


class TestGetStaticDependenciesTool:
"""Test the get_static_dependencies tool via ToolExecutor."""

def test_returns_resolved_deps(self):
index = _make_index(SAMPLE_FUNCTIONS)
executor = ToolExecutor(index)
executor.set_unit_context(
static_deps=["src/user.service.ts:UserService.findById"],
static_callers=[],
)

result = executor.execute("get_static_dependencies", {})
assert result["dependencies"]["count"] == 1
assert len(result["dependencies"]["resolved"]) == 1
assert result["dependencies"]["resolved"][0]["className"] == "UserService"
assert result["callers"]["count"] == 0

def test_returns_resolved_callers(self):
index = _make_index(SAMPLE_FUNCTIONS)
executor = ToolExecutor(index)
executor.set_unit_context(
static_deps=[],
static_callers=["src/user.controller.ts:UserController.getUser"],
)

result = executor.execute("get_static_dependencies", {})
assert result["callers"]["count"] == 1
assert result["callers"]["resolved"][0]["className"] == "UserController"

def test_empty_context(self):
index = _make_index(SAMPLE_FUNCTIONS)
executor = ToolExecutor(index)
executor.set_unit_context([], [])

result = executor.execute("get_static_dependencies", {})
assert result["dependencies"]["count"] == 0
assert result["callers"]["count"] == 0

def test_context_resets_between_units(self):
index = _make_index(SAMPLE_FUNCTIONS)
executor = ToolExecutor(index)

# First unit
executor.set_unit_context(
static_deps=["src/user.service.ts:UserService.findById"],
static_callers=[],
)
result1 = executor.execute("get_static_dependencies", {})
assert result1["dependencies"]["count"] == 1

# Second unit - different context
executor.set_unit_context(static_deps=[], static_callers=[])
result2 = executor.execute("get_static_dependencies", {})
assert result2["dependencies"]["count"] == 0
3 changes: 3 additions & 0 deletions libs/openant-core/utilities/agentic_enhancer/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def analyze_unit(
entry_point_path = self.reachability.get_entry_point_path(unit_id)
reaching_entry_point = self.reachability.get_reaching_entry_point(unit_id)

# Set static deps on tool executor for get_static_dependencies tool
self.tool_executor.set_unit_context(static_deps, static_callers)

# Build initial prompt with reachability info
user_prompt = get_user_prompt(
unit_id=unit_id,
Expand Down
41 changes: 31 additions & 10 deletions libs/openant-core/utilities/agentic_enhancer/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,40 @@

## Your Analysis Process

1. **Identify Dangerous Operations**
1. **Get Static Dependencies First**
Call `get_static_dependencies` to see what functions this code calls and what calls it.
Then use `read_function` to examine key dependencies — especially service methods
that may contain authorization, validation, or sanitization.

2. **Identify Dangerous Operations**
Look for: eval, exec, SQL queries, file I/O, deserialization, command execution, innerHTML

2. **Trace User Input Reachability**
3. **Trace User Input Reachability (Backward)**
If dangerous operations exist, trace BACKWARDS:
- Who calls this function?
- Who calls those callers?
- Does the chain lead to an entry point (route handler, CLI parser, stdin)?

3. **Apply Classification Logic**
4. **Trace Forward Into Called Functions**
Check what the function CALLS — especially service/repository methods:
- Use `search_definitions` to find implementations of called methods
- Look for authorization checks (auth, permission, guard, can, allow, authorize)
- Look for validation/sanitization in called code
- A function may delegate security to its callees (e.g., service-layer auth)
- For `this.someService.method()` patterns, search for the method name definition

5. **Apply Classification Logic**
```
Has dangerous sink?
├─ No → NEUTRAL or SECURITY_CONTROL
└─ Yes → Is reachable from entry point?
├─ Yes → EXPLOITABLE
├─ Yes → Are there security controls in called functions?
│ ├─ Yes → May be SECURITY_CONTROL or lower severity
│ └─ No → EXPLOITABLE
└─ No → VULNERABLE_INTERNAL
```

4. **Complete with finish tool**
6. **Complete with finish tool**
Provide classification, reasoning, and confidence level.

## Entry Point Examples
Expand Down Expand Up @@ -150,19 +165,25 @@ def get_user_prompt(

## Your Task

1. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc.
1. **Start with `get_static_dependencies`** to see resolved callees and callers.
Then use `read_function` to examine called service/repository methods.

2. **Consider reachability**: Can user input reach any dangerous operations?
2. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc.

3. **Consider reachability**: Can user input reach any dangerous operations?
- If this is an entry point or reachable from one: vulnerabilities are EXPLOITABLE
- If not reachable: vulnerabilities are VULNERABLE_INTERNAL

3. **Classify the code**:
- **EXPLOITABLE**: Dangerous ops + user input can reach them
4. **Trace forward**: Check called functions for authorization, validation, or security controls.
A function may delegate security to its service layer.

5. **Classify the code**:
- **EXPLOITABLE**: Dangerous ops + user input can reach them + no security controls in callees
- **VULNERABLE_INTERNAL**: Dangerous ops but no user input path
- **SECURITY_CONTROL**: Defensive code (validators, sanitizers)
- **NEUTRAL**: No security relevance

4. Call the `finish` tool with your classification and reasoning.
6. Call the `finish` tool with your classification and reasoning.

Begin your analysis."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
load_index_from_file: Load index from analyzer_output.json file
"""

import json
import re
from pathlib import Path
from typing import Optional

from utilities.file_io import read_json


class RepositoryIndex:
"""
Expand Down Expand Up @@ -247,6 +246,54 @@ def read_file_section(self, file_path: str, start_line: int, end_line: int) -> O
except Exception:
return None

def resolve_dependencies(self, dep_names: list[str]) -> list[dict]:
"""
Resolve dependency names from static analysis to function entries.

Handles both full function IDs (file:Class.method) and simple names.

Args:
dep_names: List of function IDs or names from static analysis

Returns:
List of {name, id, file, className} for each resolved dependency
"""
results = []
seen_ids = set()

for name in dep_names:
# First try as a direct function ID
func = self.functions.get(name)
if func and name not in seen_ids:
seen_ids.add(name)
results.append({
"name": name,
"id": name,
"file": name.rsplit(":", 1)[0] if ":" in name else "",
"className": func.get("className")
})
continue

# Try exact name match
matches = self.search_by_name(name, exact=True)
if not matches:
# Try just the method part (e.g., "Class.method" -> "method")
parts = name.rsplit(".", 1)
if len(parts) == 2:
matches = self.search_by_name(parts[1], exact=True)

for m in matches:
if m["id"] not in seen_ids:
seen_ids.add(m["id"])
results.append({
"name": name,
"id": m["id"],
"file": m["id"].rsplit(":", 1)[0] if ":" in m["id"] else "",
"className": m.get("className")
})

return results

def get_all_function_ids(self) -> list[str]:
"""
Get list of all function IDs.
Expand Down Expand Up @@ -284,6 +331,7 @@ def load_index_from_file(analyzer_output_path: str, repo_path: str = None) -> Re
Returns:
RepositoryIndex instance
"""
analyzer_output = read_json(analyzer_output_path)
with open(analyzer_output_path, 'r') as f:
analyzer_output = json.load(f)

return RepositoryIndex(analyzer_output, repo_path)
36 changes: 36 additions & 0 deletions libs/openant-core/utilities/agentic_enhancer/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@
"required": ["file_path", "start_line", "end_line"]
}
},
{
"name": "get_static_dependencies",
"description": "Get the statically-analyzed dependencies (functions called) and callers for the unit being analyzed. Returns resolved function IDs that can be read with read_function. Use this first to understand what the code calls and to trace into service methods for auth/validation checks.",
"input_schema": {
"type": "object",
"properties": {},
"required": []
}
},
{
"name": "finish",
"description": "Complete the analysis and return the final result. Call this when you have gathered enough context to understand the code's intent and security implications.",
Expand Down Expand Up @@ -165,6 +174,13 @@ def __init__(self, index: RepositoryIndex):
index: RepositoryIndex instance for searching
"""
self.index = index
self._unit_static_deps: list[str] = []
self._unit_static_callers: list[str] = []

def set_unit_context(self, static_deps: list[str], static_callers: list[str]):
"""Set static dependency data for the current unit being analyzed."""
self._unit_static_deps = static_deps or []
self._unit_static_callers = static_callers or []

def execute(self, tool_name: str, tool_input: dict) -> dict:
"""
Expand All @@ -188,6 +204,8 @@ def execute(self, tool_name: str, tool_input: dict) -> dict:
return self._list_functions(tool_input)
elif tool_name == "read_file_section":
return self._read_file_section(tool_input)
elif tool_name == "get_static_dependencies":
return self._get_static_dependencies(tool_input)
elif tool_name == "finish":
return self._finish(tool_input)
else:
Expand Down Expand Up @@ -315,6 +333,24 @@ def _read_file_section(self, input: dict) -> dict:
"content": content
}

def _get_static_dependencies(self, input: dict) -> dict:
"""Get resolved static dependencies and callers for the current unit."""
resolved_deps = self.index.resolve_dependencies(self._unit_static_deps)
resolved_callers = self.index.resolve_dependencies(self._unit_static_callers)

return {
"dependencies": {
"raw": self._unit_static_deps[:20],
"resolved": resolved_deps[:20],
"count": len(self._unit_static_deps)
},
"callers": {
"raw": self._unit_static_callers[:20],
"resolved": resolved_callers[:20],
"count": len(self._unit_static_callers)
}
}

def _finish(self, input: dict) -> dict:
"""Process finish tool - just validate and return the input."""
required = ["include_functions", "usage_context", "security_classification", "classification_reasoning", "confidence"]
Expand Down
Loading