-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_hf.py
More file actions
381 lines (324 loc) · 15.1 KB
/
eval_hf.py
File metadata and controls
381 lines (324 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
"""
eval_hf.py — Standalone evaluation: RAG (HF E5) + LLM (gpt-oss-120b via HF).
No HTTP server needed. Calls InferenceClient directly for both:
- feature_extraction → E5 embeddings for FAISS retrieval
- chat_completion → gpt-oss-120b to pick the precise ICD-10 subcode
Usage:
uv run python eval_hf.py # 100 random cases
uv run python eval_hf.py --n 20 # 20 cases
uv run python eval_hf.py --all # all 221 cases
uv run python eval_hf.py --n 20 --no-llm # retrieval-only (baseline)
"""
import argparse
import json
import os
import random
import re
import statistics
import sys
import time
from dataclasses import dataclass
from pathlib import Path

from openai import OpenAI
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from rag_hf import FAISSRetrieverHF, format_rag_context
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# QazCode GPT-OSS API (OpenAI-compatible).
# SECURITY NOTE(review): the API key used to be hard-coded here. It can now be
# supplied via the QAZCODE_API_KEY environment variable; the old literal is
# kept only as a backward-compatible fallback — rotate it and drop the
# fallback when possible.
QAZCODE_API_KEY = os.environ.get("QAZCODE_API_KEY", "sk-BDVloWBwHCr5oltlXwyhtA")
QAZCODE_BASE_URL = "https://hub.qazcode.ai"
LLM_MODEL = "oss-120b"

# Input test cases (one JSON file per case) and output directory for results.
TEST_DIR = Path("data/test_set")
EVAL_DIR = Path("data/evals")
# ---------------------------------------------------------------------------
# Improved system prompt — forces LLM to extract exact subcodes from RAG text
# ---------------------------------------------------------------------------
# NOTE: SYSTEM_PROMPT is rendered with .format(rag_context=...), so literal
# braces in the JSON schema example must be doubled ({{ }}).
SYSTEM_PROMPT = """\
Ты — система клинической поддержки принятия решений, основанная на официальных \
клинических протоколах Казахстана (МКБ-10).
{rag_context}
ЗАДАЧА: На основе симптомов пациента верни 5 наиболее вероятных диагнозов в формате JSON.
СТРОГИЕ ПРАВИЛА:
- Отвечай ТОЛЬКО валидным JSON — без markdown-блоков, без комментариев.
- Схема: {{"diagnoses": [{{"rank": 1, "diagnosis": "название на русском", "icd10_code": "X00.0", "explanation": "1 предложение на русском"}}]}}
- Ровно 5 диагнозов, отсортированных по вероятности.
- ОБЯЗАТЕЛЬНО используй ТОЧНЫЙ подкод МКБ-10 с десятичным расширением (например, "S22.0", "J12.3", "M07.4").
  НЕ возвращай трёхзначные категории ("S22", "J12") — только подкоды.
- Правильный код ПОЧТИ ВСЕГДА явно указан в секции «RELEVANT CLINICAL PROTOCOLS» выше.
  Ищи там в первую очередь (строки вида "Код МКБ-10: X##.#").
- Отдавай предпочтение кодам из предоставленных протоколов."""

# NOTE: unlike SYSTEM_PROMPT this string is used verbatim (never .format()-ed),
# so the schema braces are single. The previous version doubled them, which
# sent a malformed schema example ({{...}}) to the model.
SYSTEM_PROMPT_NO_RAG = """\
Ты — система клинической поддержки принятия решений (МКБ-10, Казахстан).
ЗАДАЧА: По симптомам пациента верни 5 диагнозов JSON.
Схема: {"diagnoses": [{"rank": 1, "diagnosis": "...", "icd10_code": "X00.0", "explanation": "..."}]}
- Только валидный JSON, без markdown.
- Точные подкоды МКБ-10 с десятичной частью (не "S22", а "S22.0")."""
# ---------------------------------------------------------------------------
# LLM call
# ---------------------------------------------------------------------------
def _extract_json_text(raw: str) -> str:
    """Isolate the JSON object embedded in an LLM reply.

    Strips a surrounding ``` / ```json markdown fence (if any), then slices
    the text to the outermost {...} span so stray prose around the JSON is
    discarded. Returns the (possibly unchanged) candidate JSON text.
    """
    m = re.search(r"```(?:json)?\s*\n?(.*?)```", raw, re.DOTALL)
    if m:
        raw = m.group(1).strip()
    start = raw.find("{")
    end = raw.rfind("}") + 1
    if start >= 0 and end > start:
        raw = raw[start:end]
    return raw


def _parse_json_with_repair(raw: str) -> dict:
    """Parse *raw* as JSON, attempting to repair truncated model output.

    Repair strategies, tried in order:
      1. Append any missing closing "]" / "}" brackets and re-parse.
      2. Salvage the longest complete prefix of the "diagnoses" array.

    Raises:
        json.JSONDecodeError: when no strategy yields valid JSON.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass
    # Strategy 1: output may have been cut off mid-stream — balance brackets.
    fixed = raw
    open_braces = fixed.count("{") - fixed.count("}")
    open_brackets = fixed.count("[") - fixed.count("]")
    if open_brackets > 0:
        fixed += "]" * open_brackets
    if open_braces > 0:
        fixed += "}" * open_braces
    try:
        return json.loads(fixed)
    except json.JSONDecodeError:
        # Strategy 2: keep only the complete objects inside the diagnoses array.
        diag_start = raw.find('"diagnoses"')
        if diag_start >= 0:
            arr_start = raw.find("[", diag_start)
            if arr_start >= 0:
                last_obj_end = raw.rfind("}")
                if last_obj_end > arr_start:
                    partial = raw[arr_start:last_obj_end + 1] + "]"
                    try:
                        return {"diagnoses": json.loads(partial)}
                    except json.JSONDecodeError:
                        print(f" [LLM-DBG] Raw (first 200): {raw[:200]}", file=sys.stderr)
                        raise
        raise


def llm_diagnose(client: OpenAI, symptoms: str, rag_context: str) -> list[dict] | None:
    """Call GPT-OSS via the QazCode OpenAI-compatible API.

    Args:
        client: OpenAI-compatible client pointed at the QazCode endpoint.
        symptoms: patient symptom description, sent as the user message.
        rag_context: formatted retrieval context; an empty string selects
            the no-RAG system prompt instead.

    Returns:
        The parsed list of diagnosis dicts, or None on any failure (empty
        reply, unrecoverable JSON, transport error). Failures are logged to
        stderr rather than raised so the caller can fall back to
        retrieval-only ranking.
    """
    if rag_context:
        system = SYSTEM_PROMPT.format(rag_context=rag_context)
    else:
        system = SYSTEM_PROMPT_NO_RAG
    try:
        resp = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": symptoms},
            ],
            max_tokens=16000,
            temperature=0.1,
        )
        raw = (resp.choices[0].message.content or "").strip()
        if not raw:
            print(f" [LLM-DBG] Empty response, finish_reason={resp.choices[0].finish_reason}", file=sys.stderr)
            return None
        parsed = _parse_json_with_repair(_extract_json_text(raw))
        return parsed.get("diagnoses", [])
    except Exception as e:
        # Boundary handler: any failure degrades to retrieval-only mode upstream.
        print(f" [LLM-ERR] {type(e).__name__}: {str(e)[:120]}", file=sys.stderr)
        return None
# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------
@dataclass
class Result:
    """Per-case evaluation outcome for a single test protocol."""

    protocol_id: str  # identifier of the clinical protocol under test
    ground_truth: str  # expected ICD-10 code for this case
    top_prediction: str  # rank-1 predicted ICD-10 code ("" when no prediction)
    top_3_predictions: list[str]  # up to three highest-ranked predicted codes
    accuracy_at_1: int  # 1 if top_prediction equals ground_truth, else 0
    recall_at_3: int  # 1 if any top-3 code is in the case's valid-code set
    latency_s: float  # wall-clock seconds for retrieval (+ LLM) on this case
    used_llm: bool  # True when the LLM produced the ranking (vs RAG fallback)
# ---------------------------------------------------------------------------
# Evaluate one file
# ---------------------------------------------------------------------------
def evaluate_one(
    fpath: Path,
    retriever: FAISSRetrieverHF,
    llm_client: OpenAI | None,
    valid_codes: set[str],
    ground_truth: str,
    query: str,
    protocol_id: str,
) -> Result:
    """Score one test case: retrieve protocols, optionally rank codes via LLM.

    Retrieval always runs; the LLM (when a client is given and it answers)
    supplies the ranked code list, otherwise the deduplicated retrieval order
    is used as a fallback. Returns a Result with accuracy@1 / recall@3 and
    the wall-clock latency for the whole case.
    """
    started = time.perf_counter()
    hits = retriever.retrieve(query, top_k=10)
    context = format_rag_context(hits)

    llm_answered = False
    predictions: list[str] = []
    if llm_client is not None:
        diagnoses = llm_diagnose(llm_client, query, context)
        if diagnoses:
            llm_answered = True
            ranked = sorted(diagnoses, key=lambda d: d.get("rank", 99))
            predictions = [code for code in (d.get("icd10_code", "") for d in ranked) if code]

    if not predictions:
        # LLM disabled or failed: keep first-seen order of retrieved codes.
        ordered_unique: dict[str, None] = {}
        for hit in hits:
            for code in hit.icd_codes:
                if code:
                    ordered_unique.setdefault(code, None)
        predictions = list(ordered_unique)

    elapsed = time.perf_counter() - started
    best = predictions[0] if predictions else ""
    top3 = predictions[:3]
    return Result(
        protocol_id=protocol_id,
        ground_truth=ground_truth,
        top_prediction=best,
        top_3_predictions=top3,
        accuracy_at_1=int(best == ground_truth),
        recall_at_3=int(any(code in valid_codes for code in top3)),
        latency_s=elapsed,
        used_llm=llm_answered,
    )
# ---------------------------------------------------------------------------
# Run full evaluation
# ---------------------------------------------------------------------------
def evaluate(
    retriever: FAISSRetrieverHF,
    llm_client: OpenAI | None,
    files: list[Path],
    console: Console,
) -> list[Result]:
    """Evaluate every case file in *files*, printing one progress line each.

    A case that raises is reported and skipped; all successful Results are
    returned in input order.
    """
    total = len(files)
    mode = "RAG + LLM (gpt-oss-120b)" if llm_client else "Retrieval-only"
    console.print(f"\n[cyan]Mode: [bold]{mode}[/bold] | Cases: [bold]{total}[/bold][/cyan]\n")

    collected: list[Result] = []
    for idx, fpath in enumerate(files, 1):
        case = json.loads(fpath.read_text())
        protocol_id = case["protocol_id"]
        query = case["query"]
        gt = case["gt"]
        codes = set(case["icd_codes"])
        try:
            res = evaluate_one(fpath, retriever, llm_client, codes, gt, query, protocol_id)
            collected.append(res)
            marker = "✅" if res.accuracy_at_1 else ("🔶" if res.recall_at_3 else "❌")
            tag = "[LLM]" if res.used_llm else "[RAG]"
            console.print(
                f" [{idx:3d}/{total}] {marker} {tag} gt={gt:<8} "
                f"top1={res.top_prediction:<8} {res.latency_s:.1f}s"
            )
        except Exception as e:
            console.print(f" [{idx:3d}/{total}] [red]ERROR: {e}[/red]")
    return collected
# ---------------------------------------------------------------------------
# Save + display
# ---------------------------------------------------------------------------
def _print_summary_tables(console: Console, acc1: float, rec3: float, total: int,
                          llm_count: int, lats: list[float], p95: float) -> None:
    """Render the metrics and latency summary tables to the console."""
    mt = Table(title="[bold]Evaluation Metrics[/bold]", border_style="cyan", header_style="bold magenta")
    mt.add_column("Metric", style="cyan", width=22)
    mt.add_column("Value", style="green", justify="right", width=15)
    mt.add_row("Accuracy@1", f"{acc1:.2f}%")
    mt.add_row("Recall@3", f"{rec3:.2f}%")
    mt.add_row("Total Protocols", str(total))
    mt.add_row("LLM used", str(llm_count))
    lt = Table(title="[bold]Latency Statistics[/bold]", border_style="cyan", header_style="bold magenta")
    lt.add_column("Statistic", style="cyan", width=22)
    lt.add_column("Value (s)", style="green", justify="right", width=15)
    lt.add_row("Average", f"{statistics.mean(lats):.3f}")
    lt.add_row("Min", f"{min(lats):.3f}")
    lt.add_row("Max", f"{max(lats):.3f}")
    lt.add_row("P50 (Median)", f"{statistics.median(lats):.3f}")
    lt.add_row("P95", f"{p95:.3f}")
    console.print(); console.print(mt); console.print(); console.print(lt); console.print()


def _write_jsonl(path: Path, results: list[Result]) -> None:
    """Write one JSON object per evaluated case (per-case scores) to *path*."""
    with open(path, "w") as f:
        for r in results:
            f.write(json.dumps({
                "protocol_id": r.protocol_id,
                "scores": {
                    "accuracy_at_1": r.accuracy_at_1,
                    "recall_at_3": r.recall_at_3,
                    "latency_s": round(r.latency_s, 3),
                    "ground_truth": r.ground_truth,
                    "top_prediction": r.top_prediction,
                    "top_3_predictions": r.top_3_predictions,
                    "used_llm": r.used_llm,
                },
            }, ensure_ascii=False) + "\n")


def _write_metrics(path: Path, name: str, total: int, acc1: float, rec3: float,
                   lats: list[float], p95: float) -> None:
    """Write the aggregate metrics summary JSON to *path*."""
    with open(path, "w") as f:
        json.dump({
            "submission_name": name,
            "total_protocols": total,
            "accuracy_at_1_percent": round(acc1, 2),
            "recall_at_3_percent": round(rec3, 2),
            "latency_avg_s": round(statistics.mean(lats), 3),
            "latency_min_s": round(min(lats), 3),
            "latency_max_s": round(max(lats), 3),
            "latency_p50_s": round(statistics.median(lats), 3),
            "latency_p95_s": round(p95, 3),
        }, f, indent=2)


def display_and_save(results: list[Result], name: str, console: Console):
    """Summarize *results* on the console and persist them under EVAL_DIR.

    Produces EVAL_DIR/{name}.jsonl (per-case scores) and
    EVAL_DIR/{name}_metrics.json (aggregate metrics). No-op when *results*
    is empty.
    """
    total = len(results)
    if not total:
        return
    acc1 = sum(r.accuracy_at_1 for r in results) / total * 100
    rec3 = sum(r.recall_at_3 for r in results) / total * 100
    lats = [r.latency_s for r in results]
    # P95 from 20-quantiles; fall back to max for tiny samples.
    p95 = statistics.quantiles(lats, n=20)[-1] if total >= 4 else max(lats)
    _print_summary_tables(console, acc1, rec3, total, sum(r.used_llm for r in results), lats, p95)
    EVAL_DIR.mkdir(parents=True, exist_ok=True)
    jsonl_path = EVAL_DIR / f"{name}.jsonl"
    json_path = EVAL_DIR / f"{name}_metrics.json"
    _write_jsonl(jsonl_path, results)
    _write_metrics(json_path, name, total, acc1, rec3, lats, p95)
    t = Text()
    t.append("✓ ", style="bold green")
    t.append(f"Results saved:\n JSONL: {jsonl_path}\n Metrics: {json_path}")
    console.print(Panel(t, border_style="green"))
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: parse args, init clients, run the evaluation, save."""
    parser = argparse.ArgumentParser(description="HF RAG+LLM standalone evaluation")
    parser.add_argument("--n", type=int, default=100, help="Number of test cases (default: 100)")
    parser.add_argument("--all", action="store_true", help="Run on all test files")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--name", type=str, default=None, help="Output file prefix")
    parser.add_argument("--no-llm", action="store_true", help="Retrieval-only (no LLM)")
    args = parser.parse_args()

    console = Console()
    candidates = list(TEST_DIR.glob("*.json"))
    if not candidates:
        console.print(f"[red]No test files in {TEST_DIR}[/red]")
        return

    if args.all:
        chosen = candidates
    else:
        sample_size = min(args.n, len(candidates))
        chosen = random.Random(args.seed).sample(candidates, sample_size)

    mode_tag = "retrieval" if args.no_llm else "llm"
    run_name = args.name or f"rag_hf_{mode_tag}_{len(chosen)}cases"

    llm_label = "disabled (--no-llm)" if args.no_llm else f"QazCode GPT-OSS ({LLM_MODEL})"
    console.print(Panel(
        f"[bold cyan]HF RAG + LLM Evaluation[/bold cyan]\n\n"
        f"E5 model: [yellow]intfloat/multilingual-e5-large[/yellow]\n"
        f"LLM: [yellow]{llm_label}[/yellow]\n"
        f"Cases: [yellow]{len(chosen)}[/yellow]",
        title="[bold white]Configuration[/bold white]",
        border_style="cyan",
    ))

    # LLM client only when requested; otherwise the pipeline is retrieval-only.
    llm_client: OpenAI | None = None
    if not args.no_llm:
        console.print(f"\n[cyan]Initializing LLM client ({LLM_MODEL} @ {QAZCODE_BASE_URL})…[/cyan]")
        llm_client = OpenAI(base_url=QAZCODE_BASE_URL, api_key=QAZCODE_API_KEY)
        console.print("[green]LLM ready.[/green]")

    console.print("\n[cyan]Loading RAG retriever…[/cyan]")
    retriever = FAISSRetrieverHF()
    retriever.load()
    console.print("[green]RAG ready.[/green]")

    outcomes = evaluate(retriever, llm_client, chosen, console)
    display_and_save(outcomes, run_name, console)


if __name__ == "__main__":
    main()