-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpostgres_utils.py
More file actions
148 lines (120 loc) · 4.73 KB
/
postgres_utils.py
File metadata and controls
148 lines (120 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# TODO: create the database connection once (e.g. a shared connection or a pool) instead of opening a new one inside every function.
import os
import sys
import time
from contextlib import closing
from typing import Any, Optional

import psycopg2 as psy
# ========================================================================
TABLE_NAME = 'exploits'
# Rows sent per executemany() call (and committed together) by insert().
BATCH_SIZE = 1000
# Pause between batches in insert(), to avoid overloading the database.
BATCH_DELAY_SECONDS = 3
# Column list used by INSERT_QUERY: value tuples passed to insert() must
# follow this order.
TABLE_FIELDS = ['id', 'file_path', 'description', 'date_published', 'author',
                'e_type', 'platform', 'codes']
# The column order here is also the order of `SELECT *` results, which
# search_db() unpacks positionally into Exploit(...).
CREATE_TABLE_QUERY = f'''
CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
    id INTEGER PRIMARY KEY,
    file_path TEXT,
    description TEXT,
    date_published INTEGER,
    author TEXT,
    e_type TEXT,
    platform TEXT,
    codes TEXT[]
);
'''
# One %s placeholder per column, derived from TABLE_FIELDS so the column list
# and the placeholder count cannot drift apart. Rows whose id already exists
# are skipped.
INSERT_QUERY = f'''
INSERT INTO {TABLE_NAME} ({', '.join(TABLE_FIELDS)})
VALUES ({', '.join(['%s'] * len(TABLE_FIELDS))})
ON CONFLICT (id) DO NOTHING;
'''
POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD')
if POSTGRES_PASSWORD is None:
    # Fail fast at import time: every database helper needs the password.
    # sys.exit (not the interactive-only site `exit`) works in all programs.
    print('[ERROR] Missing env variable: POSTGRES_PASSWORD', file=sys.stderr)
    sys.exit(1)
# Only the password comes from the environment; everything else targets a
# local PostgreSQL instance with default database/user.
CONNECTION_PARAMS = {
    'host': 'localhost',
    'port': 5432,
    'database': 'postgres',
    'user': 'postgres',
    'password': POSTGRES_PASSWORD,
}
# ========================================================================
class Exploit:
    """One row of the exploits table as a Python object.

    The constructor parameters are in the same order as the columns of the
    table definition, so a fetched row tuple can be unpacked directly with
    ``Exploit(*row)``.
    """

    def __init__(self, id: int, file_path: str, description: str, published: int,
                 author: str, exploit_type: str, platform: str, codes: list[str]):
        # `id` shadows the builtin, but the parameter name is kept because
        # callers may pass it by keyword.
        self.id = id
        self.file_path = file_path
        self.description = description
        # Stored under the column name; presumably an integer date or
        # timestamp (column type is INTEGER) - TODO confirm the encoding.
        self.date_published = published
        self.author = author
        # Corresponds to the `e_type` column.
        self.exploit_type = exploit_type
        self.platform = platform
        self.codes = codes
        # Not a table column; starts empty and is presumably filled in by
        # callers later.
        self.file_snippet = ''

    def __str__(self) -> str:
        """Return a one-line, human-readable dump of every attribute."""
        return (
            'EXPLOIT{ '
            f'id: {self.id}, '
            f'file_path: {self.file_path}, '
            f'description: {self.description}, '
            f'date_published: {self.date_published}, '
            f'author: {self.author}, '
            f'exploit_type: {self.exploit_type}, '
            f'platform: {self.platform}, '
            f'codes: {self.codes}, '
            f'file_snippet: {self.file_snippet}'
            ' }'
        )

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging."""
        return (
            f'{type(self).__name__}('
            f'id={self.id!r}, '
            f'file_path={self.file_path!r}, '
            f'description={self.description!r}, '
            f'published={self.date_published!r}, '
            f'author={self.author!r}, '
            f'exploit_type={self.exploit_type!r}, '
            f'platform={self.platform!r}, '
            f'codes={self.codes!r})'
        )
def create_table():
    """Create the ``TABLE_NAME`` table if it does not already exist.

    Best-effort: any exception is printed rather than raised, so callers keep
    running if the database is unavailable.
    """
    try:
        # psycopg2's `with connection` only commits/rolls back the transaction;
        # it does not close the connection, hence the explicit closing().
        with closing(psy.connect(**CONNECTION_PARAMS)) as con:
            with con, con.cursor() as cursor:
                cursor.execute(CREATE_TABLE_QUERY)
        print(f'[POSTGRES] Table {TABLE_NAME} successfully created')
    except Exception as e:
        print(f'[ERROR] Error occurred while creating table {TABLE_NAME}: {e}')
def insert(values: list[tuple], *, batch_size: Optional[int] = None,
           delay_seconds: Optional[float] = None):
    """Insert rows into ``TABLE_NAME`` in separately committed batches.

    Rows whose ``id`` already exists are skipped (``ON CONFLICT (id) DO
    NOTHING``). Because each batch is committed on its own, batches that
    finished before an error stay in the table.

    Args:
        values: Row tuples in ``TABLE_FIELDS`` order.
        batch_size: Rows per batch/commit; defaults to ``BATCH_SIZE``.
        delay_seconds: Pause between batches, to avoid overloading the
            database; defaults to ``BATCH_DELAY_SECONDS``. No pause follows
            the final batch.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.

    Database errors are printed rather than raised (best-effort).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if delay_seconds is None:
        delay_seconds = BATCH_DELAY_SECONDS
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    try:
        total = len(values)
        # closing() actually closes the connection; `with con` alone only
        # ends the transaction (rolling back the current batch on error).
        with closing(psy.connect(**CONNECTION_PARAMS)) as con:
            with con, con.cursor() as cursor:
                for start in range(0, total, batch_size):
                    end = min(total, start + batch_size)
                    cursor.executemany(INSERT_QUERY, values[start:end])
                    con.commit()
                    print(f'[POSTGRES] Batch {end} of {total}')
                    # Only throttle when another batch is still pending.
                    if end < total:
                        time.sleep(delay_seconds)
        print('[POSTGRES] Exploit data successfully loaded')
    except Exception as e:
        print(f'[ERROR] Error loading exploit data: {e}')
def _search_conditions(fields: dict[str, Any]) -> tuple[list[str], list[Any]]:
    """Translate search filters into SQL condition strings and bound params."""
    # 'ids' takes precedence: when present, every other key is ignored.
    if 'ids' in fields:
        return ['id = ANY(%s)'], [fields['ids']]

    # Column names cannot be bound as query parameters, so they are
    # interpolated into the SQL text; whitelist them to prevent injection.
    unknown = [key for key in fields if key not in TABLE_FIELDS]
    if unknown:
        raise ValueError(f'Unknown search field(s): {unknown}')

    conditions = []
    params = []
    for key, value in fields.items():
        if key == 'codes':
            # `codes` is a TEXT[] column: match rows whose array contains value.
            conditions.append('codes @> ARRAY[%s]')
        else:
            conditions.append(f'{key} = %s')
        params.append(value)
    return conditions, params


def search_db(fields: dict[str, Any], limit: int) -> list[Exploit]:
    """Return up to ``limit`` exploits matching the filters in ``fields``.

    Filter keys:
      * ``'ids'``: a list of ids; matches rows whose id is in the list. When
        present, all other keys are ignored.
      * ``'codes'``: a single code; matches rows whose ``codes`` array
        contains it.
      * any other column name from ``TABLE_FIELDS``: exact equality.

    Multiple non-``ids`` filters are combined with AND; an empty ``fields``
    applies no filter. Errors (including unknown keys) are printed and an
    empty list is returned.
    """
    try:
        conditions, params = _search_conditions(fields)
        where_sql = f'WHERE {" AND ".join(conditions)}' if conditions else ''
        search_query = f'SELECT * FROM {TABLE_NAME} {where_sql} LIMIT %s;'
        # closing() actually closes the connection; `with con` only ends the
        # transaction.
        with closing(psy.connect(**CONNECTION_PARAMS)) as con:
            with con, con.cursor() as cursor:
                cursor.execute(search_query, (*params, limit))
                results = cursor.fetchall()
        print(f'[POSTGRES] Successfully retrieved {len(results)} rows of exploit data')
        # SELECT * yields columns in table order, matching Exploit's parameters.
        return [Exploit(*row) for row in results]
    except Exception as e:
        print(f'[ERROR] Error retrieving exploit data from {TABLE_NAME} table: {e}')
        return []