Skip to content

Commit 445e26e

Browse files
committed
tools subpackage
1 parent c5d0788 commit 445e26e

3 files changed

Lines changed: 393 additions & 0 deletions

File tree

mmz/agents/tools/__init__.py

Whitespace-only changes.

mmz/agents/tools/with_guidance.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
2+
from dataclasses import dataclass, field
3+
from guidance import models, gen, block
4+
from typing import ClassVar, Optional
5+
import pydantic
6+
from pydantic import create_model
7+
from functools import cached_property
8+
import json
9+
import os
10+
11+
import guidance
12+
from guidance import one_or_more, select, zero_or_more
13+
from simple_parsing import Serializable
14+
import numpy as np
15+
16+
from mmz.agents.with_guidance import GuidanceLlamaCppConfig
17+
from mmz.agents import tools as mzt
18+
19+
20+
# stateless=True indicates this function does not depend on LLM generations
@guidance(stateless=True)
def reference_selection(lm, n_total_refs, selection_name='selected_ixes'):
    """Grammar: one or more digit tokens chosen from 0..n_total_refs-1,
    captured under *selection_name*."""
    digit_choices = list(map(str, range(n_total_refs)))
    return lm + one_or_more(select(digit_choices, name=selection_name))
25+
26+
27+
@guidance(stateless=True)
def operator(lm):
    """Grammar: select exactly one arithmetic operator token."""
    ops = ['+', '*', '**', '/', '-']
    return lm + select(ops)
30+
31+
32+
def get_summary_relevance_prompt(query: str, summaries: list[dict]) -> str:
    """Build a prompt asking the LLM to rank search-result summaries.

    Parameters
    ----------
    query : str
        The user's search query.
    summaries : list[dict]
        Each dict must provide 'title' and 'summary' keys.

    Returns
    -------
    str
        A prompt instructing the model to reply with ONLY a JSON array of
        0-based reference indices, ranked most-to-least relevant.
    """
    # Removed dead debug code: a shuffled index list was computed and printed
    # but never used (its only consumer was a commented-out prompt line).
    from datetime import datetime
    current_time = datetime.now().strftime("%Y-%m-%d")
    prompt = f"""Analyze these search results and provide a ranked list of the most relevant ones.

IMPORTANT: Evaluate and rank based on these criteria (in order of importance):
1. Timeliness - current/recent information as of {current_time}
2. Direct relevance to query: "{query}"
3. Source reliability (prefer official sources, established websites)
4. Factual accuracy (cross-reference major claims)

Search results to evaluate:
{json.dumps([{'reference index': ii,
              'title': s['title'],
              'summary': s['summary']}
             for ii, s in enumerate(summaries)]
            , indent=2)}

Return ONLY a JSON array of the 0-based reference index, ranked from most to least relevant.
Include ONLY indices that meet ALL criteria, with the most relevant first.
You should list all {len(summaries)} indices in your response.
You should not output any number larger than {len(summaries) - 1}
Respond with ONLY the JSON array, no other text."""
    return prompt
61+
62+
def get_summary_relevance_scalar_prompt(query: str, summary: dict) -> str:
    """Build a prompt asking the LLM for a 0-100 relevance score for ONE summary.

    Parameters
    ----------
    query : str
        The user's search query.
    summary : dict
        A single result with 'title' and 'summary' keys.

    Returns
    -------
    str
        A prompt instructing the model to reply with only a number in [0, 100].
    """
    import json
    from datetime import datetime
    current_time = datetime.now().strftime("%Y-%m-%d")
    # Typos fixed in the instruction text: "likley" -> "likely",
    # "with in" -> "within".
    prompt = f"""Analyze these search results and provide a number
between 0 and 100 according to its relevance to the users query,
100 being the most relevant and likely answers the query
0 being the least relevant and does not answer the query

IMPORTANT: Evaluate and estimate relevance based on these criteria (in order of importance):
1. Timeliness - current/recent information as of {current_time}
2. Direct relevance to query: "{query}"

Search results to evaluate:
{json.dumps({'title': summary['title'],
             'summary': summary['summary']}
            , indent=2)}

Respond only with a number within 0 and 100 and nothing else: """
    return prompt
