diff --git a/environments/medexqa/README.md b/environments/medexqa/README.md new file mode 100644 index 00000000..79115a63 --- /dev/null +++ b/environments/medexqa/README.md @@ -0,0 +1,249 @@ +# medexqa-env- by mnishant2 + +### Overview +- **Environment ID**: `medexqa` +- **Short description**: Medical QA with multiple-choice questions and explanations across five underrepresented medical specialties +- **Tags**: medical, clinical, single-turn, multiple-choice, explanations, evaluation + +### Datasets +- **Primary dataset(s)**: MedExQA +- **Source links**: [Paper](https://arxiv.org/abs/2406.06331), [HuggingFace Dataset](https://huggingface.co/datasets/bluesky333/MedExQA), [GitHub](https://github.com/knowlab/MedExQA) +- **Split sizes**: + + | Specialty | Dev | Test | Total | + | --------------------------- | --- | ---- | ----- | + | Biomedical Engineering | 4 | 144 | 148 | + | Clinical Laboratory Science | 9 | 368 | 377 | + | Clinical Psychology | 3 | 108 | 111 | + | Occupational Therapy | 5 | 189 | 194 | + | Speech Language Pathology | 4 | 131 | 135 | + | **Total** | **25** | **940** | **965** | + +### Task +- **Type**: single-turn +- **Prompting**: Uses the authors' instruction embedded in the user message; options A/B/C/D are included. + ``` + The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question. Your answer should be paired with an explanation why you chose that answer. + ``` +- **Answer extraction [authors' logic](https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py)** : + - Canonical letter extraction using a sequence of regex patterns (e.g., explicit "Answer is A:", leading letter, etc.) 
+ - If no explicit letter is found, fuzzy matching (thefuzz) maps the generated text to the closest option and returns the corresponding letter +- **Parser**: `Parser`, or `ThinkParser` with `extract_fn=extract_boxed_answer` in think mode; MCQ scoring uses the authors' extraction logic above. +- Run the evaluation on a single specialty or on multiple specialties +- Score explanations with lexical metrics (`rougeL`, `bleu`, `bertscore`, `meteor`) or with an LLM-as-a-judge +- **Rubric overview**: + - MCQ accuracy: 0 or 100 per example + - Explanation score: 0–100 per example (lexical metrics average); 0 if the answer is wrong + - Combined score: weighted average of MCQ and explanation (`mcq_weight`, `explanation_weight`) +- **Model Download**: + The first run downloads the NLTK resources (including `wordnet`) and the SciBERT model needed by the lexical metrics + +### Quickstart + +- Run MCQ-only (no explanation scoring): +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": false}' +``` + +- Run with explanation scoring (lexical metrics): +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true}' +``` + +- Use LLM-as-judge for explanations (instead of lexical metrics); `judge_mode` is required whenever `use_judge` is true: +```bash +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "g-eval", "judge_model": "gpt-4o-mini"}' +``` + +- Configure sampling and rollouts: +```bash +uv run vf-eval medexqa \ + -m gpt-4.1-mini \ + -n -1 -r 3 -t 1024 -T 0.7 \ + -a '{"use_think": false, "use_explanations": true, "mcq_weight": 0.5, "explanation_weight": 0.5}' +``` + +### Environment Arguments + +| Arg | Type | Default | Description | +| ---------------------- | ---------------------- | -------------- | ----------- | +| `specialty` | list[str] \| str \| None | `None` | Select one or more specialties. Codes: `BE`, `CLS`, `CP`, `OT`, `SLP`. `None`\/`ALL` loads all. | +| `use_think` | bool | `False` | Use `ThinkParser` to support `<think>...</think>` blocks. 
| +| `use_explanations` | bool | `True` | Whether to compute explanation scores. | +| `explanation_metrics` | list[str] \| str \| None | `None` | Lexical metrics to use: any of `rougeL`, `bleu`, `meteor`, `bertscore`. `None`\/`"all"` averages all four. | +| `mcq_weight` | float | `0.5` | Weight for MCQ accuracy in the combined score. | +| `explanation_weight` | float | `0.5` | Weight for explanation in the combined score. | +| `use_judge` | bool | `False` | Use LLM-as-judge for explanations instead of lexical metrics. | +| `judge_mode` | str \| None | `None` | Judge mode: `"g-eval"` or `"factscore"`. Required when `use_judge=True`. | +| `judge_model` | str | `gpt-4o-mini` | Judge model name. | +| `judge_base_url` | str \| None | `None` | Judge API base URL. | +| `judge_api_key` | str \| None | `None` | Judge API key (falls back to `JUDGE_API_KEY` or `OPENAI_API_KEY`). | +| `use_coverage` | bool | `False` | For FactScore only: enable coverage calculation (measures recall). Default is support-only (precision). Increases API calls from ~6-8 to ~12-15 per example. | +| `seed` | int \| None | `None` | When multiple specialties are selected, shuffles the combined eval set with this seed. | + +### Metrics + +- **Answer accuracy (per example)**: 0 or 100. Uses the authors' regex and fuzzy-matching logic to extract a letter. +- **Explanation score (per example)**: 0–100. If the answer is wrong, the explanation score is 0. + - Lexical metrics supported: `rougeL`, `bleu`, `meteor`, `bertscore` (with SciBERT `allenai/scibert_scivocab_uncased`). + - Selection via `explanation_metrics` (list or `'all'`/`None` to average all four). +- **Combined score**: `mcq_weight * accuracy + explanation_weight * explanation`. + +Optional LLM-as-judge for explanations: +- Set `use_explanations=true` and `use_judge=true` to replace lexical metrics with judge scoring (0–100 after scaling). +- `judge_mode` must be specified when using LLM-as-judge. 
Two modes are supported: + +#### Judge Modes + +**G-Eval Mode** (`judge_mode="g-eval"`): +- Uses Chain-of-Thought evaluation with 7 structured steps to assess explanation quality. +- Evaluates medical accuracy, correct option justification, distractor analysis, reference alignment, reasoning clarity, and completeness. +- Outputs structured JSON with detailed step-by-step analysis and a final score (0-100). + +**FactScore Mode** (`judge_mode="factscore"`): +- Decomposes explanations into atomic medical claims (5-7 key claims) using specialty-specific few-shot examples. +- Verifies each claim against the references (question, correct option text, and the two reference explanations) using 3-level support scoring (FULLY_SUPPORTED=1.0, PARTIALLY_SUPPORTED=0.5, NOT_SUPPORTED=0.0). +- Optionally computes coverage, which measures whether the model's explanation covers the key concepts in the reference explanations (disabled by default; enable with `use_coverage=True` for a combined precision+recall evaluation). + +Usage examples: +```bash +# G-Eval mode +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "g-eval", "judge_model": "gpt-4o-mini"}' + +# FactScore mode (support-only, fast) +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "factscore", "judge_model": "gpt-4o-mini"}' + +# FactScore mode with coverage (support+coverage, comprehensive) +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"use_explanations": true, "use_judge": true, "judge_mode": "factscore", "judge_model": "gpt-4o-mini", "use_coverage": true}' +``` + +#### Judge Rescoring Tool + +The `judge_rescore.py` script re-evaluates previously generated model completions with LLM judges and saves detailed results to CSV files. This is useful for experimenting with different judge configurations without re-running model inference. + +**Usage:** +```bash +# Rescore with FactScore (fast mode: support-only, ~6-8 API calls per example) +export OPENAI_API_KEY=sk-... 
+uv run python environments/medexqa/tools/judge_rescore.py \ + --judge factscore \ + --base https://openrouter.ai/api/v1 \ + --model openai/gpt-4o-mini \ + --key_var OPENAI_API_KEY \ + --input_glob 'environments/medexqa/outputs/evals/**/results.jsonl' \ + --out_csv_prefix 'environments/medexqa/outputs/judge_scores/medexqa_' \ + --sleep_ms 2000 \ + --max_retries 6 \ + --max_tokens 512 \ + --verbose + +# Rescore with FactScore + coverage (comprehensive: support+coverage, ~12-15 API calls per example) +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge factscore \ + --use_coverage \ + --model openai/gpt-4o-mini \ + ... # other args same as above + +# Rescore with G-Eval (1 API call per example) +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge geval \ + --model openai/gpt-4o-mini \ + --max_tokens 768 \ + ... # other args same as above + +# Rescore with both judges +uv run python environments/medexqa/tools/judge_rescore.py \ + --judge both \ + --model openai/gpt-4o-mini \ + ... # other args same as above +``` + +**Output:** Creates CSV files with detailed judge outputs: +- `medexqa_factscore.csv`: Contains extracted claims, support/coverage rates, and final scores +- `medexqa_geval.csv`: Contains structured JSON evaluation details and scores + +### Specialty Selection and Macro Average + +- Single specialty by code: +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"specialty": "CLS"}' +``` + +- Multiple specialties: +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"specialty": ["CLS", "CP"], "seed": 42}' +``` + +- All specialties: +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -a '{"specialty": "ALL"}' +``` + +## IMPORTANT: Macro-average accuracy (as reported in the paper): +- Run each specialty separately and average the per-run average answer accuracies; or +- Run multiple specialties with `-s` to save results. 
Each saved example includes its `specialty` in `info`, along with the per-example `answer_accuracy_reward`. Use the saved JSONL to compute per-specialty accuracies and then take the unweighted mean across specialties. + +### Testing Instructions + +#### 1. Environment Setup +```bash +# Navigate to repository root (adjust the path to your checkout) +cd /path/to/med-lm-envs + +# Sync uv environment +uv sync +``` + +#### 2. Quick Validation Test (MCQ-only) +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -n 5 -a '{"use_explanations": false}' +``` + +#### 3. Full Evaluation with Save +```bash +export OPENAI_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -s -a '{"specialty": "ALL", "use_explanations": true}' +``` + +#### 4. LLM-as-Judge for Explanations +```bash +export JUDGE_API_KEY=sk-... +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -s -a '{"use_explanations": true, "use_judge": true, "judge_mode": "g-eval", "mcq_weight": 0.5, "explanation_weight": 0.5}' +``` + +#### 5. With Think Tags +```bash +uv run vf-eval medexqa -m gpt-4.1-mini -n -1 -a '{"use_think": true}' +``` + +#### 6. Example Run with OpenRouter +```bash +export OPENROUTER_API_KEY=.... 
+uv run vf-eval medexqa -m openai/gpt-oss-20b:free -b https://openrouter.ai/api/v1 -k OPENROUTER_API_KEY -n 10 -r 1 -c 1 -a '{"use_explanations": true, "explanation_metrics": "all", "specialty": ["BE", "OT"]}' -s +``` +Output: +```text +Rewards: +reward: avg - 59.416, std - 19.928 +r1: [67.79, 65.809, 64.158, 66.619, 69.124, 0.0, 66.957, 66.327, 66.87, 60.503] +answer_accuracy_reward: avg - 90.000, std - 30.000 +r1: [100.0, 100.0, 100.0, 100.0, 100.0, 0.0, 100.0, 100.0, 100.0, 100.0] +explanation_reward: avg - 28.832, std - 10.577 +r1: [35.58, 31.618, 28.316, 33.239, 38.249, 0.0, 33.915, 32.653, 33.741, 21.006] +``` +### Citation + +```bibtex +@article{kim2024medexqa, + title={MedExQA: Medical Question Answering Benchmark with Multiple Explanations}, + author={Kim, Yunsoo and Wu, Jinge and Abdulle, Yusuf and Wu, Honghan}, + journal={arXiv preprint arXiv:2406.06331}, + year={2024} +} +``` +### Authors +This environment was put together by: + +Nishant Mishra - ([mnishant2](https://github.com/mnishant2)) \ No newline at end of file diff --git a/environments/medexqa/medexqa/__init__.py b/environments/medexqa/medexqa/__init__.py new file mode 100644 index 00000000..f01e6615 --- /dev/null +++ b/environments/medexqa/medexqa/__init__.py @@ -0,0 +1,5 @@ +"""MedExQA environment package.""" + +from medexqa.main import load_environment + +__all__ = ["load_environment"] diff --git a/environments/medexqa/medexqa/factscore_judge/__init__.py b/environments/medexqa/medexqa/factscore_judge/__init__.py new file mode 100644 index 00000000..1149e4e5 --- /dev/null +++ b/environments/medexqa/medexqa/factscore_judge/__init__.py @@ -0,0 +1,7 @@ +"""FactScore judge for MedExQA explanations.""" + +from .atomic_facts_judge import create_factscore_judge_rubric, explanation_factscore_reward +from .atomic_facts_generator import AtomicFactGenerator + +__all__ = ["create_factscore_judge_rubric", "explanation_factscore_reward", "AtomicFactGenerator"] + + diff --git 
a/environments/medexqa/medexqa/factscore_judge/atomic_facts_generator.py b/environments/medexqa/medexqa/factscore_judge/atomic_facts_generator.py new file mode 100644 index 00000000..b82d1e4d --- /dev/null +++ b/environments/medexqa/medexqa/factscore_judge/atomic_facts_generator.py @@ -0,0 +1,215 @@ +import json +from typing import List + +from openai import AsyncOpenAI + + +class AtomicFactGenerator: + """ + MedExQA-specific atomic facts generator. + + Extracts concise, checkable medical claims from an MCQA explanation that + support the chosen option and, when useful, refute key distractors. + Returns a Python list of strings (facts), not raw model text. + """ + + def __init__(self, async_openai_client: AsyncOpenAI | None, model_name: str = "gpt-4o-mini") -> None: + self.client = async_openai_client + self.model_name = model_name + + async def run(self, explanation_text: str, state: dict = None) -> List[str]: + """ + Extract atomic facts from an MCQA explanation. + + Args: + explanation_text: The explanation text to extract claims from + state: Optional state dict for token tracking + """ + explanation = (explanation_text or "").strip() + if not explanation: + return [] + + primary = await self._extract_json_claims(explanation, state=state) + if primary: + return primary + + fallback = await self._extract_json_claims(explanation, fallback=True, state=state) + return fallback or [] + + async def _extract_json_claims(self, explanation: str, fallback: bool = False, state: dict = None) -> List[str]: + if self.client is None: + return [] + + if not fallback: + prompt = ( + "You are a medical expert evaluating MCQA (multiple-choice question) explanations.\n" + "Extract atomic, checkable medical claims that: (1) justify why the correct option is right, " + "(2) when applicable, explain why key distractors are wrong, (3) preserve medical terminology.\n\n" + "Rules:\n" + "- Output a strict JSON array of strings ONLY (no extra text).\n" + "- Extract 5-7 MOST IMPORTANT 
claims (prioritize key medical concepts).\n" + "- Each claim ≤ 30 words; no duplicates; no vague statements.\n" + "- Preserve technical terms and abbreviations (e.g., 'DEXA', 'PTFE', 'AAC').\n" + "- If no checkable medical content, return [].\n\n" + "Few-shot examples (imitate format exactly):\n\n" + "# Biomedical Engineering Example 1:\n" + "Explanation: Membrane oxygenators require materials with high gas permeability for O2 and CO2 exchange. " + "Silicone rubber, polypropylene, and Teflon are highly permeable polymers. Ceramic membranes are dense, " + "brittle, have poor gas permeability, and can cause hemolysis.\n" + "Claims JSON: [\n" + " \"Membrane oxygenators require high gas permeability for O2 and CO2 exchange.\",\n" + " \"Silicone rubber has excellent gas permeability and biocompatibility.\",\n" + " \"Polypropylene provides high gas transfer and is durable.\",\n" + " \"Teflon (PTFE) is chemically inert with good blood contact properties.\",\n" + " \"Ceramic membranes have poor gas permeability compared to polymers.\",\n" + " \"Ceramic membranes are brittle and can cause hemolysis.\"\n" + "]\n\n" + "# Biomedical Engineering Example 2:\n" + "Explanation: Thermographic cameras detect infrared radiation emitted by objects due to temperature. " + "All objects above absolute zero emit infrared radiation. X-rays and UV are higher-energy and not used for thermal imaging. 
" + "Microwaves are used for radar, not temperature scanning.\n" + "Claims JSON: [\n" + " \"Thermographic cameras detect infrared radiation from objects.\",\n" + " \"All objects above absolute zero emit infrared radiation.\",\n" + " \"Infrared is ideal for measuring surface temperatures.\",\n" + " \"X-rays are too high-energy for conventional thermal imaging.\",\n" + " \"Microwaves are used for radar applications, not thermal cameras.\"\n" + "]\n\n" + "# Clinical Laboratory Science Example 1:\n" + "Explanation: Hemoglobin A1c measures average blood glucose over 2-3 months by detecting glycated hemoglobin. " + "Fasting glucose only reflects current levels. Random glucose varies throughout the day. Oral glucose tolerance test is diagnostic but not for monitoring.\n" + "Claims JSON: [\n" + " \"Hemoglobin A1c measures average blood glucose over 2-3 months.\",\n" + " \"A1c detects glycated hemoglobin formed by glucose binding.\",\n" + " \"Fasting glucose only reflects current blood glucose levels.\",\n" + " \"Random glucose varies throughout the day and is unreliable for averages.\",\n" + " \"OGTT is diagnostic but not suitable for long-term monitoring.\"\n" + "]\n\n" + "# Clinical Laboratory Science Example 2:\n" + "Explanation: Gram staining differentiates bacteria by cell wall structure. Gram-positive bacteria have thick peptidoglycan walls " + "that retain crystal violet stain. 
Gram-negative bacteria have thin peptidoglycan and outer membranes, appearing pink after counterstaining.\n" + "Claims JSON: [\n" + " \"Gram staining differentiates bacteria by cell wall structure.\",\n" + " \"Gram-positive bacteria have thick peptidoglycan cell walls.\",\n" + " \"Thick peptidoglycan retains crystal violet stain in Gram-positive bacteria.\",\n" + " \"Gram-negative bacteria have thin peptidoglycan and outer membranes.\",\n" + " \"Gram-negative bacteria appear pink after safranin counterstaining.\"\n" + "]\n\n" + "# Clinical Psychology Example 1:\n" + "Explanation: Cognitive-behavioral therapy (CBT) is first-line for generalized anxiety disorder, with strong evidence for efficacy. " + "Psychodynamic therapy lacks robust evidence for GAD. Exposure therapy is specific to phobias. Supportive therapy alone is insufficient for GAD.\n" + "Claims JSON: [\n" + " \"CBT is first-line treatment for generalized anxiety disorder.\",\n" + " \"CBT has strong evidence for efficacy in treating GAD.\",\n" + " \"Psychodynamic therapy lacks robust evidence for GAD treatment.\",\n" + " \"Exposure therapy is specific to phobias, not GAD.\",\n" + " \"Supportive therapy alone is insufficient for GAD management.\"\n" + "]\n\n" + "# Clinical Psychology Example 2:\n" + "Explanation: The PHQ-9 is a validated 9-item screening tool for major depressive disorder with scores 0-27. " + "Scores ≥10 indicate moderate depression requiring clinical evaluation. It assesses DSM-5 criteria for MDD.\n" + "Claims JSON: [\n" + " \"PHQ-9 is a validated screening tool for major depressive disorder.\",\n" + " \"PHQ-9 contains 9 items with total scores ranging 0-27.\",\n" + " \"Scores ≥10 indicate moderate depression needing evaluation.\",\n" + " \"PHQ-9 assesses DSM-5 diagnostic criteria for MDD.\"\n" + "]\n\n" + "# Occupational Therapy Example 1:\n" + "Explanation: The Barthel Index measures independence in activities of daily living (ADL) across 10 domains. 
" + "Scores range 0-100, with higher scores indicating greater independence. It's reliable for tracking functional recovery post-stroke.\n" + "Claims JSON: [\n" + " \"Barthel Index measures independence in activities of daily living.\",\n" + " \"The index assesses 10 functional domains.\",\n" + " \"Scores range from 0 (dependent) to 100 (independent).\",\n" + " \"Higher Barthel Index scores indicate greater functional independence.\",\n" + " \"Barthel Index is reliable for tracking post-stroke recovery.\"\n" + "]\n\n" + "# Occupational Therapy Example 2:\n" + "Explanation: Adaptive utensils with built-up handles improve grip for patients with arthritis by reducing required pinch force. " + "Weighted utensils help tremor patients. Angled utensils assist those with limited wrist mobility. Standard utensils lack these modifications.\n" + "Claims JSON: [\n" + " \"Built-up handle utensils improve grip for arthritis patients.\",\n" + " \"Built-up handles reduce required pinch force during eating.\",\n" + " \"Weighted utensils help stabilize tremors during eating.\",\n" + " \"Angled utensils assist patients with limited wrist mobility.\",\n" + " \"Standard utensils lack these adaptive modifications.\"\n" + "]\n\n" + "# Speech Pathology Example 1:\n" + "Explanation: Videofluoroscopic swallow study (VFSS) is the gold standard for dysphagia evaluation, visualizing all swallowing phases. " + "It detects aspiration, penetration, and pharyngeal residue. Clinical swallow exam cannot visualize aspiration. 
Endoscopy misses oral phase.\n" + "Claims JSON: [\n" + " \"VFSS is the gold standard for dysphagia evaluation.\",\n" + " \"VFSS visualizes all phases of swallowing in real-time.\",\n" + " \"VFSS can detect aspiration, penetration, and pharyngeal residue.\",\n" + " \"Clinical swallow examination cannot visualize aspiration.\",\n" + " \"Fiberoptic endoscopic evaluation misses the oral phase of swallowing.\"\n" + "]\n\n" + "# Speech Pathology Example 2:\n" + "Explanation: The Peabody Picture Vocabulary Test (PPVT) assesses receptive vocabulary in children and adults. " + "It requires pointing to pictures, not verbal responses, making it suitable for nonverbal individuals. Expressive language tests require speech production.\n" + "Claims JSON: [\n" + " \"PPVT assesses receptive vocabulary in children and adults.\",\n" + " \"PPVT requires pointing to pictures, not verbal responses.\",\n" + " \"PPVT is suitable for assessing nonverbal individuals.\",\n" + " \"Expressive language tests require speech production.\"\n" + "]\n\n" + "Now extract claims for the explanation below.\n\n" + f"Explanation:\n{explanation}\n\n" + "Claims JSON:" + ) + else: + prompt = ( + "Extract atomic, checkable medical claims from this MCQA explanation.\n" + "Return ONLY a JSON array of 4–10 strings; each ≤ 30 words. 
If none, return [].\n\n" + f"Explanation:\n{explanation}\n\n" + "Claims JSON:" + ) + + try: + resp = await self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=512, + ) + text = (resp.choices[0].message.content or "").strip() + return _parse_json_list(text) + except Exception: + return [] + + +def _parse_json_list(text: str) -> List[str]: + try: + data = json.loads(text) + if isinstance(data, list): + out = [] + for x in data: + s = (str(x) or "").strip() + if s: + out.append(s) + # keep unique order + seen = set() + uniq = [] + for s in out: + if s not in seen: + uniq.append(s) + seen.add(s) + return uniq + return [] + except Exception: + # fallback: find bracketed content + try: + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1 and end > start: + data = json.loads(text[start : end + 1]) + if isinstance(data, list): + return [str(x).strip() for x in data if str(x).strip()] + except Exception: + pass + return [] + + + + + diff --git a/environments/medexqa/medexqa/factscore_judge/atomic_facts_judge.py b/environments/medexqa/medexqa/factscore_judge/atomic_facts_judge.py new file mode 100644 index 00000000..6125e742 --- /dev/null +++ b/environments/medexqa/medexqa/factscore_judge/atomic_facts_judge.py @@ -0,0 +1,246 @@ +""" +FactScore-style judge for MedExQA explanations (reference-only, no external retrieval). + +Two-step process: +1) Extract atomic medical claims from the model's explanation. +2) Verify each claim against available references: question, correct option text, exp0, exp1. + +Returns support rate in [0, 1], scaled to [0, 100] for reward. +""" + +import json +import re +import verifiers as vf +from .atomic_facts_generator import AtomicFactGenerator + + +JUDGE_TEMPLATE = """You are a medical knowledge verification expert. Evaluate if the Passage supports the Claim. 
+ +PASSAGE: +{response} + +CLAIM TO VERIFY: +{answer} + +INSTRUCTIONS: +1. Check if the claim is FULLY supported by the passage with explicit evidence +2. Check if the claim is PARTIALLY supported (implied/inferable but not explicit) +3. Check if the claim is NOT supported (no evidence or contradicts passage) + +Respond with EXACTLY ONE of: +- "FULLY_SUPPORTED" - explicit evidence exists in passage +- "PARTIALLY_SUPPORTED" - implied/inferable from passage +- "NOT_SUPPORTED" - no evidence or contradicts passage + +Your response:""".strip() + + +def extract_support_level(text: str) -> tuple[float, bool]: + """ + Extract support level from LLM judge response. + + Returns: + (score, valid): score is 0.0, 0.5, or 1.0; valid indicates if parsing succeeded + """ + cleaned_text = (text or "").strip().upper() + + # Check for 3-level responses + if "FULLY_SUPPORTED" in cleaned_text or "FULLY SUPPORTED" in cleaned_text: + return (1.0, True) + if "PARTIALLY_SUPPORTED" in cleaned_text or "PARTIALLY SUPPORTED" in cleaned_text: + return (0.5, True) + if "NOT_SUPPORTED" in cleaned_text or "NOT SUPPORTED" in cleaned_text: + return (0.0, True) + + # Fallback to old binary format for backwards compatibility + cleaned_lower = cleaned_text.lower() + has_true = "true" in cleaned_lower + has_false = "false" in cleaned_lower + if has_true and not has_false: + return (1.0, True) + if has_false and not has_true: + return (0.0, True) + + # Ambiguous response + return (0.0, False) + + +async def explanation_factscore_reward( + judge, + prompt, + completion, + answer, + state, + **kwargs, +) -> float: + # parse explanation text + if isinstance(completion, list) and completion: + explanation = completion[-1].get("content", "") or "" + else: + explanation = str(completion) + + info = kwargs.get("info", {}) or {} + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + question = info.get("question", "") + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + correct_letter = 
(answer or "").strip().upper() + correct_option_text = options.get(correct_letter, "") + + # Gate explanation to zero if predicted MCQ answer is wrong + # Parse answer first (extracts from \boxed{} in think mode, returns raw text in normal mode) + parser = kwargs.get("parser") + if parser: + parsed = parser.parse_answer(completion) or "" + else: + parsed = explanation + + from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy + + is_correct = multiple_choice_accuracy( + llm_answer=parsed, + answer_letter=correct_letter, + answer_text=correct_option_text, + accept_answer_text=True, + strip_tex=False, + ) + + if not is_correct: + return 0.0 + + # Build references block + refs = ( + f"Question: {question}\n" + f"Correct option ({correct_letter}): {correct_option_text}\n" + f"Reference Explanation 1: {exp0}\n" + f"Reference Explanation 2: {exp1}" + ) + + # Initialize generator (reuse medredqa style) + llm_client = kwargs.get("judge_client") + llm_model = kwargs.get("judge_model", "gpt-4o-mini") + generator = AtomicFactGenerator(llm_client, model_name=llm_model) + + # Extract atomic claims from model explanation + try: + if llm_client is None: + # No client available - cannot extract claims + return 0.0 + claims = await generator.run(explanation, state=state) + except Exception as e: + # Log extraction error for debugging + import sys + print(f"Warning: Atomic facts extraction failed: {e}", file=sys.stderr) + claims = [] + if not claims: + return 0.0 + + # Step 2a: verify each model claim against references (support_rate) + # One call per claim like MedRedQA approach + support_score = 0.0 + total = 0 + + for claim in claims: + total += 1 + # Call judge like medredqa does: judge(prompt, completion, answer, state, **kwargs) + # prompt is not used in template, completion becomes {response}, answer becomes {answer} + judge_response = await judge(prompt, refs, str(claim), state, **kwargs) + score, ok = extract_support_level(judge_response) + 
if ok: + support_score += score + + support_rate = (support_score / total) if total > 0 else 0.0 + + # Step 2b: Coverage rate - DISABLED by default for speed + # This measures recall: does the model explanation cover key reference concepts? + # Enable with use_coverage=True in kwargs for balanced precision+recall evaluation + use_coverage = kwargs.get("use_coverage", False) + coverage_rate = 0.0 + + if use_coverage and llm_client is not None: + # Extract claims from both reference explanations + all_ref_claims: list[str] = [] + + # Extract from reference 1 + if (exp0 or "").strip(): + try: + ref0_claims = await generator.run(exp0, state=state) + all_ref_claims.extend(ref0_claims) + except Exception: + pass + + # Extract from reference 2 + if (exp1 or "").strip(): + try: + ref1_claims = await generator.run(exp1, state=state) + all_ref_claims.extend(ref1_claims) + except Exception: + pass + + # Remove duplicates while preserving order + seen = set() + unique_ref_claims = [] + for claim in all_ref_claims: + if claim not in seen: + unique_ref_claims.append(claim) + seen.add(claim) + + # Verify each reference claim against model explanation + coverage_score = 0.0 + coverage_total = 0 + + for ref_claim in unique_ref_claims: + coverage_total += 1 + # Check if model explanation supports this reference claim + # Call judge: passage=explanation, claim=ref_claim + cov_response = await judge(prompt, explanation, str(ref_claim), state, **kwargs) + score, ok = extract_support_level(cov_response) + if ok: + coverage_score += score + + coverage_rate = (coverage_score / coverage_total) if coverage_total > 0 else 0.0 + + # Combine support and coverage (if enabled) + if use_coverage: + # Use weighted combination when coverage is enabled + w_support = float(kwargs.get("support_weight", 0.5)) + w_coverage = float(kwargs.get("coverage_weight", 0.5)) + denom = w_support + w_coverage if (w_support + w_coverage) > 0 else 1.0 + final = (w_support * support_rate + w_coverage * coverage_rate) 
/ denom + else: + # Coverage disabled: use support_rate only + final = support_rate + + # Optionally stash structured details for external loggers (if passed in kwargs) + # Caller can access via state or judge logs; for rescore tool we return these via logs reconstruction + state = state or {} + try: + state["factscore_details"] = { + "support_rate": float(support_rate), + "coverage_rate": float(coverage_rate), + } + except Exception: + pass + + return float(final * 100.0) + + +def create_factscore_judge_rubric( + parser: vf.Parser, + judge_client, + judge_model: str = "gpt-4o-mini", + use_coverage: bool = False, + explanation_weight: float = 1.0, +) -> vf.JudgeRubric: + # Pass judge_prompt like medredqa does - uses standard {response} and {answer} placeholders + rubric = vf.JudgeRubric( + judge_client=judge_client, + judge_model=judge_model, + judge_prompt=JUDGE_TEMPLATE, + parser=parser, + use_coverage=use_coverage, # Pass through to reward function via kwargs + ) + rubric.add_reward_func(explanation_factscore_reward, weight=explanation_weight) + return rubric + + diff --git a/environments/medexqa/medexqa/geval_judge/__init__.py b/environments/medexqa/medexqa/geval_judge/__init__.py new file mode 100644 index 00000000..faa0103a --- /dev/null +++ b/environments/medexqa/medexqa/geval_judge/__init__.py @@ -0,0 +1,6 @@ +"""G-Eval judge for MedExQA explanations.""" + +from .geval_judge import create_geval_judge_rubric, explanation_geval_reward + +__all__ = ["create_geval_judge_rubric", "explanation_geval_reward"] + diff --git a/environments/medexqa/medexqa/geval_judge/geval_judge.py b/environments/medexqa/medexqa/geval_judge/geval_judge.py new file mode 100644 index 00000000..d83e90eb --- /dev/null +++ b/environments/medexqa/medexqa/geval_judge/geval_judge.py @@ -0,0 +1,214 @@ +import re +import json +import verifiers as vf + + +GEVAL_CRITERIA = """You are evaluating a medical MCQA explanation for quality and correctness. 
+ +Assess the explanation across these dimensions: + +1. MEDICAL ACCURACY: Does the explanation contain factually correct medical information that aligns with the correct option and reference explanations? Are there any medical errors or contradictions? + +2. CORRECT OPTION JUSTIFICATION: Does the explanation clearly explain WHY the correct answer is medically appropriate using valid clinical/scientific reasoning? + +3. DISTRACTOR ANALYSIS (when applicable): Does the explanation explain why incorrect options are wrong? Note: Not all explanations need this, but it enhances quality when present. + +4. REFERENCE ALIGNMENT: Do the explanation's core medical claims align with the key concepts in both reference explanations? + +5. REASONING CLARITY: Is the medical reasoning easy to follow with logical flow from evidence to conclusion? + +6. COMPLETENESS: Does the explanation cover the essential medical concepts without major omissions?""" + + +GEVAL_EVALUATION_STEPS = [ + "Extract all medical claims from the actual explanation and list them explicitly", + "Compare each claim against the correct option text and both reference explanations. 
Mark claims as: ALIGNED (matches references), CONTRADICTS (conflicts with references), or NEW_INFO (additional but not contradictory)", + "Identify if the explanation justifies WHY the correct option is right (not just states it is correct)", + "Check for any major medical errors, inaccuracies, or unsupported claims that could mislead", + "Assess whether distractor refutation is present and accurate (if applicable to this question)", + "Evaluate overall reasoning clarity, logical flow, and completeness", + "Synthesize findings into a score using this rubric: 0.0-0.2 (major errors/irrelevant/contradicts references), 0.2-0.4 (significant gaps/multiple minor errors), 0.4-0.6 (acceptable but incomplete/some inaccuracies), 0.6-0.8 (good quality with minor issues), 0.8-1.0 (excellent: comprehensive, accurate, well-reasoned)" +] + + +GEVAL_PROMPT_TEMPLATE = """You are a strict medical explanation evaluator following a structured evaluation process. + +CRITERIA: +{criteria} + +EVALUATION STEPS (follow these in order): +{evaluation_steps} + +OUTPUT FORMAT: +Respond with a JSON object containing your step-by-step analysis and final score. Use this exact structure: +{{ + "step1_claims_extracted": ["claim1", "claim2", ...], + "step2_alignment_analysis": {{ + "aligned_claims": [...], + "contradicting_claims": [...], + "new_info_claims": [...] 
+ }}, + "step3_correct_option_justified": true/false, + "step4_medical_errors_found": ["error description"] or [], + "step5_distractor_refutation": "present_and_accurate" / "present_but_weak" / "absent" / "not_applicable", + "step6_reasoning_assessment": "clear" / "somewhat_clear" / "confusing", + "step7_final_score": 0.XX, + "score_justification": "Brief 1-2 sentence explanation of the score" +}} + +QUESTION CONTEXT: +Question: {question} +Options: +{options} +Correct Answer: {correct_answer} + +REFERENCE EXPLANATIONS: +Reference 1: {ref_exp1} +Reference 2: {ref_exp2} + +MODEL EXPLANATION TO EVALUATE: +{model_explanation} + +Provide your evaluation as JSON:""" + + +def _extract_score_from_json(text: str) -> tuple[float, dict]: + """ + Extract score from JSON response. + + Returns: + (score, parsed_dict): score is 0.0-1.0; parsed_dict contains full evaluation + """ + try: + # Try to parse as JSON first + # Find JSON object in response (may have extra text before/after) + json_match = re.search(r'\{.*\}', text, re.DOTALL) + if json_match: + json_str = json_match.group(0) + data = json.loads(json_str) + + # Extract score from step7_final_score or final_score + score = float(data.get("step7_final_score", data.get("final_score", 0.0))) + score = max(0.0, min(1.0, score)) + return score, data + except Exception: + pass + + # Fallback: try old "final_score:" pattern + try: + m = re.search(r"final_score\s*:\s*(\d+\.\d+|\d+)", text, flags=re.IGNORECASE) + if m: + score = float(m.group(1)) + return max(0.0, min(1.0, score)), {} + except Exception: + pass + + # Last fallback: extract any number + try: + m = re.search(r"(\d+\.\d+|\d+)", text.strip()) + if m: + val = float(m.group(1)) + return max(0.0, min(1.0, val)), {} + except Exception: + pass + + return 0.0, {} + + +async def explanation_geval_reward( + judge, + prompt, + completion, + answer, + state, + **kwargs, +) -> float: + # Extract the last assistant message content as the explanation text + if 
isinstance(completion, list) and completion: + completion_text = completion[-1].get("content", "") or "" + else: + completion_text = str(completion) + + info = kwargs.get("info", {}) or {} + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + question = info.get("question", "") + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + correct_letter = (answer or "").strip().upper() + + # Gate explanation to zero if predicted MCQ answer is wrong + # Parse answer first (extracts from \boxed{} in think mode, returns raw text in normal mode) + parser = kwargs.get("parser") + if parser: + parsed = parser.parse_answer(completion) or "" + else: + parsed = completion_text + + from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy + + correct_option_text = options.get(correct_letter, "") + is_correct = multiple_choice_accuracy( + llm_answer=parsed, + answer_letter=correct_letter, + answer_text=correct_option_text, + accept_answer_text=True, + strip_tex=False, + ) + + if not is_correct: + return 0.0 + + # Format options string + opts_str = "\n".join(f"{k}. {options.get(k, '')}" for k in ["A", "B", "C", "D"]) + + # Format evaluation steps for prompt + steps_formatted = "\n".join([f"{i+1}. {step}" for i, step in enumerate(GEVAL_EVALUATION_STEPS)]) + + # Build the full prompt using the new template + full_prompt = GEVAL_PROMPT_TEMPLATE.format( + criteria=GEVAL_CRITERIA, + evaluation_steps=steps_formatted, + question=question, + options=opts_str, + correct_answer=f"{correct_letter} ({options.get(correct_letter, '')})", + ref_exp1=exp0, + ref_exp2=exp1, + model_explanation=completion_text + ) + + # Call judge with structured prompt requesting JSON + judge_response = await judge([ + {"role": "system", "content": "You are a strict, deterministic medical evaluator. 
Follow the evaluation steps carefully and output valid JSON only."}, + {"role": "user", "content": full_prompt} + ], "", "", state, **kwargs) + + # Parse JSON response and extract score + txt = str(judge_response) + score, eval_details = _extract_score_from_json(txt) + + # Optionally store evaluation details in state for debugging/logging + if state is not None and eval_details: + try: + state["geval_details"] = eval_details + except Exception: + pass + + return float(score * 100.0) + + +def create_geval_judge_rubric( + parser: vf.Parser, + judge_client, + judge_model: str = "gpt-4o-mini", + explanation_weight: float = 1.0, +) -> vf.JudgeRubric: + rubric = vf.JudgeRubric( + judge_client=judge_client, + judge_model=judge_model, + judge_prompt="{question}", # not used directly; reward builds full prompt + parser=parser, + ) + rubric.add_reward_func(explanation_geval_reward, weight=explanation_weight) + return rubric + + diff --git a/environments/medexqa/medexqa/main.py b/environments/medexqa/medexqa/main.py new file mode 100644 index 00000000..45440b7d --- /dev/null +++ b/environments/medexqa/medexqa/main.py @@ -0,0 +1,351 @@ +import os +import re + +import verifiers as vf +from datasets import Dataset, concatenate_datasets +import pandas as pd +import evaluate +from openai import AsyncOpenAI + +from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy +from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice + + +# MedExQA specialties +SPECIALTIES = [ + "biomedical_engineer", + "clinical_laboratory_scientist", + "clinical_psychologist", + "occupational_therapist", + "speech_pathologist", +] + +# author prompt directly taken from https://github.com/knowlab/MedExQA/blob/9a5b34af103b0c8ba0c00906e278f6572249fafa/evaluate_pipe_MedExQA.py#L32 +def _build_question_str(question: str, options: dict[str, str], use_think: bool = False) -> str: + """Build user prompt with authors' instruction embedded (as in their 
script). + + The instruction lives in the user message; the system prompt remains empty in + normal mode. In think mode, system prompt instructs use of <think> and <answer> tags. + """ + instruction = ( + "The following is a multiple-choice question. Please choose the most suitable one " + "among A, B, C and D as the answer to this question. Your answer should be paired " + "with an explanation why you chose that answer.\n\n" + ) + opts = "\n".join(f"{k}. {v}" for k, v in options.items()) + return f"{instruction}{question}\n{opts}\nAnswer:" + + +def _to_vf_format(ds: Dataset, shuffle_answers: bool, shuffle_seed: int | None) -> Dataset: + """Normalize raw rows into the fields expected by SingleTurnEnv. + + Produces rows of the form: + - question: string containing authors' instruction, question, and options + - answer: gold letter (A/B/C/D) - shuffled if shuffle_answers=True + - info: original fields including exp0/exp1 and specialty, plus shuffled options + """ + def _format_row(row: dict) -> dict: + question = row.get("question", "") or "" + + # Build options dict from A, B, C, D columns + opts = { + "A": row.get("A", ""), + "B": row.get("B", ""), + "C": row.get("C", ""), + "D": row.get("D", ""), + } + + # Get answer letter + answer_letter = (row.get("answer") or "").strip().upper() + if answer_letter not in ("A", "B", "C", "D"): + return None + + # Shuffle options if requested + if shuffle_answers: + opts, answer_letter, _ = randomize_multiple_choice( + options=opts, + answer_choice=answer_letter, + seed=shuffle_seed, + row_id=question, # Use question text for deterministic per-row shuffling + ) + + question_str = _build_question_str(question, opts) + + # Keep original data in info + info = dict(row) + + # Update info with shuffled values + if shuffle_answers: + info["A"] = opts["A"] + info["B"] = opts["B"] + info["C"] = opts["C"] + info["D"] = opts["D"] + info["answer"] = answer_letter + + return { + "question": question_str, + "answer": answer_letter, + "info": info, + } + + # 
Disable cache when shuffling to ensure fresh randomization + load_from_cache_file = not shuffle_answers + return ds.map( + _format_row, + remove_columns=ds.column_names, + load_from_cache_file=load_from_cache_file + ).filter(lambda row: row is not None, load_from_cache_file=load_from_cache_file) + + +def load_environment( + use_think: bool = False, + use_explanations: bool = True, + mcq_weight: float = 0.5, + explanation_weight: float = 0.5, + specialty: list[str] | str | None = None, # list of short codes or full names; None/"ALL" => all + explanation_metrics: list[str] | str | None = None, # None/"all" => average of all four + # MCQ shuffling + shuffle_answers: bool = False, + shuffle_seed: int | None = 1618, + # Optional judge settings + use_judge: bool = False, + judge_mode: str | None = None, # "g-eval" | "factscore" + judge_model: str = "gpt-4o-mini", + judge_base_url: str | None = None, + judge_api_key: str | None = None, + use_coverage: bool = False, # For FactScore: enable coverage calculation (slower but comprehensive) + **kwargs +) -> vf.Environment: + """ + Single-turn MedExQA environment using HuggingFace `bluesky333/MedExQA` dataset + + Key behaviors: + - User prompt embeds the authors' instruction and the options (authors' format). + - System prompt: empty (normal) or THINK_BOXED (think mode). + - Specialty selection: accepts list or string; loads requested specialties (None/ALL => all). + - MCQ accuracy: authors' regex+fuzzy extraction; returns 0 or 100. + - Explanation score: lexical metrics (ROUGE-L, BLEU, METEOR, BERTScore) averaged 0–100; 0 if answer wrong. + - Optional judge mode: explanation scored by JudgeRubric (0–100). 
+ """ + + # Load specialties (one or more) + # Note: MedExQA only has dev and test splits, no train split + # Load TSV files directly since HF dataset has column name issues + + # Resolve allowed specialties up-front and only load those files + code_map = { + "BE": "biomedical_engineer", + "CLS": "clinical_laboratory_scientist", + "CP": "clinical_psychologist", + "OT": "occupational_therapist", + "SLP": "speech_pathologist", + "ALL": "all", + } + allowed_names: set[str] + if specialty is None or (isinstance(specialty, str) and (specialty.upper() in ("ALL", ""))): + allowed_names = set(SPECIALTIES) + elif isinstance(specialty, str): + allowed_names = {code_map.get(specialty.upper(), specialty)} + else: + tmp = set() + for s in specialty: + name = code_map.get((s or "").upper(), s) + if name and name != "all": + tmp.add(name) + allowed_names = tmp if tmp else set(SPECIALTIES) + macro_active = len(allowed_names) > 1 + + # Load all requested specialties + test_datasets = [] + for sp_name in SPECIALTIES: + if sp_name not in allowed_names: + continue + try: + url = f"https://huggingface.co/datasets/bluesky333/MedExQA/resolve/main/test/{sp_name}_test.tsv" + df = pd.read_csv( + url, + sep='\t', + header=None, + names=["question", "A", "B", "C", "D", "exp0", "exp1", "answer"] + ) + df['specialty'] = sp_name + ds_part = Dataset.from_pandas(df, preserve_index=False) + test_datasets.append(ds_part) + except Exception as e: + print(f"Warning: Could not load {sp_name}: {e}") + continue + + # Concatenate and format for verifiers - no training dataset available + test_combined = concatenate_datasets(test_datasets) if test_datasets else None + test_ds = _to_vf_format(test_combined, shuffle_answers, shuffle_seed) if test_combined else None + + # Shuffle examples if multiple specialties were selected + if macro_active and test_ds is not None: + try: + test_ds = test_ds.shuffle(seed=int(kwargs.get("seed", 0))) + except Exception: + pass + + # Setup system prompt and parser - 
standardized with medredqa approach + # - Normal mode: No system prompt, parser returns raw text + # - Think mode: XML system prompt, XMLParser extracts from <answer> tags + if use_think: + # Like medredqa: think in <think> tags, answer+explanation in <answer> tags + system_prompt = ( + "Think step-by-step inside <think>...</think> tags. " + "Then, inside <answer>...</answer> tags, provide your final answer choice (A, B, C, or D) " + "followed by an explanation of why you chose that answer." + ) + parser = vf.XMLParser(fields=["think", "answer"], answer_field="answer") + else: + # Normal mode: no system prompt, parser returns raw text for multiple_choice_accuracy + system_prompt = "" + parser = vf.Parser() + + # (shuffling handled above when multiple specialties) + + # Lexical Metrics selection; pass individually or None/'all'/'overall' => average of all four + base_metrics = ["rougeL", "bleu", "meteor", "bertscore"] + if explanation_metrics is None: + selected_metrics = base_metrics + else: + if isinstance(explanation_metrics, str) and explanation_metrics.lower() in ("all", "overall"): + selected_metrics = base_metrics + elif isinstance(explanation_metrics, list) and any(str(m).lower() in ("all", "overall") for m in explanation_metrics): + selected_metrics = base_metrics + else: + # Wrap a single metric name in a list so it is not iterated character by character + selected_metrics = [explanation_metrics] if isinstance(explanation_metrics, str) else list(explanation_metrics) + + def compute_metric_score(metric_name: str, prediction: str, refs: list[str]) -> float: + try: + name = metric_name.lower() + if name in ("rouge", "rougel"): + rouge = evaluate.load("rouge") + res = rouge.compute(predictions=[prediction], references=[refs]) + return float(res.get("rougeL", 0.0)) * 100.0 + if name == "bleu": + bleu = evaluate.load("bleu") + res = bleu.compute(predictions=[prediction], references=[refs]) + sc = float(res.get("bleu", 0.0)) + return sc * 100.0 if sc <= 1.0 else sc + if name == "meteor": + meteor = evaluate.load("meteor") + res = meteor.compute(predictions=[prediction], references=[refs]) + sc = float(res.get("meteor", 0.0)) + return sc * 100.0 if sc <= 1.0 else sc + if name 
== "bertscore": + bscore = evaluate.load("bertscore") + res = bscore.compute( + predictions=[prediction], + references=[refs], + model_type="allenai/scibert_scivocab_uncased", + lang="en", + rescale_with_baseline=False, + ) + f1_list = res.get("f1", []) + return (float(f1_list[0]) * 100.0) if f1_list else 0.0 + return 0.0 + except Exception: + return 0.0 + + def compute_expl_score(pred: str, exp0: str, exp1: str) -> float: + refs = [exp0 or "", exp1 or ""] + metric_vals = [compute_metric_score(m, pred, refs) for m in selected_metrics] + metric_vals = [v for v in metric_vals if v is not None] + if not metric_vals: + return 0.0 + # always average across selected metrics + return (sum(metric_vals) / len(metric_vals)) + + # Note: No per-example macro scaling. + + def _get_completion_text(completion_obj) -> str: + if isinstance(completion_obj, list) and completion_obj: + return completion_obj[-1].get("content", "") or "" + return completion_obj if isinstance(completion_obj, str) else str(completion_obj) + + def answer_accuracy_reward(parser, completion, answer, **kwargs) -> float: + # Parse answer first (extracts from \boxed{} in think mode, returns raw text in normal mode) + parsed = parser.parse_answer(completion) or "" + info = kwargs.get("info", {}) or {} + + # Get answer_text for fallback matching + options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} + answer_text = options.get(answer, "") + + is_correct = multiple_choice_accuracy( + llm_answer=parsed, + answer_letter=answer, + answer_text=answer_text, + accept_answer_text=True, + strip_tex=False, # MedExQA doesn't use LaTeX + ) + return 100.0 if is_correct else 0.0 + + def explanation_reward(parser, completion, answer, **kwargs) -> float: + # Parse answer first (extracts from \boxed{} in think mode, returns raw text in normal mode) + parsed = parser.parse_answer(completion) or "" + info = kwargs.get("info", {}) or {} + + # Get answer_text for fallback matching 
+ options = {"A": info.get("A", ""), "B": info.get("B", ""), "C": info.get("C", ""), "D": info.get("D", "")} + answer_text = options.get(answer, "") + + # Check if answer is correct using multiple_choice_accuracy + is_correct = multiple_choice_accuracy( + llm_answer=parsed, + answer_letter=answer, + answer_text=answer_text, + accept_answer_text=True, + strip_tex=False, + ) + + if not is_correct: + return 0.0 + else: + # For lexical metrics, use the raw completion text (not parsed) + completion_text = _get_completion_text(completion) + return compute_expl_score(completion_text, info.get("exp0", ""), info.get("exp1", "")) + + # Optional: Use LLM-as-judge for explanation instead of lexical metrics + if use_explanations and use_judge: + api_key = judge_api_key if judge_api_key else os.getenv("JUDGE_API_KEY") or os.getenv("OPENAI_API_KEY") + judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key) if api_key else None + if judge_mode is None: + raise ValueError("use_judge=True requires judge_mode to be one of {'g-eval','factscore'}") + if judge_mode not in ("g-eval", "factscore"): + raise ValueError("judge_mode must be 'g-eval' or 'factscore'") + + if judge_mode == "g-eval": + from medexqa.geval_judge.geval_judge import create_geval_judge_rubric + judge_rubric = create_geval_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, explanation_weight=explanation_weight) + # Combine answer accuracy with the judge-based explanation score + judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) + rubric = judge_rubric + elif judge_mode == "factscore": + from medexqa.factscore_judge.atomic_facts_judge import create_factscore_judge_rubric + judge_rubric = create_factscore_judge_rubric(parser=parser, judge_client=judge_client, judge_model=judge_model, use_coverage=use_coverage, explanation_weight=explanation_weight) + judge_rubric.add_reward_func(answer_accuracy_reward, weight=mcq_weight) + rubric = judge_rubric + else: + # 
Lexical metrics for explanations (or MCQ-only if use_explanations=False) + if use_explanations: + rubric = vf.Rubric( + funcs=[answer_accuracy_reward, explanation_reward], + weights=[mcq_weight, explanation_weight], + parser=parser + ) + else: + # MCQ-only mode + rubric = vf.Rubric(funcs=[answer_accuracy_reward], weights=[1.0], parser=parser) + + env = vf.SingleTurnEnv( + dataset=None, # No training split available + eval_dataset=test_ds, + system_prompt=system_prompt, + parser=parser, + rubric=rubric, + **kwargs + ) + + return env diff --git a/environments/medexqa/pyproject.toml b/environments/medexqa/pyproject.toml new file mode 100644 index 00000000..ffa5a7d4 --- /dev/null +++ b/environments/medexqa/pyproject.toml @@ -0,0 +1,33 @@ +[project] +name = "medexqa" +version = "0.1.0" +description = "MedExQA Evaluation - Medical QA with Multiple Explanations" +readme = "README.md" +requires-python = ">=3.11" +authors = [ + { name = "Nishant Mishra", email = "mnishant2@gmail.com" }, +] +dependencies = [ + "datasets>=4.0.0", + "verifiers>=0.1.2.post0", + "pandas>=2.0.0", + "evaluate>=0.4.0", + "rouge-score>=0.1.2", + "sacrebleu>=2.4.0", + "bert-score>=0.3.13", + "openai>=1.0.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["medexqa/**"] +packages = ["medexqa"] + +[tool.prime.environment] +# lets Prime/vf-eval know where the loader lives in a flat repo +loader = "medexqa:load_environment" +display_name = "MedExQA" +visibility = "PUBLIC" diff --git a/environments/medexqa/tools/judge_rescore.py b/environments/medexqa/tools/judge_rescore.py new file mode 100644 index 00000000..7f5ada8a --- /dev/null +++ b/environments/medexqa/tools/judge_rescore.py @@ -0,0 +1,350 @@ +import argparse +import asyncio +import csv +import glob +import json +import os +import re +from typing import Any, Dict, List, Tuple + +from openai import AsyncOpenAI + +# Reuse existing judge implementations from the environment +from 
environments.medexqa.geval_judge.geval_judge import ( + explanation_geval_reward as geval_reward, +) +from environments.medexqa.factscore_judge.atomic_facts_judge import ( + explanation_factscore_reward as factscore_reward, +) + + +def _extract_numeric(text: str) -> float: + m = re.search(r"(\d+\.\d+|\d+)", (text or "").strip()) + if not m: + return 0.0 + try: + val = float(m.group(1)) + return max(0.0, min(1.0, val)) + except Exception: + return 0.0 + + +def _read_results(paths: List[str]) -> List[Tuple[str, Dict[str, Any]]]: + rows: List[Tuple[str, Dict[str, Any]]] = [] + for p in paths: + try: + with open(p, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + rec = json.loads(line) + rows.append((p, rec)) + except Exception as e: + print(f"Warning: failed to read {p}: {e}") + return rows + + +class JudgeRecorder: + def __init__(self, client: AsyncOpenAI, model: str, sleep_ms: int = 500, max_retries: int = 5, verbose: bool = True, max_tokens: int = 384): + self.client = client + self.model = model + self.sleep_ms = max(0, int(sleep_ms)) + self.max_retries = max(1, int(max_retries)) + self.verbose = verbose + self.max_tokens = max_tokens + self.logs: List[Dict[str, str]] = [] + + async def __call__(self, messages, *_args, **_kwargs) -> str: + # messages is a list of {role, content} + content = messages[-1].get("content", "") if messages else "" + attempt = 0 + delay = self.sleep_ms / 1000.0 + while True: + try: + if self.verbose: + print(f"[judge] calling model={self.model}, tokens<={self.max_tokens}") + resp = await self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0, + max_tokens=self.max_tokens, + ) + out = resp.choices[0].message.content or "" + self.logs.append({"prompt": content, "response": out}) + # throttle between calls + if self.sleep_ms > 0: + await asyncio.sleep(self.sleep_ms / 1000.0) + return out + except Exception as e: + attempt += 1 + msg = str(e) + if self.verbose: + print(f"[judge] error on 
attempt {attempt}: {msg}") + # retry on rate limit or transient errors + if attempt < self.max_retries: + # exponential backoff with floor at configured delay + backoff = delay * (2 ** (attempt - 1)) + await asyncio.sleep(backoff) + continue + # record failure + self.logs.append({"prompt": content, "response": f": {msg}"}) + return "" + + +async def judge_geval( + client: AsyncOpenAI, + model: str, + rec: Dict[str, Any], + *, + sleep_ms: int = 500, + max_retries: int = 5, + verbose: bool = True, + max_tokens: int = 384, +) -> Tuple[float, str, str]: + info = rec.get("info", {}) or {} + question = info.get("question", "") + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + answer = rec.get("answer", "") + completion_msgs = rec.get("completion", []) + + jr = JudgeRecorder(client, model, sleep_ms=sleep_ms, max_retries=max_retries, verbose=verbose, max_tokens=max_tokens) + score = await geval_reward(jr, None, completion_msgs, answer, state={}, info=info, judge_client=client, judge_model=model) + # Last log entry contains the overall prompt/response + judge_output = jr.logs[-1]["response"] if jr.logs else "" + refs = ( + f"Question: {question}\n" + f"Correct answer: {answer} ({options.get(answer,'')})\n" + f"Ref1: {exp0}\n" + f"Ref2: {exp1}" + ) + return float(score), judge_output, refs + + +async def judge_factscore( + client: AsyncOpenAI, + model: str, + rec: Dict[str, Any], + *, + sleep_ms: int = 500, + max_retries: int = 5, + verbose: bool = True, + max_tokens: int = 384, + use_coverage: bool = False, +) -> Tuple[float, str, str]: + info = rec.get("info", {}) or {} + question = info.get("question", "") + options = {k: info.get(k, "") for k in ["A", "B", "C", "D"]} + exp0 = info.get("exp0", "") + exp1 = info.get("exp1", "") + answer = rec.get("answer", "") + completion_msgs = rec.get("completion", []) + + jr = JudgeRecorder(client, model, sleep_ms=sleep_ms, max_retries=max_retries, 
verbose=verbose, max_tokens=max_tokens) + score = await factscore_reward(jr, None, completion_msgs, answer, state={}, info=info, judge_client=client, judge_model=model, use_coverage=use_coverage) + + # Parse logs to reconstruct claim labels and track extraction outcomes + labels: List[Tuple[str, str]] = [] # (claim, label) where passage=references (support) + coverage_labels: List[Tuple[str, str]] = [] # (ref_claim, label) where passage=explanation (coverage) + extraction_responses: List[str] = [] + for entry in jr.logs: + prompt = entry.get("prompt", "") or "" + response = entry.get("response", "") or "" + + # Look for new format: "CLAIM TO VERIFY:\n" + m = re.search(r"CLAIM TO VERIFY:\s*\n(.+?)(?:\n\nINSTRUCTIONS:)", prompt, flags=re.DOTALL) + if m: + claim = m.group(1).strip() + # Heuristic: prompts containing "PASSAGE (Combined References):" belong to support + # Prompts with "MODEL EXPLANATION:" belong to coverage + if "PASSAGE (Combined References):" in prompt: + labels.append((claim, response.strip().upper())) + elif "MODEL EXPLANATION:" in prompt: + coverage_labels.append((claim, response.strip().upper())) + + # Fallback: old format "Fact:\n" + if not m: + m_old = re.search(r"Fact:\s*\n(.+)$", prompt, flags=re.DOTALL) + if m_old: + claim = m_old.group(1).strip() + if "Question:" in prompt: + labels.append((claim, response.strip().upper())) + else: + coverage_labels.append((claim, response.strip().upper())) + + if "Claims JSON:" in prompt: + extraction_responses.append(response) + + refs = ( + f"Question: {question}\n" + f"Correct: ({answer}) {options.get(answer,'')}\n" + f"Ref1: {exp0}\n" + f"Ref2: {exp1}" + ) + # Derive error tag for extraction phase + err_tag = "" + if extraction_responses: + last_extraction = extraction_responses[-1] + try: + parsed = json.loads(last_extraction) + if isinstance(parsed, list) and len(parsed) == 0: + err_tag = "empty_extraction" + except Exception: + err_tag = "extraction_error" + elif not labels and not coverage_labels: + 
err_tag = "empty_extraction" + + # Compute support/coverage rates from labels (handle 3-level format) + def _rate(pairs: List[Tuple[str, str]]) -> float: + if not pairs: + return 0.0 + total_score = 0.0 + for _, lbl in pairs: + lbl_clean = (lbl or "").strip().upper() + if "FULLY_SUPPORTED" in lbl_clean or "FULLY SUPPORTED" in lbl_clean: + total_score += 1.0 + elif "PARTIALLY_SUPPORTED" in lbl_clean or "PARTIALLY SUPPORTED" in lbl_clean: + total_score += 0.5 + elif lbl_clean.startswith("TRUE"): # Fallback for old format + total_score += 1.0 + # NOT_SUPPORTED or FALSE = 0.0 + return float(total_score) / float(len(pairs)) + + support_rate = _rate(labels) + coverage_rate = _rate(coverage_labels) + + details = json.dumps({ + "claims": labels, + "coverage_labels": coverage_labels, + "support_rate": support_rate, + "coverage_rate": coverage_rate, + }, ensure_ascii=False) + if err_tag: + refs = refs + f"\n: {err_tag}" + return float(score), details, refs + + +async def main(): + ap = argparse.ArgumentParser(description="Re-score saved MedExQA completions with LLM judges.") + ap.add_argument("--base", default="https://openrouter.ai/api/v1", help="Judge API base URL") + ap.add_argument("--model", default="openai/gpt-oss-20b:free", help="Judge model id") + ap.add_argument("--key_var", default="OPENAI_API_KEY", help="Env var name holding the API key") + ap.add_argument("--input_glob", default="environments/medexqa/outputs/evals/**/results.jsonl", help="Glob to results.jsonl files") + ap.add_argument("--out_csv_prefix", default="environments/medexqa/outputs/judge_scores/medexqa_", help="Output CSV prefix (will append judge name)") + ap.add_argument("--sleep_ms", type=int, default=500, help="Sleep/throttle between judge calls (ms)") + ap.add_argument("--max_retries", type=int, default=5, help="Max retries on judge call errors") + ap.add_argument("--max_tokens", type=int, default=384, help="Max tokens per judge response") + ap.add_argument("--verbose", action="store_true", 
help="Verbose logging") + ap.add_argument("--judge", choices=["geval", "factscore", "both"], default="both", help="Which judge(s) to run") + ap.add_argument("--use_coverage", action="store_true", help="Enable coverage calculation for FactScore (slower but more comprehensive)") + args = ap.parse_args() + + api_key = os.getenv(args.key_var) + if not api_key: + raise SystemExit(f"Missing API key in env var {args.key_var}") + + client = AsyncOpenAI(base_url=args.base, api_key=api_key) + + # Discover saved runs + paths = sorted(glob.glob(args.input_glob, recursive=True)) + if args.verbose: + print(f"Scanning {len(paths)} results.jsonl files...") + rows = _read_results(paths) + if not rows: + print("No results found to re-score.") + return + + os.makedirs(os.path.dirname(args.out_csv_prefix), exist_ok=True) + + # Prepare CSV writers conditionally + gwriter = None + fwriter = None + geval_path = args.out_csv_prefix + "geval.csv" + fact_path = args.out_csv_prefix + "factscore.csv" + if args.judge in ("geval", "both"): + gf = open(geval_path, "w", newline="") + gwriter = csv.writer(gf) + gwriter.writerow(["run_file", "specialty", "question", "A", "B", "C", "D", "answer", "completion", "judge_model_output", "judge_score", "references", "error"]) + if args.judge in ("factscore", "both"): + ff = open(fact_path, "w", newline="") + fwriter = csv.writer(ff) + fwriter.writerow(["run_file", "specialty", "question", "A", "B", "C", "D", "answer", "completion", "claims_labels_json", "support_rate", "coverage_labels_json", "coverage_rate", "final_score", "references", "error"]) + + # Process sequentially to keep it simple + for idx, (run_file, rec) in enumerate(rows, start=1): + info = rec.get("info", {}) or {} + spec = info.get("specialty", "") + question = info.get("question", "") + A = info.get("A", "") + B = info.get("B", "") + C = info.get("C", "") + D = info.get("D", "") + answer = rec.get("answer", "") + completion_msgs = rec.get("completion", []) + completion_text = 
completion_msgs[-1].get("content", "") if completion_msgs else "" + + if args.verbose: + print(f"[{idx}/{len(rows)}] {run_file} | spec={spec} | len(prompt)={len(question)} | len(completion)={len(completion_text)}") + + # G-Eval + if args.judge in ("geval", "both") and gwriter is not None: + if args.verbose: + print(" -> G-Eval judging...") + g_score, g_out, g_refs = await judge_geval( + client, + args.model, + rec, + sleep_ms=args.sleep_ms, + max_retries=args.max_retries, + verbose=args.verbose, + max_tokens=args.max_tokens, + ) + # detect errors in logs + g_err = "" + if g_out.strip().startswith("") or g_out.strip() == "": + g_err = "empty_or_error" + gwriter.writerow([run_file, spec, question, A, B, C, D, answer, completion_text, g_out, f"{g_score:.3f}", g_refs, g_err]) + + # FactScore + if args.judge in ("factscore", "both") and fwriter is not None: + if args.verbose: + print(" -> FactScore judging...") + f_score, f_details, f_refs = await judge_factscore( + client, + args.model, + rec, + sleep_ms=args.sleep_ms, + max_retries=args.max_retries, + verbose=args.verbose, + max_tokens=args.max_tokens, + use_coverage=args.use_coverage, + ) + f_err = "" + support_rate = "" + coverage_rate = "" + coverage_labels_json = "{}" + try: + dd = json.loads(f_details) + support_rate = f"{float(dd.get('support_rate', 0.0)):.3f}" + coverage_rate = f"{float(dd.get('coverage_rate', 0.0)):.3f}" + coverage_labels_json = json.dumps(dd.get("coverage_labels", []), ensure_ascii=False) + except Exception: + pass + if f_details.strip() == "" or f_details.strip() == "{}": + f_err = "empty_or_error" + fwriter.writerow([run_file, spec, question, A, B, C, D, answer, completion_text, f_details, support_rate, coverage_labels_json, coverage_rate, f"{f_score:.3f}", f_refs, f_err]) + + if gwriter is not None: + gf.close() + print(f"Wrote: {geval_path}") + if fwriter is not None: + ff.close() + print(f"Wrote: {fact_path}") + + +if __name__ == "__main__": + asyncio.run(main()) + +
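The FactScore path above turns per-claim judge labels into a support (and optionally coverage) rate, then blends the two into a 0–100 score. The following is a minimal standalone sketch of that arithmetic, assuming the three-level labels and the `support_weight`/`coverage_weight` defaults of 0.5 seen in the diff. The names `label_rate` and `combine_scores` are hypothetical and are not part of the PR.

```python
from typing import Iterable


def label_rate(labels: Iterable[str]) -> float:
    """Average per-claim credit (hypothetical helper mirroring `_rate`).

    FULLY_SUPPORTED -> 1.0, PARTIALLY_SUPPORTED -> 0.5, anything else -> 0.0.
    """
    labels = list(labels)
    if not labels:
        return 0.0
    total = 0.0
    for raw in labels:
        # Normalise case and accept "FULLY SUPPORTED" as well as "FULLY_SUPPORTED".
        label = (raw or "").strip().upper().replace(" ", "_")
        if "FULLY_SUPPORTED" in label:
            total += 1.0
        elif "PARTIALLY_SUPPORTED" in label:
            total += 0.5
        elif label.startswith("TRUE"):  # older TRUE/FALSE label format
            total += 1.0
        # NOT_SUPPORTED, FALSE, or unparseable labels contribute 0.0
    return total / len(labels)


def combine_scores(
    support_rate: float,
    coverage_rate: float,
    use_coverage: bool = False,
    w_support: float = 0.5,
    w_coverage: float = 0.5,
) -> float:
    """Blend support and coverage into a 0-100 score (hypothetical helper)."""
    if not use_coverage:
        # Coverage disabled: the score is the support rate alone.
        return support_rate * 100.0
    denom = w_support + w_coverage
    if denom <= 0:
        denom = 1.0  # guard against zero or negative total weight
    return (w_support * support_rate + w_coverage * coverage_rate) / denom * 100.0


if __name__ == "__main__":
    support = label_rate(["FULLY_SUPPORTED", "partially supported", "NOT_SUPPORTED"])
    print(support, combine_scores(support, 1.0, use_coverage=True))
```

With coverage disabled the score is precision-like (how much of the model explanation the references support). Enabling coverage adds a recall-like term at the cost of extra judge calls per reference claim.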