Skip to content

Commit 918b725

Browse files
committed
Add initial support for the VulnerableCode agent
Signed-off-by: ziad hany <ziadhany2016@gmail.com>
1 parent fabe035 commit 918b725

13 files changed

Lines changed: 300 additions & 0 deletions

File tree

agent/__init__.py

Whitespace-only changes.

agent/admin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from django.contrib import admin
2+
3+
# Register your models here.

agent/apps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from django.apps import AppConfig
2+
3+
4+
class AgentConfig(AppConfig):
    """Django application configuration for the ``agent`` app."""

    # Dotted label Django uses to locate this application.
    name = "agent"
    # Primary-key field type applied to models that do not declare one.
    default_auto_field = "django.db.models.BigAutoField"

agent/forms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from django import forms
2+
3+
4+
class VulnerabilityAgentForm(forms.Form):
    """Single-field form carrying the user's question for the agent."""

    # Mandatory free-text question, rendered as a text input with a hint.
    message = forms.CharField(
        widget=forms.TextInput(
            attrs={
                "placeholder": "Ask the VulnerableCode Agent anything you need.",
            },
        ),
        required=True,
    )

agent/migrations/__init__.py

Whitespace-only changes.

agent/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from django.db import models
2+
3+
# Create your models here.

agent/templates/vuln-agent.html

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
{% extends "base.html" %}
2+
{% load humanize %}
3+
{% load widget_tweaks %}
4+
{% load markdown_to_html %}
5+
6+
{% block title %}
7+
VulnerableCode Agent
8+
{% endblock %}
9+
10+
{% block content %}
11+
<section class="section pt-0">
12+
13+
<article class='panel is-info'>
14+
<div class='panel-heading py-2 is-size-6'>
15+
Ask VulnerableCode Agent
16+
<div class="dropdown is-hoverable has-text-weight-normal">
17+
<div class="dropdown-trigger">
18+
<i class="fa fa-question-circle ml-2"></i>
19+
</div>
20+
<div class="dropdown-menu dropdown-instructions-width" id="dropdown-menu4" role="menu">
21+
<div class="dropdown-content dropdown-instructions-box-shadow">
22+
<div class="dropdown-item">
23+
<div>
24+
Ask the agent to analyze a vulnerability, examine the affected version, and retrieve more details about it.
25+
</div>
26+
</div>
27+
</div>
28+
</div>
29+
</div>
30+
</div>
31+
<div class="panel-block">
32+
<div class="pb-3 width-100-pct">
33+
<form
34+
action="{% url 'vuln-agent' %}"
35+
method="post"
36+
name="vulnerability_agent_form"
37+
>
38+
<div class="field has-addons mt-3">
39+
<div class="control width-100-pct">
40+
{{ vulnerability_agent_form.message|add_class:"input" }}
41+
</div>
42+
<div class="control">
43+
<button class="button is-link" type="submit" id="submit_pkg">
44+
Send
45+
</button>
46+
{% csrf_token %}
47+
</div>
48+
</div>
49+
</form>
50+
</div>
51+
</div>
52+
53+
<div id="tab-content">
54+
{{ message|markdown_to_html|safe }}
55+
</div>
56+
</article>
57+
58+
</section>
59+
60+
{% endblock %}

agent/tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from django.test import TestCase
2+
3+
# Create your tests here.

agent/urls.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from django.urls import path
2+
3+
from agent.views import VulnAgent
4+
5+
# The app root is served by the VulnAgent view; templates reverse it
# through the "vuln-agent" URL name.
urlpatterns = [path("", VulnAgent.as_view(), name="vuln-agent")]

