Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion environments/aci_bench/aci_bench/aci_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datasets.utils.logging import disable_progress_bar
from aci_bench.judge_prompts import JUDGE_DIMENSIONS, JUDGE_OUTPUT_JSON, JUDGE_TEMPLATE
from medarc_verifiers.parsers import JSONParser
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.judging import MultiJudge, MultiJudgeRubric
from medarc_verifiers.rewards import normalize_helm_reward
Expand Down Expand Up @@ -83,7 +84,7 @@ def load_environment(
if answer_format == AnswerFormat.XML:
system_prompt = system_prompt or XML_SYSTEM_PROMPT
parser_fields = ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
system_prompt = system_prompt or BOXED_SYSTEM_PROMPT
parser = vf.Parser(extract_fn=extract_boxed_answer)
Expand Down
3 changes: 2 additions & 1 deletion environments/head_qa/head_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import verifiers as vf
from datasets import load_dataset
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -135,7 +136,7 @@ def _map_example(example: dict[str, Any]) -> dict[str, Any] | None:
if answer_format == AnswerFormat.XML:
system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT)
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT)
parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer)
Expand Down
3 changes: 2 additions & 1 deletion environments/med_mcqa/med_mcqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -125,7 +126,7 @@ def _map_example(example: dict[str, Any]) -> dict[str, Any] | None:
if answer_format == AnswerFormat.XML:
system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT)
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT)
parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer)
Expand Down
3 changes: 2 additions & 1 deletion environments/medbullets/medbullets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import verifiers as vf
from datasets import Dataset, load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -131,7 +132,7 @@ def load_environment(
if answer_format == AnswerFormat.XML:
system_prompt = THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
parser = (
vf.ThinkParser(extract_fn=extract_boxed_answer) if use_think else vf.Parser(extract_fn=extract_boxed_answer)
Expand Down
6 changes: 5 additions & 1 deletion environments/medcasereasoning/medcasereasoning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
Expand Down Expand Up @@ -113,7 +115,9 @@ async def medical_diagnosis_reward_func(
)

# Parse judge response
judge_response_clean = judge_response.strip().lower()
judge_response_clean = re.sub(
r"<think>.*?</think>", "", judge_response, flags=re.DOTALL | re.IGNORECASE
).strip().lower()
else:
judge_response_clean = "no"
judge_response = "no answer"
Expand Down
3 changes: 2 additions & 1 deletion environments/medconceptsqa/medconceptsqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import verifiers as vf
from datasets import Dataset, load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -183,7 +184,7 @@ def _map(row: dict, idx: int | None = None) -> dict:
if answer_format == AnswerFormat.XML:
system_prompt = THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
system_prompt = THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT
parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer)
Expand Down
3 changes: 2 additions & 1 deletion environments/medqa/medqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -79,7 +80,7 @@ def _map(ex, idx=None):
if answer_format == AnswerFormat.XML:
system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT)
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer)
system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT)
Expand Down
5 changes: 3 additions & 2 deletions environments/medredqa/medredqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from factscore_judge.atomic_facts_judge import create_atomic_facts_judge_rubric
from medarc_verifiers.parsers.xml_parser import XMLParser
from openai import AsyncOpenAI

