diff --git a/.gitignore b/.gitignore index 65d1fa7e..e597e5b3 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +*.7z # PyInstaller # Usually these files are written by a python script from a template @@ -52,7 +53,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ - +analysis_output/ # Translations *.mo *.pot diff --git a/README.md b/README.md index e2cb8d52..9007585f 100644 --- a/README.md +++ b/README.md @@ -197,4 +197,76 @@ export MEDARC_DISABLE_TOKEN_TRACKING=true - If provider doesn't return usage data, defaults to 0 - Model tokens include all inference API calls - Judge tokens include all LLM-as-judge calls via `judge()` method (e.g., FactScore: 6-20 verification calls per example) -- **Note**: Some judge implementations (e.g., FactScore claim extraction) make additional API calls (claim extraction) that are currently not tracked not part of judge() calls or get stored in state["responses"]. These represent a small overhead (~10-20% of total judge tokens) and are present in existing implementations like MedRedQA, keep in mind when calculating. \ No newline at end of file +- **Note**: Some judge implementations (e.g., FactScore claim extraction) make additional API calls (claim extraction) that are currently not tracked not part of judge() calls or get stored in state["responses"]. These represent a small overhead (~10-20% of total judge tokens) and are present in existing implementations like MedRedQA, keep in mind when calculating. + +## MCQ Answer Analysis + +Post-hoc analysis of MCQ benchmark results. Extracts model answers using the same parsing pipeline as evaluation, generates confusion matrices, and computes cross-rollout consistency metrics. + +### Usage + +```bash +# Single model +python scripts/mcq_answer_analysis.py \ + --logs-dir /path/to/model_results \ + --output-dir ./analysis_output \ + --model model-name + +# All models +python scripts/mcq_answer_analysis.py \ + --logs-dir /path/to/all_models \ + --output-dir ./analysis_output \ + --all-models + +# Specific benchmark only +python scripts/mcq_answer_analysis.py \ + --logs-dir /path/to/all_models \ + --output-dir ./analysis_output \ + --all-models \ + --benchmark medqa +``` + +### Output Files + +| File | Description | +|------|-------------| +| `{model}_{benchmark}.csv` | Per-example analysis with parsed answers | +| `{model}_{benchmark}_confusion.csv` | Confusion matrix (correct → predicted) | +| `{model}_{benchmark}_rollouts.csv` | Cross-rollout comparison (if multiple rollouts) | +| `{model}_summary.json` | Aggregate statistics per benchmark | +| `{model}_benchmark_metrics.csv` | Summary metrics table | +| `all_models_metrics.csv` | Cross-model comparison (with `--all-models`) | + +### Metrics + +- **Accuracy**: Standard correctness rate +- **Parsing success rate**: % of completions with extracted answer +- **Variation rate**: % of questions with different answers across rollouts +- **Semantic consistency**: Among varied answers, % with same answer text (different letter, same content) +- **Positional bias**: Per-position selection rate vs ground truth distribution + +### Model Parsed Answer Logging + +New evaluations automatically log parsed answers to `info` dict in `results.jsonl`: + +```json +{ + "info": { + "model_parsed_answer": "B", + "parsing_method": "anchored_token" + } +} +``` + +This enables exact reproducibility in post-hoc analysis. For older results without logging, the analysis script applies the full parsing pipeline (environment-specific XML/boxed extraction → MCQ answer parsing). + +To enable logging in custom environments, pass `info=info` to `multiple_choice_accuracy()`: + +```python +is_correct = multiple_choice_accuracy( + llm_answer=parsed, + answer_letter=answer, + answer_text=answer_text, + info=info # Enables parsed answer logging +) +``` \ No newline at end of file diff --git a/environments/careqa/careqa.py b/environments/careqa/careqa.py index e67178af..66255a46 100644 --- a/environments/careqa/careqa.py +++ b/environments/careqa/careqa.py @@ -32,7 +32,7 @@ def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = Non """Reward based on shared multiple-choice accuracy grading.""" parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 diff --git a/environments/head_qa_v2/head_qa_v2.py b/environments/head_qa_v2/head_qa_v2.py index 9ac835be..3bfaf544 100644 --- a/environments/head_qa_v2/head_qa_v2.py +++ b/environments/head_qa_v2/head_qa_v2.py @@ -139,7 +139,7 @@ def cot_prompt(example: dict[str, Any]) -> dict[str, Any]: def accuracy(completion: Any, answer: str, parser: vf.Parser, info: dict[str, Any] | None = None) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text") if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 diff --git a/environments/longhealth/longhealth.py b/environments/longhealth/longhealth.py index 722dc7db..41fbe225 100644 --- a/environments/longhealth/longhealth.py +++ b/environments/longhealth/longhealth.py @@ -465,7 +465,7 @@ def accuracy(completion: Any, answer: str, parser: vf.Parser, info: dict | None parsed = parser.parse_answer(completion) or "" answer_text = info.get("correct_answer_text", None) if info else None is_correct = multiple_choice_accuracy( - llm_answer=parsed, answer_letter=answer, answer_text=answer_text, prefix="The correct answer is" + llm_answer=parsed, answer_letter=answer, answer_text=answer_text, prefix="The correct answer is", info=info ) return 1.0 if is_correct else 0.0 diff --git a/environments/m_arc/m_arc.py b/environments/m_arc/m_arc.py index 595bf01e..c6f28f26 100644 --- a/environments/m_arc/m_arc.py +++ b/environments/m_arc/m_arc.py @@ -194,7 +194,7 @@ def load_environment( def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = None, **kwargs) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/med_mcqa/med_mcqa.py b/environments/med_mcqa/med_mcqa.py index 2f01a106..1207a57c 100644 --- a/environments/med_mcqa/med_mcqa.py +++ b/environments/med_mcqa/med_mcqa.py @@ -132,7 +132,7 @@ def _map_example(example: dict[str, Any]) -> dict[str, Any] | None: def accuracy(completion: Any, answer: str, parser: vf.Parser, info: dict[str, Any] | None = None) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/medbullets/medbullets.py b/environments/medbullets/medbullets.py index f0d3431e..c6938080 100644 --- a/environments/medbullets/medbullets.py +++ b/environments/medbullets/medbullets.py @@ -143,7 +143,7 @@ def load_environment( def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = None, **kwargs) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/medconceptsqa/medconceptsqa.py b/environments/medconceptsqa/medconceptsqa.py index eaf39742..4e81c8ff 100644 --- a/environments/medconceptsqa/medconceptsqa.py +++ b/environments/medconceptsqa/medconceptsqa.py @@ -192,7 +192,7 @@ def _map(row: dict, idx: int | None = None) -> dict: def accuracy(completion: Any, answer: str, parser: vf.Parser, info: dict | None = None) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index 33fa4671..fe65b4f6 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -289,7 +289,7 @@ def _is_correct(parser, completion, answer: str, info: dict | None = None) -> bo completion_text = completion or "" parsed = parser.parse_answer(completion) or completion_text answer_text = (info or {}).get("answer_text", "") - return multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + return multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) def combined_reward(parser, completion, answer, **kwargs) -> float: """Gate explanation scoring on MCQ correctness.""" @@ -322,7 +322,7 @@ async def combined_judge_reward(judge, prompt, completion, answer, state: State, model_rational = getattr(parsed, "explanation", None) is_correct = multiple_choice_accuracy( - llm_answer=model_answer, answer_letter=answer, answer_text=answer_text + llm_answer=model_answer, answer_letter=answer, answer_text=answer_text, info=info ) if not is_correct: diff --git a/environments/medqa/medqa.py b/environments/medqa/medqa.py index b6da35e1..daa47944 100644 --- a/environments/medqa/medqa.py +++ b/environments/medqa/medqa.py @@ -19,7 +19,7 @@ def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = Non """Reward based on shared multiple-choice accuracy grading.""" parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 diff --git a/environments/medxpertqa/medxpertqa.py b/environments/medxpertqa/medxpertqa.py index 71f07772..6dd1eb5d 100644 --- a/environments/medxpertqa/medxpertqa.py +++ b/environments/medxpertqa/medxpertqa.py @@ -111,7 +111,7 @@ def _map(example: dict) -> dict: def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = None) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/metamedqa/metamedqa.py b/environments/metamedqa/metamedqa.py index 197d9631..65a55f78 100644 --- a/environments/metamedqa/metamedqa.py +++ b/environments/metamedqa/metamedqa.py @@ -74,7 +74,7 @@ def _map(ex: dict, idx: int | None = None): def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = None, **kwargs) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/mmlu_pro_health/mmlu_pro_health.py b/environments/mmlu_pro_health/mmlu_pro_health.py index 6ad0feb1..0c05f83f 100644 --- a/environments/mmlu_pro_health/mmlu_pro_health.py +++ b/environments/mmlu_pro_health/mmlu_pro_health.py @@ -193,7 +193,7 @@ def _convert_options(row: dict) -> dict: def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = None, **kwargs) -> float: parsed = parser.parse_answer(completion) or "" answer_text = info.get("answer_text", None) if info else None - is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text) + is_correct = multiple_choice_accuracy(llm_answer=parsed, answer_letter=answer, answer_text=answer_text, info=info) return 1.0 if is_correct else 0.0 rubric = vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) diff --git a/environments/pubmedqa/pubmedqa.py b/environments/pubmedqa/pubmedqa.py index eb312659..f6c027ec 100644 --- a/environments/pubmedqa/pubmedqa.py +++ b/environments/pubmedqa/pubmedqa.py @@ -90,6 +90,7 @@ def accuracy(completion, answer: str, parser: vf.Parser, info: dict | None = Non llm_answer=parsed, answer_letter=answer, answer_text=answer_text, + info=info, ) return 1.0 if is_correct else 0.0 diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 71e123a8..e704cfa1 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -31,6 +31,9 @@ class MCQAccuracyResult: correct_answer: Optional[str] = None """The correct answer for reference, if available.""" + parsed_answer: Optional[str] = None + """Parsed answer token (letter/number) extracted from the model output, if available.""" + def _nfkc_casefold(text: str) -> str: """Unicode normalize + casefold for robust text comparison.""" @@ -213,6 +216,7 @@ def multiple_choice_accuracy( accept_answer_text: bool = True, strip_tex: bool = True, return_details: bool = False, + info: Optional[dict] = None, ) -> bool | MCQAccuracyResult: """ Grade a multiple-choice answer with layered strategies: @@ -230,26 +234,50 @@ def multiple_choice_accuracy( accept_answer_text: Whether to fall back to text matching strip_tex: Whether to strip LaTeX formatting return_details: If True, return MCQAccuracyResult dataclass instead of bool - Returns: bool (if return_details=False) or MCQAccuracyResult (if return_details=True) """ def _result( - is_correct: bool, method: str, predicted: str | None, actual: str | None, return_details: bool + is_correct: bool, + method: str, + matched: str | None, + actual: str | None, + return_details: bool, + parsed: str | None = None, + log_method: str | None = None, ) -> bool | MCQAccuracyResult: - """Helper to format return value.""" + """Helper to format return value. + + Args: + is_correct: Whether the answer was graded as correct + method: The parsing method for MCQAccuracyResult (original behavior) + matched: The answer that matched correctly (None if incorrect) + actual: The correct answer letter + return_details: Whether to return MCQAccuracyResult or bool + parsed: What the model actually said (regardless of correctness) + log_method: The actual parsing method for info dict logging (defaults to method) + """ + # Log parsed answer to info dict if provided + if info is not None: + info["model_parsed_answer"] = parsed + info["parsing_method"] = log_method if log_method is not None else method + if not return_details: return is_correct - return MCQAccuracyResult( + + result = MCQAccuracyResult( is_correct=is_correct, method=method, - matched_answer=predicted, + matched_answer=matched, correct_answer=actual, + parsed_answer=parsed, ) + return result + if not llm_answer: - return _result(False, "none", None, None, return_details) + return _result(False, "none", None, None, return_details, parsed=None) # Normalize the response llm_answer = _remove_think_tags(llm_answer) @@ -269,19 +297,28 @@ def _result( raise ValueError(f"Invalid answer_letter '{answer_letter=}'. Must be a single letter or digit string.") explicit_choice_found = False + model_predicted = None # Track what the model actually said + parse_method = "none" # Strategy 1: Only answer letter anywhere (without anchoring) - if answer_letter == _norm_letter(llm_answer): - return _result(True, "direct_answer", llm_answer, answer_letter, return_details) + normalized_llm = _norm_letter(llm_answer) + if normalized_llm and len(llm_answer.strip()) <= 3: + model_predicted = normalized_llm + parse_method = "direct_answer" + if normalized_llm == answer_letter: + return _result(True, "direct_answer", llm_answer, answer_letter, return_details, parsed=normalized_llm) # Strategy 2: Accept leading option token like "B. answer ..." leading_match = LEADING_OPTION_PATTERN.match(llm_answer_original) if leading_match and answer_letter: predicted = _norm_letter(leading_match.group(1)) + if predicted and model_predicted is None: + model_predicted = predicted + parse_method = "anchored_token" if _token_kind_matches_answer_letter(predicted, answer_letter): explicit_choice_found = True if predicted == answer_letter: - return _result(True, "anchored_token", predicted, answer_letter, return_details) + return _result(True, "anchored_token", predicted, answer_letter, return_details, parsed=predicted) # Strategy 3: Anchored token (prefix matches first, fallback to generic anchors) prefix_matches = [] @@ -290,7 +327,7 @@ def _result( if prefix_norm: flexible_prefix = re.escape(prefix_norm).replace(r"\ ", r"\s+") prefix_pattern = re.compile( - rf"{flexible_prefix}\s*[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['’]t\s+)?\(?\s*(?P[A-Za-z]|\d{{1,2}})\s*[\)\.:]?(?![\w+\-/])", + rf"{flexible_prefix}\s*[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['']t\s+)?\(?\s*(?P[A-Za-z]|\d{{1,2}})\s*[\)\.:]?(?![\w+\-/])", re.IGNORECASE, ) prefix_matches = list(prefix_pattern.finditer(llm_answer)) @@ -299,10 +336,13 @@ def _result( if anchored_matches and answer_letter: last_match = anchored_matches[-1] predicted = _norm_letter(last_match.group("opt")) + if predicted and last_match.group("neg") is None: + model_predicted = predicted + parse_method = "anchored_token" if last_match.group("neg") is None and _token_kind_matches_answer_letter(predicted, answer_letter): explicit_choice_found = True if predicted == answer_letter and last_match.group("neg") is None: - return _result(True, "anchored_token", predicted, answer_letter, return_details) + return _result(True, "anchored_token", predicted, answer_letter, return_details, parsed=predicted) # Strategy 4: Last token in the answer tail, ignore negative contexts like "C is incorrect", if not explicit_choice_found and answer_letter: @@ -318,8 +358,12 @@ def _result( continue if _negative_after_option(tail, token_match): continue + if model_predicted is None: + model_predicted = predicted + parse_method = "last_token" if predicted == answer_letter: - return _result(True, "last_token", predicted, answer_letter, return_details) + return _result(True, "last_token", predicted, answer_letter, return_details, parsed=predicted) + break # Take the first valid token we find # Strategy 5: Exact answer text match if there's no explicit choice found # Only search at beginning and end to avoid matching reasoning in the middle @@ -343,11 +387,11 @@ def _result( # Check beginning first match = pattern.search(beginning_region) if match and not _negated_near(beginning_region, match): - return _result(True, "answer_text", beginning_region, answer_text, return_details) + return _result(True, "answer_text", beginning_region, answer_text, return_details, parsed=model_predicted) # Then check end (after reasoning) match = pattern.search(end_region) if match and not _negated_near(end_region, match): - return _result(True, "answer_text", end_region, answer_text, return_details) + return _result(True, "answer_text", end_region, answer_text, return_details, parsed=model_predicted) - return _result(False, "none", None, None, return_details) + return _result(False, "none", None, None, return_details, parsed=model_predicted, log_method=parse_method) diff --git a/scripts/mcq_answer_analysis.py b/scripts/mcq_answer_analysis.py new file mode 100755 index 00000000..4189fe35 --- /dev/null +++ b/scripts/mcq_answer_analysis.py @@ -0,0 +1,1036 @@ +#!/usr/bin/env python3 +""" +MCQ Answer Analysis Script + +Analyzes incorrect answers from model completion logs for MCQ benchmarks only. +Mirrors the exact parsing logic used during evaluation: +1. Environment-specific parsing (XML tags or \\boxed{} extraction) +2. multiple_choice_accuracy parsing strategies + +Features: +- Wrong answer extraction and categorization +- Positional bias detection +- Confusion matrices (which answers are confused for which) +- Cross-rollout variation analysis +- Semantic consistency analysis +- Robustness metrics +- Incremental updates (add new models/benchmarks without re-running all) + +Usage: + # Analyze all models + python scripts/mcq_answer_analysis.py \ + --logs-dir /path/to/medmarks_raw/raw \ + --output-dir ./analysis_output \ + --all-models + + # Analyze specific model + python scripts/mcq_answer_analysis.py \ + --logs-dir /path/to/model_results \ + --output-dir ./analysis_output \ + --model model-name + + # Add new benchmark incrementally + python scripts/mcq_answer_analysis.py \ + --logs-dir /path/to/medmarks_raw/raw \ + --output-dir ./analysis_output \ + --benchmark new_benchmark \ + --all-models +""" + +import argparse +import json +import math +import re +import sys +from pathlib import Path +from typing import Any, Optional + +import pandas as pd + +# Add parent directory to path to import medarc_verifiers +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from medarc_verifiers.rewards.multiple_choice_accuracy import ( + multiple_choice_accuracy, + _remove_think_tags, + _nfkc_casefold, + _norm_letter, + ANCHOR_PATTERN, + TOKEN_PATTERN, + LEADING_OPTION_PATTERN, +) + + +# ============================================================================= +# ENVIRONMENT-SPECIFIC PARSING (Step 1) +# Mirrors the parsing done by vf.XMLParser and extract_boxed_answer +# ============================================================================= + + +def extract_xml_answer(text: str) -> str | None: + """ + Extract content from ... tags. + Mirrors vf.XMLParser.parse_answer() logic - finds the LAST match. + """ + if not text: + return None + # Pattern matching XMLParser: \s*(.*?)\s* + pattern = r"\s*(.*?)\s*" + matches = list(re.finditer(pattern, text, re.DOTALL | re.IGNORECASE)) + if matches: + # Take the last match (same as XMLParser with last=True) + return matches[-1].group(1).strip() + return None + + +def extract_boxed_answer(text: str) -> str: + """ + Extract content from \\boxed{...} format. + Exact copy of verifiers.utils.data_utils.extract_boxed_answer + """ + def find_matching_brace(s: str, start: int) -> int: + count = 1 + i = start + while i < len(s) and count > 0: + if s[i] == "{": + count += 1 + elif s[i] == "}": + count -= 1 + i += 1 + return i - 1 if count == 0 else -1 + + # Find last \boxed{ + boxed_start = text.rfind("\\boxed{") + if boxed_start == -1: + return text + # Find the content between the braces + content_start = boxed_start + 7 # len('\\boxed{') + closing_brace = find_matching_brace(text, content_start) + + if closing_brace == -1: + return text + + return text[content_start:closing_brace] + + +def detect_answer_format(completion_text: str, system_prompt: str = "") -> str: + """ + Detect whether the completion uses XML or BOXED format. + Returns: 'xml', 'boxed', or 'unknown' + """ + # Check system prompt hints + system_lower = system_prompt.lower() + if "boxed" in system_lower or "\\boxed" in system_lower: + return "boxed" + if "" in system_lower or "xml" in system_lower: + return "xml" + + # Check completion content + if "" in completion_text.lower(): + return "xml" + if "\\boxed{" in completion_text: + return "boxed" + + # Default to XML (most common in MedARC environments) + return "xml" + + +def apply_environment_parser(completion_text: str, system_prompt: str = "") -> tuple[str, str]: + """ + Apply environment-specific parsing to extract the answer. + This mirrors what vf.XMLParser.parse_answer() or vf.Parser(extract_boxed_answer) does. + + Returns: + (extracted_text, format_used) + """ + if not completion_text: + return "", "none" + + format_type = detect_answer_format(completion_text, system_prompt) + + if format_type == "boxed": + extracted = extract_boxed_answer(completion_text) + # If boxed extraction returns original text, it means no boxed found + if extracted == completion_text: + # Fallback: try XML + xml_extracted = extract_xml_answer(completion_text) + if xml_extracted: + return xml_extracted, "xml" + return completion_text, "boxed_fallback" + return extracted, "boxed" + + # XML format (default) + extracted = extract_xml_answer(completion_text) + if extracted is not None: + return extracted, "xml" + + # No XML tags found - return original text (some models don't use tags) + return completion_text, "no_tags" + +# MCQ benchmark prefixes (exclude known open-ended/non-MCQ tasks) +# Excluded: medec (error correction), medexqa (explanation-based), medredqa (open-ended), +# medcasereasoning (diagnosis generation), careqa_open (open-ended subset) +MCQ_BENCHMARK_PREFIXES = { + "careqa", # MCQ subset (CareQA_en); careqa_open is excluded below + "head_qa_v2", + "longhealth", + "m_arc", + "med_halt", + "med_mcqa", + "medbullets", + "medcalc_bench", + "medconceptsqa", + "medqa", + "medxpertqa", + "metamedqa", + "mmlu_pro_health", + "pubmedqa", +} + +# Explicit exclusions (open-ended variants that match MCQ prefixes) +MCQ_BENCHMARK_EXCLUSIONS = { + "careqa_open", + "careqa-open", +} + + +def is_mcq_benchmark(name: str) -> bool: + """Check if benchmark name matches MCQ benchmark patterns.""" + name_lower = name.lower().strip() + + # Check explicit exclusions first + if name_lower in MCQ_BENCHMARK_EXCLUSIONS: + return False + for exclusion in MCQ_BENCHMARK_EXCLUSIONS: + if name_lower.startswith(exclusion + "-") or name_lower.startswith(exclusion + "_"): + return False + + # Check if matches MCQ prefixes + return any( + name_lower == prefix or name_lower.startswith(prefix + "-") or name_lower.startswith(prefix + "_") + for prefix in MCQ_BENCHMARK_PREFIXES + ) + + +def parse_mcq_options(prompt_text: str) -> dict[str, str]: + """Extract option texts from MCQ prompt.""" + options = {} + pattern = r'(?:^|\n)\s*[\(\[]?([A-Z])[\)\]]?[\.\):\-]\s+(.+?)(?=\n\s*[\(\[]?[A-Z][\)\]]?[\.\):\-]\s+|$)' + try: + matches = re.findall(pattern, prompt_text, re.DOTALL | re.MULTILINE) + for letter, text in matches: + cleaned = text.strip() + cleaned = re.sub(r'\s*(?:Answer|Question):\s*$', '', cleaned, flags=re.IGNORECASE) + options[letter] = cleaned + except Exception: + pass + return options + + +def extract_model_choice( + completion_text: str, + system_prompt: str = "", + use_logged_parsed: bool = True, + logged_parsed_answer: str | None = None, + logged_parsing_method: str | None = None, +) -> tuple[Optional[str], str, str]: + """ + Extract model's answer choice using the exact same logic as during evaluation: + 1. Apply environment-specific parser (XML/BOXED extraction) + 2. Apply multiple_choice_accuracy parsing strategies + + Args: + completion_text: Raw model completion + system_prompt: System prompt (helps detect format) + use_logged_parsed: If True and logged values available, use them directly + logged_parsed_answer: Pre-parsed answer from results.jsonl info dict + logged_parsing_method: Parsing method from results.jsonl info dict + + Returns: + (choice_letter, parsing_method, env_format) + """ + # If logged parsed answer is available (from new evaluation runs), use it directly + if use_logged_parsed and logged_parsed_answer is not None: + return logged_parsed_answer, logged_parsing_method or "logged", "logged" + + if not completion_text: + return None, "none", "none" + + # Step 1: Apply environment-specific parsing (XML/BOXED extraction) + # This mirrors: parsed = parser.parse_answer(completion) or "" + extracted_text, env_format = apply_environment_parser(completion_text, system_prompt) + + if not extracted_text: + return None, "none", env_format + + # Step 2: Apply multiple_choice_accuracy parsing logic to the extracted text + # This mirrors what happens inside multiple_choice_accuracy() + + # Remove think tags (same as multiple_choice_accuracy) + text = _remove_think_tags(extracted_text) + text_lower = _nfkc_casefold(text) + + # Strategy 1: Direct single letter (response is just the option) + normalized = _norm_letter(text.strip()) + if normalized and len(text.strip()) <= 3: + return normalized, "direct_answer", env_format + + # Strategy 2: Leading option token like "B. answer..." + leading_match = LEADING_OPTION_PATTERN.match(text) + if leading_match: + letter = _norm_letter(leading_match.group(1)) + if letter: + return letter, "anchored_token", env_format # Note: leading returns "anchored_token" per original code + + # Strategy 3: Anchored patterns (same as multiple_choice_accuracy) + anchored_matches = list(ANCHOR_PATTERN.finditer(text_lower)) + if anchored_matches: + last_match = anchored_matches[-1] + if last_match.group("neg") is None: # Not negated + letter = _norm_letter(last_match.group("opt")) + if letter: + return letter, "anchored_token", env_format + + # Strategy 4: Last token in tail + tail = text[-200:] if len(text) > 200 else text + tail_tokens = list(TOKEN_PATTERN.finditer(_nfkc_casefold(tail))) + if tail_tokens: + for token_match in reversed(tail_tokens): + letter = _norm_letter(token_match.group(1)) + if letter and letter.isalpha(): + return letter, "last_token", env_format + + return None, "failed", env_format + + +def analyze_result(result: dict) -> dict[str, Any]: + """Analyze a single result entry.""" + analysis = { + "example_id": result.get("example_id"), + "correct_answer": result.get("answer"), + "model_choice": None, + "is_correct": result.get("reward", 0.0) == 1.0, + "correct_answer_text": None, + "model_choice_text": None, + "all_options": {}, + "parsing_method": "failed", + "env_format": "unknown", + "options_source": "missing", + } + + # Extract system prompt (helps detect XML vs BOXED format) + system_prompt = "" + prompt = result.get("prompt", []) + if isinstance(prompt, list): + for msg in prompt: + if isinstance(msg, dict) and msg.get("role") == "system": + system_prompt = msg.get("content", "") + break + + # Extract user prompt text + prompt_text = "" + if isinstance(prompt, list): + for msg in prompt: + if isinstance(msg, dict) and msg.get("role") == "user": + prompt_text = msg.get("content", "") + break + + # Extract completion text + completion_text = "" + completion = result.get("completion", []) + if isinstance(completion, list): + for msg in completion: + if isinstance(msg, dict) and msg.get("role") == "assistant": + completion_text = msg.get("content", "") + break + + # Get options from info or parse from prompt + info = result.get("info") or {} + all_options = info.get("options") or info.get("all_options") + if not isinstance(all_options, dict) or not all_options: + all_options = parse_mcq_options(prompt_text) + analysis["options_source"] = "prompt" if all_options else "missing" + else: + analysis["options_source"] = "info" + analysis["all_options"] = all_options + + # Get correct answer text + analysis["correct_answer_text"] = info.get("answer_text") or info.get("correct_answer_text") + if not analysis["correct_answer_text"] and analysis["correct_answer"] in all_options: + analysis["correct_answer_text"] = all_options[analysis["correct_answer"]] + + # Check if we have pre-logged parsed answer (from new evaluation runs with info logging) + logged_parsed_answer = info.get("model_parsed_answer") + logged_parsing_method = info.get("parsing_method") + + # Extract model's choice using the full parsing pipeline: + # 1. Environment-specific parsing (XML/BOXED extraction) + # 2. multiple_choice_accuracy parsing strategies + model_choice, method, env_format = extract_model_choice( + completion_text, + system_prompt=system_prompt, + use_logged_parsed=True, + logged_parsed_answer=logged_parsed_answer, + logged_parsing_method=logged_parsing_method, + ) + + analysis["env_format"] = env_format + + # Validate choice against known options + if model_choice and all_options and model_choice not in all_options: + # Check if it's a numeric answer that should be mapped + if model_choice.isdigit() and all(k.isalpha() for k in all_options.keys()): + model_choice = None + method = "invalid_option" + + analysis["model_choice"] = model_choice + analysis["parsing_method"] = method + + # Get model choice text + if model_choice and model_choice in all_options: + analysis["model_choice_text"] = all_options[model_choice] + + return analysis + + +def build_confusion_matrix(results: list[dict]) -> tuple[pd.DataFrame, pd.DataFrame]: + """Build confusion matrix: correct_answer -> model_choice.""" + all_letters = set() + for r in results: + if r.get("correct_answer"): + all_letters.add(r["correct_answer"]) + if r.get("model_choice"): + all_letters.add(r["model_choice"]) + + letters = sorted(all_letters) + matrix = pd.DataFrame(0, index=letters, columns=letters) + matrix.index.name = "Correct Answer" + matrix.columns.name = "Model Choice" + + for r in results: + correct = r.get("correct_answer") + choice = r.get("model_choice") + if correct and choice and correct in letters and choice in letters: + matrix.loc[correct, choice] += 1 + + normalized = matrix.div(matrix.sum(axis=1), axis=0) * 100 + normalized = normalized.fillna(0) + matrix["Total"] = matrix.sum(axis=1) + + return matrix, normalized + + +def compare_rollouts(results_by_rollout: dict[str, list[dict]]) -> pd.DataFrame: + """Compare model choices across different rollouts.""" + rollout_ids = sorted(results_by_rollout.keys()) + if len(rollout_ids) < 2: + return pd.DataFrame() + + # Build lookup maps + rollout_maps = {} + for rid in rollout_ids: + rollout_maps[rid] = {r.get("example_id"): r for r in results_by_rollout[rid] if r.get("example_id") is not None} + + all_ids = set() + for mapping in rollout_maps.values(): + all_ids.update(mapping.keys()) + + comparisons = [] + base_rollout = "base" if "base" in rollout_ids else rollout_ids[0] + + for example_id in sorted(all_ids): + base_result = rollout_maps.get(base_rollout, {}).get(example_id, {}) + comp = { + "example_id": example_id, + "correct_answer": base_result.get("correct_answer"), + "correct_answer_text": base_result.get("correct_answer_text"), + } + + for rid in rollout_ids: + r = rollout_maps[rid].get(example_id) + if r: + comp[f"{rid}_choice"] = r.get("model_choice") + comp[f"{rid}_choice_text"] = r.get("model_choice_text") + else: + comp[f"{rid}_choice"] = None + comp[f"{rid}_choice_text"] = None + + # Analyze variation + choices = [comp.get(f"{rid}_choice") for rid in rollout_ids] + choice_texts = [comp.get(f"{rid}_choice_text") for rid in rollout_ids] + valid_choices = [c for c in choices if c is not None] + valid_texts = [t for t in choice_texts if t is not None] + + comp["has_comparable_data"] = len(valid_choices) >= 2 + + if len(valid_choices) < 2: + comp["has_variation"] = False + comp["semantic_consistency"] = None + comp["variation_type"] = "insufficient_data" + else: + has_letter_variation = len(set(valid_choices)) > 1 + has_semantic_info = len(valid_texts) == len(valid_choices) + + if has_semantic_info: + normalized_texts = [t.lower().strip() if t else None for t in valid_texts] + semantic_consistency = len(set(t for t in normalized_texts if t)) <= 1 + else: + semantic_consistency = None + + comp["has_variation"] = has_letter_variation + comp["semantic_consistency"] = semantic_consistency + + if not has_letter_variation: + comp["variation_type"] = "none" + elif semantic_consistency is True: + comp["variation_type"] = "letter_only" + elif semantic_consistency is False: + comp["variation_type"] = "semantic_change" + else: + comp["variation_type"] = "unknown_semantic" + + comparisons.append(comp) + + return pd.DataFrame(comparisons) + + +def compute_distribution_stats(counts: dict[str, int]) -> dict[str, float] | None: + """Compute entropy and bias metrics for a distribution.""" + total = sum(counts.values()) + if total == 0: + return None + probs = [c / total for c in counts.values()] + k = len(probs) + entropy = -sum(p * math.log2(p) for p in probs if p > 0) + uniform = 1.0 / k if k > 0 else 0.0 + l1_bias = sum(abs(p - uniform) for p in probs) if k > 0 else 0.0 + max_min_bias = (max(probs) - min(probs)) if k > 0 else 0.0 + return {"entropy": entropy, "l1_bias": l1_bias, "max_min_bias": max_min_bias} + + +def aggregate_statistics(results: list[dict], rollout_df: pd.DataFrame | None = None) -> dict[str, Any]: + """Compute aggregate statistics from analyzed results.""" + stats = { + "total_questions": len(results), + "total_correct": sum(1 for r in results if r.get("is_correct")), + "total_incorrect": sum(1 for r in results if not r.get("is_correct")), + } + + stats["accuracy"] = stats["total_correct"] / stats["total_questions"] if stats["total_questions"] > 0 else 0.0 + + # Parsing success rate + parsed = sum(1 for r in results if r.get("model_choice") is not None) + stats["parsing_success_rate"] = parsed / stats["total_questions"] if stats["total_questions"] > 0 else 0.0 + + # Choice distribution + choice_counts = {} + for r in results: + c = r.get("model_choice") + if c: + choice_counts[c] = choice_counts.get(c, 0) + 1 + stats["choice_distribution"] = choice_counts + stats["choice_distribution_stats"] = compute_distribution_stats(choice_counts) + + # Correct answer distribution (benchmark balance) + correct_counts = {} + for r in results: + c = r.get("correct_answer") + if c: + correct_counts[c] = correct_counts.get(c, 0) + 1 + stats["correct_answer_distribution"] = correct_counts + stats["correct_answer_distribution_stats"] = compute_distribution_stats(correct_counts) + + # Positional bias: compare choice vs correct distributions + if choice_counts and correct_counts: + # Compute per-position bias (choice_rate - correct_rate) + all_positions = set(choice_counts.keys()) | set(correct_counts.keys()) + total_choices = sum(choice_counts.values()) + total_correct = sum(correct_counts.values()) + positional_bias = {} + for pos in sorted(all_positions): + choice_rate = choice_counts.get(pos, 0) / total_choices if total_choices > 0 else 0 + correct_rate = correct_counts.get(pos, 0) / total_correct if total_correct > 0 else 0 + positional_bias[pos] = choice_rate - correct_rate + stats["positional_bias"] = positional_bias + + # Parsing method distribution + method_counts = {} + for r in results: + m = r.get("parsing_method") + if m: + method_counts[m] = method_counts.get(m, 0) + 1 + stats["parsing_method_distribution"] = method_counts + + # Environment format distribution (XML vs BOXED) + format_counts = {} + for r in results: + f = r.get("env_format") + if f: + format_counts[f] = format_counts.get(f, 0) + 1 + stats["env_format_distribution"] = format_counts + + # Options coverage + options_present = sum(1 for r in results if r.get("all_options")) + stats["options_coverage_rate"] = options_present / stats["total_questions"] if stats["total_questions"] > 0 else 0.0 + + # Rollout statistics + if rollout_df is not None and not rollout_df.empty: + comparable = rollout_df[rollout_df["has_comparable_data"] == True] if "has_comparable_data" in rollout_df.columns else rollout_df + stats["questions_with_rollouts"] = len(comparable) + stats["rollout_example_count"] = len(rollout_df) + + if "has_variation" in rollout_df.columns and len(comparable) > 0: + variation_count = comparable["has_variation"].sum() + stats["variation_rate"] = variation_count / len(comparable) + else: + stats["variation_rate"] = None + + if "semantic_consistency" in rollout_df.columns: + varied = comparable[comparable["has_variation"] == True] + if not varied.empty: + valid_sem = varied["semantic_consistency"].dropna() + stats["semantic_consistency_rate"] = valid_sem.mean() if len(valid_sem) > 0 else None + else: + stats["semantic_consistency_rate"] = None + else: + stats["semantic_consistency_rate"] = None + + if "variation_type" in rollout_df.columns: + stats["variation_type_distribution"] = rollout_df["variation_type"].value_counts().to_dict() + + return stats + + +def load_results_jsonl(filepath: Path) -> list[dict]: + """Load results from JSONL file.""" + results = [] + try: + with open(filepath) as f: + for line in f: + if line.strip(): + results.append(json.loads(line)) + except Exception as e: + print(f"Error loading {filepath}: {e}") + return results + + +def _extract_model_prefix(model_dir_name: str) -> str: + """Extract model prefix from model directory name like 'gpt-5_2-20251213-183612' -> 'gpt-5_2'.""" + # Remove timestamp suffix (format: -YYYYMMDD-HHMMSS) + import re + match = re.match(r"(.+)-\d{8}-\d{6}$", model_dir_name) + if match: + return match.group(1) + return model_dir_name + + +def _get_model_prefix_variants(model_prefix: str) -> list[str]: + """Generate possible variants of model prefix for flexible matching. + + E.g., 'gpt-oss-20b-med' -> ['gpt-oss-20b-med', 'gpt-oss-20b', 'gpt-oss'] + This handles cases where subdirectories use shorter prefixes. + """ + variants = [model_prefix] + + # Common suffixes that might be in model name but not in subdirectory names + suffix_patterns = ['-med', '-high', '-low', '-small', '-large', '-tiny', '-base', '-instruct'] + + # Try removing known suffixes + for suffix in suffix_patterns: + if model_prefix.endswith(suffix): + variants.append(model_prefix[:-len(suffix)]) + + # Also generate progressively shorter prefixes by splitting on '-' + parts = model_prefix.split('-') + for i in range(len(parts) - 1, 0, -1): + variant = '-'.join(parts[:i]) + if variant not in variants: + variants.append(variant) + + return variants + + +def _extract_benchmark_name(dirname: str, model_prefix: str) -> tuple[str, str]: + """Extract benchmark name and rollout ID from directory name. + + Handles formats: + - {model_prefix}-{benchmark} -> benchmark, "base" + - {model_prefix}-{benchmark}-rollout{seed} -> benchmark, "rollout{seed}" + - {benchmark} -> benchmark, "base" + - {benchmark}-rollout{seed} -> benchmark, "rollout{seed}" + + Also handles prefix variants where subdirectories may use shorter model prefixes + (e.g., model dir 'gpt-oss-20b-med' but subdirs use 'gpt-oss-20b' as prefix). + """ + # Try all prefix variants (longest first) + if model_prefix: + for prefix in _get_model_prefix_variants(model_prefix): + if dirname.startswith(prefix + "-"): + dirname = dirname[len(prefix) + 1:] + break + + # Handle rollout suffix + if "-rollout" in dirname: + parts = dirname.rsplit("-rollout", 1) + benchmark_name = parts[0] + rollout_id = f"rollout{parts[1]}" + else: + benchmark_name = dirname + rollout_id = "base" + + # Handle -- separator (vf-eval format) + if "--" in benchmark_name: + benchmark_name = benchmark_name.split("--")[0] + + return benchmark_name, rollout_id + + +def discover_benchmarks(model_dir: Path) -> dict[str, dict[str, Path]]: + """Discover benchmarks and rollouts in model directory. + + Supports multiple directory structures: + 1. model_dir/{model_prefix}-{benchmark}/results.jsonl (medmarks format) + 2. model_dir/{benchmark}/results.jsonl + 3. model_dir/{benchmark--model--name}/hash_id/results.jsonl (vf-eval structure) + """ + benchmarks = {} + if not model_dir.is_dir(): + return benchmarks + + # Extract model prefix from parent directory name + model_prefix = _extract_model_prefix(model_dir.name) + + for item in model_dir.iterdir(): + if not item.is_dir(): + continue + + # Check if results.jsonl is directly in this directory + results_file = item / "results.jsonl" + if results_file.exists(): + benchmark_name, rollout_id = _extract_benchmark_name(item.name, model_prefix) + + if benchmark_name not in benchmarks: + benchmarks[benchmark_name] = {} + benchmarks[benchmark_name][rollout_id] = results_file + continue + + # Check for vf-eval structure: benchmark--model--name/hash_id/results.jsonl + for subitem in item.iterdir(): + if not subitem.is_dir(): + continue + results_file = subitem / "results.jsonl" + if not results_file.exists(): + continue + + benchmark_name, rollout_id = _extract_benchmark_name(item.name, model_prefix) + # For vf-eval structure, use hash as rollout identifier if base + if rollout_id == "base": + rollout_id = subitem.name + + if benchmark_name not in benchmarks: + benchmarks[benchmark_name] = {} + benchmarks[benchmark_name][rollout_id] = results_file + + return benchmarks + + +def process_benchmark( + benchmark_name: str, + rollout_paths: dict[str, Path], + output_dir: Path, + model_name: str, +) -> dict[str, Any]: + """Process a single benchmark with all its rollouts.""" + print(f"\n Processing {benchmark_name}...") + print(f" Found {len(rollout_paths)} version(s): {', '.join(sorted(rollout_paths.keys()))}") + + analyzed_by_rollout = {} + all_analyzed = [] + + for rollout_id, results_path in sorted(rollout_paths.items()): + raw_results = load_results_jsonl(results_path) + if not raw_results: + continue + + analyzed = [] + for raw in raw_results: + analysis = analyze_result(raw) + analysis["rollout_id"] = rollout_id + analyzed.append(analysis) + all_analyzed.append(analysis) + + analyzed_by_rollout[rollout_id] = analyzed + + if not all_analyzed: + print(f" Warning: No results to analyze") + return {} + + # Save per-example analysis + output_base = output_dir / f"{model_name}_{benchmark_name}" + df = pd.DataFrame(all_analyzed) + df.to_csv(output_base.with_suffix(".csv"), index=False) + + # Confusion matrix + confusion_counts, confusion_norm = build_confusion_matrix(all_analyzed) + confusion_counts.to_csv(output_base.with_name(f"{output_base.name}_confusion.csv")) + + # Cross-rollout comparison + rollout_df = None + if len(analyzed_by_rollout) > 1: + rollout_df = compare_rollouts(analyzed_by_rollout) + if not rollout_df.empty: + rollout_df.to_csv(output_base.with_name(f"{output_base.name}_rollouts.csv"), index=False) + + # Aggregate statistics + stats = aggregate_statistics(all_analyzed, rollout_df) + stats["benchmark_name"] = benchmark_name + + print(f" Total: {stats['total_questions']}, Accuracy: {stats['accuracy']:.1%}, Parsed: {stats['parsing_success_rate']:.1%}") + if stats.get("variation_rate") is not None: + print(f" Variation: {stats['variation_rate']:.1%}, Semantic consistency: {stats.get('semantic_consistency_rate', 0):.1%}") + + return stats + + +def weighted_mean(stats_list: list[dict], key: str, weight_key: str) -> float | None: + """Compute weighted mean.""" + values = [(s.get(key), s.get(weight_key)) for s in stats_list if s.get(key) is not None and s.get(weight_key, 0) > 0] + if not values: + return None + total_weight = sum(w for _, w in values) + return sum(v * w for v, w in values) / total_weight if total_weight > 0 else None + + +def normalize_benchmark_name(benchmark_name: str, model_name: str) -> str: + """Normalize benchmark name by removing model prefix.""" + model_base = re.sub(r'-\d{8}-\d{6}$', '', model_name) + for prefix in [model_name, model_base]: + for sep in ["-", "--"]: + if benchmark_name.startswith(prefix + sep): + return benchmark_name[len(prefix) + len(sep):] + return benchmark_name + + +def compute_win_rates(all_summaries: list[dict]) -> dict[str, dict[str, float | None]]: + """Compute win rates across models per benchmark.""" + bench_map: dict[str, list[tuple[str, float, int]]] = {} + for summary in all_summaries: + model = summary.get("model_name") + per_bench = summary.get("per_benchmark_stats", {}) + for bench, stats in per_bench.items(): + if not model or not bench: + continue + bench_key = normalize_benchmark_name(bench, model) + acc = stats.get("accuracy") + n_q = stats.get("total_questions", 0) + if acc is None or n_q <= 0: + continue + bench_map.setdefault(bench_key, []).append((model, acc, n_q)) + + model_win_rates: dict[str, list[tuple[float, float]]] = {} + for bench, entries in bench_map.items(): + if len(entries) < 2: + continue + weight = math.log(entries[0][2]) if entries[0][2] > 1 else 0.0 + for i, (model_i, acc_i, _) in enumerate(entries): + wins = sum(1.0 for j, (_, acc_j, _) in enumerate(entries) if i != j and acc_i > acc_j) + ties = sum(0.5 for j, (_, acc_j, _) in enumerate(entries) if i != j and acc_i == acc_j) + win_rate = (wins + ties) / (len(entries) - 1) if len(entries) > 1 else 0.0 + model_win_rates.setdefault(model_i, []).append((win_rate, weight)) + + results = {} + for model, rates in model_win_rates.items(): + if not rates: + results[model] = {"mean_win_rate": None, "weighted_mean_win_rate": None} + continue + mean_win = sum(r for r, _ in rates) / len(rates) + weight_sum = sum(w for _, w in rates) + weighted_mean = sum(r * w for r, w in rates) / weight_sum if weight_sum > 0 else None + results[model] = {"mean_win_rate": mean_win, "weighted_mean_win_rate": weighted_mean} + + return results + + +def process_model( + model_dir: Path, + output_dir: Path, + model_name: str | None = None, + only_benchmarks: set[str] | None = None, + merge_existing: bool = False, +) -> dict[str, Any] | None: + """Process all benchmarks for a single model.""" + model_name = model_name or model_dir.name + + print(f"\n{'='*70}") + print(f"Processing model: {model_name}") + print(f"{'='*70}") + + benchmarks = discover_benchmarks(model_dir) + # Filter to MCQ only + benchmarks = {k: v for k, v in benchmarks.items() if is_mcq_benchmark(k)} + if only_benchmarks: + benchmarks = {k: v for k, v in benchmarks.items() if k in only_benchmarks} + + if not benchmarks: + print(f" No MCQ benchmarks found") + return None + + print(f" Found {len(benchmarks)} MCQ benchmark(s)") + + output_dir.mkdir(parents=True, exist_ok=True) + + # Load existing stats if merging + benchmark_stats: dict[str, dict] = {} + summary_path = output_dir / f"{model_name}_summary.json" + if merge_existing and summary_path.exists(): + try: + with open(summary_path) as f: + existing = json.load(f) + benchmark_stats = existing.get("per_benchmark_stats", {}) + except Exception: + pass + + # Process each benchmark + for bench_name, rollout_paths in sorted(benchmarks.items()): + stats = process_benchmark(bench_name, rollout_paths, output_dir, model_name) + if stats: + benchmark_stats[bench_name] = stats + + if not benchmark_stats: + return None + + # Aggregate statistics + total_q = sum(s["total_questions"] for s in benchmark_stats.values()) + total_c = sum(s["total_correct"] for s in benchmark_stats.values()) + + aggregate = { + "total_questions": total_q, + "total_correct": total_c, + "overall_accuracy": total_c / total_q if total_q > 0 else 0.0, + "parsing_success_rate": weighted_mean(list(benchmark_stats.values()), "parsing_success_rate", "total_questions"), + "options_coverage_rate": weighted_mean(list(benchmark_stats.values()), "options_coverage_rate", "total_questions"), + } + + # Rollout aggregates + benchmarks_with_rollouts = [s for s in benchmark_stats.values() if s.get("questions_with_rollouts", 0) > 0] + if benchmarks_with_rollouts: + aggregate["questions_with_rollouts"] = sum(s["questions_with_rollouts"] for s in benchmarks_with_rollouts) + aggregate["variation_rate"] = weighted_mean(benchmarks_with_rollouts, "variation_rate", "questions_with_rollouts") + aggregate["semantic_consistency_rate"] = weighted_mean(benchmarks_with_rollouts, "semantic_consistency_rate", "questions_with_rollouts") + + summary = { + "model_name": model_name, + "total_benchmarks": len(benchmark_stats), + "per_benchmark_stats": benchmark_stats, + "aggregate_stats": aggregate, + } + + # Save summary + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + + # Save benchmark metrics CSV + metrics_rows = [] + for bench, stats in benchmark_stats.items(): + row = { + "benchmark": bench, + "total_questions": stats.get("total_questions"), + "accuracy": stats.get("accuracy"), + "parsing_success_rate": stats.get("parsing_success_rate"), + "variation_rate": stats.get("variation_rate"), + "semantic_consistency_rate": stats.get("semantic_consistency_rate"), + } + metrics_rows.append(row) + pd.DataFrame(metrics_rows).to_csv(output_dir / f"{model_name}_benchmark_metrics.csv", index=False) + + print(f"\n Summary: {len(benchmark_stats)} benchmarks, {total_q} questions, {aggregate['overall_accuracy']:.1%} accuracy") + + return summary + + +def update_all_models_metrics(output_dir: Path) -> None: + """Recompute and save cross-model metrics.""" + summary_files = list(output_dir.glob("*/*_summary.json")) + if not summary_files: + print("No summary files found") + return + + all_summaries = [] + for path in summary_files: + with open(path) as f: + all_summaries.append(json.load(f)) + + win_rates = compute_win_rates(all_summaries) + + rows = [] + for summary in all_summaries: + agg = summary.get("aggregate_stats", {}) + model_name = summary.get("model_name") + row = { + "model": model_name, + "total_benchmarks": summary.get("total_benchmarks"), + "total_questions": agg.get("total_questions"), + "overall_accuracy": agg.get("overall_accuracy"), + "parsing_success_rate": agg.get("parsing_success_rate"), + "variation_rate": agg.get("variation_rate"), + "semantic_consistency_rate": agg.get("semantic_consistency_rate"), + } + if model_name in win_rates: + row.update(win_rates[model_name]) + rows.append(row) + + df = pd.DataFrame(rows) + df.to_csv(output_dir / "all_models_metrics.csv", index=False) + df.to_json(output_dir / "all_models_metrics.json", orient="records", indent=2) + print(f"\nSaved all-model metrics ({len(rows)} models)") + + +def main(): + parser = argparse.ArgumentParser(description="Analyze MCQ benchmark results") + parser.add_argument("--logs-dir", type=str, required=True, help="Directory containing model results") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for analysis") + parser.add_argument("--model", type=str, help="Specific model to analyze") + parser.add_argument("--all-models", action="store_true", help="Analyze all models in logs-dir") + parser.add_argument("--benchmark", action="append", help="Only analyze specific benchmark(s)") + + args = parser.parse_args() + + logs_dir = Path(args.logs_dir) + output_dir = Path(args.output_dir) + + if not logs_dir.exists(): + print(f"Error: {logs_dir} does not exist") + sys.exit(1) + + only_benchmarks = set(args.benchmark) if args.benchmark else None + + if args.all_models: + model_dirs = [d for d in logs_dir.iterdir() if d.is_dir()] + print(f"Found {len(model_dirs)} model directories") + + for model_dir in sorted(model_dirs): + try: + process_model( + model_dir, + output_dir / model_dir.name, + model_dir.name, + only_benchmarks, + merge_existing=bool(only_benchmarks), + ) + except Exception as e: + print(f"Error processing {model_dir.name}: {e}") + import traceback + traceback.print_exc() + + update_all_models_metrics(output_dir) + else: + model_name = args.model or logs_dir.name + process_model(logs_dir, output_dir, model_name, only_benchmarks, merge_existing=bool(only_benchmarks)) + # Update aggregate metrics after single model run + update_all_models_metrics(output_dir) + + print("\nAnalysis complete!") + + +if __name__ == "__main__": + main() diff --git a/scripts/mcq_visualizations.py b/scripts/mcq_visualizations.py new file mode 100755 index 00000000..23054517 --- /dev/null +++ b/scripts/mcq_visualizations.py @@ -0,0 +1,668 @@ +#!/usr/bin/env python3 +""" +MCQ Analysis Visualizations + +Generates paper-ready visualizations from MCQ analysis outputs. +Uses viridis colormap for consistency with blog style. + +Usage: + python scripts/mcq_visualizations.py \ + --analysis-dir ./analysis_output \ + --output-dir ./figures +""" + +import argparse +import re +import textwrap +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.backends.backend_pdf import PdfPages + +# Publication style settings +plt.style.use('seaborn-v0_8-whitegrid') +sns.set_context("paper", font_scale=1.2) +plt.rcParams['figure.dpi'] = 300 +plt.rcParams['savefig.dpi'] = 300 +plt.rcParams['font.family'] = 'sans-serif' +plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial'] +plt.rcParams['axes.labelsize'] = 12 +plt.rcParams['axes.titlesize'] = 14 +plt.rcParams['xtick.labelsize'] = 10 +plt.rcParams['ytick.labelsize'] = 10 +plt.rcParams['legend.fontsize'] = 10 + +# Color scheme +PERF_CMAP = 'viridis' +DIVERGE_CMAP = 'RdBu_r' + + +def clean_model_name(name: str) -> str: + """Clean model name for display.""" + name = re.sub(r'-\d{8}-\d{6}$', '', name) + return name.replace('_', '-') + + +def extract_model_size(name: str) -> str: + """Extract model size category. + + Categories: + - small: <10B parameters + - medium: 10-30B parameters + - large: >30B parameters or frontier models + """ + name_lower = name.lower() + + # Known frontier/large models without size in name + LARGE_MODELS = { + 'gpt-5', 'gpt_5', 'gpt-oss-120b', 'grok-4', 'sonnet-4', 'opus-4', + 'claude-4', 'gemini-2', 'baichuan-m2', 'minimax-m2', 'intellect3' + } + + # Known medium models without clear size + MEDIUM_MODELS = { + 'gpt-oss-20b', 'magistral-small' + } + + # Check known model patterns first + for pattern in LARGE_MODELS: + if pattern in name_lower: + return 'large' + + for pattern in MEDIUM_MODELS: + if pattern in name_lower: + return 'medium' + + # Try to extract size from "Xb" pattern + match = re.search(r'(\d+\.?\d*)b', name_lower) + if match: + size = float(match.group(1)) + if size < 10: + return 'small' + elif size < 30: + return 'medium' + return 'large' + + # Default to large for unknown models (frontier models often lack size in name) + return 'large' + + +def to_percent(series: pd.Series) -> pd.Series: + """Convert 0-1 values to percentages if needed.""" + if series.dropna().empty: + return series + if series.max() <= 1.5: + return series * 100 + return series + + +def load_model_metrics(analysis_dir: Path) -> pd.DataFrame: + """Load all-models metrics.""" + path = analysis_dir / 'all_models_metrics.csv' + if not path.exists(): + raise FileNotFoundError(f"Missing {path}") + df = pd.read_csv(path) + df['clean_model'] = df['model'].apply(clean_model_name) + for col in ['variation_rate', 'overall_accuracy', 'semantic_consistency_rate', 'mean_win_rate', 'weighted_mean_win_rate']: + if col in df.columns: + df[f'{col}_pct'] = to_percent(df[col]) + return df + + +def load_benchmark_metrics(analysis_dir: Path) -> pd.DataFrame: + """Load per-benchmark metrics from all models.""" + files = list(analysis_dir.glob("**/*_benchmark_metrics.csv")) + if not files: + raise FileNotFoundError(f"No benchmark metrics found in {analysis_dir}") + + dfs = [] + for fpath in files: + df = pd.read_csv(fpath) + match = re.match(r"(.+)_benchmark_metrics$", fpath.stem) + model_name = match.group(1) if match else fpath.parent.name + df["model"] = model_name + df["clean_model"] = clean_model_name(model_name) + for col in ['variation_rate', 'accuracy', 'semantic_consistency_rate']: + if col in df.columns: + df[f'{col}_pct'] = to_percent(df[col]) + dfs.append(df) + + return pd.concat(dfs, ignore_index=True) + + +def create_model_heatmap(df: pd.DataFrame, output_dir: Path): + """Create model performance heatmap.""" + print("Creating model performance heatmap...") + + # Aggregate by model + agg_cols = ['variation_rate_pct', 'overall_accuracy_pct', 'semantic_consistency_rate_pct'] + available = [c for c in agg_cols if c in df.columns] + if not available: + print(" Skipping - no data") + return + + model_stats = df.groupby('clean_model')[available].mean().reset_index() + model_stats = model_stats.sort_values(available[1] if len(available) > 1 else available[0], ascending=False) + + labels = { + 'overall_accuracy_pct': 'Accuracy\n(Higher=Better)', + 'variation_rate_pct': 'Variation Rate\n(Higher=More Sensitive)', + 'semantic_consistency_rate_pct': 'Semantic Consistency\n(Higher=Better)', + } + + heatmap_data = model_stats[available].values + fig, ax = plt.subplots(figsize=(10, max(8, len(model_stats) * 0.3))) + + sns.heatmap( + heatmap_data, annot=True, fmt='.1f', cmap=PERF_CMAP, + cbar_kws={'label': 'Percentage (%)'}, linewidths=1, linecolor='white', + ax=ax, vmin=0, vmax=100, + xticklabels=[labels.get(c, c) for c in available], + yticklabels=model_stats['clean_model'].values + ) + + ax.set_title('Model Performance Summary\n(Yellow = Higher Value)', fontweight='bold', pad=15) + ax.set_ylabel('Model', fontweight='bold') + ax.tick_params(axis='y', labelsize=8) + + plt.tight_layout() + plt.savefig(output_dir / 'model_heatmap.png', bbox_inches='tight') + plt.savefig(output_dir / 'model_heatmap.pdf', bbox_inches='tight') + plt.close() + print(f" Saved with {len(model_stats)} models") + + +def create_model_ranking(df: pd.DataFrame, output_dir: Path): + """Create model ranking by accuracy.""" + print("Creating model ranking plot...") + + if 'overall_accuracy_pct' not in df.columns: + print(" Skipping - no accuracy data") + return + + model_stats = df.groupby('clean_model').agg({ + 'overall_accuracy_pct': 'mean', + 'variation_rate_pct': 'mean' if 'variation_rate_pct' in df.columns else 'first', + }).reset_index() + model_stats = model_stats.sort_values('overall_accuracy_pct', ascending=True) + + fig, ax = plt.subplots(figsize=(12, max(8, len(model_stats) * 0.25))) + + colors = plt.cm.viridis(model_stats['overall_accuracy_pct'] / 100.0) + ax.barh(range(len(model_stats)), model_stats['overall_accuracy_pct'], color=colors, edgecolor='black', linewidth=0.3) + + ax.set_yticks(range(len(model_stats))) + ax.set_yticklabels(model_stats['clean_model'], fontsize=8) + ax.set_xlabel('Accuracy (%)', fontweight='bold') + ax.set_title('Models Ranked by Accuracy', fontweight='bold', pad=15) + + mean_acc = model_stats['overall_accuracy_pct'].mean() + ax.axvline(mean_acc, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_acc:.1f}%') + ax.legend(loc='lower right') + ax.grid(True, alpha=0.3, axis='x') + ax.set_xlim(0, 100) + + plt.tight_layout() + plt.savefig(output_dir / 'model_ranking.png', bbox_inches='tight') + plt.savefig(output_dir / 'model_ranking.pdf', bbox_inches='tight') + plt.close() + print(f" Saved with {len(model_stats)} models") + + +def create_variation_ranking(df: pd.DataFrame, output_dir: Path): + """Create model ranking by variation rate.""" + print("Creating variation ranking plot...") + + if 'variation_rate_pct' not in df.columns: + print(" Skipping - no variation data") + return + + model_stats = df.dropna(subset=['variation_rate_pct']).groupby('clean_model').agg({ + 'variation_rate_pct': 'mean', + 'overall_accuracy_pct': 'mean' if 'overall_accuracy_pct' in df.columns else 'first', + }).reset_index() + model_stats = model_stats.sort_values('variation_rate_pct', ascending=True) + + if model_stats.empty: + print(" Skipping - no data after filtering") + return + + fig, ax = plt.subplots(figsize=(12, max(8, len(model_stats) * 0.25))) + + colors = plt.cm.viridis(model_stats['overall_accuracy_pct'] / 100.0) if 'overall_accuracy_pct' in model_stats else 'steelblue' + ax.barh(range(len(model_stats)), model_stats['variation_rate_pct'], color=colors, edgecolor='black', linewidth=0.3) + + ax.set_yticks(range(len(model_stats))) + ax.set_yticklabels(model_stats['clean_model'], fontsize=8) + ax.set_xlabel('Variation Rate (%)', fontweight='bold') + ax.set_title('Models Ranked by Variation Rate\n(Lower = More Robust)', fontweight='bold', pad=15) + + sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=0, vmax=100)) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, label='Accuracy (%)', pad=0.02) + cbar.ax.tick_params(labelsize=9) + + ax.grid(True, alpha=0.3, axis='x') + ax.set_xlim(0, 100) + + plt.tight_layout() + plt.savefig(output_dir / 'variation_ranking.png', bbox_inches='tight') + plt.savefig(output_dir / 'variation_ranking.pdf', bbox_inches='tight') + plt.close() + print(f" Saved with {len(model_stats)} models") + + +def create_scatter_plot(df: pd.DataFrame, output_dir: Path): + """Create accuracy vs semantic consistency scatter with model labels.""" + print("Creating scatter plot...") + + required = ['overall_accuracy_pct', 'semantic_consistency_rate_pct'] + if not all(c in df.columns for c in required): + print(" Skipping - missing columns") + return + + extra_cols = ['weighted_mean_win_rate_pct'] if 'weighted_mean_win_rate_pct' in df.columns else [] + model_stats = df.dropna(subset=required).groupby('clean_model')[required + extra_cols].mean().reset_index() + model_stats['size'] = model_stats['clean_model'].apply(extract_model_size) + + if model_stats.empty: + print(" Skipping - no data") + return + + fig, ax = plt.subplots(figsize=(16, 14)) + ax.set_facecolor('#fafafa') + + markers = {'small': 'o', 'medium': 's', 'large': '^'} + sizes = {'small': 180, 'medium': 220, 'large': 280} + + color_col = 'weighted_mean_win_rate_pct' if 'weighted_mean_win_rate_pct' in model_stats.columns else 'overall_accuracy_pct' + vmin, vmax = model_stats[color_col].min() - 2, model_stats[color_col].max() + 2 + + for size_cat, marker in markers.items(): + mask = model_stats['size'] == size_cat + if mask.sum() == 0: + continue + subset = model_stats[mask] + ax.scatter( + subset['semantic_consistency_rate_pct'], + subset['overall_accuracy_pct'], + c=subset[color_col], + cmap=PERF_CMAP, + s=sizes[size_cat], + alpha=0.85, + marker=marker, + edgecolors='black', + linewidth=1.2, + label=size_cat.title(), + vmin=vmin, + vmax=vmax, + zorder=3 + ) + + # Add model name labels with smart positioning + from adjustText import adjust_text + texts = [] + for _, row in model_stats.iterrows(): + # Shorten long model names + label = row['clean_model'] + if len(label) > 20: + label = label[:18] + '...' + texts.append(ax.annotate( + label, + (row['semantic_consistency_rate_pct'], row['overall_accuracy_pct']), + fontsize=7, + alpha=0.9, + zorder=4 + )) + + # Try to use adjustText if available, otherwise use basic offset + try: + adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='-', color='gray', alpha=0.5, lw=0.5), + expand_points=(1.5, 1.5), force_text=(0.5, 0.5)) + except: + # Fallback: simple offset + for text in texts: + text.set_position((text.get_position()[0] + 0.5, text.get_position()[1] + 0.5)) + + sm = plt.cm.ScalarMappable(cmap=PERF_CMAP, norm=plt.Normalize(vmin=vmin, vmax=vmax)) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, shrink=0.5, pad=0.02) + cbar.set_label('Win Rate (%)' if 'win_rate' in color_col else 'Accuracy (%)', fontweight='bold', fontsize=11) + + ax.set_xlabel('Semantic Consistency (%) — Higher = Better', fontweight='bold', fontsize=12) + ax.set_ylabel('Accuracy (%) — Higher = Better', fontweight='bold', fontsize=12) + ax.set_title('Model Performance: Accuracy vs Semantic Consistency\n(Each point is a model, colored by win rate)', + fontweight='bold', pad=20, fontsize=14) + + # Clean up legend - only show unique entries + handles, labels = ax.get_legend_handles_labels() + by_label = dict(zip(labels, handles)) + ax.legend(by_label.values(), ['Small (<10B)', 'Medium (10-30B)', 'Large (>30B)'], + loc='lower left', framealpha=0.9, fontsize=10) + + ax.grid(True, alpha=0.4, linestyle='--', zorder=1) + ax.set_axisbelow(True) + + # Add quadrant labels + x_mid = model_stats['semantic_consistency_rate_pct'].median() + y_mid = model_stats['overall_accuracy_pct'].median() + ax.axvline(x_mid, color='gray', linestyle=':', alpha=0.5, zorder=2) + ax.axhline(y_mid, color='gray', linestyle=':', alpha=0.5, zorder=2) + + plt.tight_layout() + plt.savefig(output_dir / 'scatter_accuracy_consistency.png', bbox_inches='tight', dpi=300) + plt.savefig(output_dir / 'scatter_accuracy_consistency.pdf', bbox_inches='tight') + plt.close() + print(f" Saved with {len(model_stats)} models") + + +def create_violin_plots(df_bench: pd.DataFrame, output_dir: Path): + """Create violin distribution plots.""" + print("Creating violin plots...") + + if 'variation_rate_pct' not in df_bench.columns: + print(" Skipping - no variation data") + return + + df = df_bench.dropna(subset=['variation_rate_pct']) + if df.empty: + print(" Skipping - no data") + return + + model_order = df.groupby('clean_model')['variation_rate_pct'].median().sort_values().index + height = max(10, len(model_order) * 0.3) + + # Variation rate violin + fig, ax = plt.subplots(figsize=(12, height)) + sns.violinplot( + data=df, y='clean_model', x='variation_rate_pct', ax=ax, + palette='viridis', inner='box', orient='h', order=model_order, + hue='clean_model', legend=False + ) + ax.set_xlabel('Variation Rate (%) Across Benchmarks', fontweight='bold') + ax.set_ylabel('Model', fontweight='bold') + ax.set_title('Variation Rate Distribution\n(Higher = More Answer-Order Sensitive)', fontweight='bold', pad=15) + ax.set_xlim(0, 100) + ax.tick_params(axis='y', labelsize=8) + ax.grid(True, alpha=0.3, axis='x') + plt.tight_layout() + plt.savefig(output_dir / 'violin_variation.png', bbox_inches='tight') + plt.savefig(output_dir / 'violin_variation.pdf', bbox_inches='tight') + plt.close() + + # Accuracy violin + if 'accuracy_pct' in df.columns: + fig, ax = plt.subplots(figsize=(12, height)) + sns.violinplot( + data=df, y='clean_model', x='accuracy_pct', ax=ax, + palette='viridis', inner='box', orient='h', order=model_order, + hue='clean_model', legend=False + ) + ax.set_xlabel('Accuracy (%) Across Benchmarks', fontweight='bold') + ax.set_ylabel('Model', fontweight='bold') + ax.set_title('Accuracy Distribution', fontweight='bold', pad=15) + ax.set_xlim(0, 100) + ax.tick_params(axis='y', labelsize=8) + ax.grid(True, alpha=0.3, axis='x') + plt.tight_layout() + plt.savefig(output_dir / 'violin_accuracy.png', bbox_inches='tight') + plt.savefig(output_dir / 'violin_accuracy.pdf', bbox_inches='tight') + plt.close() + + print(f" Saved violin plots for {len(model_order)} models") + + +def create_benchmark_heatmap(df_bench: pd.DataFrame, output_dir: Path): + """Create benchmark performance heatmap.""" + print("Creating benchmark heatmap...") + + available = [c for c in ['variation_rate_pct', 'accuracy_pct', 'semantic_consistency_rate_pct'] if c in df_bench.columns] + if not available: + print(" Skipping - no data") + return + + bench_stats = df_bench.groupby('benchmark')[available].mean().reset_index() + bench_stats = bench_stats.sort_values(available[0], ascending=True).head(30) + bench_stats['clean_bench'] = bench_stats['benchmark'].str.replace('_', ' ').str.replace('-', ' ').str.title() + + labels = { + 'variation_rate_pct': 'Variation\n(Higher=Sensitive)', + 'accuracy_pct': 'Accuracy\n(Higher=Better)', + 'semantic_consistency_rate_pct': 'Consistency\n(Higher=Better)', + } + + heatmap_data = bench_stats[available].values + fig, ax = plt.subplots(figsize=(10, max(8, len(bench_stats) * 0.35))) + + sns.heatmap( + heatmap_data, annot=True, fmt='.1f', cmap=PERF_CMAP, + cbar_kws={'label': 'Percentage (%)'}, linewidths=1, linecolor='white', + ax=ax, vmin=0, vmax=100, + xticklabels=[labels.get(c, c) for c in available], + yticklabels=bench_stats['clean_bench'].values + ) + + ax.set_title('Benchmark Performance Summary\n(Averaged Across Models)', fontweight='bold', pad=15) + ax.set_ylabel('Benchmark', fontweight='bold') + ax.tick_params(axis='y', labelsize=9) + + plt.tight_layout() + plt.savefig(output_dir / 'benchmark_heatmap.png', bbox_inches='tight') + plt.savefig(output_dir / 'benchmark_heatmap.pdf', bbox_inches='tight') + plt.close() + print(f" Saved with {len(bench_stats)} benchmarks") + + +def create_correlation_heatmap(df: pd.DataFrame, output_dir: Path): + """Create metric correlation heatmap.""" + print("Creating correlation heatmap...") + + metrics = [c for c in ['weighted_mean_win_rate_pct', 'overall_accuracy_pct', 'semantic_consistency_rate_pct', 'variation_rate_pct'] if c in df.columns] + if len(metrics) < 2: + print(" Skipping - not enough metrics") + return + + corr = df[metrics].corr() + + labels = { + 'weighted_mean_win_rate_pct': 'Win Rate', + 'overall_accuracy_pct': 'Accuracy', + 'semantic_consistency_rate_pct': 'Consistency', + 'variation_rate_pct': 'Variation', + } + + fig, ax = plt.subplots(figsize=(8, 7)) + sns.heatmap( + corr, annot=True, fmt='.3f', cmap=DIVERGE_CMAP, + center=0, vmin=-1, vmax=1, + cbar_kws={'label': 'Correlation'}, + linewidths=2, linecolor='white', ax=ax, square=True, + annot_kws={'fontsize': 12, 'fontweight': 'bold'} + ) + + ax.set_xticklabels([labels.get(c, c) for c in metrics], rotation=45, ha='right') + ax.set_yticklabels([labels.get(c, c) for c in metrics], rotation=0) + ax.set_title('Metric Correlations\n(Red = Negative, Blue = Positive)', fontweight='bold', pad=15) + + plt.tight_layout() + plt.savefig(output_dir / 'correlation_heatmap.png', bbox_inches='tight') + plt.savefig(output_dir / 'correlation_heatmap.pdf', bbox_inches='tight') + plt.close() + print(" Saved correlation heatmap") + + +def create_summary_table(df: pd.DataFrame, output_dir: Path): + """Create summary statistics visualization.""" + print("Creating summary table...") + + stats = { + 'Total Models': df['clean_model'].nunique(), + 'Avg Accuracy (%)': df['overall_accuracy_pct'].mean() if 'overall_accuracy_pct' in df.columns else None, + 'Avg Variation (%)': df['variation_rate_pct'].mean() if 'variation_rate_pct' in df.columns else None, + 'Avg Consistency (%)': df['semantic_consistency_rate_pct'].mean() if 'semantic_consistency_rate_pct' in df.columns else None, + } + + fig = plt.figure(figsize=(14, 10)) + + # Summary box + ax1 = fig.add_subplot(2, 2, 1) + ax1.axis('off') + text = "ANALYSIS SUMMARY\n\n" + for k, v in stats.items(): + if v is not None: + text += f"{k}: {v:.1f}\n" if isinstance(v, float) else f"{k}: {v}\n" + ax1.text(0.5, 0.5, text, ha='center', va='center', fontsize=14, + bbox=dict(boxstyle='round', facecolor='#FDE724', alpha=0.7, pad=1.5)) + + # Distribution histograms + if 'variation_rate_pct' in df.columns: + ax2 = fig.add_subplot(2, 2, 2) + df['variation_rate_pct'].hist(bins=20, ax=ax2, color='steelblue', edgecolor='black', alpha=0.7) + ax2.axvline(df['variation_rate_pct'].mean(), color='red', linestyle='--', linewidth=2) + ax2.set_xlabel('Variation Rate (%)', fontweight='bold') + ax2.set_ylabel('Count', fontweight='bold') + ax2.set_title('Variation Rate Distribution', fontweight='bold') + + if 'overall_accuracy_pct' in df.columns: + ax3 = fig.add_subplot(2, 2, 3) + df['overall_accuracy_pct'].hist(bins=20, ax=ax3, color='forestgreen', edgecolor='black', alpha=0.7) + ax3.axvline(df['overall_accuracy_pct'].mean(), color='red', linestyle='--', linewidth=2) + ax3.set_xlabel('Accuracy (%)', fontweight='bold') + ax3.set_ylabel('Count', fontweight='bold') + ax3.set_title('Accuracy Distribution', fontweight='bold') + + # Key findings + ax4 = fig.add_subplot(2, 2, 4) + ax4.axis('off') + findings = f""" +KEY FINDINGS + +- {stats.get('Avg Variation (%)', 0):.0f}% average variation rate + indicates answer-order sensitivity + +- {stats.get('Avg Consistency (%)', 0):.0f}% semantic consistency + suggests content-based selection + +- Evaluation protocols should use + multiple random seeds for reliability +""" + ax4.text(0.5, 0.5, findings, ha='center', va='center', fontsize=12, + bbox=dict(boxstyle='round', facecolor='#B8DE6F', alpha=0.8, pad=1.5), + family='monospace', linespacing=1.6) + + plt.suptitle('MedMarks MCQ Analysis Summary', fontsize=16, fontweight='bold', y=0.98) + plt.tight_layout() + plt.savefig(output_dir / 'summary.png', bbox_inches='tight') + plt.savefig(output_dir / 'summary.pdf', bbox_inches='tight') + plt.close() + print(" Saved summary visualization") + + +def create_pdf_report(output_dir: Path, df: pd.DataFrame): + """Create combined PDF report.""" + print("Creating PDF report...") + + figures = [ + 'summary.png', + 'model_heatmap.png', + 'model_ranking.png', + 'variation_ranking.png', + 'scatter_accuracy_consistency.png', + 'benchmark_heatmap.png', + 'correlation_heatmap.png', + ] + + pdf_path = output_dir / 'mcq_analysis_report.pdf' + + with PdfPages(pdf_path) as pdf: + # Title page + fig = plt.figure(figsize=(8.5, 11)) + plt.axis('off') + plt.text(0.5, 0.6, 'MedMarks MCQ Analysis Report', ha='center', fontsize=24, fontweight='bold') + plt.text(0.5, 0.5, f'Models: {df["clean_model"].nunique()}', ha='center', fontsize=14) + pdf.savefig(fig, bbox_inches='tight') + plt.close() + + # Figures + for fname in figures: + fpath = output_dir / fname + if not fpath.exists(): + continue + fig = plt.figure(figsize=(11, 8.5)) + plt.axis('off') + img = plt.imread(fpath) + plt.imshow(img) + plt.title(fname.replace('.png', '').replace('_', ' ').title(), fontsize=12) + pdf.savefig(fig, bbox_inches='tight') + plt.close() + + print(f" Saved to {pdf_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Generate MCQ analysis visualizations") + parser.add_argument("--analysis-dir", type=str, required=True, help="Directory with analysis outputs") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for figures") + + args = parser.parse_args() + + analysis_dir = Path(args.analysis_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print("="*60) + print("MCQ Analysis Visualizations") + print("="*60) + + # Load data + print("\nLoading data...") + try: + df_models = load_model_metrics(analysis_dir) + print(f" Loaded {df_models['clean_model'].nunique()} models") + except FileNotFoundError as e: + print(f" {e}") + df_models = pd.DataFrame() + + try: + df_bench = load_benchmark_metrics(analysis_dir) + print(f" Loaded {df_bench['benchmark'].nunique()} benchmarks") + except FileNotFoundError as e: + print(f" {e}") + df_bench = pd.DataFrame() + + if df_models.empty and df_bench.empty: + print("\nNo data to visualize!") + return + + # Generate visualizations + print("\nGenerating visualizations...") + + if not df_models.empty: + create_model_heatmap(df_models, output_dir) + create_model_ranking(df_models, output_dir) + create_variation_ranking(df_models, output_dir) + create_scatter_plot(df_models, output_dir) + create_correlation_heatmap(df_models, output_dir) + create_summary_table(df_models, output_dir) + + if not df_bench.empty: + create_violin_plots(df_bench, output_dir) + create_benchmark_heatmap(df_bench, output_dir) + + # PDF report + if not df_models.empty: + create_pdf_report(output_dir, df_models) + + print("\n" + "="*60) + print("Visualizations complete!") + print("="*60) + print(f"\nSaved to: {output_dir}") + for f in sorted(output_dir.glob("*.png")): + print(f" - {f.name}") + + +if __name__ == "__main__": + main()