82+
83+
@guidance
def relevance_by_regex(llm, query, summaries):
    """Rank summaries, constraining the reply to digits inside '[ ... ]'."""
    ranking_prompt = get_summary_relevance_prompt(query, summaries=summaries)
    result = llm + ranking_prompt + '[ ' + gen(regex=r'\d+') + ']'
    return result
88+
89+
90+
@guidance
def relevance_by_selection(llm, query, summaries, selection_name='selected_ixes'):
    """Rank summaries using the constrained digit-selection grammar; the
    chosen indices are captured under *selection_name*."""
    ranking_prompt = get_summary_relevance_prompt(query, summaries=summaries)
    index_grammar = reference_selection(n_total_refs=len(summaries),
                                        selection_name=selection_name)
    return llm + ranking_prompt + '[ ' + index_grammar + ']'
99+
100+
101+
def get_list_of_int_grammar(name="integers"):
    """Return a guidance JSON grammar for an object with one list[int] field.

    NOTE(review): create_model receives a bare type as the field value rather
    than a (type, default) tuple — confirm this is accepted by the pinned
    pydantic version.
    """
    from pydantic import create_model
    model_cls = create_model(f"list_of_{name}", **{name: list[int]})
    return guidance.json(name=name, schema=model_cls)
108+
109+
110+
def get_list_additional_topics_prompt(query: str) -> str:
    """Build a prompt asking for a JSON list of topics related to *query*.

    Parameters
    ----------
    query : str
        The user's query to expand into related topics.

    Returns
    -------
    str
        The assembled prompt text.
    """
    from datetime import datetime
    t = str(datetime.now())
    # BUGFIX: newline after the timestamp — previously the timestamp ran
    # directly into "Given the users query..." with no separator.
    prompt = f"""The local time is {t}\n"""
    prompt += """Given the users query, produce a JSON list of other topics related to their query.\n"""
    prompt += f"""Here is their query: {query}\n"""
    prompt += """Provide a list of JSON strings of related topics: """
    return prompt
120+
121+
122+
def get_list_of_str_grammar(name="strings"):
    """Return a guidance JSON grammar for an object with one list[str] field."""
    model_cls = create_model(f"list_of_{name}", **{name: list[str]})
    return guidance.json(name=name, schema=model_cls)
126+
127+
128+
def get_q_and_a_grammar(name='answer'):
    """Return a guidance JSON grammar for an object holding a string answer
    field (named *name*) plus an integer 'confidence' field."""
    field_defs = {name: str, 'confidence': int}
    model_cls = create_model(f"{name}", **field_defs)
    return guidance.json(name=name, schema=model_cls)
132+
133+
134+
@guidance
def select_next(choices):
    # NOTE(review): unimplemented placeholder. Unlike the sibling @guidance
    # functions it takes no leading `lm` parameter — confirm the intended
    # signature before wiring this into a grammar.
    pass
137+
138+
139+
@guidance
def relevance_by_json_int_list(llm, query, summaries, name='selected_ixes'):
    """Rank summaries, constraining the reply to a JSON object whose *name*
    field is a list of integers."""
    ranking_prompt = get_summary_relevance_prompt(query, summaries=summaries)
    int_list_grammar = get_list_of_int_grammar(name=name)
    return llm + ranking_prompt + int_list_grammar
147+
148+
149+
@guidance
def relevance_scalar(llm, query, summary, name='relevance_magnitude'):
    """Score one summary's relevance; constrain the reply to a JSON object
    with a single integer field captured under *name*."""
    from pydantic import create_model
    score_model = create_model(f"scalar_{name}", **{name: int})
    scoring_prompt = get_summary_relevance_scalar_prompt(query, summary=summary)
    return llm + scoring_prompt + guidance.json(name=name, schema=score_model)