agent/views.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import re
2+
from pathlib import Path
3+
from typing import List
4+
from typing import Union
5+
6+
import chromadb
7+
import yaml
8+
from chromadb.utils import embedding_functions
9+
from django.http.response import Http404
10+
from django.shortcuts import render
11+
from django.views import View
12+
from langchain.chains import RetrievalQA
13+
from langchain.prompts import PromptTemplate
14+
from langchain.schema import Document
15+
from langchain_chroma import Chroma
16+
from langchain_community.document_loaders import DirectoryLoader
17+
from langchain_core.document_loaders import BaseLoader
18+
from langchain_huggingface import HuggingFaceEmbeddings
19+
from langchain_ollama import OllamaLLM
20+
from tqdm import tqdm
21+
22+
from agent.forms import VulnerabilityAgentForm
23+
24+
25+
class YAMLLoader(BaseLoader):
    """Load a YAML file and expose its ``summary`` field as a Document."""

    def __init__(self, file_path: Union[str, Path]):
        """Remember the path of the YAML file to load.

        Args:
            file_path: Location of the YAML file; stored as given.
        """
        self.file_path = file_path

    def load(self) -> List[Document]:
        """Parse the YAML file and return it as a one-element Document list.

        The document text is the string form of the top-level ``summary``
        value. It is empty when the file cannot be parsed as YAML, when the
        top level is not a mapping (for example an empty file), or when
        ``summary`` is missing or null.

        Returns:
            A list holding one Document whose metadata ``source`` is the
            file path as a string.

        Raises:
            OSError: If the file cannot be opened; this is not caught here.
        """
        text = ""
        try:
            # Explicit encoding so the result does not depend on the locale.
            with open(self.file_path, "r", encoding="utf-8") as file:
                data = yaml.safe_load(file)
        except yaml.YAMLError as e:
            print(f"Error loading YAML file {self.file_path}: {e}")
        else:
            # safe_load returns None for an empty file and may return a list
            # or scalar; only a mapping can carry a "summary" key.
            if isinstance(data, dict):
                summary = data.get("summary")
                # A null summary must not be indexed as the text "None".
                if summary is not None:
                    text = str(summary)

        # Keep the origin of the content so callers can trace it back.
        metadata = {"source": str(self.file_path)}
        return [Document(page_content=text, metadata=metadata)]
49+
50+
51+
# Sentence-transformer embeddings used by the vector store at query time.
# The model runs on the CPU and produces normalized vectors (for cosine
# similarity).
_EMBEDDING_MODEL_KWARGS = {"device": "cpu"}
_EMBEDDING_ENCODE_KWARGS = {"normalize_embeddings": True}
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs=_EMBEDDING_MODEL_KWARGS,
    encode_kwargs=_EMBEDDING_ENCODE_KWARGS,
)
57+
58+
59+
# Position of the path component that is recorded as the package name.
# NOTE(review): the value 8 depends on how deep the data checkout sits on
# disk -- confirm against the real layout or derive it from the data root.
_PACKAGE_PART_INDEX = 8

try:
    # Reuse the persisted index when it already exists on disk.
    chroma_client = chromadb.PersistentClient(path="vuln_index")

    # get_collection fails when the collection has not been created yet,
    # which sends us to the (slow) indexing path below.
    collection = chroma_client.get_collection(name="vuln_embeddings")

    print("✅ ChromaDB collection loaded successfully!")
except Exception as e:
    # Broad on purpose: the exception type raised for a missing collection
    # presumably varies between chromadb versions -- TODO confirm and narrow.
    print(f"⚠️ Collection not found. Initializing ChromaDB. Error: {e}")

    # Load every YAML advisory found under the data directory.
    loader = DirectoryLoader(
        "vulnerablecode-data",  # ADD THE vulnerablecode-data PATH
        glob="**/*.yaml",
        use_multithreading=True,
        loader_cls=YAMLLoader,
    )
    docs = loader.load()
    print(f"Loaded {len(docs)} documents.")

    # Open the persistent store and create the collection to fill.
    chroma_client = chromadb.PersistentClient(path="vuln_index")
    collection = chroma_client.get_or_create_collection(name="vuln_embeddings")

    # NOTE(review): this function is created but never handed to the
    # collection or to collection.add -- confirm whether it is needed.
    embedding_function = embedding_functions.DefaultEmbeddingFunction()

    # Index each document under its file name (used as the vulnerability id).
    for doc in tqdm(docs, desc="Indexing documents"):
        file = doc.metadata.get("source", "unknown")
        source_path = Path(file)
        file_name = source_path.stem

        # Paths shallower than expected (including the "unknown" fallback
        # above) have no component at that index; record a placeholder
        # instead of aborting the whole indexing run with IndexError.
        path_parts = source_path.parts
        if len(path_parts) > _PACKAGE_PART_INDEX:
            package_name = path_parts[_PACKAGE_PART_INDEX]
        else:
            package_name = "unknown"
        print(file_name, package_name)

        collection.add(
            ids=[file_name],
            documents=[doc.page_content],
            metadatas=[
                {
                    "file_name": file_name,
                    "package_name": package_name,
                    "vulnerability_id": file_name,
                }
            ],
        )

    print("✅ Documents indexed in ChromaDB.")
