Commit e4d0925

Improve answer extraction and logging in majority voting
Enhanced the extract_answer function to better handle LaTeX boxed answers and multiple-choice patterns. The majority voting summary is now emitted through the logger instead of being appended to the response, keeping the returned output clean.
1 parent 0016daa commit e4d0925

2 files changed: +91 -42 lines changed


optillm/plugins/majority_voting_plugin.py

Lines changed: 33 additions & 27 deletions
@@ -45,7 +45,14 @@ def extract_answer(text: str) -> Optional[str]:
     # Remove any trailing whitespace
     text = text.strip()
 
-    # Pattern 1: Look for "Answer:" or "Final Answer:" patterns
+    # Pattern 1: Look for LaTeX boxed format first (handle both \boxed and \\boxed)
+    boxed_match = re.search(r'\\{1,2}boxed\{([^}]+)\}', text)
+    if boxed_match:
+        answer = boxed_match.group(1).strip()
+        logger.debug(f"Extracted boxed answer: {answer}")
+        return answer
+
+    # Pattern 2: Look for "Answer:" or "Final Answer:" patterns
     answer_patterns = [
         r'(?:final\s+)?answer\s*[:=]\s*(.+?)(?:\n|$)',
         r'(?:the\s+)?(?:final\s+)?answer\s+is\s*[:=]?\s*(.+?)(?:\n|$)',
@@ -62,13 +69,6 @@ def extract_answer(text: str) -> Optional[str]:
             logger.debug(f"Extracted answer using pattern: {answer}")
             return answer
 
-    # Pattern 2: Look for LaTeX boxed format
-    boxed_match = re.search(r'\\boxed\{([^}]+)\}', text)
-    if boxed_match:
-        answer = boxed_match.group(1).strip()
-        logger.debug(f"Extracted boxed answer: {answer}")
-        return answer
-
     # Pattern 3: Look for standalone numbers (useful for math problems)
     # Check the last few lines for a number
     lines = text.split('\n')
@@ -80,20 +80,29 @@ def extract_answer(text: str) -> Optional[str]:
             logger.debug(f"Extracted number answer: {line}")
             return line
 
-    # Pattern 4: If the last line is short (< 50 chars), it might be the answer
+    # Pattern 4: For multiple choice, look for single letter answers
+    # Check this before the generic last line check
+    mc_patterns = [
+        r'(?:the\s+)?(?:correct\s+)?(?:answer|option)\s+is\s+([A-E])(?:\b|$)',
+        r'(?:choose|select|pick)\s+(?:option\s+)?([A-E])(?:\b|$)',
+        r'\b([A-E])\s*\)\s*[A-Za-z]+.*is\s+(?:the\s+)?(?:correct|right)',
+        r'^([A-E])$', # Just a letter on its own line
+    ]
+
+    for pattern in mc_patterns:
+        mc_match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
+        if mc_match:
+            answer = mc_match.group(1).upper()
+            logger.debug(f"Extracted multiple choice answer: {answer}")
+            return answer
+
+    # Pattern 5: If the last line is short (< 50 chars), it might be the answer
     if lines:
         last_line = lines[-1].strip()
         if last_line and len(last_line) < 50 and not last_line.endswith(':'):
             logger.debug(f"Using last line as answer: {last_line}")
             return last_line
 
-    # Pattern 5: For multiple choice, look for single letter answers
-    mc_match = re.search(r'\b([A-E])\b(?:\s*\))?$', text)
-    if mc_match:
-        answer = mc_match.group(1)
-        logger.debug(f"Extracted multiple choice answer: {answer}")
-        return answer
-
     logger.warning("Could not extract a clear answer from the response")
     return None

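For reference, a minimal sketch (not part of the commit) of what the new extraction order picks up; the two regexes are copied verbatim from the hunks above, and everything else is standard library:

import re

BOXED = r'\\{1,2}boxed\{([^}]+)\}'
MC = r'(?:the\s+)?(?:correct\s+)?(?:answer|option)\s+is\s+([A-E])(?:\b|$)'

# Pattern 1 now fires before the generic "Answer:" patterns, so boxed answers win.
print(re.search(BOXED, r"Therefore the result is \boxed{42}.").group(1))          # -> 42
# The multiple-choice patterns are tried before the "short last line" fallback.
print(re.search(MC, "The correct answer is B.", re.IGNORECASE).group(1).upper())  # -> B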
@@ -240,21 +249,18 @@ def run(
         # Get the full response corresponding to the most common answer
         winning_response = answer_to_response.get(most_common_answer, candidates[0])
 
-        # Add voting summary to the response
-        voting_summary = f"\n\n**Majority Voting Result**:\n"
-        voting_summary += f"- Generated {k} candidates\n"
-        voting_summary += f"- Most common answer: {most_common_answer}\n"
-        voting_summary += f"- Votes: {count}/{len(answers)} ({confidence:.1%} confidence)\n"
+        # Log voting summary to console instead of adding to response
+        logger.info("Majority Voting Summary:")
+        logger.info(f" - Generated {k} candidates")
+        logger.info(f" - Most common answer: {most_common_answer}")
+        logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)")
 
         if len(answer_counts) > 1:
-            voting_summary += f"- Other answers: "
             other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer]
-            voting_summary += ", ".join(other_answers)
-
-        # Return the full response from the winning answer with voting summary
-        final_response = winning_response + voting_summary
+            logger.info(f" - Other answers: {', '.join(other_answers)}")
 
-        return final_response, total_tokens
+        # Return only the full response from the winning answer
+        return winning_response, total_tokens
 
     except Exception as e:
         logger.error(f"Error in majority voting: {str(e)}")

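Because the summary now goes through logger.info rather than the returned string, callers only see it when logging is configured; a minimal example using only the standard library:

import logging

logging.basicConfig(level=logging.INFO)
# The plugin's "Majority Voting Summary:" lines now appear in the log output,
# while the value returned by run() contains only the winning candidate's text.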
scripts/eval_optillmbench.py

Lines changed: 58 additions & 15 deletions
@@ -21,19 +21,44 @@
 logger = logging.getLogger(__name__)
 
 # Define the approaches to test
-# Each approach is (name, description)
+# Each approach is (name, description, extra_body_params)
 APPROACHES = [
-    ("none", "Baseline without any optimization"),
-    ("leap", "LEAP Approach"),
-    ("rto", "Round Trip Optimization"),
-    ("cot_reflection", "Chain of Thought with Reflection"),
-    ("self_consistency", "Self Consistency Check"),
-    ("plansearch", "Planning with Search"),
-    ("re2", "ReRead Approach"),
-    ("z3", "Z3 Solver for Mathematical Problems"),
-    ("coc", "Chain of Code"),
-    ("executecode" , "Execute Code"),
-    ("spl", "System Prompt Learning")
+    ("none", "Baseline without any optimization", {}),
+    ("leap", "LEAP Approach", {}),
+    ("rto", "Round Trip Optimization", {}),
+    ("cot_reflection", "Chain of Thought with Reflection", {}),
+    ("self_consistency", "Self Consistency Check", {}),
+    ("plansearch", "Planning with Search", {}),
+    ("re2", "ReRead Approach", {}),
+    ("z3", "Z3 Solver for Mathematical Problems", {}),
+    ("coc", "Chain of Code", {}),
+    ("executecode" , "Execute Code", {}),
+    ("spl", "System Prompt Learning", {})
+]
+
+# Define test-time compute approaches for sequential and parallel scaling
+TEST_TIME_COMPUTE_APPROACHES = [
+    # Baseline
+    ("none", "Baseline without any optimization", {}),
+
+    # Sequential test-time compute using thinkdeeper with different thinking budgets
+    ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", {
+        "decoding": "thinkdeeper",
+        "max_thinking_tokens": 8000
+    }),
+    ("thinkdeeper_16k", "ThinkDeeper with 16K thinking tokens", {
+        "decoding": "thinkdeeper",
+        "max_thinking_tokens": 16000
+    }),
+    ("thinkdeeper_32k", "ThinkDeeper with 32K thinking tokens", {
+        "decoding": "thinkdeeper",
+        "max_thinking_tokens": 32000
+    }),
+
+    # Parallel test-time compute using majority voting with different k values
+    ("majority_voting_6", "Majority Voting with k=6", {"k": 6}),
+    ("majority_voting_36", "Majority Voting with k=36", {"k": 36}),
+    ("majority_voting_60", "Majority Voting with k=60", {"k": 60}),
 ]
 
 def load_optillm_bench() -> datasets.Dataset:
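Each entry now carries the extra request parameters its approach needs; a quick illustration of how one tuple from the list above unpacks (values copied from the diff):

name, description, extra_body_params = (
    "thinkdeeper_8k",
    "ThinkDeeper with 8K thinking tokens",
    {"decoding": "thinkdeeper", "max_thinking_tokens": 8000},
)
# extra_body_params is later merged into the request's extra_body in evaluate_model (see below)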
@@ -265,6 +290,7 @@ def evaluate_model(
     model: str,
     dataset: datasets.Dataset,
     approach: str,
+    approach_extra_body: Dict[str, Any] = None,
     max_samples: int = None
 ) -> Tuple[Dict[str, float], List[Dict[str, Any]]]:
     """
@@ -286,8 +312,18 @@ def evaluate_model(
     # Prepare the dataset
     examples = dataset if max_samples is None else dataset.select(range(max_samples))
 
-    # Create model name with approach
-    full_model_name = f"{approach}-{model}" if approach != "none" else model
+    # Create model name with approach - handle special cases
+    if approach == "none":
+        full_model_name = model
+    elif approach.startswith("thinkdeeper_"):
+        # For thinkdeeper, use base model name (decoding is passed in extra_body)
+        full_model_name = model
+    elif approach.startswith("majority_voting_"):
+        # For majority voting, use majority_voting prefix
+        full_model_name = f"majority_voting-{model}"
+    else:
+        # Standard approach prefix
+        full_model_name = f"{approach}-{model}"
 
     for example in tqdm(examples, desc=f"Evaluating {approach}"):
         try:
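A worked illustration of the naming logic above, using a placeholder base model name:

# approach              -> model string sent to the API
# "none"                -> "gpt-4o-mini"
# "thinkdeeper_8k"      -> "gpt-4o-mini"                   (decoding is selected via extra_body)
# "majority_voting_6"   -> "majority_voting-gpt-4o-mini"
# "leap"                -> "leap-gpt-4o-mini"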
@@ -297,6 +333,11 @@ def evaluate_model(
             # Record start time
             start_time = time.time()
 
+            # Prepare extra_body parameters
+            extra_body = {"spl_learning": False}
+            if approach_extra_body:
+                extra_body.update(approach_extra_body)
+
             # Make API call
             response = client.chat.completions.create(
                 model=full_model_name,
@@ -306,7 +347,7 @@ def evaluate_model(
                 ],
                 temperature=0.2,
                 max_tokens=4096,
-                extra_body= {"spl_learning": False},
+                extra_body=extra_body,
             )
 
             # Calculate time taken
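A minimal sketch of how the merged extra_body reaches the OpenAI-compatible endpoint; the base URL, API key, and model name below are placeholders rather than values from this script:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")  # assumed local optillm proxy
response = client.chat.completions.create(
    model="majority_voting-gpt-4o-mini",  # placeholder base model with the majority_voting prefix
    messages=[{"role": "user", "content": "What is 17 * 3?"}],
    temperature=0.2,
    max_tokens=4096,
    extra_body={"spl_learning": False, "k": 6},  # "k" is consumed by the majority_voting plugin
)
print(response.choices[0].message.content)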
@@ -469,6 +510,8 @@ def main():
                         help="Directory to save results")
     parser.add_argument("--approaches", nargs="+",
                         help="Specific approaches to evaluate (default: all)")
+    parser.add_argument("--test-time-compute", action="store_true",
+                        help="Evaluate test-time compute approaches (sequential and parallel scaling)")
     parser.add_argument("--debug", action="store_true", help="Enable debug logging")
     args = parser.parse_args()

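A possible invocation of the updated script; only the flags visible in this diff are shown, and any other required arguments (such as the model to evaluate) are omitted here:

python scripts/eval_optillmbench.py --test-time-compute --debug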