From f82f3d4737b99cdf5f4b91a4fe0e85eb8d0293de Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Sun, 12 Oct 2025 23:12:55 +0200 Subject: [PATCH 1/9] MedEXQA: env for qa and explanation eval --- environments/medexqa/README.md | 135 ++++++++++++++ environments/medexqa/medexqa.py | 275 ++++++++++++++++++++++++++++ environments/medexqa/pyproject.toml | 25 +++ 3 files changed, 435 insertions(+) create mode 100644 environments/medexqa/README.md create mode 100644 environments/medexqa/medexqa.py create mode 100644 environments/medexqa/pyproject.toml diff --git a/environments/medexqa/README.md b/environments/medexqa/README.md new file mode 100644 index 00000000..48156cef --- /dev/null +++ b/environments/medexqa/README.md @@ -0,0 +1,135 @@ +# medexqa-env- by mnishant2 + +### Overview +- **Environment ID**: `medexqa` +- **Short description**: Medical QA with multiple-choice questions and explanations across five underrepresented medical specialties +- **Tags**: medical, clinical, single-turn, multiple-choice, explanations, train, evaluation + +### Datasets +- **Primary dataset(s)**: MedExQA +- **Source links**: [Paper](https://arxiv.org/abs/2406.06331), [HuggingFace Dataset](https://huggingface.co/datasets/bluesky333/MedExQA), [GitHub](https://github.com/knowlab/MedExQA) +- **Split sizes**: + + | Specialty | Dev | Test | Total | + | --------------------------- | --- | ---- | ----- | + | Biomedical Engineering | 4 | 144 | 148 | + | Clinical Laboratory Science | 9 | 368 | 377 | + | Clinical Psychology | 3 | 108 | 111 | + | Occupational Therapy | 5 | 189 | 194 | + | Speech Language Pathology | 4 | 131 | 135 | + | **Total** | **25** | **940** | **965** | + +### Task +- **Type**: single-turn +- **System Prompt**: Uses the authors' prompt from their evaluation code: + ``` + "The following is a multiple-choice question. Please choose the most suitable one + among A, B, C and D as the answer to this question. Your answer should be paired + with an explanation why you chose that answer." + ``` +- **Parser**: `Parser` or `ThinkParser`, with `extract_fn=extract_boxed_answer` for strict letter-in-\boxed{}-format parsing +- **Rubric overview**: + - MCQ-only mode: Binary scoring based on correctly boxed letter choice + - Full evaluation mode: Weighted combination of MCQ accuracy + explanation quality (using LLM-as-judge) + +### Quickstart + +Run MCQ-only evaluation (default): + +```bash +uv run vf-eval medexqa -m gpt-4.1-mini +``` + +Run with explanation evaluation: + +```bash +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true}' +``` + +Configure model and sampling: + +```bash +uv run vf-eval medexqa \ + -m gpt-4.1-mini \ + -n -1 -r 3 -t 1024 -T 0.7 \ + -a '{"use_think": false, "use_explanations": false}' +``` + +### Environment Arguments + +| Arg | Type | Default | Description | +| -------------------- | ----- | ------------- | ---------------------------------------------------------------------------------- | +| `use_think` | bool | `False` | Whether to check for `...` formatting with `ThinkParser` | +| `use_explanations` | bool | `False` | Whether to evaluate explanation quality using LLM-as-judge | +| `mcq_weight` | float | `0.5` | Weight for MCQ accuracy (only used when `use_explanations=True`) | +| `explanation_weight` | float | `0.5` | Weight for explanation quality (only used when `use_explanations=True`) | +| `judge_model` | str | `gpt-4o-mini` | Model to use for judging explanations | +| `judge_base_url` | str | `None` | Base URL for judge model API | +| `judge_api_key` | str | `None` | API key for judge (falls back to `JUDGE_API_KEY` or `OPENAI_API_KEY` env vars) | + +### Metrics + +**MCQ-Only Mode** (`use_explanations=False`): + +| Metric | Weight | Meaning | +| ------ | ------ | ------- | +| `correct_answer_reward_func` | 1.0 | 1.0 if parsed letter is correct, else 0.0 | +| `parser.get_format_reward_func()` | 0.0 | Optional format adherence (not counted) | + +**Full Evaluation Mode** (`use_explanations=True`): + +| Metric | Weight (default) | Meaning | +| ------ | ---------------- | ------- | +| `correct_answer_reward_func` | 0.5 | 1.0 if parsed letter is correct, else 0.0 | +| `explanation_quality_reward` | 0.5 | 0.0-1.0 score from LLM judge comparing model's explanation to two reference explanations | + +**Explanation Judge Criteria:** +- Medical accuracy +- Relevance to the question +- Clarity and completeness +- Proper use of medical concepts + +### Testing Instructions + +#### 1. Environment Setup +```bash +# Navigate to repository root +cd /data/storage_hpc_nishant/med-lm-envs + +# Sync uv environment +uv sync +``` + +#### 2. Quick Validation Test (MCQ-only) +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -n 5 +``` + +#### 3. Full MCQ Evaluation +```bash +export OPENAI_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -s +``` + +#### 4. With Explanation Evaluation +```bash +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -a '{"use_explanations": true}' -s +``` + +#### 5. With Think Tags +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -a '{"use_think": true}' +``` + +### Citation + +```bibtex +@article{kim2024medexqa, + title={MedExQA: Medical Question Answering Benchmark with Multiple Explanations}, + author={Kim, Yunsoo and Wu, Jinge and Abdulle, Yusuf and Wu, Honghan}, + journal={arXiv preprint arXiv:2406.06331}, + year={2024} +} +``` diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py new file mode 100644 index 00000000..ba17361c --- /dev/null +++ b/environments/medexqa/medexqa.py @@ -0,0 +1,275 @@ +import os +import re + +import verifiers as vf +from datasets import Dataset, concatenate_datasets +from openai import AsyncOpenAI +from verifiers.utils.data_utils import BOXED_SYSTEM_PROMPT, THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer +import pandas as pd + + +# MedExQA specialties +SPECIALTIES = [ + "biomedical_engineer", + "clinical_laboratory_scientist", + "clinical_psychologist", + "occupational_therapist", + "speech_pathologist", +] + + + +def _build_question_str(question: str, options: dict[str, str]) -> str: + """Format question with answer choices, following authors' format with boxed instruction.""" + # Instruction adapted from authors' code https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py#L32 + instruction = ( + "The following is a multiple-choice question. Please choose the most suitable one " + "among A, B, C and D as the answer to this question. " + 'Put your answer in \\boxed{X} format where X is the letter choice. ' + "Your answer should be paired with an explanation why you chose that answer.\n\n" + ) + opts = "\n".join(f"{k}. {v}" for k, v in options.items()) + return f"{instruction}{question}\n{opts}\nAnswer:" + + +def _to_vf_format(ds: Dataset) -> Dataset: + """ + Shape each row for SingleTurnEnv's defaults: + - 'question': formatted question string with options + - 'answer': gold letter (A/B/C/D) + - 'info': keep all original fields including explanations + """ + def _format_row(row: dict) -> dict: + question = row.get("question", "") or "" + + # Build options dict from A, B, C, D columns + opts = { + "A": row.get("A", ""), + "B": row.get("B", ""), + "C": row.get("C", ""), + "D": row.get("D", ""), + } + + # Get answer letter + answer_letter = (row.get("answer") or "").strip().upper() + if answer_letter not in ("A", "B", "C", "D"): + return None + + question_str = _build_question_str(question, opts) + + # Keep original data in info + info = dict(row) + + return { + "question": question_str, + "answer": answer_letter, + "info": info, + } + + return ds.map(_format_row, remove_columns=ds.column_names).filter(lambda row: row is not None) + + +def load_environment( + use_think: bool = False, + use_explanations: bool = False, + mcq_weight: float = 0.5, + explanation_weight: float = 0.5, + judge_model: str = "gpt-4o-mini", + judge_base_url: str | None = None, + judge_api_key: str | None = None, + **kwargs +) -> vf.Environment: + """ + Single-turn MedExQA environment using HuggingFace `bluesky333/MedExQA` dataset + + Each example is normalized to the fields expected by `vf.SingleTurnEnv`: + { + "question": "", # string used as the user prompt + "answer": "", # top-level gold letter + "info": { ...original example fields... } # full source row including exp0, exp1 + } + + - Loads all 5 medical specialties (biomedical engineering, clinical lab science, + clinical psychology, occupational therapy, speech language pathology) + - No training split (dataset does not provide one) + - Test split used as evaluation data (940 total examples) + + - Parser extracts \\boxed{A|B|C|D} from completions + + - Reward looks for exact match between parsed letter and answer letter + - Optional: Explanation quality evaluation using LLM-as-judge + """ + + # Load all specialties and concatenate + # Note: MedExQA only has dev and test splits, no train split + # Load TSV files directly since HF dataset has column name issues + test_datasets = [] + + for specialty in SPECIALTIES: + try: + # Download and load TSV file directly + url = f"https://huggingface.co/datasets/bluesky333/MedExQA/resolve/main/test/{specialty}_test.tsv" + + # Load TSV with pandas (no headers in file) + df = pd.read_csv( + url, + sep='\t', + header=None, + names=["question", "A", "B", "C", "D", "exp0", "exp1", "answer"] + ) + + # Add specialty column + df['specialty'] = specialty + + # Convert to HF dataset + test_ds = Dataset.from_pandas(df, preserve_index=False) + test_datasets.append(test_ds) + except Exception as e: + print(f"Warning: Could not load {specialty}: {e}") + continue + + # Concatenate all specialties + test_combined = concatenate_datasets(test_datasets) if test_datasets else None + + # Format for verifiers - no training dataset available + test_ds = _to_vf_format(test_combined) if test_combined else None + + # Setup system prompt - use standard boxed prompts since instruction is in question + # Like M-ARC, we put the instruction in the question itself, so use standard prompts + system_prompt = THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT + + # Parser for extracting \\boxed{} answers + parser = ( + vf.ThinkParser(extract_fn=extract_boxed_answer) if use_think + else vf.Parser(extract_fn=extract_boxed_answer) + ) + + def correct_answer_reward_func(parser, completion, answer, **kwargs) -> float: + """Reward function for MCQ accuracy.""" + response = parser.parse_answer(completion) or "" + return 1.0 if response == answer else 0.0 + + # Create rubric based on whether we're evaluating explanations + if use_explanations: + # Setup judge for explanation evaluation + api_key = judge_api_key if judge_api_key else os.getenv("JUDGE_API_KEY") + if not api_key: + api_key = os.getenv("OPENAI_API_KEY") + + judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key) + + # We construct the judge prompt directly below when calling the judge + + # Important: the JudgeRubric formats only with {question}, {answer}, {response}. + # To include reference explanations exp0/exp1, we fully format the prompt + # ourselves and pass it as {question}. Hence, set rubric prompt to "{question}". + judge_rubric = vf.JudgeRubric( + judge_client=judge_client, + judge_model=judge_model, + judge_prompt="{question}", + ) + + async def combined_reward( + judge, prompt, completion, answer, state, **kwargs + ) -> float: + """Combined reward: MCQ accuracy + explanation quality.""" + # 1. Calculate MCQ accuracy + + mcq_score = correct_answer_reward_func(parser, completion, answer) + + # 2. Calculate explanation quality (strictly after the boxed answer) + completion_text = completion if isinstance(completion, str) else str(completion) + boxed_pattern = r"\\boxed\{[A-D]\}" + match = re.search(boxed_pattern, completion_text) + + if match: + explanation = completion_text[match.end():].strip() + else: + explanation = completion_text.strip() + # If the explanation is too short, set the score to 0.0 + if len(explanation.split()) < 10: + explanation_score = 0.0 + else: + info = kwargs.get("info", {}) + if not info: + return (mcq_weight * mcq_score) # no info, skip explanation + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + if not exp0 or not exp1: + return (mcq_weight * mcq_score) # missing refs + + question = info.get("question", "") + opts = { + "A": info.get("A", ""), + "B": info.get("B", ""), + "C": info.get("C", ""), + "D": info.get("D", ""), + } + opts_str = "\n".join(f"{k}. {v}" for k, v in opts.items()) + formatted_question = f"{question}\n{opts_str}" + + # Build judge prompt directly to avoid brace-escaping issues + full_prompt = ( + "You are evaluating the quality of a medical explanation.\n\n" + "**Question:**\n" + formatted_question + "\n\n" + "**Correct Answer:** " + str(answer) + "\n\n" + "**Reference Explanation 1:**\n" + str(exp0) + "\n\n" + "**Reference Explanation 2:**\n" + str(exp1) + "\n\n" + "**Model's Response:**\n" + explanation + "\n\n" + "Evaluate whether the model's explanation is medically accurate, relevant, and demonstrates understanding of the medical concepts. The explanation should justify why the answer is correct.\n\n" + "Compare the model's explanation quality to the reference explanations. Consider:\n" + "- Medical accuracy\n" + "- Relevance to the question\n" + "- Clarity and completeness\n" + "- Proper use of medical concepts\n\n" + "Respond with a score from 0.0 to 1.0:\n" + "- 1.0 = Excellent (as good as or better than references)\n" + "- 0.75 = Good (mostly correct with minor issues)\n" + "- 0.5 = Acceptable (partially correct)\n" + "- 0.25 = Poor (significant errors)\n" + "- 0.0 = Wrong or irrelevant\n\n" + "Respond with ONLY a number between 0.0 and 1.0." + ) + + judge_response = await judge_rubric.judge( + [{"role": "user", "content": full_prompt}], + "", # completion (unused) + "", # answer (unused) + state, + **kwargs, + ) + + try: + score_str = str(judge_response).strip() + number_match = re.search(r"(\d+\.?\d*)", score_str) + if number_match: + explanation_score = float(number_match.group(1)) + explanation_score = max(0.0, min(1.0, explanation_score)) + else: + explanation_score = 0.0 + except (ValueError, AttributeError): + explanation_score = 0.0 + + # Return weighted combination + return (mcq_weight * mcq_score) + (explanation_weight * explanation_score) + + # Add combined reward function + judge_rubric.add_reward_func(combined_reward, weight=1.0) + + rubric = judge_rubric + else: + # MCQ-only evaluation + rubric = vf.Rubric( + funcs=[correct_answer_reward_func], + weights=[1.0], + parser=parser, + ) + + return vf.SingleTurnEnv( + dataset=None, # No training split available + eval_dataset=test_ds, + system_prompt=system_prompt, + parser=parser, + rubric=rubric, + **kwargs + ) diff --git a/environments/medexqa/pyproject.toml b/environments/medexqa/pyproject.toml new file mode 100644 index 00000000..3f06da96 --- /dev/null +++ b/environments/medexqa/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "medexqa" +version = "0.1.0" +description = "MedExQA Evaluation - Medical QA with Multiple Explanations" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "datasets>=4.0.0", + "verifiers>=0.1.2.post0", + "openai>=1.0.0", + "pandas>=2.0.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["medexqa.py"] + +[tool.prime.environment] +# lets Prime/vf-eval know where the loader lives in a flat repo +loader = "medexqa:load_environment" +display_name = "MedExQA" +visibility = "PUBLIC" From 86921ab5005de11871307d76cb7e6416968ffe72 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Wed, 15 Oct 2025 21:21:18 +0200 Subject: [PATCH 2/9] =?UTF-8?q?MedExQA:=20switch=20to=20lexical=20metrics?= =?UTF-8?q?=20+=20optional=20judge;=20authors=E2=80=99=20prompt;=20special?= =?UTF-8?q?ty=20codes;=20separate=20MCQ/expl=20metrics;=20metric=20scaling?= =?UTF-8?q?=200-100;=20deps=20update?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- environments/medexqa/medexqa.py | 293 ++++++++++++++++++---------- environments/medexqa/pyproject.toml | 6 +- 2 files changed, 195 insertions(+), 104 deletions(-) diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index ba17361c..586b985f 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -3,9 +3,11 @@ import verifiers as vf from datasets import Dataset, concatenate_datasets -from openai import AsyncOpenAI -from verifiers.utils.data_utils import BOXED_SYSTEM_PROMPT, THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer +from verifiers.utils.data_utils import THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer import pandas as pd +import evaluate +from thefuzz import process +from openai import AsyncOpenAI # MedExQA specialties @@ -19,17 +21,17 @@ +AUTHOR_SYSTEM_PROMPT = ( + "The following is a multiple-choice question. Please choose the most suitable one " + "among A, B, C and D as the answer to this question. " + "Your answer should be paired with an explanation why you chose that answer." +) + + def _build_question_str(question: str, options: dict[str, str]) -> str: - """Format question with answer choices, following authors' format with boxed instruction.""" - # Instruction adapted from authors' code https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py#L32 - instruction = ( - "The following is a multiple-choice question. Please choose the most suitable one " - "among A, B, C and D as the answer to this question. " - 'Put your answer in \\boxed{X} format where X is the letter choice. ' - "Your answer should be paired with an explanation why you chose that answer.\n\n" - ) + """Format question with answer choices only; instruction is provided via system prompt.""" opts = "\n".join(f"{k}. {v}" for k, v in options.items()) - return f"{instruction}{question}\n{opts}\nAnswer:" + return f"{question}\n{opts}\nAnswer:" def _to_vf_format(ds: Dataset) -> Dataset: @@ -74,6 +76,12 @@ def load_environment( use_explanations: bool = False, mcq_weight: float = 0.5, explanation_weight: float = 0.5, + specialty: str = "all", + explanation_metrics: list[str] | None = None, + metrics_aggregation: str = "average", + macroaverage: bool = False, + # Optional judge settings + use_judge: bool = False, judge_model: str = "gpt-4o-mini", judge_base_url: str | None = None, judge_api_key: str | None = None, @@ -134,9 +142,8 @@ def load_environment( # Format for verifiers - no training dataset available test_ds = _to_vf_format(test_combined) if test_combined else None - # Setup system prompt - use standard boxed prompts since instruction is in question - # Like M-ARC, we put the instruction in the question itself, so use standard prompts - system_prompt = THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT + # Setup system prompt - use authors' instruction in system; prepend think prompt if requested + system_prompt = f"{THINK_BOXED_SYSTEM_PROMPT}\n{AUTHOR_SYSTEM_PROMPT}" if use_think else AUTHOR_SYSTEM_PROMPT # Parser for extracting \\boxed{} answers parser = ( @@ -149,121 +156,201 @@ def correct_answer_reward_func(parser, completion, answer, **kwargs) -> float: response = parser.parse_answer(completion) or "" return 1.0 if response == answer else 0.0 - # Create rubric based on whether we're evaluating explanations - if use_explanations: - # Setup judge for explanation evaluation - api_key = judge_api_key if judge_api_key else os.getenv("JUDGE_API_KEY") - if not api_key: - api_key = os.getenv("OPENAI_API_KEY") - - judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key) - - # We construct the judge prompt directly below when calling the judge - - # Important: the JudgeRubric formats only with {question}, {answer}, {response}. - # To include reference explanations exp0/exp1, we fully format the prompt - # ourselves and pass it as {question}. Hence, set rubric prompt to "{question}". + # Optional specialty filter (short codes supported) + if specialty and test_ds is not None: + code_map = { + "BE": "biomedical_engineer", + "CLS": "clinical_laboratory_scientist", + "CP": "clinical_psychologist", + "OT": "occupational_therapist", + "SLP": "speech_pathologist", + "ALL": "all", + } + spec_upper = (specialty or "all").upper() + resolved = code_map.get(spec_upper, specialty) + if resolved != "all": + test_ds = test_ds.filter(lambda row: (row.get("info") or {}).get("specialty") == resolved) + + # Helpers (authors' answer extraction logic) + def process_before_extraction(gen: str, choice_dict: dict[str, str]) -> str: + for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1] or ""), reverse=True): + pattern = re.compile(re.escape((val or "").rstrip(".")), re.IGNORECASE) + gen = pattern.sub(key, gen) + return gen + + def extract_choice(gen: str, choice_list: list[str]) -> str: + res = re.search(r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen) + if res is None: + res = re.search(r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen) + if res is None: + res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen) + if res is None: + res = re.search(r"(? str: + gen = process_before_extraction(completion_text or "", options) + pred = extract_choice(gen, [options.get(c, "") for c in ["A", "B", "C", "D"]]) + return (pred or "").upper() + + # Metrics selection; 'all'/'overall' => average of all four + base_metrics = ["rougeL", "bleu", "meteor", "bertscore"] + if explanation_metrics is None: + selected_metrics = base_metrics + else: + if isinstance(explanation_metrics, str) and explanation_metrics.lower() in ("all", "overall"): + selected_metrics = base_metrics + elif isinstance(explanation_metrics, list) and any(str(m).lower() in ("all", "overall") for m in explanation_metrics): + selected_metrics = base_metrics + else: + selected_metrics = explanation_metrics + + def compute_metric_score(metric_name: str, prediction: str, refs: list[str]) -> float: + try: + name = metric_name.lower() + if name in ("rouge", "rougel"): + rouge = evaluate.load("rouge") + res = rouge.compute(predictions=[prediction], references=[refs]) + return float(res.get("rougeL", 0.0)) * 100.0 + if name == "bleu": + bleu = evaluate.load("bleu") + res = bleu.compute(predictions=[prediction], references=[refs]) + sc = float(res.get("bleu", 0.0)) + return sc * 100.0 if sc <= 1.0 else sc + if name == "meteor": + meteor = evaluate.load("meteor") + res = meteor.compute(predictions=[prediction], references=[refs]) + sc = float(res.get("meteor", 0.0)) + return sc * 100.0 if sc <= 1.0 else sc + if name == "bertscore": + bscore = evaluate.load("bertscore") + res = bscore.compute( + predictions=[prediction], + references=[refs], + model_type="allenai/scibert_scivocab_uncased", + lang="en", + rescale_with_baseline=False, + ) + f1_list = res.get("f1", []) + return (float(f1_list[0]) * 100.0) if f1_list else 0.0 + return 0.0 + except Exception: + return 0.0 + + def compute_expl_score(pred: str, exp0: str, exp1: str) -> float: + refs = [exp0 or "", exp1 or ""] + metric_vals = [compute_metric_score(m, pred, refs) for m in selected_metrics] + metric_vals = [v for v in metric_vals if v is not None] + if not metric_vals: + return 0.0 + # always average across selected metrics + return (sum(metric_vals) / len(metric_vals)) + + # Precompute specialty counts for macroaverage weighting (if requested) + spec_counts: dict[str, int] = {} + total_examples = 0 + if test_ds is not None: + for row in test_ds: + info_row = row.get("info") or {} + spec = info_row.get("specialty") or "unknown" + spec_counts[spec] = spec_counts.get(spec, 0) + 1 + total_examples += 1 + num_specs = len(spec_counts) if spec_counts else 1 + + def _macro_scale(spec: str) -> float: + if not macroaverage: + return 1.0 + if spec_counts and total_examples and num_specs: + n_k = spec_counts.get(spec, 1) + return float(total_examples) / float(num_specs * n_k) + return 1.0 + + def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: + completion_text = completion if isinstance(completion, str) else str(completion) + info = kwargs.get("info", {}) or {} + options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} + gold = (answer or "").strip().upper() + pred_letter = extract_answer_letter(completion_text, options) + base = 1.0 if pred_letter == gold else 0.0 + spec = (info.get("specialty") or "unknown") + return base * _macro_scale(spec) + + def explanation_reward(parser, completion, answer, **kwargs) -> float: + completion_text = completion if isinstance(completion, str) else str(completion) + info = kwargs.get("info", {}) or {} + options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} + gold = (answer or "").strip().upper() + pred_letter = extract_answer_letter(completion_text, options) + if pred_letter != gold: + base = 0.0 + else: + base = compute_expl_score(completion_text, info.get("exp0", ""), info.get("exp1", "")) + spec = (info.get("specialty") or "unknown") + return base * _macro_scale(spec) + + # Optional: Use LLM-as-judge for explanation instead of lexical metrics + if use_explanations and use_judge: + api_key = judge_api_key if judge_api_key else os.getenv("JUDGE_API_KEY") or os.getenv("OPENAI_API_KEY") + judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key) if api_key else None judge_rubric = vf.JudgeRubric( judge_client=judge_client, judge_model=judge_model, judge_prompt="{question}", ) - async def combined_reward( - judge, prompt, completion, answer, state, **kwargs - ) -> float: - """Combined reward: MCQ accuracy + explanation quality.""" - # 1. Calculate MCQ accuracy - - mcq_score = correct_answer_reward_func(parser, completion, answer) - - # 2. Calculate explanation quality (strictly after the boxed answer) + async def explanation_judge_reward(judge, prompt, completion, answer, state, **kwargs) -> float: completion_text = completion if isinstance(completion, str) else str(completion) - boxed_pattern = r"\\boxed\{[A-D]\}" - match = re.search(boxed_pattern, completion_text) - - if match: - explanation = completion_text[match.end():].strip() + info = kwargs.get("info", {}) or {} + options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} + gold = (answer or "").strip().upper() + pred_letter = extract_answer_letter(completion_text, options) + if pred_letter != gold: + base = 0.0 else: - explanation = completion_text.strip() - # If the explanation is too short, set the score to 0.0 - if len(explanation.split()) < 10: - explanation_score = 0.0 - else: - info = kwargs.get("info", {}) - if not info: - return (mcq_weight * mcq_score) # no info, skip explanation - exp0 = info.get("exp0", "") - exp1 = info.get("exp1", "") - if not exp0 or not exp1: - return (mcq_weight * mcq_score) # missing refs - + # Build judge prompt question = info.get("question", "") - opts = { - "A": info.get("A", ""), - "B": info.get("B", ""), - "C": info.get("C", ""), - "D": info.get("D", ""), - } - opts_str = "\n".join(f"{k}. {v}" for k, v in opts.items()) + opts_str = "\n".join(f"{k}. {options.get(k, '')}" for k in ["A","B","C","D"]) formatted_question = f"{question}\n{opts_str}" - - # Build judge prompt directly to avoid brace-escaping issues + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") full_prompt = ( "You are evaluating the quality of a medical explanation.\n\n" "**Question:**\n" + formatted_question + "\n\n" - "**Correct Answer:** " + str(answer) + "\n\n" + "**Correct Answer:** " + str(gold) + "\n\n" "**Reference Explanation 1:**\n" + str(exp0) + "\n\n" "**Reference Explanation 2:**\n" + str(exp1) + "\n\n" - "**Model's Response:**\n" + explanation + "\n\n" - "Evaluate whether the model's explanation is medically accurate, relevant, and demonstrates understanding of the medical concepts. The explanation should justify why the answer is correct.\n\n" - "Compare the model's explanation quality to the reference explanations. Consider:\n" - "- Medical accuracy\n" - "- Relevance to the question\n" - "- Clarity and completeness\n" - "- Proper use of medical concepts\n\n" - "Respond with a score from 0.0 to 1.0:\n" - "- 1.0 = Excellent (as good as or better than references)\n" - "- 0.75 = Good (mostly correct with minor issues)\n" - "- 0.5 = Acceptable (partially correct)\n" - "- 0.25 = Poor (significant errors)\n" - "- 0.0 = Wrong or irrelevant\n\n" + "**Model's Response:**\n" + completion_text + "\n\n" "Respond with ONLY a number between 0.0 and 1.0." ) - judge_response = await judge_rubric.judge( [{"role": "user", "content": full_prompt}], - "", # completion (unused) - "", # answer (unused) + "", + "", state, **kwargs, ) - try: score_str = str(judge_response).strip() - number_match = re.search(r"(\d+\.?\d*)", score_str) - if number_match: - explanation_score = float(number_match.group(1)) - explanation_score = max(0.0, min(1.0, explanation_score)) - else: - explanation_score = 0.0 - except (ValueError, AttributeError): - explanation_score = 0.0 - - # Return weighted combination - return (mcq_weight * mcq_score) + (explanation_weight * explanation_score) - - # Add combined reward function - judge_rubric.add_reward_func(combined_reward, weight=1.0) - + import re as _re + m = _re.search(r"(\d+\.?\d*)", score_str) + s = float(m.group(1)) if m else 0.0 + except Exception: + s = 0.0 + base = max(0.0, min(1.0, s)) * 100.0 + spec = (info.get("specialty") or "unknown") + return base * _macro_scale(spec) + + # Use JudgeRubric with two metrics: answer accuracy (sync), explanation judge (async) + judge_rubric.add_reward_func(answer_accuracy_reward, weight=0.0) + judge_rubric.add_reward_func(explanation_judge_reward, weight=0.0) rubric = judge_rubric else: - # MCQ-only evaluation - rubric = vf.Rubric( - funcs=[correct_answer_reward_func], - weights=[1.0], - parser=parser, - ) + # Keep metrics separate (no combined reward) + rubric = vf.Rubric(funcs=[answer_accuracy_reward, explanation_reward], weights=[0.0, 0.0], parser=parser) return vf.SingleTurnEnv( dataset=None, # No training split available diff --git a/environments/medexqa/pyproject.toml b/environments/medexqa/pyproject.toml index 3f06da96..24d44f14 100644 --- a/environments/medexqa/pyproject.toml +++ b/environments/medexqa/pyproject.toml @@ -7,8 +7,12 @@ requires-python = ">=3.11" dependencies = [ "datasets>=4.0.0", "verifiers>=0.1.2.post0", - "openai>=1.0.0", "pandas>=2.0.0", + "evaluate>=0.4.0", + "rouge-score>=0.1.2", + "sacrebleu>=2.4.0", + "bert-score>=0.3.13", + "thefuzz>=0.22.1", ] [build-system] From 2106e6efc934ecab0576c8980086d1e6fc4fc0e6 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Thu, 16 Oct 2025 01:49:21 +0200 Subject: [PATCH 3/9] added specialty wise eval, lexical substring match metrics, optional LLM-as-a-judge --- environments/medexqa/README.md | 132 +++++++++++------- environments/medexqa/medexqa.py | 232 +++++++++++++++++--------------- 2 files changed, 207 insertions(+), 157 deletions(-) diff --git a/environments/medexqa/README.md b/environments/medexqa/README.md index 48156cef..5ad8f8ad 100644 --- a/environments/medexqa/README.md +++ b/environments/medexqa/README.md @@ -3,7 +3,7 @@ ### Overview - **Environment ID**: `medexqa` - **Short description**: Medical QA with multiple-choice questions and explanations across five underrepresented medical specialties -- **Tags**: medical, clinical, single-turn, multiple-choice, explanations, train, evaluation +- **Tags**: medical, clinical, single-turn, multiple-choice, explanations, evaluation ### Datasets - **Primary dataset(s)**: MedExQA @@ -21,74 +21,97 @@ ### Task - **Type**: single-turn -- **System Prompt**: Uses the authors' prompt from their evaluation code: +- **Prompting**: Uses the authors' instruction embedded in the user message; options A/B/C/D are included. ``` - "The following is a multiple-choice question. Please choose the most suitable one - among A, B, C and D as the answer to this question. Your answer should be paired - with an explanation why you chose that answer." + The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question. Your answer should be paired with an explanation why you chose that answer. ``` -- **Parser**: `Parser` or `ThinkParser`, with `extract_fn=extract_boxed_answer` for strict letter-in-\boxed{}-format parsing +- **Answer extraction [authors' logic](https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py)** : + - Canonical letter extraction using a sequence of regex patterns (e.g., explicit "Answer is A:", leading letter, etc.) + - If no explicit letter is found, fuzzy matching (thefuzz) maps the generated text to the closest option and returns the corresponding letter +- **Parser**: `Parser` or `ThinkParser` with `extract_fn=extract_boxed_answer` supported for think-mode; MCQ scoring uses the authors' extraction logic above. +- Run Evaluation per specialty or on multiple specialties +- Use lexical metrics('rougeL', 'bleu', 'bertscore', 'meteor') or use an LLM-as-a-judge for explanation evaluation - **Rubric overview**: - - MCQ-only mode: Binary scoring based on correctly boxed letter choice - - Full evaluation mode: Weighted combination of MCQ accuracy + explanation quality (using LLM-as-judge) + - MCQ accuracy: 0 or 100 per example + - Explanation score: 0–100 per example (lexical metrics average); 0 if the answer is wrong + - Combined score: weighted average of MCQ and explanation (`mcq_weight`, `explanation_weight`) +- **Model Download**: + In the first run it will download `wordnet`, `NLTK` and `sciBERT` models for running the lexical metrics ### Quickstart -Run MCQ-only evaluation (default): - +- Run MCQ-only (no explanation scoring): ```bash -uv run vf-eval medexqa -m gpt-4.1-mini +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": false}' ``` -Run with explanation evaluation: - +- Run with explanation scoring (lexical metrics): ```bash -export JUDGE_API_KEY=sk-... uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true}' ``` -Configure model and sampling: +- Use LLM-as-judge for explanations (instead of lexical metrics): +```bash +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_model": "gpt-4o-mini"}' +``` +- Configure sampling and rollouts: ```bash uv run vf-eval medexqa \ - -m gpt-4.1-mini \ - -n -1 -r 3 -t 1024 -T 0.7 \ - -a '{"use_think": false, "use_explanations": false}' + -m gpt-4.1-mini \ + -n -1 -r 3 -t 1024 -T 0.7 \ + -a '{"use_think": false, "use_explanations": true, "mcq_weight": 0.5, "explanation_weight": 0.5}' ``` ### Environment Arguments -| Arg | Type | Default | Description | -| -------------------- | ----- | ------------- | ---------------------------------------------------------------------------------- | -| `use_think` | bool | `False` | Whether to check for `...` formatting with `ThinkParser` | -| `use_explanations` | bool | `False` | Whether to evaluate explanation quality using LLM-as-judge | -| `mcq_weight` | float | `0.5` | Weight for MCQ accuracy (only used when `use_explanations=True`) | -| `explanation_weight` | float | `0.5` | Weight for explanation quality (only used when `use_explanations=True`) | -| `judge_model` | str | `gpt-4o-mini` | Model to use for judging explanations | -| `judge_base_url` | str | `None` | Base URL for judge model API | -| `judge_api_key` | str | `None` | API key for judge (falls back to `JUDGE_API_KEY` or `OPENAI_API_KEY` env vars) | +| Arg | Type | Default | Description | +| ---------------------- | ---------------------- | -------------- | ----------- | +| `specialty` | list[str] \/ str \| None | `None` | Select one or more specialties. Codes: `BE`, `CLS`, `CP`, `OT`, `SLP`. `None`\/`ALL` loads all. | +| `use_think` | bool | `False` | Use `ThinkParser` to support `...` blocks. | +| `use_explanations` | bool | `True` | Whether to compute explanation scores. | +| `explanation_metrics` | list[str] \/ str \| None | `None` | Lexical metrics to use: any of `rougeL`, `bleu`, `meteor`, `bertscore`. `None`\/`"all"` averages all four. | +| `mcq_weight` | float | `0.5` | Weight for MCQ accuracy in the combined score. | +| `explanation_weight` | float | `0.5` | Weight for explanation in the combined score. | +| `use_judge` | bool | `False` | Use LLM-as-judge for explanations instead of lexical metrics. | +| `judge_model` | str | `gpt-4o-mini` | Judge model name. | +| `judge_base_url` | str \| None | `None` | Judge API base URL. | +| `judge_api_key` | str \| None | `None` | Judge API key (falls back to `JUDGE_API_KEY` or `OPENAI_API_KEY`). | +| `seed` | int \| None | `None` | When multiple specialties are selected, shuffles the combined eval set with this seed. | ### Metrics -**MCQ-Only Mode** (`use_explanations=False`): +- **Answer accuracy (per example)**: 0 or 100. Uses authors' regex+fuzzy logic to extract a letter. +- **Explanation score (per example)**: 0–100. If the answer is wrong, the explanation score is 0. + - Lexical metrics supported: `rougeL`, `bleu`, `meteor`, `bertscore` (w/ SciBERT `allenai/scibert_scivocab_uncased`). + - Selection via `explanation_metrics` (list or `'all'`/`None` to average all four). +- **Combined score**: `mcq_weight * accuracy + explanation_weight * explanation`. + +Optional LLM-as-judge for explanations: +- Set `use_explanations=true` and `use_judge=true` to replace lexical metrics with judge scoring (0–100 after scaling). +- Criteria include medical accuracy, relevance, clarity, completeness, and use of medical concepts. 0 if the answer from string matching is wrong. -| Metric | Weight | Meaning | -| ------ | ------ | ------- | -| `correct_answer_reward_func` | 1.0 | 1.0 if parsed letter is correct, else 0.0 | -| `parser.get_format_reward_func()` | 0.0 | Optional format adherence (not counted) | +### Specialty Selection and Macro Average + +- Single specialty by code: +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"specialty": "CLS"}' +``` -**Full Evaluation Mode** (`use_explanations=True`): +- Multiple specialties: +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"specialty": ["CLS", "CP"], "seed": 42}' +``` -| Metric | Weight (default) | Meaning | -| ------ | ---------------- | ------- | -| `correct_answer_reward_func` | 0.5 | 1.0 if parsed letter is correct, else 0.0 | -| `explanation_quality_reward` | 0.5 | 0.0-1.0 score from LLM judge comparing model's explanation to two reference explanations | +- All specialties: +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"specialty": "ALL"}' +``` -**Explanation Judge Criteria:** -- Medical accuracy -- Relevance to the question -- Clarity and completeness -- Proper use of medical concepts +## IMPORTANT: Macro-average accuracy (as reported in the paper): +- Run each specialty separately and average the per-run average answer accuracies; or +- Run multiple specialties with `-s` to save results. Each saved example includes its `specialty` in `info`, along with the `per-example answer_accuracy_reward`. Use the saved JSONL to compute per-specialty accuracies and then take the unweighted mean across specialties. ### Testing Instructions @@ -103,19 +126,19 @@ uv sync #### 2. Quick Validation Test (MCQ-only) ```bash -uv run vf-eval medexqa -m gpt-4.1-mini -n 5 +uv run vf-eval medexqa -m gpt-4.1-mini -n 5 -a '{"use_explanations": false}' ``` -#### 3. Full MCQ Evaluation +#### 3. Full Evaluation with Save ```bash export OPENAI_API_KEY=sk-... -uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -s +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -s -a '{"specialty": "ALL", "use_explanations": true}' ``` -#### 4. With Explanation Evaluation +#### 4. LLM-as-Judge for Explanations ```bash export JUDGE_API_KEY=sk-... -uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -a '{"use_explanations": true}' -s +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -s -a '{"use_explanations": true, "use_judge": true, "mcq_weight": 0.5, "explanation_weight": 0.5}' ``` #### 5. With Think Tags @@ -123,6 +146,21 @@ uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -a '{"use_explanations": true}' -s uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -a '{"use_think": true}' ``` +#### 6. Example Run with openrouter +```bash +export OPENROUTER_API_KEY=.... +uv run vf-eval medexqa -m openai/gpt-oss-20b:free -b https://openrouter.ai/api/v1 -k OPENAI_API_KEY -n 10 -r 1 -c 1 -a '{"use_explanations": true, "explanation_metrics": "all", "specialty": ["BE", "OT"]}' -s +``` +output +```bash +Rewards: +reward: avg - 59.416, std - 19.928 +r1: [67.79, 65.809, 64.158, 66.619, 69.124, 0.0, 66.957, 66.327, 66.87, 60.503] +answer_accuracy_reward: avg - 90.000, std - 30.000 +r1: [100.0, 100.0, 100.0, 100.0, 100.0, 0.0, 100.0, 100.0, 100.0, 100.0] +explanation_reward: avg - 28.832, std - 10.577 +r1: [35.58, 31.618, 28.316, 33.239, 38.249, 0.0, 33.915, 32.653, 33.741, 21.006] +``` ### Citation ```bibtex diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index 586b985f..0b37607f 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -19,27 +19,29 @@ "speech_pathologist", ] - - -AUTHOR_SYSTEM_PROMPT = ( - "The following is a multiple-choice question. Please choose the most suitable one " - "among A, B, C and D as the answer to this question. " - "Your answer should be paired with an explanation why you chose that answer." -) - - +# author prompt directly taken from https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py#L32 def _build_question_str(question: str, options: dict[str, str]) -> str: - """Format question with answer choices only; instruction is provided via system prompt.""" + """Build user prompt with authors' instruction embedded (as in their script). + + The instruction lives in the user message; the system prompt remains empty in + normal mode, and only adds THINK_BOXED in think-mode. + """ + instruction = ( + "The following is a multiple-choice question. Please choose the most suitable one " + "among A, B, C and D as the answer to this question. Your answer should be paired " + "with an explanation why you chose that answer.\n\n" + ) opts = "\n".join(f"{k}. {v}" for k, v in options.items()) - return f"{question}\n{opts}\nAnswer:" + return f"{instruction}{question}\n{opts}\nAnswer:" def _to_vf_format(ds: Dataset) -> Dataset: - """ - Shape each row for SingleTurnEnv's defaults: - - 'question': formatted question string with options - - 'answer': gold letter (A/B/C/D) - - 'info': keep all original fields including explanations + """Normalize raw rows into the fields expected by SingleTurnEnv. + + Produces rows of the form: + - question: string containing authors' instruction, question, and options + - answer: gold letter (A/B/C/D) + - info: original fields including exp0/exp1 and specialty """ def _format_row(row: dict) -> dict: question = row.get("question", "") or "" @@ -73,13 +75,11 @@ def _format_row(row: dict) -> dict: def load_environment( use_think: bool = False, - use_explanations: bool = False, + use_explanations: bool = True, mcq_weight: float = 0.5, explanation_weight: float = 0.5, - specialty: str = "all", - explanation_metrics: list[str] | None = None, - metrics_aggregation: str = "average", - macroaverage: bool = False, + specialty: list[str] | str | None = None, # list of short codes or full names; None/"ALL" => all + explanation_metrics: list[str] | str | None = None, # None/"all" => average of all four # Optional judge settings use_judge: bool = False, judge_model: str = "gpt-4o-mini", @@ -90,60 +90,75 @@ def load_environment( """ Single-turn MedExQA environment using HuggingFace `bluesky333/MedExQA` dataset - Each example is normalized to the fields expected by `vf.SingleTurnEnv`: - { - "question": "", # string used as the user prompt - "answer": "", # top-level gold letter - "info": { ...original example fields... } # full source row including exp0, exp1 - } - - - Loads all 5 medical specialties (biomedical engineering, clinical lab science, - clinical psychology, occupational therapy, speech language pathology) - - No training split (dataset does not provide one) - - Test split used as evaluation data (940 total examples) - - - Parser extracts \\boxed{A|B|C|D} from completions - - - Reward looks for exact match between parsed letter and answer letter - - Optional: Explanation quality evaluation using LLM-as-judge + Key behaviors: + - User prompt embeds the authors' instruction and the options (authors' format). + - System prompt: empty (normal) or THINK_BOXED (think mode). + - Specialty selection: accepts list or string; loads requested specialties (None/ALL => all). + - MCQ accuracy: authors' regex+fuzzy extraction; returns 0 or 100. + - Explanation score: lexical metrics (ROUGE-L, BLEU, METEOR, BERTScore) averaged 0–100; 0 if answer wrong. + - Optional judge mode: explanation scored by JudgeRubric (0–100). """ - # Load all specialties and concatenate + # Load specialties (one or more) # Note: MedExQA only has dev and test splits, no train split # Load TSV files directly since HF dataset has column name issues - test_datasets = [] - for specialty in SPECIALTIES: + # Resolve allowed specialties up-front and only load those files + code_map = { + "BE": "biomedical_engineer", + "CLS": "clinical_laboratory_scientist", + "CP": "clinical_psychologist", + "OT": "occupational_therapist", + "SLP": "speech_pathologist", + "ALL": "all", + } + allowed_names: set[str] + if specialty is None or (isinstance(specialty, str) and (specialty.upper() in ("ALL", ""))): + allowed_names = set(SPECIALTIES) + elif isinstance(specialty, str): + allowed_names = {code_map.get(specialty.upper(), specialty)} + else: + tmp = set() + for s in specialty: + name = code_map.get((s or "").upper(), s) + if name and name != "all": + tmp.add(name) + allowed_names = tmp if tmp else set(SPECIALTIES) + macro_active = len(allowed_names) > 1 + + # Load all requested specialties + test_datasets = [] + for sp_name in SPECIALTIES: + if sp_name not in allowed_names: + continue try: - # Download and load TSV file directly - url = f"https://huggingface.co/datasets/bluesky333/MedExQA/resolve/main/test/{specialty}_test.tsv" - - # Load TSV with pandas (no headers in file) + url = f"https://huggingface.co/datasets/bluesky333/MedExQA/resolve/main/test/{sp_name}_test.tsv" df = pd.read_csv( url, sep='\t', header=None, names=["question", "A", "B", "C", "D", "exp0", "exp1", "answer"] ) - - # Add specialty column - df['specialty'] = specialty - - # Convert to HF dataset - test_ds = Dataset.from_pandas(df, preserve_index=False) - test_datasets.append(test_ds) + df['specialty'] = sp_name + ds_part = Dataset.from_pandas(df, preserve_index=False) + test_datasets.append(ds_part) except Exception as e: - print(f"Warning: Could not load {specialty}: {e}") + print(f"Warning: Could not load {sp_name}: {e}") continue - # Concatenate all specialties + # Concatenate and format for verifiers - no training dataset available test_combined = concatenate_datasets(test_datasets) if test_datasets else None - - # Format for verifiers - no training dataset available test_ds = _to_vf_format(test_combined) if test_combined else None - # Setup system prompt - use authors' instruction in system; prepend think prompt if requested - system_prompt = f"{THINK_BOXED_SYSTEM_PROMPT}\n{AUTHOR_SYSTEM_PROMPT}" if use_think else AUTHOR_SYSTEM_PROMPT + # Shuffle examples if multiple specialties were selected + if macro_active and test_ds is not None: + try: + test_ds = test_ds.shuffle(seed=int(kwargs.get("seed", 0))) + except Exception: + pass + + # Setup system prompt - empty for normal; use think-boxed for think mode + system_prompt = THINK_BOXED_SYSTEM_PROMPT if use_think else "" # Parser for extracting \\boxed{} answers parser = ( @@ -156,20 +171,7 @@ def correct_answer_reward_func(parser, completion, answer, **kwargs) -> float: response = parser.parse_answer(completion) or "" return 1.0 if response == answer else 0.0 - # Optional specialty filter (short codes supported) - if specialty and test_ds is not None: - code_map = { - "BE": "biomedical_engineer", - "CLS": "clinical_laboratory_scientist", - "CP": "clinical_psychologist", - "OT": "occupational_therapist", - "SLP": "speech_pathologist", - "ALL": "all", - } - spec_upper = (specialty or "all").upper() - resolved = code_map.get(spec_upper, specialty) - if resolved != "all": - test_ds = test_ds.filter(lambda row: (row.get("info") or {}).get("specialty") == resolved) + # (shuffling handled above when multiple specialties) # Helpers (authors' answer extraction logic) def process_before_extraction(gen: str, choice_dict: dict[str, str]) -> str: @@ -197,7 +199,7 @@ def extract_answer_letter(completion_text: str, options: dict[str, str]) -> str: pred = extract_choice(gen, [options.get(c, "") for c in ["A", "B", "C", "D"]]) return (pred or "").upper() - # Metrics selection; 'all'/'overall' => average of all four + # Lexical Metrics selection; pass individually or None/'all'/'overall' => average of all four base_metrics = ["rougeL", "bleu", "meteor", "bertscore"] if explanation_metrics is None: selected_metrics = base_metrics @@ -250,37 +252,35 @@ def compute_expl_score(pred: str, exp0: str, exp1: str) -> float: # always average across selected metrics return (sum(metric_vals) / len(metric_vals)) - # Precompute specialty counts for macroaverage weighting (if requested) - spec_counts: dict[str, int] = {} - total_examples = 0 - if test_ds is not None: - for row in test_ds: - info_row = row.get("info") or {} - spec = info_row.get("specialty") or "unknown" - spec_counts[spec] = spec_counts.get(spec, 0) + 1 - total_examples += 1 - num_specs = len(spec_counts) if spec_counts else 1 - - def _macro_scale(spec: str) -> float: - if not macroaverage: - return 1.0 - if spec_counts and total_examples and num_specs: - n_k = spec_counts.get(spec, 1) - return float(total_examples) / float(num_specs * n_k) - return 1.0 + # Note: No per-example macro scaling. + + def _get_completion_text(completion_obj) -> str: + if isinstance(completion_obj, list) and completion_obj: + return completion_obj[-1].get("content", "") or "" + return completion_obj if isinstance(completion_obj, str) else str(completion_obj) def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: - completion_text = completion if isinstance(completion, str) else str(completion) + completion_text = _get_completion_text(completion) info = kwargs.get("info", {}) or {} + state = kwargs.get("state") options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} gold = (answer or "").strip().upper() pred_letter = extract_answer_letter(completion_text, options) - base = 1.0 if pred_letter == gold else 0.0 - spec = (info.get("specialty") or "unknown") - return base * _macro_scale(spec) + base = 100.0 if pred_letter == gold else 0.0 + # Persist per-specialty counters into state so runs saved with -s can be summarized post-hoc + if isinstance(state, dict): + spec = (info.get("specialty") or "unknown") + counters = state.get("specialty_counters") or {} + curr = counters.get(spec) or {"correct": 0, "total": 0} + curr["total"] = int(curr.get("total", 0)) + 1 + if pred_letter == gold: + curr["correct"] = int(curr.get("correct", 0)) + 1 + counters[spec] = curr + state["specialty_counters"] = counters + return base def explanation_reward(parser, completion, answer, **kwargs) -> float: - completion_text = completion if isinstance(completion, str) else str(completion) + completion_text = _get_completion_text(completion) info = kwargs.get("info", {}) or {} options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} gold = (answer or "").strip().upper() @@ -289,8 +289,7 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: base = 0.0 else: base = compute_expl_score(completion_text, info.get("exp0", ""), info.get("exp1", "")) - spec = (info.get("specialty") or "unknown") - return base * _macro_scale(spec) + return base # Optional: Use LLM-as-judge for explanation instead of lexical metrics if use_explanations and use_judge: @@ -303,7 +302,7 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: ) async def explanation_judge_reward(judge, prompt, completion, answer, state, **kwargs) -> float: - completion_text = completion if isinstance(completion, str) else str(completion) + completion_text = _get_completion_text(completion) info = kwargs.get("info", {}) or {} options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} gold = (answer or "").strip().upper() @@ -320,10 +319,22 @@ async def explanation_judge_reward(judge, prompt, completion, answer, state, **k full_prompt = ( "You are evaluating the quality of a medical explanation.\n\n" "**Question:**\n" + formatted_question + "\n\n" - "**Correct Answer:** " + str(gold) + "\n\n" + "**Correct Answer:** " + str(answer) + "\n\n" "**Reference Explanation 1:**\n" + str(exp0) + "\n\n" "**Reference Explanation 2:**\n" + str(exp1) + "\n\n" "**Model's Response:**\n" + completion_text + "\n\n" + "Evaluate whether the model's explanation is medically accurate, relevant, and demonstrates understanding of the medical concepts. The explanation should justify why the answer is correct.\n\n" + "Compare the model's explanation quality to the reference explanations. Consider:\n" + "- Medical accuracy\n" + "- Relevance to the question\n" + "- Clarity and completeness\n" + "- Proper use of medical concepts\n\n" + "Respond with a score from 0.0 to 1.0:\n" + "- 1.0 = Excellent (as good as or better than references)\n" + "- 0.75 = Good (mostly correct with minor issues)\n" + "- 0.5 = Acceptable (partially correct)\n" + "- 0.25 = Poor (significant errors)\n" + "- 0.0 = Wrong or irrelevant\n\n" "Respond with ONLY a number between 0.0 and 1.0." ) judge_response = await judge_rubric.judge( @@ -336,23 +347,22 @@ async def explanation_judge_reward(judge, prompt, completion, answer, state, **k try: score_str = str(judge_response).strip() import re as _re - m = _re.search(r"(\d+\.?\d*)", score_str) - s = float(m.group(1)) if m else 0.0 + number_match = _re.search(r"(\d+\.?\d*)", score_str) + explanation_score = float(number_match.group(1)) if number_match else 0.0 except Exception: - s = 0.0 - base = max(0.0, min(1.0, s)) * 100.0 - spec = (info.get("specialty") or "unknown") - return base * _macro_scale(spec) + explanation_score = 0.0 + base = max(0.0, min(1.0, explanation_score)) * 100.0 + return base # Use JudgeRubric with two metrics: answer accuracy (sync), explanation judge (async) - judge_rubric.add_reward_func(answer_accuracy_reward, weight=0.0) - judge_rubric.add_reward_func(explanation_judge_reward, weight=0.0) + judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) + judge_rubric.add_reward_func(explanation_judge_reward, weight=explanation_weight) rubric = judge_rubric else: - # Keep metrics separate (no combined reward) - rubric = vf.Rubric(funcs=[answer_accuracy_reward, explanation_reward], weights=[0.0, 0.0], parser=parser) + # Keep metrics separate (and a combine drewad with tunable weights) + rubric = vf.Rubric(funcs=[answer_accuracy_reward, explanation_reward], weights=[mcq_weight, explanation_weight], parser=parser) - return vf.SingleTurnEnv( + env = vf.SingleTurnEnv( dataset=None, # No training split available eval_dataset=test_ds, system_prompt=system_prompt, @@ -360,3 +370,5 @@ async def explanation_judge_reward(judge, prompt, completion, answer, state, **k rubric=rubric, **kwargs ) + + return env From f5375564b7574de095e9a6862eba7bf96845b6e5 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Thu, 16 Oct 2025 10:58:03 +0200 Subject: [PATCH 4/9] cleaned up code --- environments/medexqa/medexqa.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index 0b37607f..69d72a93 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -262,21 +262,10 @@ def _get_completion_text(completion_obj) -> str: def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: completion_text = _get_completion_text(completion) info = kwargs.get("info", {}) or {} - state = kwargs.get("state") options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} gold = (answer or "").strip().upper() pred_letter = extract_answer_letter(completion_text, options) base = 100.0 if pred_letter == gold else 0.0 - # Persist per-specialty counters into state so runs saved with -s can be summarized post-hoc - if isinstance(state, dict): - spec = (info.get("specialty") or "unknown") - counters = state.get("specialty_counters") or {} - curr = counters.get(spec) or {"correct": 0, "total": 0} - curr["total"] = int(curr.get("total", 0)) + 1 - if pred_letter == gold: - curr["correct"] = int(curr.get("correct", 0)) + 1 - counters[spec] = curr - state["specialty_counters"] = counters return base def explanation_reward(parser, completion, answer, **kwargs) -> float: From 1bf188ed0bbf028208b4b0cd90b8c4c5985eea30 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Mon, 3 Nov 2025 17:52:48 +0100 Subject: [PATCH 5/9] Add FactScore and G-Eval judges for explanation evaluation --- environments/medexqa/README.md | 74 +++- .../factscore_judge/atomic_facts_generator.py | 211 +++++++++++ .../factscore_judge/atomic_facts_judge.py | 259 +++++++++++++ .../medexqa/geval_judge/geval_judge.py | 200 ++++++++++ environments/medexqa/medexqa.py | 83 +---- environments/medexqa/tools/judge_rescore.py | 350 ++++++++++++++++++ 6 files changed, 1112 insertions(+), 65 deletions(-) create mode 100644 environments/medexqa/factscore_judge/atomic_facts_generator.py create mode 100644 environments/medexqa/factscore_judge/atomic_facts_judge.py create mode 100644 environments/medexqa/geval_judge/geval_judge.py create mode 100644 environments/medexqa/tools/judge_rescore.py diff --git a/environments/medexqa/README.md b/environments/medexqa/README.md index 5ad8f8ad..802dbfe2 100644 --- a/environments/medexqa/README.md +++ b/environments/medexqa/README.md @@ -75,9 +75,11 @@ uv run vf-eval medexqa \ | `mcq_weight` | float | `0.5` | Weight for MCQ accuracy in the combined score. | | `explanation_weight` | float | `0.5` | Weight for explanation in the combined score. | | `use_judge` | bool | `False` | Use LLM-as-judge for explanations instead of lexical metrics. | +| `judge_mode` | str \| None | `None` | Judge mode: `"g-eval"` or `"factscore"`. Required when `use_judge=True`. | | `judge_model` | str | `gpt-4o-mini` | Judge model name. | | `judge_base_url` | str \| None | `None` | Judge API base URL. | | `judge_api_key` | str \| None | `None` | Judge API key (falls back to `JUDGE_API_KEY` or `OPENAI_API_KEY`). | +| `use_coverage` | bool | `False` | For FactScore only: enable coverage calculation (measures recall). Default is support-only (precision). Increases API calls from ~6-8 to ~12-15 per example. | | `seed` | int \| None | `None` | When multiple specialties are selected, shuffles the combined eval set with this seed. | ### Metrics @@ -90,7 +92,77 @@ uv run vf-eval medexqa \ Optional LLM-as-judge for explanations: - Set `use_explanations=true` and `use_judge=true` to replace lexical metrics with judge scoring (0–100 after scaling). -- Criteria include medical accuracy, relevance, clarity, completeness, and use of medical concepts. 0 if the answer from string matching is wrong. +- Must specify `judge_mode` when using LLM-as-judge. Two modes are supported: + +#### Judge Modes + +**G-Eval Mode** (`judge_mode="g-eval"`): +- Uses Chain-of-Thought evaluation with 7 structured steps to assess explanation quality. +- Evaluates medical accuracy, correct option justification, distractor analysis, reference alignment, reasoning clarity, and completeness. +- Outputs structured JSON with detailed step-by-step analysis and a final score (0-100). + +**FactScore Mode** (`judge_mode="factscore"`): +- Decomposes explanations into atomic medical claims (5-7 key claims) using specialty-specific few-shot examples. +- Verifies each claim against reference explanations using 3-level support scoring (FULLY_SUPPORTED=1.0, PARTIALLY_SUPPORTED=0.5, NOT_SUPPORTED=0.0). +- Optionally computes coverage to measure if model covers key concepts from references (disabled by default; enable with `use_coverage=True` for comprehensive precision+recall evaluation). + +Usage examples: +```bash +# G-Eval mode +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "g-eval", "judge_model": "gpt-4o-mini"}' + +# FactScore mode (support-only, fast) +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "factscore", "judge_model": "gpt-4o-mini"}' + +# FactScore mode with coverage (support+coverage, comprehensive) +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "factscore", "judge_model": "gpt-4o-mini", "use_coverage": true}' +``` + +#### Judge Rescoring Tool + +The `judge_rescore.py` script allows you to re-evaluate previously generated model completions with LLM judges and save detailed results to CSV files. This is useful for experimenting with different judge configurations without re-running model inference. + +**Usage:** +```bash +# Rescore with FactScore (fast mode: support-only, ~6-8 API calls per example) +export OPENAI_API_KEY=sk-... +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge factscore \ + --base https://openrouter.ai/api/v1 \ + --model openai/gpt-4o-mini \ + --key_var OPENAI_API_KEY \ + --input_glob 'environments/medexqa/outputs/evals/**/results.jsonl' \ + --out_csv_prefix 'environments/medexqa/outputs/judge_scores/medexqa_' \ + --sleep_ms 2000 \ + --max_retries 6 \ + --max_tokens 512 \ + --verbose + +# Rescore with FactScore + coverage (comprehensive: support+coverage, ~12-15 API calls per example) +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge factscore \ + --use_coverage \ + --model openai/gpt-4o-mini \ + ... # other args same as above + +# Rescore with G-Eval (1 API call per example) +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge geval \ + --model openai/gpt-4o-mini \ + --max_tokens 768 \ + ... # other args same as above + +# Rescore with both judges +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge both \ + --model openai/gpt-4o-mini \ + ... # other args same as above +``` + +**Output:** Creates CSV files with detailed judge outputs: +- `medexqa_factscore.csv`: Contains extracted claims, support/coverage rates, and final scores +- `medexqa_geval.csv`: Contains structured JSON evaluation details and scores ### Specialty Selection and Macro Average diff --git a/environments/medexqa/factscore_judge/atomic_facts_generator.py b/environments/medexqa/factscore_judge/atomic_facts_generator.py new file mode 100644 index 00000000..1b11d987 --- /dev/null +++ b/environments/medexqa/factscore_judge/atomic_facts_generator.py @@ -0,0 +1,211 @@ +import json +from typing import List + +from openai import AsyncOpenAI + + +class AtomicFactGenerator: + """ + MedExQA-specific atomic facts generator. + + Extracts concise, checkable medical claims from an MCQA explanation that + support the chosen option and, when useful, refute key distractors. + Returns a Python list of strings (facts), not raw model text. + """ + + def __init__(self, async_openai_client: AsyncOpenAI | None, model_name: str = "gpt-4o-mini") -> None: + self.client = async_openai_client + self.model_name = model_name + + async def run(self, explanation_text: str) -> List[str]: + """ + Extract atomic facts from an MCQA explanation. + """ + explanation = (explanation_text or "").strip() + if not explanation: + return [] + + primary = await self._extract_json_claims(explanation) + if primary: + return primary + + fallback = await self._extract_json_claims(explanation, fallback=True) + return fallback or [] + + async def _extract_json_claims(self, explanation: str, fallback: bool = False) -> List[str]: + if self.client is None: + return [] + + if not fallback: + prompt = ( + "You are a medical expert evaluating MCQA (multiple-choice question) explanations.\n" + "Extract atomic, checkable medical claims that: (1) justify why the correct option is right, " + "(2) when applicable, explain why key distractors are wrong, (3) preserve medical terminology.\n\n" + "Rules:\n" + "- Output a strict JSON array of strings ONLY (no extra text).\n" + "- Extract 5-7 MOST IMPORTANT claims (prioritize key medical concepts).\n" + "- Each claim ≤ 30 words; no duplicates; no vague statements.\n" + "- Preserve technical terms and abbreviations (e.g., 'DEXA', 'PTFE', 'AAC').\n" + "- If no checkable medical content, return [].\n\n" + "Few-shot examples (imitate format exactly):\n\n" + "# Biomedical Engineering Example 1:\n" + "Explanation: Membrane oxygenators require materials with high gas permeability for O2 and CO2 exchange. " + "Silicone rubber, polypropylene, and Teflon are highly permeable polymers. Ceramic membranes are dense, " + "brittle, have poor gas permeability, and can cause hemolysis.\n" + "Claims JSON: [\n" + " \"Membrane oxygenators require high gas permeability for O2 and CO2 exchange.\",\n" + " \"Silicone rubber has excellent gas permeability and biocompatibility.\",\n" + " \"Polypropylene provides high gas transfer and is durable.\",\n" + " \"Teflon (PTFE) is chemically inert with good blood contact properties.\",\n" + " \"Ceramic membranes have poor gas permeability compared to polymers.\",\n" + " \"Ceramic membranes are brittle and can cause hemolysis.\"\n" + "]\n\n" + "# Biomedical Engineering Example 2:\n" + "Explanation: Thermographic cameras detect infrared radiation emitted by objects due to temperature. " + "All objects above absolute zero emit infrared radiation. X-rays and UV are higher-energy and not used for thermal imaging. " + "Microwaves are used for radar, not temperature scanning.\n" + "Claims JSON: [\n" + " \"Thermographic cameras detect infrared radiation from objects.\",\n" + " \"All objects above absolute zero emit infrared radiation.\",\n" + " \"Infrared is ideal for measuring surface temperatures.\",\n" + " \"X-rays are too high-energy for conventional thermal imaging.\",\n" + " \"Microwaves are used for radar applications, not thermal cameras.\"\n" + "]\n\n" + "# Clinical Laboratory Science Example 1:\n" + "Explanation: Hemoglobin A1c measures average blood glucose over 2-3 months by detecting glycated hemoglobin. " + "Fasting glucose only reflects current levels. Random glucose varies throughout the day. Oral glucose tolerance test is diagnostic but not for monitoring.\n" + "Claims JSON: [\n" + " \"Hemoglobin A1c measures average blood glucose over 2-3 months.\",\n" + " \"A1c detects glycated hemoglobin formed by glucose binding.\",\n" + " \"Fasting glucose only reflects current blood glucose levels.\",\n" + " \"Random glucose varies throughout the day and is unreliable for averages.\",\n" + " \"OGTT is diagnostic but not suitable for long-term monitoring.\"\n" + "]\n\n" + "# Clinical Laboratory Science Example 2:\n" + "Explanation: Gram staining differentiates bacteria by cell wall structure. Gram-positive bacteria have thick peptidoglycan walls " + "that retain crystal violet stain. Gram-negative bacteria have thin peptidoglycan and outer membranes, appearing pink after counterstaining.\n" + "Claims JSON: [\n" + " \"Gram staining differentiates bacteria by cell wall structure.\",\n" + " \"Gram-positive bacteria have thick peptidoglycan cell walls.\",\n" + " \"Thick peptidoglycan retains crystal violet stain in Gram-positive bacteria.\",\n" + " \"Gram-negative bacteria have thin peptidoglycan and outer membranes.\",\n" + " \"Gram-negative bacteria appear pink after safranin counterstaining.\"\n" + "]\n\n" + "# Clinical Psychology Example 1:\n" + "Explanation: Cognitive-behavioral therapy (CBT) is first-line for generalized anxiety disorder, with strong evidence for efficacy. " + "Psychodynamic therapy lacks robust evidence for GAD. Exposure therapy is specific to phobias. Supportive therapy alone is insufficient for GAD.\n" + "Claims JSON: [\n" + " \"CBT is first-line treatment for generalized anxiety disorder.\",\n" + " \"CBT has strong evidence for efficacy in treating GAD.\",\n" + " \"Psychodynamic therapy lacks robust evidence for GAD treatment.\",\n" + " \"Exposure therapy is specific to phobias, not GAD.\",\n" + " \"Supportive therapy alone is insufficient for GAD management.\"\n" + "]\n\n" + "# Clinical Psychology Example 2:\n" + "Explanation: The PHQ-9 is a validated 9-item screening tool for major depressive disorder with scores 0-27. " + "Scores ≥10 indicate moderate depression requiring clinical evaluation. It assesses DSM-5 criteria for MDD.\n" + "Claims JSON: [\n" + " \"PHQ-9 is a validated screening tool for major depressive disorder.\",\n" + " \"PHQ-9 contains 9 items with total scores ranging 0-27.\",\n" + " \"Scores ≥10 indicate moderate depression needing evaluation.\",\n" + " \"PHQ-9 assesses DSM-5 diagnostic criteria for MDD.\"\n" + "]\n\n" + "# Occupational Therapy Example 1:\n" + "Explanation: The Barthel Index measures independence in activities of daily living (ADL) across 10 domains. " + "Scores range 0-100, with higher scores indicating greater independence. It's reliable for tracking functional recovery post-stroke.\n" + "Claims JSON: [\n" + " \"Barthel Index measures independence in activities of daily living.\",\n" + " \"The index assesses 10 functional domains.\",\n" + " \"Scores range from 0 (dependent) to 100 (independent).\",\n" + " \"Higher Barthel Index scores indicate greater functional independence.\",\n" + " \"Barthel Index is reliable for tracking post-stroke recovery.\"\n" + "]\n\n" + "# Occupational Therapy Example 2:\n" + "Explanation: Adaptive utensils with built-up handles improve grip for patients with arthritis by reducing required pinch force. " + "Weighted utensils help tremor patients. Angled utensils assist those with limited wrist mobility. Standard utensils lack these modifications.\n" + "Claims JSON: [\n" + " \"Built-up handle utensils improve grip for arthritis patients.\",\n" + " \"Built-up handles reduce required pinch force during eating.\",\n" + " \"Weighted utensils help stabilize tremors during eating.\",\n" + " \"Angled utensils assist patients with limited wrist mobility.\",\n" + " \"Standard utensils lack these adaptive modifications.\"\n" + "]\n\n" + "# Speech Pathology Example 1:\n" + "Explanation: Videofluoroscopic swallow study (VFSS) is the gold standard for dysphagia evaluation, visualizing all swallowing phases. " + "It detects aspiration, penetration, and pharyngeal residue. Clinical swallow exam cannot visualize aspiration. Endoscopy misses oral phase.\n" + "Claims JSON: [\n" + " \"VFSS is the gold standard for dysphagia evaluation.\",\n" + " \"VFSS visualizes all phases of swallowing in real-time.\",\n" + " \"VFSS can detect aspiration, penetration, and pharyngeal residue.\",\n" + " \"Clinical swallow examination cannot visualize aspiration.\",\n" + " \"Fiberoptic endoscopic evaluation misses the oral phase of swallowing.\"\n" + "]\n\n" + "# Speech Pathology Example 2:\n" + "Explanation: The Peabody Picture Vocabulary Test (PPVT) assesses receptive vocabulary in children and adults. " + "It requires pointing to pictures, not verbal responses, making it suitable for nonverbal individuals. Expressive language tests require speech production.\n" + "Claims JSON: [\n" + " \"PPVT assesses receptive vocabulary in children and adults.\",\n" + " \"PPVT requires pointing to pictures, not verbal responses.\",\n" + " \"PPVT is suitable for assessing nonverbal individuals.\",\n" + " \"Expressive language tests require speech production.\"\n" + "]\n\n" + "Now extract claims for the explanation below.\n\n" + f"Explanation:\n{explanation}\n\n" + "Claims JSON:" + ) + else: + prompt = ( + "Extract atomic, checkable medical claims from this MCQA explanation.\n" + "Return ONLY a JSON array of 4–10 strings; each ≤ 30 words. If none, return [].\n\n" + f"Explanation:\n{explanation}\n\n" + "Claims JSON:" + ) + + try: + resp = await self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=512, + ) + text = (resp.choices[0].message.content or "").strip() + return _parse_json_list(text) + except Exception: + return [] + + +def _parse_json_list(text: str) -> List[str]: + try: + data = json.loads(text) + if isinstance(data, list): + out = [] + for x in data: + s = (str(x) or "").strip() + if s: + out.append(s) + # keep unique order + seen = set() + uniq = [] + for s in out: + if s not in seen: + uniq.append(s) + seen.add(s) + return uniq + return [] + except Exception: + # fallback: find bracketed content + try: + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1 and end > start: + data = json.loads(text[start : end + 1]) + if isinstance(data, list): + return [str(x).strip() for x in data if str(x).strip()] + except Exception: + pass + return [] + + + + + diff --git a/environments/medexqa/factscore_judge/atomic_facts_judge.py b/environments/medexqa/factscore_judge/atomic_facts_judge.py new file mode 100644 index 00000000..e56f4e69 --- /dev/null +++ b/environments/medexqa/factscore_judge/atomic_facts_judge.py @@ -0,0 +1,259 @@ +""" +FactScore-style judge for MedExQA explanations (reference-only, no external retrieval). + +Two-step process: +1) Extract atomic medical claims from the model's explanation. +2) Verify each claim against available references: question, correct option text, exp0, exp1. + +Returns support rate in [0, 1], scaled to [0, 100] for reward. +""" + +import json +import re +import verifiers as vf +from .atomic_facts_generator import AtomicFactGenerator + + +JUDGE_TEMPLATE = """You are a medical knowledge verification expert. Evaluate if the Passage supports the Claim. + +CONTEXT: +Question: {question} +Correct Answer: ({correct_letter}) {correct_option} + +Reference Explanations: +- Reference 1: {exp0} +- Reference 2: {exp1} + +PASSAGE (Combined References): +{response} + +CLAIM TO VERIFY: +{answer} + +INSTRUCTIONS: +1. Check if the claim is FULLY supported by the passage with explicit evidence +2. Check if the claim is PARTIALLY supported (implied/inferable but not explicit) +3. Check if the claim is NOT supported (no evidence or contradicts passage) + +Respond with EXACTLY ONE of: +- "FULLY_SUPPORTED" - explicit evidence exists in passage +- "PARTIALLY_SUPPORTED" - implied/inferable from passage +- "NOT_SUPPORTED" - no evidence or contradicts passage + +Your response:""".strip() + + +def extract_support_level(text: str) -> tuple[float, bool]: + """ + Extract support level from LLM judge response. + + Returns: + (score, valid): score is 0.0, 0.5, or 1.0; valid indicates if parsing succeeded + """ + cleaned_text = (text or "").strip().upper() + + # Check for 3-level responses + if "FULLY_SUPPORTED" in cleaned_text or "FULLY SUPPORTED" in cleaned_text: + return (1.0, True) + if "PARTIALLY_SUPPORTED" in cleaned_text or "PARTIALLY SUPPORTED" in cleaned_text: + return (0.5, True) + if "NOT_SUPPORTED" in cleaned_text or "NOT SUPPORTED" in cleaned_text: + return (0.0, True) + + # Fallback to old binary format for backwards compatibility + cleaned_lower = cleaned_text.lower() + has_true = "true" in cleaned_lower + has_false = "false" in cleaned_lower + if has_true and not has_false: + return (1.0, True) + if has_false and not has_true: + return (0.0, True) + + # Ambiguous response + return (0.0, False) + + +async def explanation_factscore_reward( + judge, + prompt, + completion, + answer, + state, + **kwargs, +) -> float: + # parse explanation text + if isinstance(completion, list) and completion: + explanation = completion[-1].get("content", "") or "" + else: + explanation = str(completion) + + info = kwargs.get("info", {}) or {} + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + question = info.get("question", "") + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + correct_letter = (answer or "").strip().upper() + correct_option_text = options.get(correct_letter, "") + + # Gate explanation to zero if predicted MCQ answer is wrong + try: + m = re.search(r"(? 0 else 0.0 + + # Step 2b: Coverage rate - DISABLED by default for speed + # This measures recall: does the model explanation cover key reference concepts? + # Enable with use_coverage=True in kwargs for balanced precision+recall evaluation + use_coverage = kwargs.get("use_coverage", False) + coverage_rate = 0.0 + + if use_coverage and llm_client is not None: + # Extract claims from both reference explanations + all_ref_claims: list[str] = [] + + # Extract from reference 1 + if (exp0 or "").strip(): + try: + ref0_claims = await generator.run(exp0) + all_ref_claims.extend(ref0_claims) + except Exception: + pass + + # Extract from reference 2 + if (exp1 or "").strip(): + try: + ref1_claims = await generator.run(exp1) + all_ref_claims.extend(ref1_claims) + except Exception: + pass + + # Remove duplicates while preserving order + seen = set() + unique_ref_claims = [] + for claim in all_ref_claims: + if claim not in seen: + unique_ref_claims.append(claim) + seen.add(claim) + + # Verify each reference claim against model explanation + coverage_score = 0.0 + coverage_total = 0 + + for ref_claim in unique_ref_claims: + coverage_total += 1 + # Check if model explanation supports this reference claim + prompt_msg = JUDGE_TEMPLATE.format( + response=explanation, # Passage = model explanation + answer=str(ref_claim), # Claim = reference claim + question=question, + correct_letter=correct_letter, + correct_option=correct_option_text, + exp0=exp0, + exp1=exp1 + ) + cov_response = await judge([ + {"role": "user", "content": prompt_msg} + ], "", "", state, **kwargs) + score, ok = extract_support_level(cov_response) + if ok: + coverage_score += score + + coverage_rate = (coverage_score / coverage_total) if coverage_total > 0 else 0.0 + + # Combine support and coverage (if enabled) + if use_coverage: + # Use weighted combination when coverage is enabled + w_support = float(kwargs.get("support_weight", 0.5)) + w_coverage = float(kwargs.get("coverage_weight", 0.5)) + denom = w_support + w_coverage if (w_support + w_coverage) > 0 else 1.0 + final = (w_support * support_rate + w_coverage * coverage_rate) / denom + else: + # Coverage disabled: use support_rate only + final = support_rate + + # Optionally stash structured details for external loggers (if passed in kwargs) + # Caller can access via state or judge logs; for rescore tool we return these via logs reconstruction + state = state or {} + try: + state["factscore_details"] = { + "support_rate": float(support_rate), + "coverage_rate": float(coverage_rate), + } + except Exception: + pass + + return float(final * 100.0) + + +def create_factscore_judge_rubric( + parser: vf.Parser, + judge_client, + judge_model: str = "gpt-4o-mini", + use_coverage: bool = False, +) -> vf.JudgeRubric: + rubric = vf.JudgeRubric( + judge_client=judge_client, + judge_model=judge_model, + judge_prompt=JUDGE_TEMPLATE, + parser=parser, + use_coverage=use_coverage, # Pass through to reward function via kwargs + ) + rubric.add_reward_func(explanation_factscore_reward, weight=1.0) + return rubric + + diff --git a/environments/medexqa/geval_judge/geval_judge.py b/environments/medexqa/geval_judge/geval_judge.py new file mode 100644 index 00000000..7df46956 --- /dev/null +++ b/environments/medexqa/geval_judge/geval_judge.py @@ -0,0 +1,200 @@ +import re +import json +import verifiers as vf + + +GEVAL_CRITERIA = """You are evaluating a medical MCQA explanation for quality and correctness. + +Assess the explanation across these dimensions: + +1. MEDICAL ACCURACY: Does the explanation contain factually correct medical information that aligns with the correct option and reference explanations? Are there any medical errors or contradictions? + +2. CORRECT OPTION JUSTIFICATION: Does the explanation clearly explain WHY the correct answer is medically appropriate using valid clinical/scientific reasoning? + +3. DISTRACTOR ANALYSIS (when applicable): Does the explanation explain why incorrect options are wrong? Note: Not all explanations need this, but it enhances quality when present. + +4. REFERENCE ALIGNMENT: Do the explanation's core medical claims align with the key concepts in both reference explanations? + +5. REASONING CLARITY: Is the medical reasoning easy to follow with logical flow from evidence to conclusion? + +6. COMPLETENESS: Does the explanation cover the essential medical concepts without major omissions?""" + + +GEVAL_EVALUATION_STEPS = [ + "Extract all medical claims from the actual explanation and list them explicitly", + "Compare each claim against the correct option text and both reference explanations. Mark claims as: ALIGNED (matches references), CONTRADICTS (conflicts with references), or NEW_INFO (additional but not contradictory)", + "Identify if the explanation justifies WHY the correct option is right (not just states it is correct)", + "Check for any major medical errors, inaccuracies, or unsupported claims that could mislead", + "Assess whether distractor refutation is present and accurate (if applicable to this question)", + "Evaluate overall reasoning clarity, logical flow, and completeness", + "Synthesize findings into a score using this rubric: 0.0-0.2 (major errors/irrelevant/contradicts references), 0.2-0.4 (significant gaps/multiple minor errors), 0.4-0.6 (acceptable but incomplete/some inaccuracies), 0.6-0.8 (good quality with minor issues), 0.8-1.0 (excellent: comprehensive, accurate, well-reasoned)" +] + + +GEVAL_PROMPT_TEMPLATE = """You are a strict medical explanation evaluator following a structured evaluation process. + +CRITERIA: +{criteria} + +EVALUATION STEPS (follow these in order): +{evaluation_steps} + +OUTPUT FORMAT: +Respond with a JSON object containing your step-by-step analysis and final score. Use this exact structure: +{{ + "step1_claims_extracted": ["claim1", "claim2", ...], + "step2_alignment_analysis": {{ + "aligned_claims": [...], + "contradicting_claims": [...], + "new_info_claims": [...] + }}, + "step3_correct_option_justified": true/false, + "step4_medical_errors_found": ["error description"] or [], + "step5_distractor_refutation": "present_and_accurate" / "present_but_weak" / "absent" / "not_applicable", + "step6_reasoning_assessment": "clear" / "somewhat_clear" / "confusing", + "step7_final_score": 0.XX, + "score_justification": "Brief 1-2 sentence explanation of the score" +}} + +QUESTION CONTEXT: +Question: {question} +Options: +{options} +Correct Answer: {correct_answer} + +REFERENCE EXPLANATIONS: +Reference 1: {ref_exp1} +Reference 2: {ref_exp2} + +MODEL EXPLANATION TO EVALUATE: +{model_explanation} + +Provide your evaluation as JSON:""" + + +def _extract_score_from_json(text: str) -> tuple[float, dict]: + """ + Extract score from JSON response. + + Returns: + (score, parsed_dict): score is 0.0-1.0; parsed_dict contains full evaluation + """ + try: + # Try to parse as JSON first + # Find JSON object in response (may have extra text before/after) + json_match = re.search(r'\{.*\}', text, re.DOTALL) + if json_match: + json_str = json_match.group(0) + data = json.loads(json_str) + + # Extract score from step7_final_score or final_score + score = float(data.get("step7_final_score", data.get("final_score", 0.0))) + score = max(0.0, min(1.0, score)) + return score, data + except Exception: + pass + + # Fallback: try old "final_score:" pattern + try: + m = re.search(r"final_score\s*:\s*(\d+\.\d+|\d+)", text, flags=re.IGNORECASE) + if m: + score = float(m.group(1)) + return max(0.0, min(1.0, score)), {} + except Exception: + pass + + # Last fallback: extract any number + try: + m = re.search(r"(\d+\.\d+|\d+)", text.strip()) + if m: + val = float(m.group(1)) + return max(0.0, min(1.0, val)), {} + except Exception: + pass + + return 0.0, {} + + +async def explanation_geval_reward( + judge, + prompt, + completion, + answer, + state, + **kwargs, +) -> float: + # Extract the last assistant message content as the explanation text + if isinstance(completion, list) and completion: + completion_text = completion[-1].get("content", "") or "" + else: + completion_text = str(completion) + + info = kwargs.get("info", {}) or {} + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + question = info.get("question", "") + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + correct_letter = (answer or "").strip().upper() + + # Gate explanation to zero if predicted MCQ answer is wrong + try: + m = re.search(r"(? vf.JudgeRubric: + rubric = vf.JudgeRubric( + judge_client=judge_client, + judge_model=judge_model, + judge_prompt="{question}", # not used directly; reward builds full prompt + parser=parser, + ) + rubric.add_reward_func(explanation_geval_reward, weight=1.0) + return rubric + + diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index 69d72a93..c7d07ef3 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -82,9 +82,11 @@ def load_environment( explanation_metrics: list[str] | str | None = None, # None/"all" => average of all four # Optional judge settings use_judge: bool = False, + judge_mode: str | None = None, # "g-eval" | "factscore" judge_model: str = "gpt-4o-mini", judge_base_url: str | None = None, judge_api_key: str | None = None, + use_coverage: bool = False, # For FactScore: enable coverage calculation (slower but comprehensive) **kwargs ) -> vf.Environment: """ @@ -284,71 +286,24 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: if use_explanations and use_judge: api_key = judge_api_key if judge_api_key else os.getenv("JUDGE_API_KEY") or os.getenv("OPENAI_API_KEY") judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key) if api_key else None - judge_rubric = vf.JudgeRubric( - judge_client=judge_client, - judge_model=judge_model, - judge_prompt="{question}", - ) - - async def explanation_judge_reward(judge, prompt, completion, answer, state, **kwargs) -> float: - completion_text = _get_completion_text(completion) - info = kwargs.get("info", {}) or {} - options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} - gold = (answer or "").strip().upper() - pred_letter = extract_answer_letter(completion_text, options) - if pred_letter != gold: - base = 0.0 - else: - # Build judge prompt - question = info.get("question", "") - opts_str = "\n".join(f"{k}. {options.get(k, '')}" for k in ["A","B","C","D"]) - formatted_question = f"{question}\n{opts_str}" - exp0 = info.get("exp0", "") - exp1 = info.get("exp1", "") - full_prompt = ( - "You are evaluating the quality of a medical explanation.\n\n" - "**Question:**\n" + formatted_question + "\n\n" - "**Correct Answer:** " + str(answer) + "\n\n" - "**Reference Explanation 1:**\n" + str(exp0) + "\n\n" - "**Reference Explanation 2:**\n" + str(exp1) + "\n\n" - "**Model's Response:**\n" + completion_text + "\n\n" - "Evaluate whether the model's explanation is medically accurate, relevant, and demonstrates understanding of the medical concepts. The explanation should justify why the answer is correct.\n\n" - "Compare the model's explanation quality to the reference explanations. Consider:\n" - "- Medical accuracy\n" - "- Relevance to the question\n" - "- Clarity and completeness\n" - "- Proper use of medical concepts\n\n" - "Respond with a score from 0.0 to 1.0:\n" - "- 1.0 = Excellent (as good as or better than references)\n" - "- 0.75 = Good (mostly correct with minor issues)\n" - "- 0.5 = Acceptable (partially correct)\n" - "- 0.25 = Poor (significant errors)\n" - "- 0.0 = Wrong or irrelevant\n\n" - "Respond with ONLY a number between 0.0 and 1.0." - ) - judge_response = await judge_rubric.judge( - [{"role": "user", "content": full_prompt}], - "", - "", - state, - **kwargs, - ) - try: - score_str = str(judge_response).strip() - import re as _re - number_match = _re.search(r"(\d+\.?\d*)", score_str) - explanation_score = float(number_match.group(1)) if number_match else 0.0 - except Exception: - explanation_score = 0.0 - base = max(0.0, min(1.0, explanation_score)) * 100.0 - return base - - # Use JudgeRubric with two metrics: answer accuracy (sync), explanation judge (async) - judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) - judge_rubric.add_reward_func(explanation_judge_reward, weight=explanation_weight) - rubric = judge_rubric + if judge_mode is None: + raise ValueError("use_judge=True requires judge_mode to be one of {'g-eval','factscore'}") + if judge_mode not in ("g-eval", "factscore"): + raise ValueError("judge_mode must be 'g-eval' or 'factscore'") + + if judge_mode == "g-eval": + from .geval_judge.geval_judge import create_geval_judge_rubric + judge_rubric = create_geval_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model) + # Combine answer accuracy with the judge-based explanation score + judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) + rubric = judge_rubric + elif judge_mode == "factscore": + from .factscore_judge.atomic_facts_judge import create_factscore_judge_rubric + judge_rubric = create_factscore_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, use_coverage=use_coverage) + judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) + rubric = judge_rubric else: - # Keep metrics separate (and a combine drewad with tunable weights) + # Keep metrics separate (and a combined reward with tunable weights) rubric = vf.Rubric(funcs=[answer_accuracy_reward, explanation_reward], weights=[mcq_weight, explanation_weight], parser=parser) env = vf.SingleTurnEnv( diff --git a/environments/medexqa/tools/judge_rescore.py b/environments/medexqa/tools/judge_rescore.py new file mode 100644 index 00000000..7f5ada8a --- /dev/null +++ b/environments/medexqa/tools/judge_rescore.py @@ -0,0 +1,350 @@ +import argparse +import asyncio +import csv +import glob +import json +import os +import re +from typing import Any, Dict, List, Tuple + +from openai import AsyncOpenAI + +# Reuse existing judge implementations from the environment +from environments.medexqa.geval_judge.geval_judge import ( + explanation_geval_reward as geval_reward, +) +from environments.medexqa.factscore_judge.atomic_facts_judge import ( + explanation_factscore_reward as factscore_reward, +) + + +def _extract_numeric(text: str) -> float: + m = re.search(r"(\d+\.\d+|\d+)", (text or "").strip()) + if not m: + return 0.0 + try: + val = float(m.group(1)) + return max(0.0, min(1.0, val)) + except Exception: + return 0.0 + + +def _read_results(paths: List[str]) -> List[Tuple[str, Dict[str, Any]]]: + rows: List[Tuple[str, Dict[str, Any]]] = [] + for p in paths: + try: + with open(p, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + rec = json.loads(line) + rows.append((p, rec)) + except Exception as e: + print(f"Warning: failed to read {p}: {e}") + return rows + + +class JudgeRecorder: + def __init__(self, client: AsyncOpenAI, model: str, sleep_ms: int = 500, max_retries: int = 5, verbose: bool = True, max_tokens: int = 384): + self.client = client + self.model = model + self.sleep_ms = max(0, int(sleep_ms)) + self.max_retries = max(1, int(max_retries)) + self.verbose = verbose + self.max_tokens = max_tokens + self.logs: List[Dict[str, str]] = [] + + async def __call__(self, messages, *_args, **_kwargs) -> str: + # messages is a list of {role, content} + content = messages[-1].get("content", "") if messages else "" + attempt = 0 + delay = self.sleep_ms / 1000.0 + while True: + try: + if self.verbose: + print(f"[judge] calling model={self.model}, tokens<=256") + resp = await self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0, + max_tokens=self.max_tokens, + ) + out = resp.choices[0].message.content or "" + self.logs.append({"prompt": content, "response": out}) + # throttle between calls + if self.sleep_ms > 0: + await asyncio.sleep(self.sleep_ms / 1000.0) + return out + except Exception as e: + attempt += 1 + msg = str(e) + if self.verbose: + print(f"[judge] error on attempt {attempt}: {msg}") + # retry on rate limit or transient errors + if attempt < self.max_retries: + # exponential backoff with floor at configured delay + backoff = delay * (2 ** (attempt - 1)) + await asyncio.sleep(backoff) + continue + # record failure + self.logs.append({"prompt": content, "response": f": {msg}"}) + return "" + + +async def judge_geval( + client: AsyncOpenAI, + model: str, + rec: Dict[str, Any], + *, + sleep_ms: int = 500, + max_retries: int = 5, + verbose: bool = True, + max_tokens: int = 384, +) -> Tuple[float, str, str]: + info = rec.get("info", {}) or {} + question = info.get("question", "") + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + answer = rec.get("answer", "") + completion_msgs = rec.get("completion", []) + + jr = JudgeRecorder(client, model, sleep_ms=sleep_ms, max_retries=max_retries, verbose=verbose, max_tokens=max_tokens) + score = await geval_reward(jr, None, completion_msgs, answer, state={}, info=info, judge_client=client, judge_model=model) + # Last log entry contains the overall prompt/response + judge_output = jr.logs[-1]["response"] if jr.logs else "" + refs = ( + f"Question: {question}\n" + f"Correct answer: {answer} ({options.get(answer,'')})\n" + f"Ref1: {exp0}\n" + f"Ref2: {exp1}" + ) + return float(score), judge_output, refs + + +async def judge_factscore( + client: AsyncOpenAI, + model: str, + rec: Dict[str, Any], + *, + sleep_ms: int = 500, + max_retries: int = 5, + verbose: bool = True, + max_tokens: int = 384, + use_coverage: bool = False, +) -> Tuple[float, str, str]: + info = rec.get("info", {}) or {} + question = info.get("question", "") + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + answer = rec.get("answer", "") + completion_msgs = rec.get("completion", []) + + jr = JudgeRecorder(client, model, sleep_ms=sleep_ms, max_retries=max_retries, verbose=verbose, max_tokens=max_tokens) + score = await factscore_reward(jr, None, completion_msgs, answer, state={}, info=info, judge_client=client, judge_model=model, use_coverage=use_coverage) + + # Parse logs to reconstruct claim labels and track extraction outcomes + labels: List[Tuple[str, str]] = [] # (claim, label) where passage=references (support) + coverage_labels: List[Tuple[str, str]] = [] # (ref_claim, label) where passage=explanation (coverage) + extraction_responses: List[str] = [] + for entry in jr.logs: + prompt = entry.get("prompt", "") or "" + response = entry.get("response", "") or "" + + # Look for new format: "CLAIM TO VERIFY:\n" + m = re.search(r"CLAIM TO VERIFY:\s*\n(.+?)(?:\n\nINSTRUCTIONS:)", prompt, flags=re.DOTALL) + if m: + claim = m.group(1).strip() + # Heuristic: prompts containing "PASSAGE (Combined References):" belong to support + # Prompts with "MODEL EXPLANATION:" belong to coverage + if "PASSAGE (Combined References):" in prompt: + labels.append((claim, response.strip().upper())) + elif "MODEL EXPLANATION:" in prompt: + coverage_labels.append((claim, response.strip().upper())) + + # Fallback: old format "Fact:\n" + if not m: + m_old = re.search(r"Fact:\s*\n(.+)$", prompt, flags=re.DOTALL) + if m_old: + claim = m_old.group(1).strip() + if "Question:" in prompt: + labels.append((claim, response.strip().upper())) + else: + coverage_labels.append((claim, response.strip().upper())) + + if "Claims JSON:" in prompt: + extraction_responses.append(response) + + refs = ( + f"Question: {question}\n" + f"Correct: ({answer}) {options.get(answer,'')}\n" + f"Ref1: {exp0}\n" + f"Ref2: {exp1}" + ) + # Derive error tag for extraction phase + err_tag = "" + if extraction_responses: + last_extraction = extraction_responses[-1] + try: + parsed = json.loads(last_extraction) + if isinstance(parsed, list) and len(parsed) == 0: + err_tag = "empty_extraction" + except Exception: + err_tag = "extraction_error" + elif not labels and not coverage_labels: + err_tag = "empty_extraction" + + # Compute support/coverage rates from labels (handle 3-level format) + def _rate(pairs: List[Tuple[str, str]]) -> float: + if not pairs: + return 0.0 + total_score = 0.0 + for _, lbl in pairs: + lbl_clean = (lbl or "").strip().upper() + if "FULLY_SUPPORTED" in lbl_clean or "FULLY SUPPORTED" in lbl_clean: + total_score += 1.0 + elif "PARTIALLY_SUPPORTED" in lbl_clean or "PARTIALLY SUPPORTED" in lbl_clean: + total_score += 0.5 + elif lbl_clean.startswith("TRUE"): # Fallback for old format + total_score += 1.0 + # NOT_SUPPORTED or FALSE = 0.0 + return float(total_score) / float(len(pairs)) + + support_rate = _rate(labels) + coverage_rate = _rate(coverage_labels) + + details = json.dumps({ + "claims": labels, + "coverage_labels": coverage_labels, + "support_rate": support_rate, + "coverage_rate": coverage_rate, + }, ensure_ascii=False) + if err_tag: + refs = refs + f"\n: {err_tag}" + return float(score), details, refs + + +async def main(): + ap = argparse.ArgumentParser(description="Re-score saved MedExQA completions with LLM judges.") + ap.add_argument("--base", default="https://openrouter.ai/api/v1", help="Judge API base URL") + ap.add_argument("--model", default="openai/gpt-oss-20b:free", help="Judge model id") + ap.add_argument("--key_var", default="OPENAI_API_KEY", help="Env var name holding the API key") + ap.add_argument("--input_glob", default="environments/medexqa/outputs/evals/**/results.jsonl", help="Glob to results.jsonl files") + ap.add_argument("--out_csv_prefix", default="environments/medexqa/outputs/judge_scores/medexqa_", help="Output CSV prefix (will append judge name)") + ap.add_argument("--sleep_ms", type=int, default=500, help="Sleep/throttle between judge calls (ms)") + ap.add_argument("--max_retries", type=int, default=5, help="Max retries on judge call errors") + ap.add_argument("--max_tokens", type=int, default=384, help="Max tokens per judge response") + ap.add_argument("--verbose", action="store_true", help="Verbose logging") + ap.add_argument("--judge", choices=["geval", "factscore", "both"], default="both", help="Which judge(s) to run") + ap.add_argument("--use_coverage", action="store_true", help="Enable coverage calculation for FactScore (slower but more comprehensive)") + args = ap.parse_args() + + api_key = os.getenv(args.key_var) + if not api_key: + raise SystemExit(f"Missing API key in env var {args.key_var}") + + client = AsyncOpenAI(base_url=args.base, api_key=api_key) + + # Discover saved runs + paths = sorted(glob.glob(args.input_glob, recursive=True)) + if args.verbose: + print(f"Scanning {len(paths)} results.jsonl files...") + rows = _read_results(paths) + if not rows: + print("No results found to re-score.") + return + + os.makedirs(os.path.dirname(args.out_csv_prefix), exist_ok=True) + + # Prepare CSV writers conditionally + gwriter = None + fwriter = None + geval_path = args.out_csv_prefix + "geval.csv" + fact_path = args.out_csv_prefix + "factscore.csv" + if args.judge in ("geval", "both"): + gf = open(geval_path, "w", newline="") + gwriter = csv.writer(gf) + gwriter.writerow(["run_file", "specialty", "question", "A", "B", "C", "D", "answer", "completion", "judge_model_output", "judge_score", "references", "error"]) + if args.judge in ("factscore", "both"): + ff = open(fact_path, "w", newline="") + fwriter = csv.writer(ff) + fwriter.writerow(["run_file", "specialty", "question", "A", "B", "C", "D", "answer", "completion", "claims_labels_json", "support_rate", "coverage_labels_json", "coverage_rate", "final_score", "references", "error"]) + + # Process sequentially to keep it simple + for idx, (run_file, rec) in enumerate(rows, start=1): + info = rec.get("info", {}) or {} + spec = info.get("specialty", "") + question = info.get("question", "") + A = info.get("A", "") + B = info.get("B", "") + C = info.get("C", "") + D = info.get("D", "") + answer = rec.get("answer", "") + completion_msgs = rec.get("completion", []) + completion_text = completion_msgs[-1].get("content", "") if completion_msgs else "" + + if args.verbose: + print(f"[{idx}/{len(rows)}] {run_file} | spec={spec} | len(prompt)={len(question)} | len(completion)={len(completion_text)}") + + # G-Eval + if args.judge in ("geval", "both") and gwriter is not None: + if args.verbose: + print(" -> G-Eval judging...") + g_score, g_out, g_refs = await judge_geval( + client, + args.model, + rec, + sleep_ms=args.sleep_ms, + max_retries=args.max_retries, + verbose=args.verbose, + max_tokens=args.max_tokens, + ) + # detect errors in logs + g_err = "" + if g_out.strip().startswith("") or g_out.strip() == "": + g_err = "empty_or_error" + gwriter.writerow([run_file, spec, question, A, B, C, D, answer, completion_text, g_out, f"{g_score:.3f}", g_refs, g_err]) + + # FactScore + if args.judge in ("factscore", "both") and fwriter is not None: + if args.verbose: + print(" -> FactScore judging...") + f_score, f_details, f_refs = await judge_factscore( + client, + args.model, + rec, + sleep_ms=args.sleep_ms, + max_retries=args.max_retries, + verbose=args.verbose, + max_tokens=args.max_tokens, + use_coverage=args.use_coverage, + ) + f_err = "" + support_rate = "" + coverage_rate = "" + coverage_labels_json = "{}" + try: + dd = json.loads(f_details) + support_rate = f"{float(dd.get('support_rate', 0.0)):.3f}" + coverage_rate = f"{float(dd.get('coverage_rate', 0.0)):.3f}" + coverage_labels_json = json.dumps(dd.get("coverage_labels", []), ensure_ascii=False) + except Exception: + pass + if f_details.strip() == "" or f_details.strip() == "{}": + f_err = "empty_or_error" + fwriter.writerow([run_file, spec, question, A, B, C, D, answer, completion_text, f_details, support_rate, coverage_labels_json, coverage_rate, f"{f_score:.3f}", f_refs, f_err]) + + if gwriter is not None: + gf.close() + print(f"Wrote: {geval_path}") + if fwriter is not None: + ff.close() + print(f"Wrote: {fact_path}") + + +if __name__ == "__main__": + asyncio.run(main()) + + From a415f6c50fdd48c303f430b0a81f3b950b8eddd9 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Tue, 4 Nov 2025 01:01:02 +0100 Subject: [PATCH 6/9] Add MCQ shuffle and standardized answer extraction to MedExQA --- environments/medexqa/medexqa.py | 116 +++++++++++++++++----------- environments/medexqa/pyproject.toml | 2 +- 2 files changed, 73 insertions(+), 45 deletions(-) diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index c7d07ef3..ad2ae0b9 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -6,9 +6,11 @@ from verifiers.utils.data_utils import THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer import pandas as pd import evaluate -from thefuzz import process from openai import AsyncOpenAI +from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy +from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice + # MedExQA specialties SPECIALTIES = [ @@ -35,13 +37,13 @@ def _build_question_str(question: str, options: dict[str, str]) -> str: return f"{instruction}{question}\n{opts}\nAnswer:" -def _to_vf_format(ds: Dataset) -> Dataset: +def _to_vf_format(ds: Dataset, shuffle_answers: bool, shuffle_seed: int | None) -> Dataset: """Normalize raw rows into the fields expected by SingleTurnEnv. Produces rows of the form: - question: string containing authors' instruction, question, and options - - answer: gold letter (A/B/C/D) - - info: original fields including exp0/exp1 and specialty + - answer: gold letter (A/B/C/D) - shuffled if shuffle_answers=True + - info: original fields including exp0/exp1 and specialty, plus shuffled options """ def _format_row(row: dict) -> dict: question = row.get("question", "") or "" @@ -59,18 +61,41 @@ def _format_row(row: dict) -> dict: if answer_letter not in ("A", "B", "C", "D"): return None + # Shuffle options if requested + if shuffle_answers: + opts, answer_letter, _ = randomize_multiple_choice( + options=opts, + answer_choice=answer_letter, + seed=shuffle_seed, + row_id=question, # Use question text for deterministic per-row shuffling + ) + question_str = _build_question_str(question, opts) # Keep original data in info info = dict(row) + # Update info with shuffled values + if shuffle_answers: + info["A"] = opts["A"] + info["B"] = opts["B"] + info["C"] = opts["C"] + info["D"] = opts["D"] + info["answer"] = answer_letter + return { "question": question_str, "answer": answer_letter, "info": info, } - return ds.map(_format_row, remove_columns=ds.column_names).filter(lambda row: row is not None) + # Disable cache when shuffling to ensure fresh randomization + load_from_cache_file = not shuffle_answers + return ds.map( + _format_row, + remove_columns=ds.column_names, + load_from_cache_file=load_from_cache_file + ).filter(lambda row: row is not None, load_from_cache_file=load_from_cache_file) def load_environment( @@ -80,6 +105,9 @@ def load_environment( explanation_weight: float = 0.5, specialty: list[str] | str | None = None, # list of short codes or full names; None/"ALL" => all explanation_metrics: list[str] | str | None = None, # None/"all" => average of all four + # MCQ shuffling + shuffle_answers: bool = False, + shuffle_seed: int | None = 1618, # Optional judge settings use_judge: bool = False, judge_mode: str | None = None, # "g-eval" | "factscore" @@ -150,7 +178,7 @@ def load_environment( # Concatenate and format for verifiers - no training dataset available test_combined = concatenate_datasets(test_datasets) if test_datasets else None - test_ds = _to_vf_format(test_combined) if test_combined else None + test_ds = _to_vf_format(test_combined, shuffle_answers, shuffle_seed) if test_combined else None # Shuffle examples if multiple specialties were selected if macro_active and test_ds is not None: @@ -175,32 +203,6 @@ def correct_answer_reward_func(parser, completion, answer, **kwargs) -> float: # (shuffling handled above when multiple specialties) - # Helpers (authors' answer extraction logic) - def process_before_extraction(gen: str, choice_dict: dict[str, str]) -> str: - for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1] or ""), reverse=True): - pattern = re.compile(re.escape((val or "").rstrip(".")), re.IGNORECASE) - gen = pattern.sub(key, gen) - return gen - - def extract_choice(gen: str, choice_list: list[str]) -> str: - res = re.search(r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen) - if res is None: - res = re.search(r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen) - if res is None: - res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen) - if res is None: - res = re.search(r"(? str: - gen = process_before_extraction(completion_text or "", options) - pred = extract_choice(gen, [options.get(c, "") for c in ["A", "B", "C", "D"]]) - return (pred or "").upper() - # Lexical Metrics selection; pass individually or None/'all'/'overall' => average of all four base_metrics = ["rougeL", "bleu", "meteor", "bertscore"] if explanation_metrics is None: @@ -264,23 +266,41 @@ def _get_completion_text(completion_obj) -> str: def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: completion_text = _get_completion_text(completion) info = kwargs.get("info", {}) or {} + + # Get answer_text for fallback matching options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} - gold = (answer or "").strip().upper() - pred_letter = extract_answer_letter(completion_text, options) - base = 100.0 if pred_letter == gold else 0.0 - return base + answer_text = options.get(answer, "") + + is_correct = multiple_choice_accuracy( + llm_answer=completion_text, + answer_letter=answer, + answer_text=answer_text, + accept_answer_text=True, + strip_tex=False, # MedExQA doesn't use LaTeX + ) + return 100.0 if is_correct else 0.0 def explanation_reward(parser, completion, answer, **kwargs) -> float: completion_text = _get_completion_text(completion) info = kwargs.get("info", {}) or {} + + # Get answer_text for fallback matching options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} - gold = (answer or "").strip().upper() - pred_letter = extract_answer_letter(completion_text, options) - if pred_letter != gold: - base = 0.0 + answer_text = options.get(answer, "") + + # Check if answer is correct using multiple_choice_accuracy + is_correct = multiple_choice_accuracy( + llm_answer=completion_text, + answer_letter=answer, + answer_text=answer_text, + accept_answer_text=True, + strip_tex=False, + ) + + if not is_correct: + return 0.0 else: - base = compute_expl_score(completion_text, info.get("exp0", ""), info.get("exp1", "")) - return base + return compute_expl_score(completion_text, info.get("exp0", ""), info.get("exp1", "")) # Optional: Use LLM-as-judge for explanation instead of lexical metrics if use_explanations and use_judge: @@ -303,8 +323,16 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) rubric = judge_rubric else: - # Keep metrics separate (and a combined reward with tunable weights) - rubric = vf.Rubric(funcs=[answer_accuracy_reward, explanation_reward], weights=[mcq_weight, explanation_weight], parser=parser) + # Lexical metrics for explanations (or MCQ-only if use_explanations=False) + if use_explanations: + rubric = vf.Rubric( + funcs=[answer_accuracy_reward, explanation_reward], + weights=[mcq_weight, explanation_weight], + parser=parser + ) + else: + # MCQ-only mode + rubric = vf.Rubric(funcs=[answer_accuracy_reward], weights=[1.0], parser=parser) env = vf.SingleTurnEnv( dataset=None, # No training split available diff --git a/environments/medexqa/pyproject.toml b/environments/medexqa/pyproject.toml index 24d44f14..41a28464 100644 --- a/environments/medexqa/pyproject.toml +++ b/environments/medexqa/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "rouge-score>=0.1.2", "sacrebleu>=2.4.0", "bert-score>=0.3.13", - "thefuzz>=0.22.1", + "openai>=1.0.0", ] [build-system] From 6a17754c821be7d7cb1ca685b37d4ea9e3289e99 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Wed, 5 Nov 2025 15:54:16 +0100 Subject: [PATCH 7/9] author details --- environments/medexqa/README.md | 4 ++++ environments/medexqa/pyproject.toml | 3 +++ 2 files changed, 7 insertions(+) diff --git a/environments/medexqa/README.md b/environments/medexqa/README.md index 802dbfe2..79115a63 100644 --- a/environments/medexqa/README.md +++ b/environments/medexqa/README.md @@ -243,3 +243,7 @@ r1: [35.58, 31.618, 28.316, 33.239, 38.249, 0.0, 33.915, 32.653, 33.741, 21.006] year={2024} } ``` +### Authors +This environment has been put together by: + +Nishant Mishra - ([mnishant2](https://github.com/mnishant2)) \ No newline at end of file diff --git a/environments/medexqa/pyproject.toml b/environments/medexqa/pyproject.toml index 41a28464..ab9bae1b 100644 --- a/environments/medexqa/pyproject.toml +++ b/environments/medexqa/pyproject.toml @@ -4,6 +4,9 @@ version = "0.1.0" description = "MedExQA Evaluation - Medical QA with Multiple Explanations" readme = "README.md" requires-python = ">=3.11" +authors = [ + { name = "Nishant Mishra", email = "mnishant2@gmail.com" }, +] dependencies = [ "datasets>=4.0.0", "verifiers>=0.1.2.post0", From 9e27e9802f65ac408ecb1823deb69e57a92e1bfe Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Sat, 8 Nov 2025 02:44:45 +0100 Subject: [PATCH 8/9] Update medexqa judges: fix template to use standard placeholders, use judge() for token tracking compatibility --- .../medexqa/factscore_judge/__init__.py | 7 +++ .../factscore_judge/atomic_facts_generator.py | 12 +++-- .../factscore_judge/atomic_facts_judge.py | 46 ++++--------------- environments/medexqa/geval_judge/__init__.py | 6 +++ environments/medexqa/medexqa.py | 4 +- 5 files changed, 33 insertions(+), 42 deletions(-) create mode 100644 environments/medexqa/factscore_judge/__init__.py create mode 100644 environments/medexqa/geval_judge/__init__.py diff --git a/environments/medexqa/factscore_judge/__init__.py b/environments/medexqa/factscore_judge/__init__.py new file mode 100644 index 00000000..1149e4e5 --- /dev/null +++ b/environments/medexqa/factscore_judge/__init__.py @@ -0,0 +1,7 @@ +"""FactScore judge for MedExQA explanations.""" + +from .atomic_facts_judge import create_factscore_judge_rubric, explanation_factscore_reward +from .atomic_facts_generator import AtomicFactGenerator + +__all__ = ["create_factscore_judge_rubric", "explanation_factscore_reward", "AtomicFactGenerator"] + diff --git a/environments/medexqa/factscore_judge/atomic_facts_generator.py b/environments/medexqa/factscore_judge/atomic_facts_generator.py index 1b11d987..b82d1e4d 100644 --- a/environments/medexqa/factscore_judge/atomic_facts_generator.py +++ b/environments/medexqa/factscore_judge/atomic_facts_generator.py @@ -17,22 +17,26 @@ def __init__(self, async_openai_client: AsyncOpenAI | None, model_name: str = "g self.client = async_openai_client self.model_name = model_name - async def run(self, explanation_text: str) -> List[str]: + async def run(self, explanation_text: str, state: dict = None) -> List[str]: """ Extract atomic facts from an MCQA explanation. + + Args: + explanation_text: The explanation text to extract claims from + state: Optional state dict for token tracking """ explanation = (explanation_text or "").strip() if not explanation: return [] - primary = await self._extract_json_claims(explanation) + primary = await self._extract_json_claims(explanation, state=state) if primary: return primary - fallback = await self._extract_json_claims(explanation, fallback=True) + fallback = await self._extract_json_claims(explanation, fallback=True, state=state) return fallback or [] - async def _extract_json_claims(self, explanation: str, fallback: bool = False) -> List[str]: + async def _extract_json_claims(self, explanation: str, fallback: bool = False, state: dict = None) -> List[str]: if self.client is None: return [] diff --git a/environments/medexqa/factscore_judge/atomic_facts_judge.py b/environments/medexqa/factscore_judge/atomic_facts_judge.py index e56f4e69..94a4798a 100644 --- a/environments/medexqa/factscore_judge/atomic_facts_judge.py +++ b/environments/medexqa/factscore_judge/atomic_facts_judge.py @@ -16,15 +16,7 @@ JUDGE_TEMPLATE = """You are a medical knowledge verification expert. Evaluate if the Passage supports the Claim. -CONTEXT: -Question: {question} -Correct Answer: ({correct_letter}) {correct_option} - -Reference Explanations: -- Reference 1: {exp0} -- Reference 2: {exp1} - -PASSAGE (Combined References): +PASSAGE: {response} CLAIM TO VERIFY: @@ -122,7 +114,7 @@ async def explanation_factscore_reward( if llm_client is None: # No client available - cannot extract claims return 0.0 - claims = await generator.run(explanation) + claims = await generator.run(explanation, state=state) except Exception as e: # Log extraction error for debugging import sys @@ -138,18 +130,9 @@ async def explanation_factscore_reward( for claim in claims: total += 1 - prompt_msg = JUDGE_TEMPLATE.format( - response=refs, - answer=str(claim), - question=question, - correct_letter=correct_letter, - correct_option=correct_option_text, - exp0=exp0, - exp1=exp1 - ) - judge_response = await judge([ - {"role": "user", "content": prompt_msg} - ], "", "", state, **kwargs) + # Call judge like medredqa does: judge(prompt, completion, answer, state, **kwargs) + # prompt is not used in template, completion becomes {response}, answer becomes {answer} + judge_response = await judge(prompt, refs, str(claim), state, **kwargs) score, ok = extract_support_level(judge_response) if ok: support_score += score @@ -169,7 +152,7 @@ async def explanation_factscore_reward( # Extract from reference 1 if (exp0 or "").strip(): try: - ref0_claims = await generator.run(exp0) + ref0_claims = await generator.run(exp0, state=state) all_ref_claims.extend(ref0_claims) except Exception: pass @@ -177,7 +160,7 @@ async def explanation_factscore_reward( # Extract from reference 2 if (exp1 or "").strip(): try: - ref1_claims = await generator.run(exp1) + ref1_claims = await generator.run(exp1, state=state) all_ref_claims.extend(ref1_claims) except Exception: pass @@ -197,18 +180,8 @@ async def explanation_factscore_reward( for ref_claim in unique_ref_claims: coverage_total += 1 # Check if model explanation supports this reference claim - prompt_msg = JUDGE_TEMPLATE.format( - response=explanation, # Passage = model explanation - answer=str(ref_claim), # Claim = reference claim - question=question, - correct_letter=correct_letter, - correct_option=correct_option_text, - exp0=exp0, - exp1=exp1 - ) - cov_response = await judge([ - {"role": "user", "content": prompt_msg} - ], "", "", state, **kwargs) + # Call judge: passage=explanation, claim=ref_claim + cov_response = await judge(prompt, explanation, str(ref_claim), state, **kwargs) score, ok = extract_support_level(cov_response) if ok: coverage_score += score @@ -246,6 +219,7 @@ def create_factscore_judge_rubric( judge_model: str = "gpt-4o-mini", use_coverage: bool = False, ) -> vf.JudgeRubric: + # Pass judge_prompt like medredqa does - uses standard {response} and {answer} placeholders rubric = vf.JudgeRubric( judge_client=judge_client, judge_model=judge_model, diff --git a/environments/medexqa/geval_judge/__init__.py b/environments/medexqa/geval_judge/__init__.py new file mode 100644 index 00000000..faa0103a --- /dev/null +++ b/environments/medexqa/geval_judge/__init__.py @@ -0,0 +1,6 @@ +"""G-Eval judge for MedExQA explanations.""" + +from .geval_judge import create_geval_judge_rubric, explanation_geval_reward + +__all__ = ["create_geval_judge_rubric", "explanation_geval_reward"] + diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa.py index ad2ae0b9..0ae3b92c 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa.py @@ -312,13 +312,13 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: raise ValueError("judge_mode must be 'g-eval' or 'factscore'") if judge_mode == "g-eval": - from .geval_judge.geval_judge import create_geval_judge_rubric + from environments.medexqa.geval_judge.geval_judge import create_geval_judge_rubric judge_rubric = create_geval_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model) # Combine answer accuracy with the judge-based explanation score judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) rubric = judge_rubric elif judge_mode == "factscore": - from .factscore_judge.atomic_facts_judge import create_factscore_judge_rubric + from environments.medexqa.factscore_judge.atomic_facts_judge import create_factscore_judge_rubric judge_rubric = create_factscore_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, use_coverage=use_coverage) judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) rubric = judge_rubric From a70feb3c4485777b9b786b698337e4436f618ac5 Mon Sep 17 00:00:00 2001 From: mnishant2 Date: Sat, 8 Nov 2025 14:34:14 +0100 Subject: [PATCH 9/9] Standardize parser logic and implement namespace package for medexqa --- environments/medexqa/medexqa/__init__.py | 5 ++ .../{ => medexqa}/factscore_judge/__init__.py | 0 .../factscore_judge/atomic_facts_generator.py | 0 .../factscore_judge/atomic_facts_judge.py | 27 +++++++--- .../{ => medexqa}/geval_judge/__init__.py | 0 .../{ => medexqa}/geval_judge/geval_judge.py | 28 +++++++--- .../medexqa/{medexqa.py => medexqa/main.py} | 53 ++++++++++--------- environments/medexqa/pyproject.toml | 3 +- 8 files changed, 77 insertions(+), 39 deletions(-) create mode 100644 environments/medexqa/medexqa/__init__.py rename environments/medexqa/{ => medexqa}/factscore_judge/__init__.py (100%) rename environments/medexqa/{ => medexqa}/factscore_judge/atomic_facts_generator.py (100%) rename environments/medexqa/{ => medexqa}/factscore_judge/atomic_facts_judge.py (92%) rename environments/medexqa/{ => medexqa}/geval_judge/__init__.py (100%) rename environments/medexqa/{ => medexqa}/geval_judge/geval_judge.py (90%) rename environments/medexqa/{medexqa.py => medexqa/main.py} (87%) diff --git a/environments/medexqa/medexqa/__init__.py b/environments/medexqa/medexqa/__init__.py new file mode 100644 index 00000000..f01e6615 --- /dev/null +++ b/environments/medexqa/medexqa/__init__.py @@ -0,0 +1,5 @@ +"""MedExQA environment package.""" + +from medexqa.main import load_environment + +__all__ = ["load_environment"] diff --git a/environments/medexqa/factscore_judge/__init__.py b/environments/medexqa/medexqa/factscore_judge/__init__.py similarity index 100% rename from environments/medexqa/factscore_judge/__init__.py rename to environments/medexqa/medexqa/factscore_judge/__init__.py diff --git a/environments/medexqa/factscore_judge/atomic_facts_generator.py b/environments/medexqa/medexqa/factscore_judge/atomic_facts_generator.py similarity index 100% rename from environments/medexqa/factscore_judge/atomic_facts_generator.py rename to environments/medexqa/medexqa/factscore_judge/atomic_facts_generator.py diff --git a/environments/medexqa/factscore_judge/atomic_facts_judge.py b/environments/medexqa/medexqa/factscore_judge/atomic_facts_judge.py similarity index 92% rename from environments/medexqa/factscore_judge/atomic_facts_judge.py rename to environments/medexqa/medexqa/factscore_judge/atomic_facts_judge.py index 94a4798a..6125e742 100644 --- a/environments/medexqa/factscore_judge/atomic_facts_judge.py +++ b/environments/medexqa/medexqa/factscore_judge/atomic_facts_judge.py @@ -88,12 +88,24 @@ async def explanation_factscore_reward( correct_option_text = options.get(correct_letter, "") # Gate explanation to zero if predicted MCQ answer is wrong - try: - m = re.search(r"(? vf.JudgeRubric: # Pass judge_prompt like medredqa does - uses standard {response} and {answer} placeholders rubric = vf.JudgeRubric( @@ -227,7 +240,7 @@ def create_factscore_judge_rubric( parser=parser, use_coverage=use_coverage, # Pass through to reward function via kwargs ) - rubric.add_reward_func(explanation_factscore_reward, weight=1.0) + rubric.add_reward_func(explanation_factscore_reward, weight=explanation_weight) return rubric diff --git a/environments/medexqa/geval_judge/__init__.py b/environments/medexqa/medexqa/geval_judge/__init__.py similarity index 100% rename from environments/medexqa/geval_judge/__init__.py rename to environments/medexqa/medexqa/geval_judge/__init__.py diff --git a/environments/medexqa/geval_judge/geval_judge.py b/environments/medexqa/medexqa/geval_judge/geval_judge.py similarity index 90% rename from environments/medexqa/geval_judge/geval_judge.py rename to environments/medexqa/medexqa/geval_judge/geval_judge.py index 7df46956..d83e90eb 100644 --- a/environments/medexqa/geval_judge/geval_judge.py +++ b/environments/medexqa/medexqa/geval_judge/geval_judge.py @@ -137,12 +137,25 @@ async def explanation_geval_reward( correct_letter = (answer or "").strip().upper() # Gate explanation to zero if predicted MCQ answer is wrong - try: - m = re.search(r"(? vf.JudgeRubric: rubric = vf.JudgeRubric( judge_client=judge_client, @@ -194,7 +208,7 @@ def create_geval_judge_rubric( judge_prompt="{question}", # not used directly; reward builds full prompt parser=parser, ) - rubric.add_reward_func(explanation_geval_reward, weight=1.0) + rubric.add_reward_func(explanation_geval_reward, weight=explanation_weight) return rubric diff --git a/environments/medexqa/medexqa.py b/environments/medexqa/medexqa/main.py similarity index 87% rename from environments/medexqa/medexqa.py rename to environments/medexqa/medexqa/main.py index 0ae3b92c..45440b7d 100644 --- a/environments/medexqa/medexqa.py +++ b/environments/medexqa/medexqa/main.py @@ -3,7 +3,6 @@ import verifiers as vf from datasets import Dataset, concatenate_datasets -from verifiers.utils.data_utils import THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer import pandas as pd import evaluate from openai import AsyncOpenAI @@ -22,11 +21,11 @@ ] # author prompt directly taken from https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py#L32 -def _build_question_str(question: str, options: dict[str, str]) -> str: +def _build_question_str(question: str, options: dict[str, str], use_think: bool = False) -> str: """Build user prompt with authors' instruction embedded (as in their script). The instruction lives in the user message; the system prompt remains empty in - normal mode, and only adds THINK_BOXED in think-mode. + normal mode. In think mode, system prompt instructs use of and tags. """ instruction = ( "The following is a multiple-choice question. Please choose the most suitable one " @@ -187,19 +186,21 @@ def load_environment( except Exception: pass - # Setup system prompt - empty for normal; use think-boxed for think mode - system_prompt = THINK_BOXED_SYSTEM_PROMPT if use_think else "" - - # Parser for extracting \\boxed{} answers - parser = ( - vf.ThinkParser(extract_fn=extract_boxed_answer) if use_think - else vf.Parser(extract_fn=extract_boxed_answer) - ) - - def correct_answer_reward_func(parser, completion, answer, **kwargs) -> float: - """Reward function for MCQ accuracy.""" - response = parser.parse_answer(completion) or "" - return 1.0 if response == answer else 0.0 + # Setup system prompt and parser - standardized with medredqa approach + # - Normal mode: No system prompt, parser returns raw text + # - Think mode: XML system prompt, XMLParser extracts from tags + if use_think: + # Like medredqa: think in tags, answer+explanation in tags + system_prompt = ( + "Think step-by-step inside ... tags. " + "Then, inside ... tags, provide your final answer choice (A, B, C, or D) " + "followed by an explanation of why you chose that answer." + ) + parser = vf.XMLParser(fields=["think", "answer"], answer_field="answer") + else: + # Normal mode: no system prompt, parser returns raw text for multiple_choice_accuracy + system_prompt = "" + parser = vf.Parser() # (shuffling handled above when multiple specialties) @@ -264,7 +265,8 @@ def _get_completion_text(completion_obj) -> str: return completion_obj if isinstance(completion_obj, str) else str(completion_obj) def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: - completion_text = _get_completion_text(completion) + # Parse answer first (extracts from \boxed{} in think mode, returns raw text in normal mode) + parsed = parser.parse_answer(completion) or "" info = kwargs.get("info", {}) or {} # Get answer_text for fallback matching @@ -272,7 +274,7 @@ def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: answer_text = options.get(answer, "") is_correct = multiple_choice_accuracy( - llm_answer=completion_text, + llm_answer=parsed, answer_letter=answer, answer_text=answer_text, accept_answer_text=True, @@ -281,7 +283,8 @@ def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: return 100.0 if is_correct else 0.0 def explanation_reward(parser, completion, answer, **kwargs) -> float: - completion_text = _get_completion_text(completion) + # Parse answer first (extracts from \boxed{} in think mode, returns raw text in normal mode) + parsed = parser.parse_answer(completion) or "" info = kwargs.get("info", {}) or {} # Get answer_text for fallback matching @@ -290,7 +293,7 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: # Check if answer is correct using multiple_choice_accuracy is_correct = multiple_choice_accuracy( - llm_answer=completion_text, + llm_answer=parsed, answer_letter=answer, answer_text=answer_text, accept_answer_text=True, @@ -300,6 +303,8 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: if not is_correct: return 0.0 else: + # For lexical metrics, use the raw completion text (not parsed) + completion_text = _get_completion_text(completion) return compute_expl_score(completion_text, info.get("exp0", ""), info.get("exp1", "")) # Optional: Use LLM-as-judge for explanation instead of lexical metrics @@ -312,14 +317,14 @@ def explanation_reward(parser, completion, answer, **kwargs) -> float: raise ValueError("judge_mode must be 'g-eval' or 'factscore'") if judge_mode == "g-eval": - from environments.medexqa.geval_judge.geval_judge import create_geval_judge_rubric - judge_rubric = create_geval_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model) + from medexqa.geval_judge.geval_judge import create_geval_judge_rubric + judge_rubric = create_geval_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, explanation_weight=explanation_weight) # Combine answer accuracy with the judge-based explanation score judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) rubric = judge_rubric elif judge_mode == "factscore": - from environments.medexqa.factscore_judge.atomic_facts_judge import create_factscore_judge_rubric - judge_rubric = create_factscore_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, use_coverage=use_coverage) + from medexqa.factscore_judge.atomic_facts_judge import create_factscore_judge_rubric + judge_rubric = create_factscore_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, use_coverage=use_coverage, explanation_weight=explanation_weight) judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) rubric = judge_rubric else: diff --git a/environments/medexqa/pyproject.toml b/environments/medexqa/pyproject.toml index ab9bae1b..ffa5a7d4 100644 --- a/environments/medexqa/pyproject.toml +++ b/environments/medexqa/pyproject.toml @@ -23,7 +23,8 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build] -include = ["medexqa.py"] +include = ["medexqa/**"] +packages = ["medexqa"] [tool.prime.environment] # lets Prime/vf-eval know where the loader lives in a flat repo