-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_improved_extraction.py
More file actions
161 lines (131 loc) · 6.04 KB
/
test_improved_extraction.py
File metadata and controls
161 lines (131 loc) · 6.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python3
"""Test improved answer extraction on actual AIME reasoning paths."""
import re
import json
def test_actual_aime_extraction():
    """Print an extraction report for the stored AIME results.

    Reads ``results/aime_test_14b_8branch.json`` (relative to the current
    working directory) and, for every entry under its ``"results"`` key,
    prints the ground truth, the stored answer, the answer returned by
    ``extract_answer_improved``, the first 200 characters of the reasoning
    and the first 15 numeric tokens found in it.

    Nothing is asserted or returned; the output is meant for manual review.

    Raises:
        FileNotFoundError: If the results file does not exist.
        KeyError: If the file or an entry lacks one of the keys read below.
    """
    # Decode explicitly as UTF-8 so reading does not depend on the
    # platform's locale encoding.
    with open('results/aime_test_14b_8branch.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    print("🔍 Testing Improved Answer Extraction on Actual AIME Results")
    print("=" * 70)

    # Loop-invariant, so compile it once instead of once per result.
    number_pattern = re.compile(r"[-+]?\d[\d,]*(?:\.\d+)?")

    for result in data['results']:
        problem_id = result['problem_id']
        ground_truth = result['ground_truth']
        our_answer = result['our_answer']
        # A JSON null would break the slicing and regex scan below.
        reasoning = result['our_reasoning'] or ""

        print(f"\n📝 Problem {problem_id}:")
        print("-" * 50)
        print(f"Ground Truth: {ground_truth}")
        print(f"Our Answer: {our_answer}")

        # Test improved extraction
        improved_answer = extract_answer_improved(reasoning)
        print(f"Improved Extraction: {improved_answer}")

        # Show reasoning preview
        print(f"Reasoning preview: {reasoning[:200]}...")

        # Show all numbers found (first 15 only)
        all_numbers = number_pattern.findall(reasoning)
        print(f"All numbers found: {all_numbers[:15]}...")
# Numeric token: optional sign, then either comma-grouped thousands
# ("1,234,567") or a plain digit run of any length ("1234"), then an optional
# decimal part.  The plain-digit alternative keeps an ungrouped number such as
# "1234" from being cut down to its first three digits.
_NUMBER = r"[-+]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"

# Answer-bearing phrases in priority order; group 1 captures the candidate.
_ANSWER_PATTERNS = tuple(
    re.compile(source, re.IGNORECASE)
    for source in (
        # Explicit markers (highest priority).
        r"####\s*(" + _NUMBER + r")",  # #### answer
        r"final answer.*?(" + _NUMBER + r")",  # final answer
        r"answer is\s*[:\s]?([-+]?\d[\d,]*(?:\.\d+)?)(?=[\.\n]|$)",  # answer is
        r"\\boxed\{([^}]+)\}",  # boxed answers
        # Looser "<phrase>: <number>" forms, tried only if the above fail.
        r"answer is[:\s]*(" + _NUMBER + r")",
        r"final answer[:\s]*(" + _NUMBER + r")",
        r"the answer[:\s]*(" + _NUMBER + r")",
        r"correct answer[:\s]*(" + _NUMBER + r")",
        r"answer[:\s]*(" + _NUMBER + r")",
    )
)

# Any numeric token; used by the last-number fallback.
_ANY_NUMBER = re.compile(r"[-+]?\d[\d,]*(?:\.\d+)?")


def extract_answer_improved(reasoning: str) -> str:
    """Extract a numeric answer from a block of reasoning text.

    The patterns in ``_ANSWER_PATTERNS`` are tried in order. For each one,
    only its first match in the text is considered; the captured text is
    passed through ``clean_answer`` and returned if ``is_valid_answer``
    accepts it. If no pattern yields a valid answer, the last numeric token
    in the text that is valid and at least 1 is returned.

    Args:
        reasoning: The reasoning text to scan.

    Returns:
        The cleaned answer as a string, or ``""`` when the input is empty or
        whitespace-only, or when no acceptable number is found.
    """
    if not reasoning or not reasoning.strip():
        return ""

    # Strategies 1 and 2: phrase-anchored patterns, highest priority first.
    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(reasoning)
        if match:
            candidate = clean_answer(match.group(1).strip())
            if is_valid_answer(candidate):
                return candidate

    # Strategy 3: fall back to the last reasonable number.  Walking the tokens
    # backwards lets us stop at the first acceptable one.  Values below 1 are
    # skipped to filter out small intermediate calculations.
    for token in reversed(_ANY_NUMBER.findall(reasoning)):
        candidate = clean_answer(token)
        if is_valid_answer(candidate) and float(candidate) >= 1:
            return candidate

    return ""
def clean_answer(answer: str) -> str:
    """Normalize a raw answer string for comparison.

    Steps, in order: strip surrounding whitespace; drop a leading
    "The answer is" / "Answer:" / "Final answer" prefix; remove dollar signs
    and all whitespace; unwrap ``\\boxed{...}``; strip leading and trailing
    brackets, parentheses and braces; strip trailing punctuation; remove
    commas. If the result is numeric it is rendered canonically: integral
    values without a decimal part (``"3.0"`` -> ``"3"``), others via
    ``str(float)``.

    Args:
        answer: The raw answer text. Any falsy value yields ``""``.

    Returns:
        The normalized answer. Text that is not a finite number is returned
        as cleaned, without numeric normalization.
    """
    if not answer:
        return ""

    answer = str(answer).strip()

    # Remove common prefixes
    answer = re.sub(r'^(The answer is|Answer:|Final answer:?)\s*', '', answer, flags=re.IGNORECASE)

    # Remove dollar signs and whitespace
    answer = re.sub(r'[\$\s]+', '', answer)

    # Remove boxed formatting
    answer = re.sub(r'\\boxed\{([^}]+)\}', r'\1', answer)

    # Remove leading/trailing brackets, parentheses and braces.  Inside a raw
    # string a single backslash is enough to escape "[" and "]"; doubling it
    # closes the character class early and the pattern never matches "(42)".
    answer = re.sub(r'^[\[\](){}]+|[\[\](){}]+$', '', answer)

    # Remove trailing punctuation
    answer = re.sub(r'[.,;:!?]+$', '', answer)

    # Remove commas from numbers
    answer = answer.replace(",", "")

    # Integer literals are normalized exactly, without a float round-trip
    # that would lose precision above 2**53.
    try:
        return str(int(answer))
    except ValueError:
        pass

    # Otherwise go through float to drop unnecessary decimals.
    try:
        num = float(answer)
        if num == int(num):
            return str(int(num))
        return str(num)
    except (ValueError, OverflowError):
        # ValueError: not a number (or NaN); OverflowError: infinity cannot
        # be converted to int.  Either way, return the cleaned text as is.
        return answer
def is_valid_answer(answer: str) -> bool:
    """Return True when *answer* is a number within the accepted range.

    An answer is accepted when it is non-empty, not whitespace-only, parses
    with ``float`` and lies in the closed interval [-100000, 1000000].
    """
    # Reject empty and whitespace-only input up front.
    if not (answer and answer.strip()):
        return False

    try:
        value = float(answer)
    except ValueError:
        return False

    return -100000 <= value <= 1000000
# Run the extraction report only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    test_actual_aime_extraction()