155+
156+
157+
@dataclass
class GuidanceGuide(Serializable):
    """Convenience wrapper around a guidance LLM for search-result relevance
    ranking, scalar relevance scoring, and topic expansion."""
    # Preset name passed to GuidanceLlamaCppConfig.get_preset when no
    # explicit model_config is supplied.
    model_preset: Optional[str] = 'med'

    # Explicit model configuration; takes precedence over model_preset.
    model_config: Optional[GuidanceLlamaCppConfig] = None

    def __post_init__(self):
        # Resolve the preset only when the caller didn't pass a full config.
        if self.model_config is None:
            self.model_config = GuidanceLlamaCppConfig.get_preset(self.model_preset)

    @property
    def model(self) -> models.Model:
        """The loaded guidance model from the resolved config."""
        return self.model_config.model

    def get_relevant_ixes_from_summary(self, user_q: str,
                                       summaries: list[dict],
                                       relevance_grammar: callable = relevance_by_selection,
                                       as_list: bool = True):
        """Ask the model which summaries are relevant to *user_q*.

        Parameters
        ----------
        user_q : the user's query string.
        summaries : dicts with 'title' and 'summary' keys.
        relevance_grammar : a @guidance function capturing 'selected_ixes'.
        as_list : parse the raw capture into a list of ints when True.

        Returns
        -------
        The raw captured string, or a list of integer indices when as_list.
        """
        res = self.model + relevance_grammar(user_q, summaries=summaries)
        # Raw text captured for the 'selected_ixes' grammar variable.
        res = res['selected_ixes']
        if as_list:
            if relevance_grammar == relevance_by_selection:
                # The selection grammar emits bare digits; bracket them so
                # the capture parses as a JSON array.
                res = json.loads(f"[{res}]")
            else:
                # JSON grammars capture a whole object, e.g.
                # {"selected_ixes": [...]}. BUGFIX: `res` is already the
                # captured string here; the old code indexed it a second
                # time (res['selected_ixes']) which raises TypeError on str.
                res = json.loads(res)['selected_ixes']
        return res

    def filter_to_relevant_summeries(self, user_q: str,
                                     summaries: list[dict],
                                     relevance_grammar: callable = relevance_by_selection) -> list[dict]:
        """Return only the summaries the model judged relevant.

        (Typo'd name kept for backward compatibility; a correctly-spelled
        alias is defined below.)
        """
        ixes_to_keep = self.get_relevant_ixes_from_summary(
            user_q=user_q,
            summaries=summaries,
            relevance_grammar=relevance_grammar
        )
        return [summaries[i] for i in ixes_to_keep]

    # Correctly-spelled alias; the original name is kept so callers still work.
    filter_to_relevant_summaries = filter_to_relevant_summeries

    def get_relevance_score(self, user_q: str, summary: dict) -> int:
        """Score a single summary's relevance to *user_q* on a 0-100 scale."""
        res = self.model + relevance_scalar(query=user_q,
                                            summary=summary)
        # The capture is a JSON object: {"relevance_magnitude": <int>}
        res = json.loads(res['relevance_magnitude'])['relevance_magnitude']
        return int(res)

    def expand_topic_grammar(self, user_q: str):
        """Compose prompt + JSON list-of-strings grammar for topic expansion."""
        return (self.model
                + get_list_additional_topics_prompt(query=user_q)
                + get_list_of_str_grammar(name='topics'))

    def expand_topics(self, user_q: str,
                      as_list: bool = True,
                      deduplicate_list: bool = True
                      ) -> str | list[str]:
        """Ask the model for topics related to *user_q*.

        Returns the raw captured JSON string, or (default) a list of topic
        strings, optionally deduplicated (ordering is not preserved when
        deduplicating, since a set is used).
        """
        # ** First ['topics'] access pulls the raw grammar capture from guidance.
        res = self.expand_topic_grammar(user_q=user_q)['topics']
        if as_list:
            # ** Second ['topics'] access pulls the list out of the parsed JSON.
            topic_l = json.loads(res)['topics']
            topic_l = list(set(topic_l)) if deduplicate_list else topic_l
            return topic_l
        else:
            return res

    def answer_query(self, user_q: str, content):
        """Unimplemented stub — intended to answer *user_q* from *content*."""
        return
