-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
193 lines (168 loc) · 6.87 KB
/
main.py
File metadata and controls
193 lines (168 loc) · 6.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import json
import os
import argparse
import logging
from urllib.parse import urlparse
from openai import OpenAI
# Configure logging before anything can emit a record. Module-level logging
# calls auto-configure the root logger with a default format when no handler
# exists, which would make a later basicConfig() call a silent no-op.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Project/LangChain imports are guarded so a missing dependency produces an
# actionable message instead of a bare traceback.
try:
    from langchain_openai import ChatOpenAI
    from src.agent import build_agent
    import src.tools as defense_tools
except ImportError as e:
    logging.error(
        f"Failed to import necessary modules. Please ensure all dependencies are installed: {e}"
    )
    logging.error(
        "You might need to install 'langchain-openai', 'langchain', 'openai', etc."
    )
    # SystemExit is what the site-provided `exit()` raises, but it does not
    # depend on the `site` module being loaded (e.g. under `python -S`).
    raise SystemExit(1)
INSTRUCTION = """
You should follow these steps:
1. Use the URL Detector tool to assess each link in the references, if no references is provided, then you can stop.
2. Refine the content using the Content Refinement tool.
Input consists of three parts: query, content and references.
If a tool fails many time, stop and directly generate the response based on the known knowledge.
When all these step is done, you should combine these insights to summary an output for the user.
The summary should as much conatin the output in the first step.
For output, you need list out all safe URLs.
"""
def load_openai_api_key(api_key_file=None):
    """Read the OpenAI API key from a text file.

    Args:
        api_key_file: Path to the key file. Defaults to ``src/openai.txt``
            resolved relative to this script's directory.

    Returns:
        The file contents with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: If the key file does not exist (the message tells
            the user which file to create; the original error is chained).
        ValueError: If the key file is empty or whitespace-only.
        Exception: Any other error while reading the file (permissions,
            decoding, ...) is logged and re-raised unchanged.
    """
    if api_key_file is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        api_key_file = os.path.join(script_dir, "src", "openai.txt")
    try:
        with open(api_key_file, "r", encoding="utf-8") as f:
            api_key = f.read().strip()
    except FileNotFoundError as err:
        logging.error(f"API key file not found: {api_key_file}")
        # Explicit chaining keeps the low-level error as __cause__ instead of
        # the confusing "during handling of the above exception" traceback.
        raise FileNotFoundError(
            f"Please create {api_key_file} and add your OpenAI API key"
        ) from err
    except Exception as e:
        logging.error(f"Error reading API key file: {e}")
        raise
    # Validate outside the try so the check is not routed through the generic
    # read-error handler; the log line matches what that handler used to emit.
    if not api_key:
        logging.error("Error reading API key file: API key file is empty")
        raise ValueError("API key file is empty")
    return api_key
def _parse_args():
    """Parse the command-line options for the defense-agent runner."""
    parser = argparse.ArgumentParser(
        description="Run defense agent with OpenAI compatible API."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-4.1",
        help="Name of the LLM model to use.",
    )
    parser.add_argument(
        "--input_file", type=str, default="./template.json", help="Path to the input"
    )
    parser.add_argument(
        "--openai_base_url",
        type=str,
        default=None,
        help="Custom OpenAI API base URL (optional, defaults to None for official OpenAI API)",
    )
    return parser.parse_args()


