From d1d164df9288051c031620fae79e7e64e7644cf1 Mon Sep 17 00:00:00 2001 From: Anthony Wang Date: Mon, 13 Apr 2026 19:52:24 +0000 Subject: [PATCH 1/2] Patch thinking-model grader bugs Temporary fix for three grader bugs that misscored thinking models (Qwen 3.5/Qwen3-4B-Thinking) which produce think-tag reasoning blocks: - Switch 9 envs from vf.XMLParser to custom XMLParser (last=True) - Strip think-tag blocks in SCT parser before extracting "Rating:" - Strip think-tag blocks in medcasereasoning judge before yes/no check --- environments/aci_bench/aci_bench/aci_bench.py | 3 ++- environments/head_qa/head_qa.py | 3 ++- environments/med_mcqa/med_mcqa.py | 3 ++- environments/medbullets/medbullets.py | 3 ++- .../medcasereasoning/medcasereasoning.py | 6 +++++- environments/medconceptsqa/medconceptsqa.py | 3 ++- environments/medqa/medqa.py | 3 ++- environments/medredqa/medredqa.py | 5 +++-- environments/medxpertqa/medxpertqa.py | 3 ++- environments/pubmedqa/pubmedqa.py | 3 ++- environments/sctpublic/sctpublic/sctpublic.py | 17 +++++++++++++++++ medarc_verifiers/parsers/xml_parser.py | 19 +++++++++++++++++++ 12 files changed, 60 insertions(+), 11 deletions(-) diff --git a/environments/aci_bench/aci_bench/aci_bench.py b/environments/aci_bench/aci_bench/aci_bench.py index 4f7346db..57e3b6ad 100644 --- a/environments/aci_bench/aci_bench/aci_bench.py +++ b/environments/aci_bench/aci_bench/aci_bench.py @@ -5,6 +5,7 @@ from datasets.utils.logging import disable_progress_bar from aci_bench.judge_prompts import JUDGE_DIMENSIONS, JUDGE_OUTPUT_JSON, JUDGE_TEMPLATE from medarc_verifiers.parsers import JSONParser +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import XML_SYSTEM_PROMPT, AnswerFormat from medarc_verifiers.judging import MultiJudge, MultiJudgeRubric from medarc_verifiers.rewards import normalize_helm_reward @@ -83,7 +84,7 @@ def load_environment( if answer_format == AnswerFormat.XML: system_prompt = system_prompt or XML_SYSTEM_PROMPT parser_fields = ["answer"] - parser = 
vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: system_prompt = system_prompt or BOXED_SYSTEM_PROMPT parser = vf.Parser(extract_fn=extract_boxed_answer) diff --git a/environments/head_qa/head_qa.py b/environments/head_qa/head_qa.py index e17b7937..471450da 100644 --- a/environments/head_qa/head_qa.py +++ b/environments/head_qa/head_qa.py @@ -29,6 +29,7 @@ import verifiers as vf from datasets import load_dataset +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice @@ -135,7 +136,7 @@ def _map_example(example: dict[str, Any]) -> dict[str, Any] | None: if answer_format == AnswerFormat.XML: system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT) parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT) parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer) diff --git a/environments/med_mcqa/med_mcqa.py b/environments/med_mcqa/med_mcqa.py index 28392a3c..a7e642bc 100644 --- a/environments/med_mcqa/med_mcqa.py +++ b/environments/med_mcqa/med_mcqa.py @@ -23,6 +23,7 @@ import verifiers as vf from datasets import load_dataset from datasets.utils.logging import disable_progress_bar +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat 
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice @@ -125,7 +126,7 @@ def _map_example(example: dict[str, Any]) -> dict[str, Any] | None: if answer_format == AnswerFormat.XML: system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT) parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT) parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer) diff --git a/environments/medbullets/medbullets.py b/environments/medbullets/medbullets.py index f0d3431e..7ffa5867 100644 --- a/environments/medbullets/medbullets.py +++ b/environments/medbullets/medbullets.py @@ -1,6 +1,7 @@ import verifiers as vf from datasets import Dataset, load_dataset from datasets.utils.logging import disable_progress_bar +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice @@ -131,7 +132,7 @@ def load_environment( if answer_format == AnswerFormat.XML: system_prompt = THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: parser = ( vf.ThinkParser(extract_fn=extract_boxed_answer) if use_think else 
vf.Parser(extract_fn=extract_boxed_answer) diff --git a/environments/medcasereasoning/medcasereasoning.py b/environments/medcasereasoning/medcasereasoning.py index 87afc915..9d81a70d 100644 --- a/environments/medcasereasoning/medcasereasoning.py +++ b/environments/medcasereasoning/medcasereasoning.py @@ -1,3 +1,5 @@ +import re + import verifiers as vf from datasets import load_dataset from datasets.utils.logging import disable_progress_bar @@ -113,7 +115,9 @@ async def medical_diagnosis_reward_func( ) # Parse judge response - judge_response_clean = judge_response.strip().lower() + judge_response_clean = re.sub( + r".*?", "", judge_response, flags=re.DOTALL | re.IGNORECASE + ).strip().lower() else: judge_response_clean = "no" judge_response = "no answer" diff --git a/environments/medconceptsqa/medconceptsqa.py b/environments/medconceptsqa/medconceptsqa.py index 0b8e443d..45d7e6bd 100644 --- a/environments/medconceptsqa/medconceptsqa.py +++ b/environments/medconceptsqa/medconceptsqa.py @@ -4,6 +4,7 @@ import verifiers as vf from datasets import Dataset, load_dataset from datasets.utils.logging import disable_progress_bar +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice @@ -183,7 +184,7 @@ def _map(row: dict, idx: int | None = None) -> dict: if answer_format == AnswerFormat.XML: system_prompt = THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: system_prompt = THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT parser = 
vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer) diff --git a/environments/medqa/medqa.py b/environments/medqa/medqa.py index af68fce1..f4497c5d 100644 --- a/environments/medqa/medqa.py +++ b/environments/medqa/medqa.py @@ -3,6 +3,7 @@ import verifiers as vf from datasets import load_dataset from datasets.utils.logging import disable_progress_bar +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice @@ -79,7 +80,7 @@ def _map(ex, idx=None): if answer_format == AnswerFormat.XML: system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT) parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer) system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT) diff --git a/environments/medredqa/medredqa.py b/environments/medredqa/medredqa.py index 50810dab..3d3f82ae 100644 --- a/environments/medredqa/medredqa.py +++ b/environments/medredqa/medredqa.py @@ -5,6 +5,7 @@ from datasets import load_dataset from datasets.utils.logging import disable_progress_bar from factscore_judge.atomic_facts_judge import create_atomic_facts_judge_rubric +from medarc_verifiers.parsers.xml_parser import XMLParser from openai import AsyncOpenAI disable_progress_bar() # suppress datasets mapping progress bar @@ -82,9 +83,9 @@ def load_environment( # Create parser - XMLParser with both think and answer fields if use_think is True parser = ( - 
vf.XMLParser(fields=["think", "answer"], answer_field="answer") + XMLParser(fields=["think", "answer"], answer_field="answer") if use_think - else vf.XMLParser(fields=["answer"], answer_field="answer") + else XMLParser(fields=["answer"], answer_field="answer") ) # Create JudgeRubric using the helper function from judge.py diff --git a/environments/medxpertqa/medxpertqa.py b/environments/medxpertqa/medxpertqa.py index 2ef3334e..f8cca9a8 100644 --- a/environments/medxpertqa/medxpertqa.py +++ b/environments/medxpertqa/medxpertqa.py @@ -2,6 +2,7 @@ import verifiers as vf from datasets import load_dataset from datasets.utils.logging import disable_progress_bar +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import AnswerFormat from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice @@ -105,7 +106,7 @@ def _map(example: dict) -> dict: if answer_format == AnswerFormat.XML: parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") elif answer_format == AnswerFormat.BOXED: parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer) else: diff --git a/environments/pubmedqa/pubmedqa.py b/environments/pubmedqa/pubmedqa.py index dc733550..aa03fdf8 100644 --- a/environments/pubmedqa/pubmedqa.py +++ b/environments/pubmedqa/pubmedqa.py @@ -4,6 +4,7 @@ import verifiers as vf from datasets import load_dataset from datasets.utils.logging import disable_progress_bar +from medarc_verifiers.parsers.xml_parser import XMLParser from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy from medarc_verifiers.utils.randomize_multiple_choice 
import randomize_multiple_choice @@ -147,7 +148,7 @@ def load_environment( if answer_format == AnswerFormat.XML: parser_fields = ["think", "answer"] if use_think else ["answer"] - parser = vf.XMLParser(fields=parser_fields, answer_field="answer") + parser = XMLParser(fields=parser_fields, answer_field="answer") system_prompt = THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT elif answer_format == AnswerFormat.BOXED: parser = ( diff --git a/environments/sctpublic/sctpublic/sctpublic.py b/environments/sctpublic/sctpublic/sctpublic.py index 3829cead..b6916398 100644 --- a/environments/sctpublic/sctpublic/sctpublic.py +++ b/environments/sctpublic/sctpublic/sctpublic.py @@ -93,11 +93,28 @@ def parse_prompt(row): return df +_THINK_BLOCK_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE) + + class SCTParser(vf.Parser): """Parser for SCT numeric ratings (-2 to +2).""" + @staticmethod + def _strip_think_blocks(text: str) -> str: + """Remove ... blocks so we only parse the final answer. + + Also handles the case where a model emits an unpaired + (no opening tag) -- we keep only the text after the last . + """ + if "" in text.lower(): + text = _THINK_BLOCK_RE.sub("", text) + if "" in text.lower(): + text = text.split("")[-1] + return text.strip() + def parse_answer(self, completion: Any) -> str | None: response = getattr(completion, "content", str(completion)) + response = self._strip_think_blocks(response) try: rating = response.split("Rating: ")[1][:5] pattern = r"\+2|\+1|0|-1|-2" diff --git a/medarc_verifiers/parsers/xml_parser.py b/medarc_verifiers/parsers/xml_parser.py index 6a1176fc..49606d78 100644 --- a/medarc_verifiers/parsers/xml_parser.py +++ b/medarc_verifiers/parsers/xml_parser.py @@ -61,6 +61,25 @@ def parse(self, completion: Messages | str, strip: bool = True, last: bool = Fal return parsed return None + def parse_answer(self, completion: Messages | str) -> str | None: + """Extract the last answer from a completion. 
+ + Overrides upstream to always use ``last=True`` when parsing chat messages. + Qwen 3.5 reasoning models echo ```` tags inside their thinking + blocks, so we must take the **last** match to get the real answer. + """ + if isinstance(completion, str): + parsed = self.parse(completion, last=True) + if parsed and hasattr(parsed, self.answer_field) and getattr(parsed, self.answer_field) is not None: + return getattr(parsed, self.answer_field) + else: + for msg in reversed(self.get_assistant_messages(completion)): + content = str(msg.get("content", "")) + parsed = self.parse(content, last=True) + if parsed and hasattr(parsed, self.answer_field) and getattr(parsed, self.answer_field) is not None: + return getattr(parsed, self.answer_field) + return None + def _has_any_field(self, parsed: Any) -> bool: for _, alternatives in self._fields: for alt in alternatives: From 7bcbc6a27c3abe3326aaaceb6354c1cd1313535c Mon Sep 17 00:00:00 2001 From: Anthony Wang Date: Wed, 22 Apr 2026 19:25:20 +0000 Subject: [PATCH 2/2] Add rubric_train environment for RL training with HealthBench-style rewards --- environments/rubric_train/pyproject.toml | 25 ++++ environments/rubric_train/rubric_train.py | 140 ++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 environments/rubric_train/pyproject.toml create mode 100644 environments/rubric_train/rubric_train.py diff --git a/environments/rubric_train/pyproject.toml b/environments/rubric_train/pyproject.toml new file mode 100644 index 00000000..19bb9476 --- /dev/null +++ b/environments/rubric_train/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "rubric-train" +description = "Training env: rubric-judged open-ended medical prompts for RL reward" +tags = ["medical", "rubric", "single-turn", "llm-judge", "train"] +version = "0.1.0" +requires-python = ">=3.11" +dependencies = [ + "medarc-verifiers>=0.1.0", + "verifiers>=0.1.4", + "datasets>=4.1.1", + "openai>=2.1.0", + "healthbench", +] + +[dependency-groups] +dev = [ + 
"ruff>=0.13.3", +] + +[tool.prime.environment] +loader = "rubric_train:load_environment" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/environments/rubric_train/rubric_train.py b/environments/rubric_train/rubric_train.py new file mode 100644 index 00000000..c48ddc4d --- /dev/null +++ b/environments/rubric_train/rubric_train.py @@ -0,0 +1,140 @@ +"""RL training env: HealthBench-style rubric-judged prompts. + +Wraps the existing `healthbench` env's judge-per-criterion reward path so it +can be driven from prime-rl as a training env with a `dataset=` attribute +(not just eval). Each rollout's reward = sum(points for criteria met) / +sum(positive_points_possible), bounded to [0, 1]. + +Train split comes from HealthBench's `consensus` (3,671 examples) or `hard` +(1,000 examples) datasets. Eval split comes from HealthBench `all` test set. +Choosing a smaller eval slice keeps per-step validation cheap. +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import verifiers as vf +from datasets import load_dataset +from datasets.utils.logging import disable_progress_bar + +# Reuse healthbench's judge-plumbing + scoring exactly — single source of truth. 
+from healthbench import (  # type: ignore[import-not-found]
+    HEALTHBENCH_CONSENSUS_CRITERIA_LOOKUP,
+    _format_prompt_to_judge,
+    _judge_single_criterion,
+    _process_healthbench_dataset,
+)
+from medarc_verifiers.judging import MultiJudge, MultiJudgeRubric
+from medarc_verifiers.types import Messages
+from verifiers.types import Info, State
+
+disable_progress_bar()
+
+
+# `train_split` argument -> (Hugging Face dataset repo, split name) used for training.
+TRAIN_MAPPING = {
+    "consensus": ("neuralleap/healthbench-consensus", "train"),
+    "hard": ("neuralleap/healthbench-hard", "train"),
+}
+# (Hugging Face dataset repo, split name) used for evaluation.
+EVAL_DATASET = ("neuralleap/healthbench-regular", "test")
+
+# Concurrent judge calls per rollout when `max_parallel_judges` is None.
+_DEFAULT_PARALLEL_JUDGES = 3
+
+
+def _load_rubric_split(repo: str, split: str, limit: int | None):
+    """Load one dataset split, keep its first `limit` rows, and attach an `info` column."""
+    ds = load_dataset(repo, split=split)
+    if limit is not None:
+        ds = ds.select(range(min(limit, len(ds))))
+    return ds.map(lambda ex: {"info": _process_healthbench_dataset(ex)})
+
+
+def _completion_text(completion: Any) -> Any:
+    """Return the content of the final message of a rollout ("" when there is none)."""
+    if isinstance(completion, list):
+        if not completion:
+            # An empty rollout must not reach the judge as the literal string "[]".
+            return ""
+        content = completion[-1].get("content", "")
+        return "" if content is None else content
+    return str(completion)
+
+
+def _normalized_score(judgments: list[Any], positive_total: float) -> float:
+    """Sum the points of every met criterion, divide by `positive_total`, clamp to [0, 1]."""
+    earned = sum(
+        float(j.get("points_possible", 0))
+        for j in judgments
+        if isinstance(j, dict) and j.get("criteria_met")
+    )
+    # Met negative-point criteria can push the ratio below 0; the clamp keeps the
+    # scalar inside [0, 1].
+    return max(0.0, min(1.0, earned / positive_total))
+
+
+def load_environment(
+    judge_model: str | list[str] = "openai/gpt-oss-20b",
+    judge_base_url: str | list[str] | None = None,
+    judge_api_key: str | list[str] | None = None,
+    train_split: str = "consensus",
+    num_train_examples: int | None = None,
+    num_eval_examples: int = 100,
+    judge_timeout: int | None = 300,
+    max_parallel_judges: int | None = 3,
+    **kwargs: Any,
+) -> vf.Environment:
+    """Return a SingleTurnEnv with HealthBench-style rubric reward.
+
+    The reward for a rollout is `sum(points for criteria met)` divided by
+    `sum(p for p in points_possible if p > 0)`, clamped to [0, 1]. A met
+    negative-point criterion subtracts from the numerator, so penalties still
+    lower the reward until the clamp at 0 is reached.
+
+    `judge_model` defaults to gpt-oss-20b (matches the scale-sweep judge on
+    n-4:18001). Override via `judge_base_url` for a different endpoint.
+
+    Args:
+        judge_model: Judge model name(s), forwarded to `MultiJudge.from_env_args`.
+        judge_base_url: Judge endpoint(s), forwarded to `MultiJudge.from_env_args`.
+        judge_api_key: Judge API key(s), forwarded to `MultiJudge.from_env_args`.
+        train_split: Key of `TRAIN_MAPPING` selecting the training dataset.
+        num_train_examples: Keep only the first N training rows; None keeps all.
+        num_eval_examples: Keep only the first N eval rows; 0 keeps all.
+        judge_timeout: Timeout forwarded to `MultiJudge.from_env_args`.
+        max_parallel_judges: Maximum concurrent judge calls within one rollout;
+            None falls back to 3.
+        **kwargs: Accepted so extra loader arguments do not fail; not used here.
+
+    Raises:
+        ValueError: If `train_split` is not a key of `TRAIN_MAPPING`, or if
+            `max_parallel_judges` is smaller than 1.
+    """
+    if train_split not in TRAIN_MAPPING:
+        raise ValueError(f"Unknown train_split={train_split}; expected one of {list(TRAIN_MAPPING)}")
+
+    criteria_parallel = _DEFAULT_PARALLEL_JUDGES if max_parallel_judges is None else max_parallel_judges
+    if criteria_parallel < 1:
+        # Semaphore(0) would block every judge call forever; fail at load time instead.
+        raise ValueError(f"max_parallel_judges must be >= 1, got {max_parallel_judges}")
+
+    train_repo, train_split_name = TRAIN_MAPPING[train_split]
+    train_ds = _load_rubric_split(train_repo, train_split_name, num_train_examples)
+
+    eval_repo, eval_split_name = EVAL_DATASET
+    # A falsy `num_eval_examples` (0) means "use the whole eval split".
+    eval_ds = _load_rubric_split(eval_repo, eval_split_name, num_eval_examples or None)
+
+    multi_judge = MultiJudge.from_env_args(
+        judge_model=judge_model,
+        judge_base_url=judge_base_url,
+        judge_api_key=judge_api_key,
+        judge_prompt="{question}",
+        judge_timeout=judge_timeout,
+    )
+    rubric = MultiJudgeRubric(multi_judge)
+
+    async def reward_rubric(prompt: Messages, completion: Messages, info: Info, state: State) -> float:
+        """Per-rollout reward: points of met criteria over the positive-point total.
+
+        Each (criterion, points) pair is judged with `_judge_single_criterion`,
+        with at most `criteria_parallel` judge calls in flight for this rollout.
+        Returns 0.0 when the example has no positive-point criterion.
+        """
+        # Pair criteria with points once so the numerator and the denominator are
+        # computed over exactly the criteria that get judged.
+        rated = list(zip(info.get("criteria", []) or [], info.get("points_list", []) or []))
+        positive_total = sum(pts for _, pts in rated if pts > 0)
+        if positive_total <= 0:
+            return 0.0
+
+        conversation = _format_prompt_to_judge(prompt, _completion_text(completion))
+        # Created per call, so the limit applies per rollout, not across rollouts.
+        semaphore = asyncio.Semaphore(criteria_parallel)
+        judgments = await asyncio.gather(
+            *(
+                _judge_single_criterion(
+                    idx=idx,
+                    criterion=criterion,
+                    points_possible=pts,
+                    conversation=conversation,
+                    rubric=rubric,
+                    semaphore=semaphore,
+                    state=state,
+                )
+                for idx, (criterion, pts) in enumerate(rated)
+            )
+        )
+        return _normalized_score(judgments, positive_total)
+
+    rubric.add_reward_func(reward_rubric, weight=1.0)
+
+    return vf.SingleTurnEnv(
+        dataset=train_ds,
+        eval_dataset=eval_ds,
+        system_prompt="",
+        rubric=rubric,
+    )