223+
224+

mmz/datasets/cvelist5.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import os
2+
import pandas as pd
3+
from dataclasses import dataclass, field
4+
from typing import List, Optional, Any, Tuple
5+
from simple_parsing import Serializable
6+
from functools import cached_property
7+
import subprocess
8+
import json
9+
from mmz.utils import run_subprocess
10+
11+
12+
@dataclass
class ExploitDB(Serializable):
    """Thin wrapper around a local exploit-db checkout's `searchsploit` CLI.

    Parameters
    ----------
    root_directory : str
        Path to the exploitdb checkout containing the `searchsploit` script.
    """
    root_directory: str
    # Derived in __post_init__ and excluded from the generated __init__.
    # BUGFIX: was `init=None` — `init` is a bool flag; None only worked by
    # being falsy.
    search_sploit_path: str = field(default=None, init=False)

    def __post_init__(self):
        self.search_sploit_path = os.path.join(self.root_directory, 'searchsploit')

    def list_exploits(self) -> List[str]:
        """List all regular files directly under the ExploitDB root directory."""
        return [f for f in os.listdir(self.root_directory)
                if os.path.isfile(os.path.join(self.root_directory, f))]

    @cached_property
    def help_string(self):
        """searchsploit's `-h` output (or an error string), computed once."""
        output, error = run_subprocess([self.search_sploit_path, '-h'])
        return f"Error: {error}" if error else output

    def print_help(self) -> str:
        """Print and return the searchsploit help text."""
        print(self.help_string)
        return self.help_string

    def searchsploit(self, *args: str) -> str:
        """Invoke searchsploit with *args* and return its stdout."""
        command = [self.search_sploit_path] + list(args)
        output, _ = run_subprocess(command)
        return output

    def searchsploit_as_json(self, *args: str,
                             deserialize_results: bool = True) -> str | Any:
        """Invoke searchsploit in JSON mode; optionally deserialize the output.

        Adds `-j` to the arguments unless a JSON flag is already present.
        """
        if not any(_j in args for _j in ['-j', '--json']):
            args = list(['-j'] + list(args))
        o = self.searchsploit(*args)
        return json.loads(o) if deserialize_results else o

    @staticmethod
    def flatten_cve_results(results) -> list[dict]:
        """
        Flattens the given CVE results to a simple list of dictionaries.

        Args:
            results (dict): A dictionary containing CVE search results, with
                'RESULTS_EXPLOIT' and 'RESULTS_SHELLCODE' keys.

        Returns:
            list: A flattened list of dictionaries with high-level summaries of each CVE.
        """
        kept_keys = ('Title', 'EDB-ID', 'Date_Published', 'Author',
                     'Type', 'Platform', 'Verified')
        flattened_results = []
        # Both sections share the same record schema; only the Source tag
        # differs — so one loop replaces the two duplicated ones.
        for section, source in (('RESULTS_EXPLOIT', 'exploit-db:exploit'),
                                ('RESULTS_SHELLCODE', 'exploit-db:shellcode')):
            for item in results[section]:
                flattened_item = {k: item[k] for k in kept_keys}
                flattened_item['Source'] = source
                flattened_results.append(flattened_item)
        return flattened_results

    def searchsploit_as_summary(self, *args: str,
                                fields: list[str] = None,
                                n: int = None) -> str:
        """Run a search and return a compact JSON table of selected fields.

        Parameters
        ----------
        *args : search terms / flags passed through to searchsploit.
        fields : columns to keep (default: ['index', 'Title']).
        n : keep only the first *n* rows when given.
        """
        fields = ['index', 'Title'] if fields is None else fields
        json_dat = self.searchsploit_as_json(*args)
        df = pd.DataFrame(self.flatten_cve_results(json_dat))
        _df = df.reset_index()[fields]
        _df = _df if n is None else _df.head(n)
        return _df.to_json()