disable_progress_bar() # suppress datasets mapping progress bar
Expand Down Expand Up @@ -82,9 +83,9 @@ def load_environment(

# Create parser - XMLParser with both think and answer fields if use_think is True
parser = (
vf.XMLParser(fields=["think", "answer"], answer_field="answer")
XMLParser(fields=["think", "answer"], answer_field="answer")
if use_think
else vf.XMLParser(fields=["answer"], answer_field="answer")
else XMLParser(fields=["answer"], answer_field="answer")
)

# Create JudgeRubric using the helper function from judge.py
Expand Down
3 changes: 2 additions & 1 deletion environments/medxpertqa/medxpertqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -105,7 +106,7 @@ def _map(example: dict) -> dict:

if answer_format == AnswerFormat.XML:
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
elif answer_format == AnswerFormat.BOXED:
parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer)
else:
Expand Down
3 changes: 2 additions & 1 deletion environments/pubmedqa/pubmedqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar
from medarc_verifiers.parsers.xml_parser import XMLParser
from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
Expand Down Expand Up @@ -147,7 +148,7 @@ def load_environment(

if answer_format == AnswerFormat.XML:
parser_fields = ["think", "answer"] if use_think else ["answer"]
parser = vf.XMLParser(fields=parser_fields, answer_field="answer")
parser = XMLParser(fields=parser_fields, answer_field="answer")
system_prompt = THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT
elif answer_format == AnswerFormat.BOXED:
parser = (
Expand Down
25 changes: 25 additions & 0 deletions environments/rubric_train/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[project]
name = "rubric-train"
description = "Training env: rubric-judged open-ended medical prompts for RL reward"
tags = ["medical", "rubric", "single-turn", "llm-judge", "train"]
version = "0.1.0"
requires-python = ">=3.11"
dependencies = [
"medarc-verifiers>=0.1.0",
"verifiers>=0.1.4",
"datasets>=4.1.1",
"openai>=2.1.0",
"healthbench",
]

[dependency-groups]
dev = [
"ruff>=0.13.3",
]

[tool.prime.environment]
loader = "rubric_train:load_environment"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
140 changes: 140 additions & 0 deletions environments/rubric_train/rubric_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""RL training env: HealthBench-style rubric-judged prompts.

Wraps the existing `healthbench` env's judge-per-criterion reward path so it
can be driven from prime-rl as a training env with a `dataset=` attribute
(not just eval). Each rollout's reward = sum(points for criteria met) /
sum(positive_points_possible), bounded to [0, 1].

Train split comes from HealthBench's `consensus` (3,671 examples) or `hard`
(1,000 examples) datasets. Eval split comes from HealthBench `all` test set.
Choosing a smaller eval slice keeps per-step validation cheap.
"""

from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any

import verifiers as vf
from datasets import load_dataset
from datasets.utils.logging import disable_progress_bar

# Reuse healthbench's judge-plumbing + scoring exactly — single source of truth.
from healthbench import ( # type: ignore[import-not-found]
HEALTHBENCH_CONSENSUS_CRITERIA_LOOKUP,
_format_prompt_to_judge,
_judge_single_criterion,
_process_healthbench_dataset,
)
from medarc_verifiers.judging import MultiJudge, MultiJudgeRubric
from medarc_verifiers.types import Messages
from verifiers.types import Info, State

disable_progress_bar()


TRAIN_MAPPING = {
"consensus": ("neuralleap/healthbench-consensus", "train"),
"hard": ("neuralleap/healthbench-hard", "train"),
}
EVAL_DATASET = ("neuralleap/healthbench-regular", "test")


def load_environment(
judge_model: str | list[str] = "openai/gpt-oss-20b",
judge_base_url: str | list[str] | None = None,
judge_api_key: str | list[str] | None = None,
train_split: str = "consensus",
num_train_examples: int | None = None,
num_eval_examples: int = 100,
judge_timeout: int | None = 300,
max_parallel_judges: int | None = 3,
**kwargs: Any,
) -> vf.Environment:
"""Return a SingleTurnEnv with HealthBench-style rubric reward.

The reward for a rollout is `sum(points for criteria met)` divided by
`sum(p for p in points_possible if p > 0)`. Negative points stay negative
when met, which preserves the penalty structure of the original rubric.

`judge_model` defaults to gpt-oss-20b (matches the scale-sweep judge on
n-4:18001). Override via `judge_base_url` for a different endpoint.
"""
if train_split not in TRAIN_MAPPING:
raise ValueError(f"Unknown train_split={train_split}; expected one of {list(TRAIN_MAPPING)}")

train_repo, train_split_name = TRAIN_MAPPING[train_split]
train_ds = load_dataset(train_repo, split=train_split_name)
if num_train_examples is not None:
train_ds = train_ds.select(range(min(num_train_examples, len(train_ds))))
train_ds = train_ds.map(lambda ex: {"info": _process_healthbench_dataset(ex)})

eval_repo, eval_split_name = EVAL_DATASET
eval_ds = load_dataset(eval_repo, split=eval_split_name)
if num_eval_examples:
eval_ds = eval_ds.select(range(min(num_eval_examples, len(eval_ds))))
eval_ds = eval_ds.map(lambda ex: {"info": _process_healthbench_dataset(ex)})

multi_judge = MultiJudge.from_env_args(
judge_model=judge_model,
judge_base_url=judge_base_url,
judge_api_key=judge_api_key,
judge_prompt="{question}",
judge_timeout=judge_timeout,
)
rubric = MultiJudgeRubric(multi_judge)

criteria_parallel = max_parallel_judges if max_parallel_judges is not None else 3

async def reward_rubric(prompt: Messages, completion: Messages, info: Info, state: State) -> float:
"""Per-rollout reward = fraction of positive-point rubric criteria met.

Mirrors healthbench.reward_healthbench but normalizes by positive-point
total so the scalar lives in [0, 1] for RL stability.
"""
if isinstance(completion, list) and completion:
raw_completion = completion[-1].get("content", "")
else:
raw_completion = str(completion)

conversation = _format_prompt_to_judge(prompt, raw_completion)
criteria = info.get("criteria", []) or []
points_list = info.get("points_list", []) or []
if not points_list:
return 0.0

positive_total = sum(pt for pt in points_list if pt > 0)
if positive_total <= 0:
return 0.0

semaphore = asyncio.Semaphore(criteria_parallel)
tasks = [
_judge_single_criterion(
idx=idx,
criterion=criterion,
points_possible=pts,
conversation=conversation,
rubric=rubric,
semaphore=semaphore,
state=state,
)
for idx, (criterion, pts) in enumerate(zip(criteria, points_list))
]
judgments = await asyncio.gather(*tasks)

earned = 0.0
for j in judgments:
if isinstance(j, dict) and j.get("criteria_met"):
earned += float(j.get("points_possible", 0))
# Clamp to [0, 1] — negative-point criteria can push earned slightly below 0.
return max(0.0, min(1.0, earned / positive_total))

rubric.add_reward_func(reward_rubric, weight=1.0)

return vf.SingleTurnEnv(
dataset=train_ds,
eval_dataset=eval_ds,
system_prompt="",
rubric=rubric,
)
17 changes: 17 additions & 0 deletions environments/sctpublic/sctpublic/sctpublic.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,28 @@ def parse_prompt(row):
return df


_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)


class SCTParser(vf.Parser):
"""Parser for SCT numeric ratings (-2 to +2)."""

@staticmethod
def _strip_think_blocks(text: str) -> str:
"""Remove <think>...</think> blocks so we only parse the final answer.

Also handles the case where a model emits an unpaired </think>
(no opening tag) -- we keep only the text after the last </think>.
"""
if "<think>" in text.lower():
text = _THINK_BLOCK_RE.sub("", text)
if "</think>" in text.lower():
text = text.split("</think>")[-1]
return text.strip()

def parse_answer(self, completion: Any) -> str | None:
response = getattr(completion, "content", str(completion))
response = self._strip_think_blocks(response)
try:
rating = response.split("Rating: ")[1][:5]
pattern = r"\+2|\+1|0|-1|-2"
Expand Down
19 changes: 19 additions & 0 deletions medarc_verifiers/parsers/xml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,25 @@ def parse(self, completion: Messages | str, strip: bool = True, last: bool = Fal
return parsed
return None

def parse_answer(self, completion: Messages | str) -> str | None:
"""Extract the last answer from a completion.

Overrides upstream to always use ``last=True`` when parsing chat messages.
Qwen 3.5 reasoning models echo ``<answer>`` tags inside their thinking
blocks, so we must take the **last** match to get the real answer.
"""
if isinstance(completion, str):
parsed = self.parse(completion, last=True)
if parsed and hasattr(parsed, self.answer_field) and getattr(parsed, self.answer_field) is not None:
return getattr(parsed, self.answer_field)
else:
for msg in reversed(self.get_assistant_messages(completion)):
content = str(msg.get("content", ""))
parsed = self.parse(content, last=True)
if parsed and hasattr(parsed, self.answer_field) and getattr(parsed, self.answer_field) is not None:
return getattr(parsed, self.answer_field)
return None

def _has_any_field(self, parsed: Any) -> bool:
for _, alternatives in self._fields:
for alt in alternatives:
Expand Down