-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_chat_reliability.py
More file actions
172 lines (145 loc) · 5.51 KB
/
test_chat_reliability.py
File metadata and controls
172 lines (145 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
"""Test chat endpoint reliability with 600 requests per batch."""
import requests
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import time
BASE_URL = "http://localhost:8000"
BATCH_SIZE = 600
MAX_BATCHES = 10 # Safety limit
def test_single_request(request_num):
    """Send one chat request and report its outcome with timing.

    Args:
        request_num: 1-based index of the request; embedded in the question
            text so individual requests are distinguishable in server logs.

    Returns:
        dict with keys: success (bool), request_num (int), duration
        (float, seconds), and either response_length (on success) or
        error (a short message, on failure). Never raises -- failures
        are reported in the dict so one bad request can't abort a batch.
    """
    start = time.time()  # hoisted out of try so the except path always has it
    try:
        # Test with general mode (no session needed). The response is an
        # SSE-style stream of "data: {...}" lines, so stream=True and read
        # incrementally. Use a context manager so the connection is always
        # released -- the original leaked it when breaking out of the
        # stream early, which matters with 600 concurrent requests.
        with requests.post(
            f"{BASE_URL}/api/chat",
            json={
                "session_id": "test",
                "question": f"What is 2+2? (request {request_num})",
                "mode": "general"
            },
            timeout=30,
            stream=True
        ) as response:
            if response.status_code != 200:
                return {
                    "success": False,
                    "request_num": request_num,
                    "error": f"HTTP {response.status_code}",
                    "duration": time.time() - start
                }
            # Accumulate streamed tokens until the server signals completion.
            tokens = []
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode('utf-8')
                if not line_str.startswith('data: '):
                    continue
                try:
                    data = json.loads(line_str[6:])
                except json.JSONDecodeError:
                    # Ignore malformed/partial SSE payloads; keep reading.
                    continue
                if 'token' in data:
                    tokens.append(data['token'])
                if data.get('done'):
                    break
        response_text = ''.join(tokens)
        duration = time.time() - start
        if not response_text:
            # A 200 that yields no tokens still counts as a failure here.
            return {
                "success": False,
                "request_num": request_num,
                "error": "Empty response",
                "duration": duration
            }
        return {
            "success": True,
            "request_num": request_num,
            "response_length": len(response_text),
            "duration": duration
        }
    except Exception as e:
        # Timeouts, connection resets, etc. -- report, don't propagate.
        return {
            "success": False,
            "request_num": request_num,
            "error": str(e),
            "duration": time.time() - start
        }
def run_batch(batch_num, batch_size=BATCH_SIZE):
    """Fire off ``batch_size`` concurrent chat requests and summarize them.

    Args:
        batch_num: label shown in the printed banners.
        batch_size: number of requests to issue (defaults to BATCH_SIZE).

    Returns:
        (all_passed, failures): ``all_passed`` is True when every request
        succeeded; ``failures`` is the list of failed result dicts.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"BATCH {batch_num}: Testing {batch_size} requests")
    print(f"Started at: {datetime.now().strftime('%H:%M:%S')}")
    print(banner)
    results = []
    failures = []
    # Fan the requests out over a bounded pool (at most 20 in flight).
    with ThreadPoolExecutor(max_workers=20) as executor:
        futures = {
            executor.submit(test_single_request, n): n
            for n in range(1, batch_size + 1)
        }
        done_so_far = 0
        for fut in as_completed(futures):
            done_so_far += 1
            outcome = fut.result()
            results.append(outcome)
            if not outcome['success']:
                failures.append(outcome)
                print(f"❌ Request {outcome['request_num']}: FAILED - {outcome['error']}")
            # Emit a progress line every 50 completed requests.
            if done_so_far % 50 == 0:
                ok_so_far = done_so_far - len(failures)
                print(f"Progress: {done_so_far}/{batch_size} | ✓ {ok_so_far} | ✗ {len(failures)}")
    # Batch summary.
    ok_total = sum(1 for r in results if r['success'])
    avg_duration = sum(r['duration'] for r in results) / len(results)
    print(f"\n{banner}")
    print(f"BATCH {batch_num} SUMMARY")
    print(banner)
    print(f"Total requests: {batch_size}")
    print(f"✓ Successful: {ok_total} ({ok_total/batch_size*100:.1f}%)")
    print(f"✗ Failed: {len(failures)} ({len(failures)/batch_size*100:.1f}%)")
    print(f"Avg duration: {avg_duration:.2f}s")
    print(f"Completed at: {datetime.now().strftime('%H:%M:%S')}")
    if failures:
        print(f"\nFailure breakdown:")
        # Tally identical error messages so repeated causes stand out.
        error_types = {}
        for f in failures:
            error_types[f['error']] = error_types.get(f['error'], 0) + 1
        for error, count in sorted(error_types.items(), key=lambda x: -x[1]):
            print(f" - {error}: {count}")
    return len(failures) == 0, failures
def main():
    """Run batches of requests until one batch fully passes or MAX_BATCHES is hit."""
    rule = "=" * 60
    print("\n" + rule)
    print("RAG MED CHAT RELIABILITY TEST")
    print(rule)
    print(f"Testing with batches of {BATCH_SIZE} requests")
    print(f"Will retry failed batches up to {MAX_BATCHES} total batches")
    # Health check: bail out early if the backend is unreachable or unhealthy.
    try:
        health = requests.get(f"{BASE_URL}/", timeout=5)
    except Exception as e:
        print(f"\n❌ Cannot connect to backend: {e}")
        return
    if health.status_code != 200:
        print(f"\n❌ Backend not responding correctly (HTTP {health.status_code})")
        return
    print("✓ Backend is up and running\n")
    batch_num = 1
    while batch_num <= MAX_BATCHES:
        passed, failures = run_batch(batch_num, BATCH_SIZE)
        if passed:
            cheer = '🎉' * 20
            print(f"\n{cheer}")
            print(f"SUCCESS! All {BATCH_SIZE} requests passed in batch {batch_num}!")
            print(f"{cheer}\n")
            break
        print(f"\n⚠️ Batch {batch_num} had {len(failures)} failures. Retrying with new batch...\n")
        batch_num += 1
        time.sleep(2)  # Brief pause between batches
    if batch_num > MAX_BATCHES:
        print(f"\n❌ Reached maximum of {MAX_BATCHES} batches. Check your backend!")
# Entry point: run the reliability test only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()