"""
==============================================================================
llm.py
==============================================================================
This file contains the LLM class for the project.
"""
import time
import random
from datetime import datetime
import openai
from logger import log_llm_call, log_problematic_request


def timed_llm_call(client, api_provider, model, prompt, role, call_id, max_tokens=4096, log_dir=None,
                   sleep_seconds=15, retries_on_timeout=1000, attempt=1, use_json_mode=False):
    """
    Make a timed LLM call with error handling and retry logic.

    EMPTY RESPONSE HANDLING STRATEGY:
    - Training calls (call_id starts with 'train_'): Mark as incorrect (return wrong answers)
    - Test calls (call_id starts with 'test_'): Mark as incorrect (return wrong answers)
    - All empty responses are logged to problematic_requests/ for SambaNova support analysis

    For training and test calls: Returns "INCORRECT_DUE_TO_EMPTY_RESPONSE" repeated 4 times
    (comma-separated) to handle the 4-question format used in financial NER evaluation.

    Args:
        client: API client
        api_provider: Provider name; "openai" selects the "max_completion_tokens"
            parameter, anything else falls back to "max_tokens"
        model: Model name to use
        prompt: Text prompt to send
        role: Role for logging (generator, reflector, curator)
        call_id: Unique identifier for this call (format: {train|test}_{role}_{details})
        max_tokens: Maximum number of tokens to generate
        log_dir: Directory for detailed logging
        sleep_seconds: Base sleep time between retries
        retries_on_timeout: Maximum number of retries for timeouts/rate limits/server
            errors/empty responses
        attempt: Current attempt number (carried across retries)
        use_json_mode: Whether to use JSON mode for structured output

    Returns:
        tuple: (response_text, call_info_dict)

    Special return value for empty responses (training and testing alike):
        ("INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, ...", call_info)
    """
    start_time = time.time()
    prompt_time = time.time()
    print(f"[{role.upper()}] Starting call {call_id}...")

    # Flag for dynamic API-key rotation on retries (key mixer); not enabled here,
    # but passed through to log_problematic_request below
    using_key_mixer = False
    while True:
        try:
            # Get client
            active_client = client

            # Prepare API call parameters
            if api_provider == "openai":
                max_tokens_key = "max_completion_tokens"
            else:
                max_tokens_key = "max_tokens"
            api_params = {
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.0,
                max_tokens_key: max_tokens,
            }

            # Add JSON mode if requested
            if use_json_mode:
                api_params["response_format"] = {"type": "json_object"}

            call_start = time.time()
            response = active_client.chat.completions.create(**api_params)
            call_end = time.time()

            # Check if response is valid
            if not response or not response.choices or len(response.choices) == 0:
                raise Exception("Empty response from API")

            response_time = time.time()
            total_time = response_time - start_time

            response_content = response.choices[0].message.content
            if response_content is None:
                raise Exception("API returned None content")

            call_info = {
                "role": role,
                "call_id": call_id,
                "model": model,
                "prompt": prompt,
                "response": response_content,
                "prompt_time": prompt_time - start_time,
                "response_time": response_time - prompt_time,
                "total_time": total_time,
                "call_time": call_end - call_start,
                "prompt_length": len(prompt),
                "response_length": len(response_content),
                "prompt_num_tokens": response.usage.prompt_tokens,
                "response_num_tokens": response.usage.completion_tokens,
            }
            print(f"[{role.upper()}] Call {call_id} completed in {total_time:.2f}s")
            if log_dir:
                log_llm_call(log_dir, call_info)
            return response_content, call_info
        except Exception as e:
            # Check for both timeout and rate limit errors
            is_timeout = any(k in str(e).lower() for k in ["timeout", "timed out", "connection"])
            is_rate_limit = any(k in str(e).lower() for k in ["rate limit", "429", "rate_limit_exceeded"])
            is_empty_response = "empty response" in str(e).lower() or "api returned none content" in str(e).lower()

            # Check for server errors (500, 502, 503, etc.) that should be retried
            is_server_error = False
            if hasattr(e, 'response'):
                try:
                    status_code = getattr(e.response, 'status_code', None)
                    if status_code and status_code >= 500:
                        is_server_error = True
                        print(f"[{role.upper()}] Server error detected: HTTP {status_code}")
                except Exception:
                    pass

            # Also check for 500 errors in the error message itself
            if any(k in str(e).lower() for k in ["500 internal server error", "internal server error", "502 bad gateway", "503 service unavailable"]):
                is_server_error = True
                print(f"[{role.upper()}] Server error detected in message: {str(e)[:100]}...")

            # Also check for specific OpenAI exceptions
            if hasattr(openai, 'RateLimitError') and isinstance(e, openai.RateLimitError):
                is_rate_limit = True

            # Check for OpenAI InternalServerError
            if hasattr(openai, 'InternalServerError') and isinstance(e, openai.InternalServerError):
                is_server_error = True
                print(f"[{role.upper()}] OpenAI InternalServerError detected")

            # Debug empty response issues
            if is_empty_response:
                print(f"\n🚨 DEBUG: Empty response detected for {call_id}")
                print(f"📝 Exception type: {type(e).__name__}")
                print(f"📝 Exception message: {str(e)}")
                print(f"📝 Using JSON mode: {use_json_mode}")
                print(f"📝 Model: {model}")
                print(f"📝 Prompt length: {len(prompt)}")
                print(f"📝 Prompt preview (first 500 chars):")
                print(f"   {prompt[:500]}...")
                print(f"📝 Full exception details: {repr(e)}")
                if hasattr(e, 'response'):
                    print(f"📝 Raw response object: {e.response}")
                    if hasattr(e.response, 'text'):
                        print(f"📝 Raw response text: {e.response.text}")
                    if hasattr(e.response, 'content'):
                        print(f"📝 Raw response content: {e.response.content}")
                print("-" * 60)

            # Log problematic requests for SambaNova support
            log_problematic_request(call_id, prompt, model, api_params, e, log_dir, using_key_mixer,
                                    client if using_key_mixer else None)
            # For empty responses, behavior depends on context; the problematic
            # request was already logged above for SambaNova support
            if is_empty_response:
                # Check if this is a training or test call to decide behavior
                if call_id.startswith('train_'):
                    # In training: mark as an incorrect answer (same as testing)
                    print(f"[{role.upper()}] 🚨 Empty response in training - marking as INCORRECT for {call_id}")
                    error_time = time.time()
                    call_info = {
                        "role": role,
                        "call_id": call_id,
                        "model": model,
                        "prompt": prompt,
                        "error": "TRAINING_INCORRECT: " + str(e),
                        "total_time": error_time - start_time,
                        "prompt_length": len(prompt),
                        "response_length": 0,
                        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3],
                        "datetime": datetime.now().isoformat(),
                        "training_marked_incorrect_due_to_empty_response": True,
                    }
                    if log_dir:
                        log_llm_call(log_dir, call_info)
                    # Return a response that will be marked as incorrect;
                    # for the 4-question format, return 4 wrong answers
                    incorrect_response = "INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE"
                    return incorrect_response, call_info
                elif call_id.startswith('test_'):
                    # In testing: treat as an incorrect answer
                    print(f"[{role.upper()}] 🚨 Empty response in testing - marking as INCORRECT for {call_id}")
                    error_time = time.time()
                    call_info = {
                        "role": role,
                        "call_id": call_id,
                        "model": model,
                        "prompt": prompt,
                        "error": "TEST_INCORRECT: " + str(e),
                        "total_time": error_time - start_time,
                        "prompt_length": len(prompt),
                        "response_length": 0,
                        "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3],
                        "datetime": datetime.now().isoformat(),
                        "test_marked_incorrect_due_to_empty_response": True,
                    }
                    if log_dir:
                        log_llm_call(log_dir, call_info)
                    # Return a response that will be marked as incorrect;
                    # for the 4-question format, return 4 wrong answers
                    incorrect_response = "INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE"
                    return incorrect_response, call_info
            # Retry logic for timeouts, rate limits, server errors, and empty
            # responses whose call_id is neither train_ nor test_
            if (is_timeout or is_rate_limit or is_server_error or is_empty_response) and attempt < retries_on_timeout:
                attempt += 1
                if is_rate_limit:
                    error_type = "rate limited"
                    base_sleep = sleep_seconds * 2
                elif is_server_error:
                    error_type = "server error (500+)"
                    base_sleep = sleep_seconds * 1.5  # Moderate delay for server errors
                elif is_empty_response:
                    error_type = "returned empty response"
                    base_sleep = sleep_seconds
                else:
                    error_type = "timed out"
                    base_sleep = sleep_seconds
                jitter = random.uniform(0.5, 1.5)  # Add jitter to avoid thundering herd
                sleep_time = base_sleep * jitter
                print(f"[{role.upper()}] Call {call_id} {error_type}, sleeping {sleep_time:.1f}s then retrying "
                      f"({attempt}/{retries_on_timeout})...")
                time.sleep(sleep_time)
                continue
            error_time = time.time()
            call_info = {
                "role": role,
                "call_id": call_id,
                "model": model,
                "prompt": prompt,
                "error": str(e),
                "total_time": error_time - start_time,
                "prompt_length": len(prompt),
                "attempt": attempt,
            }
            print(f"[{role.upper()}] Call {call_id} failed after {error_time - start_time:.2f}s: {e}")
            if log_dir:
                log_llm_call(log_dir, call_info)
            raise e
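

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the project's entry points).
# The client construction below is an assumption: it presumes an
# OpenAI-compatible SDK client with a placeholder API key, base URL, and model
# name, none of which are defined in this module. Adapt to your provider.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical client setup; api_key, base_url, and model are placeholders
    client = openai.OpenAI(api_key="YOUR_API_KEY", base_url="https://api.example.com/v1")
    response_text, call_info = timed_llm_call(
        client=client,
        api_provider="openai_compatible",  # any value other than "openai" sends "max_tokens"
        model="example-model",
        prompt="Return a JSON object with a single key 'answer'.",
        role="generator",
        call_id="test_generator_demo",  # 'test_' prefix: empty responses are marked incorrect
        max_tokens=256,
        log_dir=None,        # set a directory path to enable detailed logging
        use_json_mode=True,  # requires a provider that supports response_format
    )
    print(f"Response: {response_text}")
    print(f"Tokens used: {call_info.get('response_num_tokens')}")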