30 changes: 26 additions & 4 deletions apps/grpo/main.py
@@ -39,6 +39,8 @@
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer

from forge.reward.hf_rm import HFRewardModel
from forge.reward.ensemble import EnsembleReward

@dataclass
class Episode:
@@ -137,11 +139,24 @@ def simple_grpo_loss(
    ).mean()
    return loss


@dataclass
class RewardActor(ForgeActor):

    use_model: bool = False
    reward_functions: list[Callable] | None = None
    rm_specs: list[dict] | None = None
    ensemble_reduce: str = "mean"

    def __post_init__(self):
        if self.use_model:
            # Build one HF reward model per spec and wrap them all in a single ensemble.
            rms = [HFRewardModel(**spec) for spec in (self.rm_specs or [])]
            self.reward_fn = EnsembleReward(rms, reduce=self.ensemble_reduce)
            self.reward_functions = [self.reward_fn]
        else:
            # Fall back to the built-in rule-based rewards.
            self.reward_functions = [
                MathReward(),
                ThinkingReward(),
            ]

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
@@ -309,6 +324,11 @@ async def main(cfg: DictConfig):
    metric_logging_cfg = cfg.get("metric_logging", {})
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)

    reward_cfg = cfg.get("reward", {})
    use_model = bool(reward_cfg.get("use_model", False))
    rm_specs = reward_cfg.get("rm_specs", [])
    ensemble_reduce = reward_cfg.get("ensemble_reduce", "mean")

# ---- Setup services ---- #

@@ -332,7 +352,9 @@ async def main(cfg: DictConfig):
        ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
        RewardActor.options(**cfg.services.reward_actor).as_service(
            use_model=use_model,
            rm_specs=rm_specs,
            ensemble_reduce=ensemble_reduce,
        ),
    )

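A quick sketch of how the new RewardActor fields behave (illustration only: RewardActor normally runs as a Forge service, so constructing it inline like this just shows what __post_init__ builds; the spec dict mirrors reward.rm_specs in the YAML below):

# Rule-based mode: keep the existing built-in rewards.
rule_based = RewardActor(use_model=False)
# rule_based.reward_functions -> [MathReward(), ThinkingReward()]

# Model mode: each spec dict becomes an HFRewardModel, all wrapped in one EnsembleReward.
model_based = RewardActor(
    use_model=True,
    rm_specs=[{"model_id": "Skywork/Skywork-Reward-V2-Qwen3-0.6B", "max_length": 4096}],
    ensemble_reduce="mean",
)
# model_based.reward_functions -> [EnsembleReward([HFRewardModel(...)], reduce="mean")]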
16 changes: 15 additions & 1 deletion apps/grpo/qwen3_1_7b.yaml
@@ -114,6 +114,20 @@ ref_model:
  initial_load_path: hf://${model}
  initial_load_in_hf: true

reward:
  use_model: true               # true -> use HF RM ensemble; false -> use builtin rewards
  ensemble_reduce: mean         # mean | median | max | vote
  builtin: ["math", "thinking"] # used if use_model: false
  rm_specs:
    - model_id: Skywork/Skywork-Reward-V2-Qwen3-0.6B
      # device: cuda:1
      # torch_dtype: torch.bfloat16
      max_length: 4096
    - model_id: Qwen/Qwen2.5-1.5B # placeholder; use a real RM checkpoint here
      # device: cuda:1
      # torch_dtype: torch.bfloat16
      max_length: 4096
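Each rm_specs entry above is forwarded verbatim as HFRewardModel(**spec), so leaving device/torch_dtype commented out means the wrapper's defaults apply. A sketch of the expansion (the second entry is the placeholder flagged in the comment, not a real reward model):

rms = [
    HFRewardModel(model_id="Skywork/Skywork-Reward-V2-Qwen3-0.6B", max_length=4096),
    HFRewardModel(model_id="Qwen/Qwen2.5-1.5B", max_length=4096),  # placeholder checkpoint
]
ensemble = EnsembleReward(rms, reduce="mean")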

# All resource allocations
services:
  policy:
@@ -130,7 +144,7 @@ services:
    procs: 1
    num_replicas: 1
    mesh_name: reward_actor
    with_gpus: true

actors:
  dataset:
2 changes: 1 addition & 1 deletion scripts/install.sh
@@ -310,4 +310,4 @@ EOF
log_info " conda deactivate && conda activate $CONDA_DEFAULT_ENV"
}

main "$@"
52 changes: 52 additions & 0 deletions src/forge/reward/ensemble.py
@@ -0,0 +1,52 @@
from __future__ import annotations
from typing import Protocol, List, Literal, Optional
import torch

class RewardFn(Protocol):
    def __call__(
        self,
        prompts: List[str],
        responses: List[str],
        targets: Optional[List[str]] = None,
    ) -> torch.Tensor: ...

class EnsembleReward:
    """
    Wraps multiple reward functions and reduces their scores.
    Assumes each fn returns a 1D tensor [batch].
    """

    def __init__(
        self,
        fns: List[RewardFn],
        reduce: Literal["mean", "median", "max", "vote"] = "mean",
        eps: float = 1e-5,
    ):
        self.fns = fns
        self.reduce = reduce
        self.eps = eps

    @torch.inference_mode()
    def __call__(self, prompts, responses, targets=None) -> torch.Tensor:
        scores = []
        for fn in self.fns:
            s = fn(prompts, responses, targets)
            if not isinstance(s, torch.Tensor):
                s = torch.as_tensor(s, dtype=torch.float32)
            # Keep scores on CPU and device-agnostic; the trainer can move them later.
            scores.append(s.float().cpu())

        stacked = torch.stack(scores, dim=0)  # [n_models, batch]

        if self.reduce == "mean":
            return stacked.mean(0)
        if self.reduce == "median":
            return stacked.median(0).values
        if self.reduce == "max":
            return stacked.max(0).values
        if self.reduce == "vote":
            # Interpret > 0 as "good"; vote => fraction of positives in [0, 1]
            return (stacked > 0.0).float().mean(0)
        raise ValueError(f"Unknown reduce: {self.reduce}")
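Usage sketch for EnsembleReward with toy reward functions (the functions and values below are illustrative only, not part of the PR):

import torch

from forge.reward.ensemble import EnsembleReward

# Two toy reward functions standing in for real models; each returns a 1D [batch] tensor.
def length_reward(prompts, responses, targets=None):
    return torch.tensor([min(len(r) / 100.0, 1.0) for r in responses])

def keyword_reward(prompts, responses, targets=None):
    return torch.tensor([1.0 if "because" in r.lower() else -1.0 for r in responses])

ensemble = EnsembleReward([length_reward, keyword_reward], reduce="mean")
scores = ensemble(
    ["Why is the sky blue?"],
    ["Because sunlight scatters more strongly at shorter wavelengths."],
)
print(scores)  # 1D tensor with one averaged score per sample

voter = EnsembleReward([length_reward, keyword_reward], reduce="vote")
print(voter(["Why?"], ["Short."]))  # fraction of functions scoring > 0, here tensor([0.5000])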
69 changes: 69 additions & 0 deletions src/forge/reward/hf_rm.py
@@ -0,0 +1,69 @@
from __future__ import annotations
from typing import List, Optional
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class HFRewardModel:
    """
    Minimal RM wrapper. Returns one scalar reward per sample.
    - If the head has 1 logit, applies sigmoid and uses it as the score.
    - If the head has 2 logits, uses the probability of the second ("good") class.
    - Otherwise, uses the last logit as the score.
    """

    def __init__(
        self,
        model_id: str,
        device: str = "cuda",
        torch_dtype: torch.dtype | str = torch.bfloat16,
        max_length: int = 4096,
        template: str = "{prompt}\n\n{response}",
    ):
        # Accept string dtypes from YAML configs, e.g. "bfloat16" or "torch.bfloat16".
        if isinstance(torch_dtype, str):
            torch_dtype = getattr(torch, torch_dtype.removeprefix("torch."))
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, padding_side="right", truncation_side="left"
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id, torch_dtype=torch_dtype
        )
        self.device = device
        self.model.to(self.device).eval()
        self.max_length = max_length
        self.template = template
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token_id is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                # A new token was added, so grow the embedding matrix to match.
                self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    @torch.inference_mode()
    def __call__(
        self, prompts: List[str], responses: List[str], targets: Optional[List[str]] = None
    ) -> torch.Tensor:
        # Render each (prompt, response) pair with the template and score the whole batch.
        texts = [self.template.format(prompt=p, response=r) for p, r in zip(prompts, responses)]
        inputs = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        logits = self.model(**inputs).logits  # [batch, num_labels]
        if logits.shape[-1] == 1:
            # Single-logit RM head: sigmoid gives one score in [0, 1] per sample.
            scores = torch.sigmoid(logits).squeeze(-1)
        elif logits.shape[-1] == 2:
            # Two-class head: probability of the "good" class, per sample.
            scores = torch.softmax(logits, dim=-1)[..., 1]
        else:
            # Fallback: assume the last logit corresponds to "positive/good".
            scores = logits[..., -1]
        return scores.float()  # 1D tensor [batch]
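A scoring sketch for HFRewardModel itself (downloading the Skywork checkpoint from the YAML config and having a GPU are assumptions; the CPU/float32 fallback just keeps the example self-contained):

import torch

from forge.reward.hf_rm import HFRewardModel

device = "cuda" if torch.cuda.is_available() else "cpu"
rm = HFRewardModel(
    model_id="Skywork/Skywork-Reward-V2-Qwen3-0.6B",  # same RM as in the YAML config
    device=device,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    max_length=4096,
)
scores = rm(
    prompts=["What is 2 + 2?", "Name a prime number."],
    responses=["2 + 2 = 4.", "9 is a prime number."],
)
print(scores.shape, scores)  # torch.Size([2]); a single-logit RM head yields sigmoid scores in [0, 1]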