105+
106+
107+
import guidance
108+
from guidance import capture, Tool
109+
from pydantic import create_model
110+
from guidance import regex
111+
112+
@guidance(stateless=True)
def searchsploit_call(lm):
    """Grammar for a `searchsploit(<word>)` tool call.

    The argument is restricted to a single \\w+ token and captured into the
    LM state under 'tool_args' for the tool body to read.
    """
    # Removed dead debug line: `words_rx.match('foo bar')` whose result was
    # discarded.
    words_rx = regex(r'\w+')
    # capture just 'names' the expression, to be saved in the LM state
    return lm + 'searchsploit(' + capture(words_rx, 'tool_args') + ')'
121+
122+
123+
@guidance
def searchsploit(lm):
    """Tool body: run searchsploit on the captured args and append a summary.

    NOTE(review): the exploitdb path is hard-coded to a developer machine —
    it should come from configuration or an environment variable.
    """
    search_terms = lm['tool_args']
    db = ExploitDB(root_directory="/home/morgan/Projects/EXTERNAL/exploitdb/")
    # BUGFIX: split into whole terms before unpacking. The old code splatted
    # the raw string (*search_terms), passing each CHARACTER as a separate
    # CLI argument.
    lm += f' = {db.searchsploit_as_summary(*search_terms.split(), n=10)}'
    return lm
133+
134+
135+
def test_guide():
    """Smoke test: let the model call the searchsploit tool mid-generation."""
    from mmz.agents.tools.with_guidance import GuidanceGuide
    from guidance import gen
    gg = GuidanceGuide()
    # Pair the call grammar with the tool body that executes it.
    searchsploit_tool = Tool(searchsploit_call(), searchsploit)
    few_shot = 'You are on a linux 6.0 system, write a brief report about the vulnerabilities from CVEs'
    lm = gg.model + few_shot + gen(max_tokens=1000,
                                   tools=[searchsploit_tool],
                                   stop=')')
    print(lm)


# BUGFIX: guard the call so merely importing this module doesn't load a
# model and start generating (was an unconditional call at import time).
if __name__ == "__main__":
    test_guide()
148+
149+
#exploit_db = ExploitDB(root_directory="/home/morgan/Projects/EXTERNAL/exploitdb/")
150+
#t = exploit_db.searchsploit_as_summary('linux', 'password')
151+
#len(t)
152+
#
153+
#
154+
##json_dat = exploit_db.searchsploit_as_json('dell')
155+
#json_dat = exploit_db.searchsploit_as_json('linux', 'password')
156+
#df = pd.DataFrame(exploit_db.flatten_cve_results(json_dat))
157+
#df.reset_index()[['index', 'Title']].to_json()
158+
#df.info()
159+
#json_dat.keys()
160+
#{k: type(v) for k, v in json_dat.items()}
161+
#{k: v if isinstance(v, str) else v[:5] for k, v in json_dat.items()}
162+
163+
#type(raw_json['SEARCH'])
164+
#exploit_db.print_help()
165+
166+
#if __name__ == "__main__":
167+
# Assuming root_directory is set to the correct path
168+
# exploit_db = ExploitDB(root_directory="/home/morgan/Projects/EXTERNAL/exploitdb/")
169+
# exploit_db.print_help()

0 commit comments

Comments
 (0)