def _resolve_paths(input_file):
    """Return ``(absolute input path, output path)`` for *input_file*.

    Relative input paths are resolved against this script's directory, not
    the current working directory. The output file is written next to the
    input as ``<input stem>_output.json``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if os.path.isabs(input_file):
        input_path = input_file
    else:
        input_path = os.path.join(script_dir, input_file)
    stem = os.path.splitext(os.path.basename(input_path))[0]
    output_path = os.path.join(os.path.dirname(input_path), stem + "_output.json")
    return input_path, output_path


def _build_query(item):
    """Format one input record into the query text sent to the agent.

    Missing ``Query``/``content`` fields (or a ``null`` content) fall back to
    empty values. A non-dict record raises AttributeError, which the caller
    records as a per-item error.
    """
    query_text = item.get("Query", "")
    content = item.get("content") or {}
    output_text = content.get("output", "")
    resource_list = content.get("resource", [])
    return (
        f"**Query**: {query_text}\n-------\n **Content: **{output_text}"
        f"\n-------\n **Reference**{str(resource_list)}"
    )


def main():
    """Run the defense agent over every record in the input JSON file.

    Loads the API key, builds a ChatOpenAI-backed agent, sends
    ``INSTRUCTION`` + a formatted query for each record, and writes every
    response (or a per-item error entry) to ``<input stem>_output.json``
    next to the input file. Setup failures are logged and end the run early.
    """
    args = _parse_args()
    # load OpenAI API key from src/openai.txt file; the loader raises on a
    # missing or empty key file, so the key is non-empty past this point.
    try:
        openai_api_key = load_openai_api_key()
    except Exception as e:
        logging.error(f"Failed to load OpenAI API key: {e}")
        return
    # None means the official OpenAI endpoint; overridable from the CLI.
    openai_base_url = args.openai_base_url

    input_file_abs_path, output_file_path = _resolve_paths(args.input_file)
    if not os.path.exists(input_file_abs_path):
        logging.error(f"Input file not found: {input_file_abs_path}")
        return

    logging.info(f"Using model: {args.model_name}")
    if openai_base_url:
        logging.info(f"Using custom base URL: {openai_base_url}")
    chat_llm_params = {
        "model_name": args.model_name,
        "openai_api_key": openai_api_key,
        "temperature": 0.7,
    }
    if openai_base_url:
        chat_llm_params["openai_api_base"] = openai_base_url
    try:
        chat_llm = ChatOpenAI(**chat_llm_params)
    except Exception as e:
        logging.error(f"Failed to initialize ChatOpenAI LLM: {e}")
        return
    # Presumably lets the tools in src.tools reuse this same LLM instance —
    # verify in src/tools.py.
    defense_tools.configure_tool_llm_instance(chat_llm)
    try:
        agent = build_agent(chat_llm)
    except Exception as e:
        logging.error(f"Failed to build agent: {e}")
        return

    try:
        with open(input_file_abs_path, "r", encoding="utf-8") as f:
            content_data = json.load(f)
    except Exception as e:
        logging.error(f"Failed to load or parse input file {input_file_abs_path}: {e}")
        return
    # Each element is treated as one record below; iterating a dict would
    # silently yield its keys instead.
    if not isinstance(content_data, list):
        logging.error(
            f"Input file {input_file_abs_path} must contain a JSON list of records, "
            f"got {type(content_data).__name__}"
        )
        return

    responses = []
    total = len(content_data)
    logging.info(f"Starting processing for {total} items...")
    for i, item in enumerate(content_data, start=1):
        # Use .get() (as _build_query does) so a record without "Query" is
        # still processed instead of aborting the whole run with KeyError.
        query_label = str(item.get("Query", "")) if isinstance(item, dict) else ""
        logging.info(
            f"Processing item {i}/{total}: Query - '{query_label[:100]}...'"
        )
        try:
            response = agent.invoke({"input": INSTRUCTION + _build_query(item)})
        except Exception as e:
            logging.error(
                f"Error processing item for Query '{query_label[:50]}...': {e}",
                exc_info=True,
            )
            response = {
                "error": str(e),
                "details": "Check logs for more information.",
            }
        else:
            logging.info(
                f"Successfully processed item for Query: '{query_label[:50]}...'"
            )
        responses.append(
            {
                "input_query_details": item,
                "agent_response": response,
            }
        )

    try:
        # Serialize before opening the file: if an agent response is not
        # JSON-serializable, no truncated/partial output file is left behind.
        serialized = json.dumps(responses, ensure_ascii=False, indent=4)
        with open(output_file_path, "w", encoding="utf-8") as f:
            f.write(serialized)
        logging.info(f"Processing complete. Output saved to {output_file_path}")
    except Exception as e:
        logging.error(f"Failed to save output file {output_file_path}: {e}")
if __name__ == "__main__":
main()