112+
113+
114+
# Language model that produces the final answer.
llm = OllamaLLM(model="deepseek-r1:14b")

# LangChain view over the persisted Chroma collection, embedding queries
# with the HuggingFace model configured in this module.
vector_db = Chroma(
    client=chroma_client,
    collection_name="vuln_embeddings",
    embedding_function=embeddings,
)

# MMR retriever that returns a single document per query.
retriever = vector_db.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 1},
)

# "stuff" chain: the retrieved document is placed directly in the prompt.
qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=retriever,
    chain_type="stuff",
)
120+
121+
122+
class VulnAgent(View):
    """Page that lets a user ask the VulnerableCode agent a question."""

    template_name = "vuln-agent.html"

    def get(self, request):
        """Render the question form, bound to the query-string data."""
        context = {
            "vulnerability_agent_form": VulnerabilityAgentForm(request.GET),
        }
        return render(request=request, template_name=self.template_name, context=context)

    def post(self, request):
        """Answer the submitted question and render it under the form.

        An invalid submission re-renders the bound form with HTTP status 400.
        A Django view must return a response object, so the ``Http404``
        exception class is not used as a return value here.
        """
        form = VulnerabilityAgentForm(request.POST)
        if not form.is_valid():
            return render(
                request=request,
                template_name=self.template_name,
                context={"vulnerability_agent_form": form},
                status=400,
            )

        question = form.cleaned_data["message"]
        message_data = self.summary_analyzer(question=question)

        context = {
            # Pre-fill the input with the question that was just asked.
            "vulnerability_agent_form": VulnerabilityAgentForm(initial={"message": question}),
            "message": message_data,
        }
        return render(request=request, template_name=self.template_name, context=context)

    def summary_analyzer(self, question):
        """Ask the QA chain about ``question`` and return the cleaned answer.

        Args:
            question: Free-text question, optionally containing a VCID.

        Returns:
            The chain's ``result`` text with any ``<think>...</think>``
            section removed and surrounding whitespace stripped.
        """
        prompt = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are a highly specialized Vulnerability Analysis Assistant. Your task is to analyze the following vulnerability summary and accurately extract the affected and fixed versions of the software.

Output Format:
- Affected Version: Use one of the following formats:
- >= <version>, <= <version>, > <version>, < <version>
- A specific range like <version1> - <version2>
- Fixed Version: Use one of the following formats:
- >= <version>, <= <version>, > <version>, < <version>
- "Not Fixed" if no fixed version is mentioned.

Instructions:
- Ensure accuracy by considering different ways affected and fixed versions might be described in the summary.
- Extract only version-related details without adding any extra information.

Database Context:
{context}

Question:
{question}

Provide the answer strictly based on the above context.
""",
        )
        vulnerability_id = extract_vulnerability_id(question)
        # The retriever is a shared module-level object, so the filter must be
        # reset on every call: restrict to the requested id when there is one,
        # otherwise drop any filter instead of filtering on None.
        if vulnerability_id:
            retriever.search_kwargs["filter"] = {"vulnerability_id": vulnerability_id}
        else:
            retriever.search_kwargs.pop("filter", None)
        context = retriever.invoke(question)

        print(context)
        formatted_prompt = prompt.format(context=context, question=question)
        # NOTE(review): qa_chain is built on the same retriever, so it
        # presumably retrieves again using the formatted prompt as the
        # query -- confirm this double retrieval is intended.
        response = qa_chain.invoke(formatted_prompt)

        result = response["result"]
        # Remove the <think>...</think> section the model may emit.
        cleaned_result = re.sub(r"<think>.*?</think>", "", result, flags=re.DOTALL).strip()
        return cleaned_result
182+
183+
184+
def extract_vulnerability_id(query):
    """
    Extract the first vulnerability ID from a user query.

    An ID is ``VCID-`` followed by one or more hyphen-separated
    alphanumeric groups, for example ``VCID-xxxx-xxxx-xxxx``. Hyphens that
    trail the ID in the surrounding text are not part of the result.

    Returns the matched ID, or None when the query contains no ID.
    """
    # Each group must contain at least one alphanumeric character, so the
    # match can never end on a hyphen.
    match = re.search(r"VCID-[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*", query)
    return match.group(0) if match else None

0 commit comments

Comments
 (0)