# accuracy.py
# pylint: disable=all
"""
Reward functions for accuracy evaluation.

This module provides reward functions that evaluate the accuracy of model responses
by comparing them with ground truth answers, optionally using preprocessing steps
like normalization and LaTeX parsing.
"""
import re
from typing import Any, Callable, Dict, List, Optional, Union, cast

from ..models import (
    EvaluateResult,
    Message,
    MetricResult,
    ChatCompletionContentPartParam,
    ChatCompletionContentPartTextParam,
)
from ..typed_interface import reward_function


def _to_text(content: Optional[Union[str, List[ChatCompletionContentPartParam]]]) -> str:
    """Coerce Message.content into a plain string for regex and comparisons."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    # List[ChatCompletionContentPartTextParam]
    try:
        texts: List[str] = []
        for part in content:
            if isinstance(part, ChatCompletionContentPartTextParam):
                texts.append(part.text)
        return "\n".join(texts)
    except Exception:
        return ""

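# Illustrative behaviour of _to_text (comment sketch, not executed):
#   _to_text(None)    -> ""
#   _to_text("hello") -> "hello"
#   For a list of content parts, the text of each ChatCompletionContentPartTextParam
#   is collected and joined with newlines; non-text parts are skipped.
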
def normalize_text(text: str) -> str:
    """
    Normalize text for comparison by removing excess whitespace and punctuation.

    Args:
        text: The text to normalize

    Returns:
        Normalized text string
    """
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r'[,.;:!?"\']', "", text)
    # Remove parentheses, brackets, etc. that often appear in math expressions
    # but keep their contents
    text = re.sub(r"[\(\)\[\]\{\}]", "", text)
    # Map unicode multiplication/division signs to ASCII before stripping other
    # symbols; doing it after the regex below would delete them instead of converting.
    text = text.replace("×", "*").replace("÷", "/")
    text = re.sub(r"[^\w\s\d+-/*=]", "", text)
    return text.strip()

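# Illustrative behaviour of normalize_text (comment sketch):
#   normalize_text("The answer is: (3/4)!")  -> "the answer is 3/4"
#   normalize_text("6 × 7 = 42")             -> "6 * 7 = 42"
# Note that decimal points are stripped along with other punctuation, so "3.14"
# normalizes to "314"; compare_math_expressions below partly compensates for this.
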
def extract_math_expression(text: str) -> str:
    """
    Extract mathematical expressions from text.

    This function attempts to find the final answer in mathematical texts,
    handling both numerical answers and expressions.

    Args:
        text: Text that might contain mathematical expressions

    Returns:
        Extracted mathematical expression, or normalized text if no clear
        expression is found
    """
    # Try to find answer patterns like "= 42" or "answer is 42"
    answer_patterns = [
        # Common exact answer formats
        r"(?:answer|result|solution)(?:\s+is|\s*[:=])\s*(?:x\s*=\s*)?([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"(?:therefore|thus|so)[,:]?\s*(?:x\s*=\s*)?([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"(?:the value of|value)\s*(?:x|y|z)\s*(?:is|=)\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        r"x\s*=\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",  # x = 4
        r"(?:=|equals)\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)",
        # Common answer formats with parentheses
        r"(?:answer|result|solution)[^0-9\n.]*?is[^0-9\n.]*?((?:\([-+]?\))?(?:\d+(?:\.\d+)?(?:/\d+)?))",
        r"(?:answer|result|value)[^0-9\n.]*?((?:\([-+]?\))?(?:\d+(?:\.\d+)?(?:/\d+)?))",
        # Special cases for pi
        r"(?:answer|result|value|=)\s*(?:is\s*)?(?:π|pi)",
        r"(?:answer|result|value|=)\s*(?:is\s*)?(\d+(?:\.\d+)?π)",
        r"(?:answer|result|value|=)\s*(?:is\s*)?π(?:\s*=\s*)?(?:≈\s*)?(3\.14\d*)",
        # Numerical answers with units
        r"(?:answer|result|value|=)\s*(?:is\s*)?([-+]?\d+(?:\.\d+)?)\s*(?:meters|feet|kg|seconds)",
        # LaTeX patterns
        r"\$x\s*=\s*([-+]?\d+(?:\.\d+)?(?:/\d+)?)\$",  # LaTeX: $x = 4$
        # Decimal approximations
        r"(?:approximately|about|≈|~)\s*([-+]?\d+\.\d+)",
    ]

    # Check patterns in both original and lowercase text
    for text_variant in [text, text.lower()]:
        for pattern in answer_patterns:
            match = re.search(pattern, text_variant, re.IGNORECASE)
            if match:
                # Check if this is a pi-only match
                if pattern == r"(?:answer|result|value|=)\s*(?:is\s*)?(?:π|pi)":
                    return "3.14159"  # Return standard pi approximation
                if match.groups():
                    result = match.group(1).strip()
                    # Clean up any trailing punctuation
                    result = re.sub(r"[.,;:]$", "", result)
                    # Handle pi symbols in the answer
                    if "π" in result or "pi" in result.lower():
                        result = result.replace("π", "").replace("Pi", "").replace("pi", "")
                        try:
                            # If it's just a coefficient of pi, convert to decimal
                            if result.strip() in ("", "1"):
                                return "3.14159"  # π alone or 1π
                            else:
                                # Try to convert coefficient to float and multiply by pi
                                coef = float(result.strip())
                                return str(coef * 3.14159)
                        except (ValueError, TypeError):
                            # If conversion fails, return the original with pi
                            return result
                    return result

    # Check for answers in the last line (common in math problems)
    lines = text.strip().split("\n")
    for i in range(min(3, len(lines))):  # Check last 3 lines
        last_line = lines[-(i + 1)].strip()
        if "answer" in last_line.lower() or "result" in last_line.lower() or "solution" in last_line.lower():
            # Extract numbers from the last line
            numbers = re.findall(r"[-+]?\d+(?:\.\d+)?", last_line)
            if numbers:
                return numbers[-1]  # Take the last number

    # Direct search for numbers that might be answers
    # Only use as a fallback for short responses with few numbers
    if len(text) < 200:  # Only for short answers
        # Count decimal numbers in text
        numbers = re.findall(r"(?:^|\s|[^\w])([-+]?\d+(?:\.\d+)?)(?:\s|$|[^\w])", text)
        if len(numbers) == 1:  # If there's only one number, it's likely the answer
            return numbers[0]
        elif numbers and len(text.split()) < 30:  # Very short text with numbers
            # Take the last number in a short response
            return numbers[-1]

    # Look for capitalized city names or other proper nouns as answers
    if re.search(r"capital|city|country|president|largest|smallest", text.lower()):
        noun_pattern = r"is\s+([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*)"
        match = re.search(noun_pattern, text)
        if match:
            return match.group(1).strip()

    # Look for LaTeX math expressions
    latex_patterns = [
        r"\$x\s*=\s*([^$]+)\$",  # Inline math with x = ...
        r"\$([^$]+)\$",  # Inline math: $...$
        r"\\\((.*?)\\\)",  # Inline math: \(...\)
        r"\\\[(.*?)\\\]",  # Display math: \[...\]
    ]
    for pattern in latex_patterns:
        matches = re.findall(pattern, text)
        if matches:
            # Process the last match which is often the final answer
            latex_expr = matches[-1].strip()
            # Try to extract numbers from LaTeX
            if "=" in latex_expr:
                # If there's an equals sign, take what's on the right
                parts = latex_expr.split("=")
                latex_expr = parts[-1].strip()
            # Extract plain numbers from LaTeX expression
            nums = re.findall(r"[-+]?\d+(?:\.\d+)?", latex_expr)
            if nums:
                return nums[-1]
            # If no plain numbers, return the cleaned LaTeX
            return re.sub(r"[\\{}\[\]]", "", latex_expr)

    # If we've reached here, try a more aggressive approach for common words
    for word in ["Paris", "London", "yes", "no", "true", "false"]:
        if word.lower() in text.lower():
            return word

    # Fall back to normalized text for short texts
    if len(text) < 50:
        return normalize_text(text)

    return ""

def compare_math_expressions(pred: str, gt: str) -> float:
    """
    Compare two mathematical expressions for equivalence.

    Args:
        pred: Predicted math expression
        gt: Ground truth math expression

    Returns:
        Similarity score between 0.0 and 1.0
    """
    if not pred and not gt:
        return 1.0
    if not pred or not gt:
        return 0.0

    pred_norm = normalize_text(pred)
    gt_norm = normalize_text(gt)
    if pred_norm == gt_norm:
        return 1.0

    if len(gt) > 2 and not gt.replace(".", "").isdigit():
        if gt.lower() in pred.lower() or pred.lower() in gt.lower():
            return 1.0

    pred_clean = pred_norm.replace(" ", "")
    gt_clean = gt_norm.replace(" ", "")

    if (pred_clean.startswith("3.14") and gt_clean.startswith("3.14")) or (
        pred_clean.startswith("314") and gt_clean.startswith("314")
    ):
        return 1.0

    try:
        pred_float = float(pred_clean)
        gt_float = float(gt_clean)
        abs_diff = abs(pred_float - gt_float)
        pred_str_decimal_part = str(pred_float).split(".")[1] if "." in str(pred_float) else ""
        gt_str_decimal_part = str(gt_float).split(".")[1] if "." in str(gt_float) else ""
        if (
            len(pred_str_decimal_part) >= 2
            and len(gt_str_decimal_part) >= 2
            and pred_str_decimal_part[0:2] == gt_str_decimal_part[0:2]
        ):
            if abs_diff < 0.01:
                return 1.0
            if max(abs(gt_float), 0.001) > 0 and abs_diff / max(abs(gt_float), 0.001) < 0.05:
                return 0.9
    except (ValueError, ZeroDivisionError, IndexError):
        pass

    pred_decimal_from_fraction: Optional[float] = None
    if "/" in pred_clean and pred_clean.count("/") == 1:
        try:
            num, denom = pred_clean.split("/")
            pred_decimal_from_fraction = float(num) / float(denom)
        except (ValueError, ZeroDivisionError):
            pass

    gt_decimal_from_fraction: Optional[float] = None
    if "/" in gt_clean and gt_clean.count("/") == 1:
        try:
            num, denom = gt_clean.split("/")
            gt_decimal_from_fraction = float(num) / float(denom)
        except (ValueError, ZeroDivisionError):
            pass

    try:
        pred_val_inter: Optional[float] = None
        if pred_decimal_from_fraction is not None:
            pred_val_inter = pred_decimal_from_fraction
        else:
            try:
                pred_val_inter = float(pred_clean)
            except ValueError:
                pass

        gt_val_inter: Optional[float] = None
        if gt_decimal_from_fraction is not None:
            gt_val_inter = gt_decimal_from_fraction
        else:
            try:
                gt_val_inter = float(gt_clean)
            except ValueError:
                pass

        if pred_val_inter is None or gt_val_inter is None:
            return string_similarity(pred_norm, gt_norm)

        pred_value: float = cast(float, pred_val_inter)
        gt_value: float = cast(float, gt_val_inter)

        if pred_value == gt_value:
            return 1.0

        abs_error = abs(pred_value - gt_value)
        abs_tolerance = 0.1
        if abs(gt_value) < 0.1:
            abs_tolerance = 0.001
        elif abs(gt_value) < 1.0:
            abs_tolerance = 0.01
        if abs_error <= abs_tolerance:
            return 1.0

        if gt_value != 0:
            relative_error = abs_error / abs(gt_value)
            if relative_error < 0.001:
                return 1.0
            if relative_error < 0.01:
                return 0.9
            if relative_error < 0.05:
                return 0.8
            if relative_error < 0.1:
                return 0.5
            if relative_error < 0.3:
                return 0.3
            return 0.0
        else:
            if abs_error < 0.01:
                return 1.0
            if abs_error < 0.1:
                return 0.5
            return 0.0
    except (ValueError, TypeError):
        return string_similarity(pred_norm, gt_norm)

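# Illustrative behaviour of compare_math_expressions (comment sketch):
#   compare_math_expressions("42", "42")       -> 1.0  (exact match after normalization)
#   compare_math_expressions("Paris", "paris") -> 1.0  (case-insensitive string match)
#   compare_math_expressions("50", "42")       -> 0.3  (large relative error earns partial credit)
# Non-numeric inputs fall back to the word-overlap score from string_similarity.
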
def string_similarity(s1: str, s2: str) -> float:
    """Compute the Jaccard similarity between the word sets of two strings."""
    if not s1 and not s2:
        return 1.0
    if not s1 or not s2:
        return 0.0
    words1, words2 = set(s1.split()), set(s2.split())
    if not words1 and not words2:
        return 1.0
    intersection = len(words1.intersection(words2))
    union = len(words1.union(words2))
    return intersection / union if union > 0 else 0.0

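# Illustrative behaviour of string_similarity (comment sketch):
#   string_similarity("red blue", "blue green") -> 1/3  (one shared word out of three total)
#   string_similarity("", "")                   -> 1.0  (two empty strings count as identical)
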
@reward_function
def accuracy_reward(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Union[List[Message], List[Dict[str, Any]]],
    extract_fn: Optional[Callable[[str], str]] = None,
    compare_fn: Optional[Callable[[str, str], float]] = None,
    **kwargs: Any,
) -> EvaluateResult:
    """
    Evaluate the accuracy of the last assistant message against a ground truth answer.

    Args:
        messages: Conversation messages; the last one must be an assistant response.
        ground_truth: List whose first message carries the reference answer.
        extract_fn: Optional custom answer-extraction function
            (defaults to extract_math_expression).
        compare_fn: Optional custom comparison function
            (defaults to compare_math_expressions).

    Returns:
        EvaluateResult with the similarity score and extraction/accuracy metrics.
    """
    model_response_text = ""
    if not messages:
        return EvaluateResult(
            score=0.0,
            reason="No messages provided.",
            metrics={"accuracy": MetricResult(score=0.0, is_score_valid=False, reason="No messages provided.")},
        )

    model_last_message = messages[-1]
    if isinstance(model_last_message, Message):
        if model_last_message.role == "assistant" and model_last_message.content is not None:
            model_response_text = _to_text(model_last_message.content)
        else:
            return EvaluateResult(
                score=0.0,
                reason="Last message not valid assistant response.",
                metrics={
                    "accuracy": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Invalid assistant response.",
                    )
                },
            )
    elif isinstance(model_last_message, dict):
        if model_last_message.get("role") == "assistant" and model_last_message.get("content") is not None:
            model_response_text = model_last_message.get("content", "")
        else:
            return EvaluateResult(
                score=0.0,
                reason="Last message not valid assistant response (dict).",
                metrics={
                    "accuracy": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Invalid assistant response (dict).",
                    )
                },
            )
    else:
        return EvaluateResult(
            score=0.0,
            reason=f"Unexpected type for last message: {type(model_last_message)}.",
            metrics={"accuracy": MetricResult(score=0.0, is_score_valid=False, reason="Invalid message type.")},
        )

    ground_truth_comparison_text = ""
    if not ground_truth or not isinstance(ground_truth, list) or len(ground_truth) == 0:
        return EvaluateResult(
            score=0.0,
            reason="Ground truth not provided/invalid.",
            metrics={
                "accuracy": MetricResult(
                    score=0.0,
                    is_score_valid=False,
                    reason="Invalid ground truth format.",
                )
            },
        )

    first_gt_message = ground_truth[0]
    if isinstance(first_gt_message, Message):
        if first_gt_message.content is not None:
            ground_truth_comparison_text = _to_text(first_gt_message.content)
        else:
            return EvaluateResult(
                score=0.0,
                reason="First GT message has no content.",
                metrics={
                    "accuracy": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="Ground truth content missing.",
                    )
                },
            )
    elif isinstance(first_gt_message, dict):
        if first_gt_message.get("content") is not None:
            ground_truth_comparison_text = first_gt_message.get("content", "")
        else:
            return EvaluateResult(
                score=0.0,
                reason="First GT message (dict) has no content.",
                metrics={
                    "accuracy": MetricResult(
                        score=0.0,
                        is_score_valid=False,
                        reason="GT content missing (dict).",
                    )
                },
            )
    else:
        return EvaluateResult(
            score=0.0,
            reason=f"Unexpected type for first GT message: {type(first_gt_message)}.",
            metrics={"accuracy": MetricResult(score=0.0, is_score_valid=False, reason="Invalid GT message type.")},
        )

    extracted_answer = extract_fn(model_response_text) if extract_fn else extract_math_expression(model_response_text)

    # If extraction failed but the ground truth string appears verbatim in the
    # response, treat the ground truth itself as the extracted answer.
    if (
        not extracted_answer
        and model_response_text
        and len(ground_truth_comparison_text) > 2
        and ground_truth_comparison_text.lower() in model_response_text.lower()
    ):
        extracted_answer = ground_truth_comparison_text

    has_extracted = bool(extracted_answer)

    similarity_score = (
        compare_fn(extracted_answer, ground_truth_comparison_text)
        if compare_fn
        else compare_math_expressions(extracted_answer, ground_truth_comparison_text)
    )
    success = similarity_score >= 0.9

    reason = f"Expected: '{ground_truth_comparison_text}', Extracted: '{extracted_answer}', Similarity: {similarity_score:.2f}"

    metrics = {
        "answer_extraction": MetricResult(
            score=1.0 if has_extracted else 0.0,
            is_score_valid=has_extracted,
            reason=(f"Extracted answer: '{extracted_answer}'" if has_extracted else "Failed to extract answer"),
        ),
        "answer_accuracy": MetricResult(
            score=similarity_score,
            is_score_valid=success,
            reason=f"Answer similarity: {similarity_score:.2f}",
        ),
    }

    return EvaluateResult(score=similarity_score, reason=reason, metrics=metrics)

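# Illustrative usage of accuracy_reward (a sketch, not part of the module; assumes
# the reward_function decorator passes these arguments through unchanged and that
# EvaluateResult exposes a .score attribute):
#
#   messages = [
#       {"role": "user", "content": "What is 6 x 7?"},
#       {"role": "assistant", "content": "Working it out step by step, the answer is 42."},
#   ]
#   ground_truth = [{"role": "assistant", "content": "42"}]
#   result = accuracy_reward(messages=messages, ground_truth=ground_truth)
#   # result.score should be 1.0 here, with "answer_extraction" and
#   # "answer_accuracy" entries in result.metrics.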