From c7884855193d64b8de29454e7b6f8346935be44e Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 13 Feb 2026 17:31:53 +0800 Subject: [PATCH 1/7] Add ClipB example --- examples/entropy/README.md | 29 ++++ examples/entropy/clipb.yaml | 100 ++++++++++++ examples/entropy/clipb_trainer.patch | 11 ++ trinity/algorithm/__init__.py | 1 + trinity/algorithm/advantage_fn/__init__.py | 1 + .../algorithm/advantage_fn/clipb_advantage.py | 152 ++++++++++++++++++ trinity/algorithm/algorithm.py | 22 +++ trinity/common/verl_config.py | 1 + 8 files changed, 317 insertions(+) create mode 100644 examples/entropy/README.md create mode 100644 examples/entropy/clipb.yaml create mode 100644 examples/entropy/clipb_trainer.patch create mode 100644 trinity/algorithm/advantage_fn/clipb_advantage.py diff --git a/examples/entropy/README.md b/examples/entropy/README.md new file mode 100644 index 0000000000..e38144ccd4 --- /dev/null +++ b/examples/entropy/README.md @@ -0,0 +1,29 @@ +# Entropy dynamics of RL training + +This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392). + +## Data Preparation + +We utilize the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set. We exclude 500 questions from the training set to form the validation set (denoted by dapo-validation-500). +The training set is filtered out samples from the training set with excessively high (≥ 15/16) or low (≤ 1/16) pass rates, as evaluated by Qwen2.5-7B-Instruct. + +## Clip_B Experiment + +1. Apply the patch to keep entropy information in the trainer batch: + +```bash +cd /path/to/Trinity-RFT +git apply examples/entropy/clipb_trainer.patch +``` + +2. Update the dataset paths in the config file [`clipb.yaml`](clipb.yaml) to point to your local data. + +3. Run the experiment: + +```bash +trinity run examples/entropy/clipb.yaml +``` + +## Clip_V Implementation + +Coming soon. diff --git a/examples/entropy/clipb.yaml b/examples/entropy/clipb.yaml new file mode 100644 index 0000000000..d78edda47e --- /dev/null +++ b/examples/entropy/clipb.yaml @@ -0,0 +1,100 @@ +project: math_dapo +name: clipb_example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} + max_prompt_tokens: 1024 + max_response_tokens: 7168 +algorithm: + algorithm_type: grpo_verl + advantage_fn: clipb + advantage_fn_args: + mu: 2.5 + repeat_times: 16 + kl_loss_fn_args: + kl_coef: 0.0 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 20 + batch_size: 64 + explorer_input: + taskset: + name: dapo_235 + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 1.0 + logprobs: 20 + eval_tasksets: + - name: dapo-validation-500 + storage_type: file + path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 0.7 + - name: amc23 + storage_type: file + path: math-ai/amc23 # Path to the AMC23 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name: aime24 + storage_type: file + path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset + split: 'train' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name : aime25 + storage_type: file + path: math-ai/aime25 # Path to the AIME2025 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + default_workflow_type: 'async_math_workflow' + default_reward_fn_type: 'math_boxed_reward' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + max_read_timeout: 7200 +explorer: + eval_interval: 20 + eval_on_startup: true + runner_per_model: 8 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + seed: 42 +trainer: + trainer_type: 'verl' + save_interval: 200 + trainer_config: + algorithm: + rollout_correction: + bypass_mode: false +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 3200 diff --git a/examples/entropy/clipb_trainer.patch b/examples/entropy/clipb_trainer.patch new file mode 100644 index 0000000000..03ca08b29c --- /dev/null +++ b/examples/entropy/clipb_trainer.patch @@ -0,0 +1,11 @@ +--- a/trinity/trainer/verl_trainer.py ++++ b/trinity/trainer/verl_trainer.py +@@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + } + metrics.update(old_log_prob_metrics) +- old_log_prob.batch.pop("entropys") ++ # Keep entropys in batch so advantage_fn (e.g. Clip_B) can use it ++ # old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 52bb605bcd..e693684cb9 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -29,6 +29,7 @@ "multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm", "on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm", "jsd": "trinity.algorithm.algorithm.JSDAlgorithm", + "grpo_verl": "trinity.algorithm.algorithm.GRPOverlAlgorithm", }, ) diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 239862ba58..7f59211dbb 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -19,6 +19,7 @@ "rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage", "on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage", "jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage", + "clipb": "trinity.algorithm.advantage_fn.clipb_advantage.ClipBAdvantageFn", }, ) diff --git a/trinity/algorithm/advantage_fn/clipb_advantage.py b/trinity/algorithm/advantage_fn/clipb_advantage.py new file mode 100644 index 0000000000..62898ada1a --- /dev/null +++ b/trinity/algorithm/advantage_fn/clipb_advantage.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +"""Advantage computation for Clip_B +Ref: https://arxiv.org/pdf/2602.03392""" + +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Tuple + +import torch + +if TYPE_CHECKING: + from verl import DataProto + +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn + + +class ClipBAdvantageFn(AdvantageFn): + """Clip_B advantage: keep all positive-advantage tokens, + one-side clip negative-advantage tokens by entropy signal.""" + + def __init__( + self, + epsilon: float = 1e-6, + mu: float = 2.5, + ) -> None: + self.epsilon = epsilon + self.mu = mu + + def __call__( + self, + exps: "DataProto", + **kwargs, + ) -> Tuple["DataProto", Dict]: + """ + Compute advantage for Clip_B. + exps should contain the following fields: + - token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + - response_mask: `(torch.Tensor)` + shape: (bs, response_length) + - uid: `(torch.Tensor)` + shape: (bs,) + - rollout_log_probs: `(torch.Tensor)` + shape: (bs, response_length) + - entropys: `(torch.Tensor)` + shape: (bs, response_length) + Returns: + exps: DataProto with advantages and returns added + metrics: Dict with clipping metrics + """ + token_level_rewards = exps.batch["token_level_rewards"] + response_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device) + id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device) + elif len(id2score[idx]) > 1: + group_scores = torch.stack(id2score[idx]).to( + dtype=scores.dtype, device=scores.device + ) + id2mean[idx] = torch.mean(group_scores) + id2std[idx] = torch.std(group_scores) + else: + raise ValueError(f"no score in prompt index: {idx}") + + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores.clone() + + # --- BEGIN: token filtering logic --- + # Use recomputed logprobs & entropy from current model (not rollout) + LP = exps.batch["rollout_log_probs"] # [B, T], recomputed logprobs + H = exps.batch["entropys"] # [B, T], recomputed entropy + M = response_mask # [B, T], mask of valid tokens + p = LP.exp() # [B, T], probability of valid tokens + S = p * (H + LP) # [B, T], indicator + + # Detach for constructing clip mask (no gradient needed) + xS = S.detach().to(torch.float32) # [B, T] + m = M.to(torch.float32) # [B, T] + + # Masked global mean & variance (population variance, denominator = n) + n = m.sum().clamp_min(1.0) + ES = (xS * m).sum() / n # scalar + varS = ((xS - ES) ** 2 * m).sum() / n # scalar + stdS = varS.sqrt() # scalar + + # Centered signal + z = xS - ES # [B, T] + + # if stdS is too small, keep all tokens; otherwise + # keep all positive-advantage tokens; one-side clip negative-advantage tokens + if stdS.item() < 1e-12: + keep = torch.ones_like(M, dtype=M.dtype) # all kept + else: + A = exps.batch["advantages"].detach().to(torch.float32) # [B, T] + pos_mask = A > 0 + neg_mask = A < 0 + + keep_pos = torch.ones_like(pos_mask, dtype=torch.bool) # positive: all kept + keep_neg = z >= -(self.mu * stdS) # negative: lower-side clip + keep_zero = torch.ones_like(pos_mask, dtype=torch.bool) # zero: all kept + + keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero)) + keep = keep_bool.to(M.dtype) + + M_clipped = M * keep + exps.batch["response_mask"] = M_clipped + # --- END: token filtering logic --- + + # Monitoring metrics + total_tokens = m.sum().clamp_min(1.0) + frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item() + + A = exps.batch["advantages"].detach().to(torch.float32) + pos_mask = (A > 0).to(M.dtype) + neg_mask = (A < 0).to(M.dtype) + total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0) + total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0) + frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item() + frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item() + + metrics = { + "frac_clipped": frac_clipped, + "frac_clipped_pos": frac_clipped_pos, + "frac_clipped_neg": frac_clipped_neg, + "ES": ES.item(), + "varS": varS.item(), + } + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + "mu": 2.5, + } diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 5352bba449..35b7453222 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -540,3 +540,25 @@ def default_config(cls) -> Dict: "kl_loss_fn": "none", "entropy_loss_fn": "none", } + + +class GRPOverlAlgorithm(AlgorithmType): + """GRPO algorithm, but advantage computation is done in trainer.""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = True + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "advantage_fn": "grpo", + "sample_strategy": "default", + "policy_loss_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 1aaa4a3c4a..689b3231b4 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -175,6 +175,7 @@ class Actor: router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig) # do not set loss_agg_mode: str = "token-mean" + loss_scale_factor: Optional[float] = None clip_ratio: float = 0.2 clip_ratio_low: Optional[float] = None clip_ratio_high: Optional[float] = None From b8d67e3a1a93ffec83404c80cd5db88d1fc74a97 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 24 Feb 2026 12:27:27 +0800 Subject: [PATCH 2/7] add clipv --- examples/entropy/README.md | 19 +- examples/entropy/clipv.yaml | 100 ++++ examples/entropy/clipv_dp_actor.py | 453 ++++++++++++++++++ examples/entropy/clipv_trainer.patch | 124 +++++ trinity/algorithm/advantage_fn/__init__.py | 1 + .../algorithm/advantage_fn/clipv_advantage.py | 153 ++++++ 6 files changed, 848 insertions(+), 2 deletions(-) create mode 100644 examples/entropy/clipv.yaml create mode 100644 examples/entropy/clipv_dp_actor.py create mode 100644 examples/entropy/clipv_trainer.patch create mode 100644 trinity/algorithm/advantage_fn/clipv_advantage.py diff --git a/examples/entropy/README.md b/examples/entropy/README.md index e38144ccd4..941cd3657d 100644 --- a/examples/entropy/README.md +++ b/examples/entropy/README.md @@ -2,6 +2,8 @@ This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392). +NOTE: This example is only tested on verl==0.7.0. + ## Data Preparation We utilize the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set. We exclude 500 questions from the training set to form the validation set (denoted by dapo-validation-500). @@ -16,7 +18,7 @@ cd /path/to/Trinity-RFT git apply examples/entropy/clipb_trainer.patch ``` -2. Update the dataset paths in the config file [`clipb.yaml`](clipb.yaml) to point to your local data. +2. Update the dataset paths and other configurations in the file [`clipb.yaml`](clipb.yaml) to point to your local data. 3. Run the experiment: @@ -26,4 +28,17 @@ trinity run examples/entropy/clipb.yaml ## Clip_V Implementation -Coming soon. +1. Apply the patch to keep entropy information in the trainer batch: + +```bash +cd /path/to/Trinity-RFT +git apply examples/entropy/clipv_trainer.patch +``` + +2. Update the dataset paths and other configurations in the file [`clipv.yaml`](clipv.yaml) to point to your local data. + +3. Run the experiment: + +```bash +trinity run examples/entropy/clipv.yaml +``` diff --git a/examples/entropy/clipv.yaml b/examples/entropy/clipv.yaml new file mode 100644 index 0000000000..72deb72848 --- /dev/null +++ b/examples/entropy/clipv.yaml @@ -0,0 +1,100 @@ +project: math_dapo +name: clipv_example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} + max_prompt_tokens: 1024 + max_response_tokens: 7168 +algorithm: + algorithm_type: grpo_verl + advantage_fn: clipv + advantage_fn_args: + mu: 8.5 + repeat_times: 8 + kl_loss_fn_args: + kl_coef: 0.0 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 20 + batch_size: 64 + explorer_input: + taskset: + name: dapo_235 + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 1.0 + logprobs: 20 + eval_tasksets: + - name: dapo-validation-500 + storage_type: file + path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 0.7 + - name: amc23 + storage_type: file + path: math-ai/amc23 # Path to the AMC23 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name: aime24 + storage_type: file + path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset + split: 'train' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name : aime25 + storage_type: file + path: math-ai/aime25 # Path to the AIME2025 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + default_workflow_type: 'async_math_workflow' + default_reward_fn_type: 'math_boxed_reward' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + max_read_timeout: 7200 +explorer: + eval_interval: 20 + eval_on_startup: true + runner_per_model: 8 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + seed: 42 +trainer: + trainer_type: 'verl' + save_interval: 100 + trainer_config: + algorithm: + rollout_correction: + bypass_mode: false +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 3600 diff --git a/examples/entropy/clipv_dp_actor.py b/examples/entropy/clipv_dp_actor.py new file mode 100644 index 0000000000..cf48ee63db --- /dev/null +++ b/examples/entropy/clipv_dp_actor.py @@ -0,0 +1,453 @@ +""" +Single Process Actor. +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py +Note: This patch only works for verl==0.7.0 +""" + +import logging +import os +from typing import cast + +import torch +import torch.nn.functional as F +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.attention_utils import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, +) +from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_id +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import ( + gather_outputs_and_unpad, + ulysses_pad, + ulysses_pad_and_slice_inputs, +) + +from trinity.trainer.verl.dp_actor import DataParallelPPOActor as OriginalDPActor + +__all__ = ["DataParallelPPOActor"] + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def compute_N_from_logits( + logits: torch.Tensor, entropy: torch.Tensor | None = None +) -> torch.Tensor: + """ + logits: [..., V], better to use float32 for stability + entropy: [...], if None, compute H = -sum p log p from logits + return: [...], the same shape as logits without the last dimension + """ + z = logits # use float32 for stability + logp = F.log_softmax(z, dim=-1) # [..., V] + p = logp.exp() # [..., V] + + p2 = p * p # [..., V] + sum_p2 = p2.sum(dim=-1) # [...] + term2 = (p2 * logp).sum(dim=-1) # sum p_i^2 log p_i + + if entropy is None: + entropy = -(p * logp).sum(dim=-1) # H = -sum p log p + N = entropy * sum_p2 + term2 # [... ] + return N + + +class DataParallelPPOActor(OriginalDPActor): + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob( + self, data: DataProto, calculate_entropy: bool = False, calculate_nec: bool = False + ) -> ( + tuple[torch.Tensor, torch.Tensor | None] + | tuple[torch.Tensor, torch.Tensor | None, torch.Tensor] + ): + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + - when calculate_nec is False: (log_probs, entropys) + - when calculate_nec is True: (log_probs, entropys, necs) + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + # !!! Patch starts !!! + nec_lst = [] + # !!! Patch ends !!! + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + # !!! Patch starts !!! + outputs = self._forward_micro_batch( + model_inputs, + temperature=temperature, + calculate_entropy=calculate_entropy, + calculate_nec=calculate_nec, + ) + # !!! Patch ends !!! + # !!! Patch starts !!! + if calculate_nec: + entropy, log_probs, nec = cast( + tuple[torch.Tensor, torch.Tensor, torch.Tensor], outputs + ) + nec_lst.append(nec) + else: + entropy, log_probs = cast(tuple[torch.Tensor, torch.Tensor], outputs) + # !!! Patch ends !!! + log_probs_lst.append(log_probs) + if calculate_entropy: + entropy_lst.append(entropy) + + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + # !!! Patch starts !!! + necs = None + if calculate_nec: + necs = torch.concat(nec_lst, dim=0) + # !!! Patch ends !!! + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + # !!! Patch starts !!! + if calculate_nec: + necs = restore_dynamic_batch(necs, batch_idx_list) + # !!! Patch ends !!! + + # !!! Patch starts !!! + if calculate_nec: + return log_probs, entropys, necs + # !!! Patch ends !!! + return log_probs, entropys + + def _forward_micro_batch( # type: ignore # noqa: C901 + self, micro_batch, temperature, calculate_entropy: bool = False, calculate_nec: bool = False + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns: + entropy: # (bs, response_len) + log_probs: # (bs, response_len) + """ + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + # !!! Patch starts !!! + nec = None + # !!! Patch ends !!! + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + is_mask_all_zero = attention_mask.sum() == 0 + if is_mask_all_zero: + input_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=input_ids.device, + dtype=input_ids.dtype, + ) + if position_ids.dim() == 3: + position_ids_rmpad = torch.zeros( + (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + else: + position_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import ( + process_multi_modal_inputs_for_minicpmo, + ) + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, + "vision_config", + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + ( + input_ids_rmpad, + position_ids_rmpad, + pad_size, + ) = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + # !!! Patch starts !!! + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) + # !!! Patch ends !!! + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits( + logits_rmpad + ) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) + # !!! Patch starts !!! + if calculate_nec: + H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None + N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) + # !!! Patch ends !!! + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + # !!! Patch starts !!! + if calculate_nec: + N_rmpad = gather_outputs_and_unpad( + N_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + # !!! Patch ends !!! + + if is_mask_all_zero: + log_probs = log_probs[:0] + if calculate_entropy: + entropy_rmpad = entropy_rmpad[:0] + # !!! Patch starts !!! + if calculate_nec: + N_rmpad = N_rmpad[:0] + # !!! Patch ends !!! + + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + # !!! Patch starts !!! + if calculate_nec: + full_N = pad_input( + hidden_states=N_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + # !!! Patch ends !!! + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + # !!! Patch starts !!! + if calculate_nec: + nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + # !!! Patch ends !!! + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + # !!! Patch starts !!! + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) + # !!! Patch ends !!! + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint( + verl_F.entropy_from_logits, logits + ) + # !!! Patch starts !!! + if calculate_nec: + H_for_N = entropy.to(torch.float32) if calculate_entropy else None + nec = compute_N_from_logits(logits, entropy=H_for_N) + # !!! Patch ends !!! + + # !!! Patch starts !!! + if calculate_nec: + return entropy, log_probs, nec + # !!! Patch ends !!! + return entropy, log_probs diff --git a/examples/entropy/clipv_trainer.patch b/examples/entropy/clipv_trainer.patch new file mode 100644 index 0000000000..160172b414 --- /dev/null +++ b/examples/entropy/clipv_trainer.patch @@ -0,0 +1,124 @@ +--- a/trinity/trainer/verl/fsdp_workers.py ++++ b/trinity/trainer/verl/fsdp_workers.py +@@ -610,7 +610,8 @@ + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): +- from trinity.trainer.verl.dp_actor import DataParallelPPOActor ++ from examples.entropy.clipv_dp_actor import DataParallelPPOActor +@@ -903,10 +904,16 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension): + # perform recompute log_prob + with self.ulysses_sharding_manager: + with adapter_ctx: +- output, entropys = self.actor.compute_log_prob( +- data=data, calculate_entropy=not is_lora ++ # !!! Patch starts !!! ++ output, entropys, necs = self.actor.compute_log_prob( ++ data=data, calculate_entropy=not is_lora, calculate_nec=True + ) ++ # !!! Patch ends !!! + tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} ++ # !!! Patch starts !!! ++ if necs is not None: ++ tensors["necs"] = necs ++ # !!! Patch ends !!! + if not is_lora: + tensors["entropys"] = entropys + output = DataProto.from_dict( + +--- a/trinity/trainer/verl_trainer.py ++++ b/trinity/trainer/verl_trainer.py +@@ -24,13 +24,15 @@ from verl.trainer.ppo.ray_trainer import ( + Role, + create_colocated_worker_cls, + ) ++from verl import DataProto ++from verl.utils import tensordict_utils as tu ++from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path + from verl.utils.debug import marked_timer + from verl.utils.fs import copy_local_path_from_hdfs + from verl.utils.metric import reduce_metrics + from verl.workers.config import FSDPEngineConfig +- + from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN + from trinity.algorithm.utils import prefix_metrics + from trinity.common.config import Config +@@ -433,6 +435,34 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + self.config.actor_rollout_ref.actor.optim.total_training_steps = self.total_training_steps + self.config.critic.optim.total_training_steps = self.total_training_steps + ++ def _compute_old_log_prob(self, batch: DataProto): ++ if self.use_legacy_worker_impl == "disable": ++ # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free ++ # step 1: convert dataproto to tensordict. ++ batch_td = batch.to_tensordict() ++ # step 2: convert from padding to nopadding ++ batch_td = left_right_2_no_padding(batch_td) ++ # step 3: add meta info ++ tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False) ++ output = self.actor_rollout_wg.compute_log_prob(batch_td) ++ # gather output ++ entropy = tu.get(output, "entropy") ++ log_probs = tu.get(output, "log_probs") ++ old_log_prob_mfu = tu.get(output, "metrics")["mfu"] ++ necs = tu.get(output, "necs") ++ # step 4. No padding to padding ++ entropy = no_padding_2_padding(entropy, batch_td) ++ log_probs = no_padding_2_padding(log_probs, batch_td) ++ necs = no_padding_2_padding(necs, batch_td) ++ # step 5: rebuild a tensordict and convert to dataproto ++ old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float(), "necs": necs.float()}) ++ old_log_prob = DataProto.from_tensordict(old_log_prob) ++ else: ++ old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) ++ old_log_prob_mfu = 0 ++ return old_log_prob, old_log_prob_mfu ++ ++ + async def save_state_dict(self): # checkpoint sync + actor_local_path = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}", "actor" +@@ -488,6 +518,10 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) + entropys = old_log_prob.batch["entropys"] ++ # !!! Patch starts !!! ++ new_log_probs = old_log_prob.batch["old_log_probs"] ++ necs = old_log_prob.batch["necs"] ++ # !!! Patch ends !!! + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( +@@ -501,13 +535,19 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + "perf/mfu/actor_infer": old_log_prob_mfu, + } + metrics.update(old_log_prob_metrics) +- old_log_prob.batch.pop("entropys") ++ # old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) ++ ++ # !!! Patch starts !!! ++ batch.batch["new_log_probs"] = new_log_probs ++ batch.batch["new_entropys"] = entropys ++ batch.batch["necs"] = necs ++ # !!! Patch ends !!! + + if self.algorithm.use_reference: # ref_logprob may not be used + # compute reference log_prob +@@ -526,7 +566,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) + metrics.update(prefix_metrics(kl_metrics, prefix="critic")) + # compute advantages, executed on the driver process +- batch, _ = self.advantage_fn(batch) ++ batch, adv_metrics = self.advantage_fn(batch) ++ metrics.update(prefix_metrics(adv_metrics, prefix="clipv")) + else: + # skip token_level_scores for sft/dpo + if "token_level_scores" in batch.batch.keys(): diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 7f59211dbb..fc23a412df 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -20,6 +20,7 @@ "on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage", "jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage", "clipb": "trinity.algorithm.advantage_fn.clipb_advantage.ClipBAdvantageFn", + "clipv": "trinity.algorithm.advantage_fn.clipv_advantage.ClipVAdvantageFn", }, ) diff --git a/trinity/algorithm/advantage_fn/clipv_advantage.py b/trinity/algorithm/advantage_fn/clipv_advantage.py new file mode 100644 index 0000000000..0a126c839e --- /dev/null +++ b/trinity/algorithm/advantage_fn/clipv_advantage.py @@ -0,0 +1,153 @@ +"""GRPO advantage computation with Clip_V token filtering. +""" + +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Tuple + +import torch + +if TYPE_CHECKING: + from verl import DataProto + +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn + + +class ClipVAdvantageFn(AdvantageFn): + """Clip_V advantage: one-side clip only negative-advantage tokens, + and cap the global clipped-token ratio.""" + + def __init__( + self, + epsilon: float = 1e-6, + mu: float = 2.0, + max_frac: float = 1e-4, + ) -> None: + self.epsilon = epsilon + self.mu = mu + self.max_frac = max_frac + + def __call__( + self, + exps: "DataProto", + **kwargs, + ) -> Tuple["DataProto", Dict]: + token_level_rewards = exps.batch["token_level_rewards"] + response_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + + new_log_probs = exps.batch["old_log_probs"] + new_entropys = exps.batch["entropys"] + necs = exps.batch["necs"] + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx, grouped_scores in id2score.items(): + if len(grouped_scores) == 1: + id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device) + id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device) + elif len(grouped_scores) > 1: + group_scores = torch.stack(grouped_scores).to( + dtype=scores.dtype, device=scores.device + ) + id2mean[idx] = torch.mean(group_scores) + id2std[idx] = torch.std(group_scores) + else: + raise ValueError(f"no score in prompt index: {idx}") + + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores.clone() + + LP = new_log_probs + H = new_entropys + N = necs + M = response_mask + p = LP.exp() + S = p * (H + LP) + + xD = (N - S).detach().to(torch.float32) + A = exps.batch["advantages"].detach().to(torch.float32) + m = M.to(torch.float32) + + n = m.sum().clamp_min(1.0) + mean_d = (xD * m).sum() / n + var_d = ((xD - mean_d) ** 2 * m).sum() / n + std_d = var_d.sqrt() + + if std_d.item() < 1e-12: + keep = torch.ones_like(M, dtype=M.dtype) + else: + pos_mask = A > 0 + neg_mask = A < 0 + + keep_neg = xD <= (self.mu * std_d) + keep_bool = torch.where(pos_mask, torch.ones_like(pos_mask), keep_neg) + + total_tokens = m.sum().clamp_min(1.0) + clipped_mask = (M > 0) & (~keep_bool) + frac_clipped = (clipped_mask.to(torch.float32).sum() / total_tokens).item() + + if frac_clipped <= self.max_frac: + keep = keep_bool.to(M.dtype) + else: + max_clipped_tokens = max(int(self.max_frac * total_tokens.item()), 1) + neg_to_clip = neg_mask & (M > 0) & (~keep_neg) + neg_to_clip_count = int(neg_to_clip.to(torch.int32).sum().item()) + + if neg_to_clip_count <= max_clipped_tokens: + keep = keep_bool.to(M.dtype) + else: + # Keep only top-K most violating negative-advantage tokens clipped. + candidate_scores = xD.masked_fill(~neg_to_clip, float("-inf")).view(-1) + k = min(max_clipped_tokens, neg_to_clip_count) + _, indices = torch.topk(candidate_scores, k, largest=True) + + limited_clip_mask = torch.zeros_like(candidate_scores, dtype=torch.bool) + limited_clip_mask[indices] = True + limited_clip_mask = limited_clip_mask.view_as(xD) + + final_keep = keep_bool.clone() + final_keep[neg_to_clip] = True + final_keep[limited_clip_mask] = False + keep = final_keep.to(M.dtype) + + M_clipped = M * keep + exps.batch["response_mask"] = M_clipped + + total_tokens = M.to(torch.float32).sum().clamp_min(1.0) + frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item() + + pos_mask = (A > 0).to(M.dtype) + neg_mask = (A < 0).to(M.dtype) + total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0) + total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0) + frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item() + frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item() + + metrics = { + "frac_clipped": frac_clipped, + "frac_clipped_pos": frac_clipped_pos, + "frac_clipped_neg": frac_clipped_neg, + "stdD": std_d.item(), + } + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + "mu": 8.5, + "max_frac": 1e-4, + } From 8e01fabf0ccacb3146622a59a7d5161766f44b5c Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 24 Feb 2026 12:29:35 +0800 Subject: [PATCH 3/7] fix typo --- examples/entropy/clipb.yaml | 2 +- examples/entropy/clipv.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/entropy/clipb.yaml b/examples/entropy/clipb.yaml index d78edda47e..3e0fdebfcf 100644 --- a/examples/entropy/clipb.yaml +++ b/examples/entropy/clipb.yaml @@ -61,7 +61,7 @@ buffer: response_key: 'answer' rollout_args: temperature: 0.7 - - name : aime25 + - name: aime25 storage_type: file path: math-ai/aime25 # Path to the AIME2025 dataset split: 'test' diff --git a/examples/entropy/clipv.yaml b/examples/entropy/clipv.yaml index 72deb72848..423e048a5d 100644 --- a/examples/entropy/clipv.yaml +++ b/examples/entropy/clipv.yaml @@ -61,7 +61,7 @@ buffer: response_key: 'answer' rollout_args: temperature: 0.7 - - name : aime25 + - name: aime25 storage_type: file path: math-ai/aime25 # Path to the AIME2025 dataset split: 'test' From 1bff739d5c5deb3e663a0db92e467e08572ebec7 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 24 Feb 2026 15:22:14 +0800 Subject: [PATCH 4/7] remove some comments --- examples/entropy/clipv_dp_actor.py | 32 ---------------------------- examples/entropy/clipv_trainer.patch | 8 ------- 2 files changed, 40 deletions(-) diff --git a/examples/entropy/clipv_dp_actor.py b/examples/entropy/clipv_dp_actor.py index cf48ee63db..ede043607f 100644 --- a/examples/entropy/clipv_dp_actor.py +++ b/examples/entropy/clipv_dp_actor.py @@ -106,22 +106,17 @@ def compute_log_prob( log_probs_lst = [] entropy_lst = [] - # !!! Patch starts !!! nec_lst = [] - # !!! Patch ends !!! for micro_batch in micro_batches: micro_batch = micro_batch.to(get_device_id()) model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} with torch.no_grad(): - # !!! Patch starts !!! outputs = self._forward_micro_batch( model_inputs, temperature=temperature, calculate_entropy=calculate_entropy, calculate_nec=calculate_nec, ) - # !!! Patch ends !!! - # !!! Patch starts !!! if calculate_nec: entropy, log_probs, nec = cast( tuple[torch.Tensor, torch.Tensor, torch.Tensor], outputs @@ -129,7 +124,6 @@ def compute_log_prob( nec_lst.append(nec) else: entropy, log_probs = cast(tuple[torch.Tensor, torch.Tensor], outputs) - # !!! Patch ends !!! log_probs_lst.append(log_probs) if calculate_entropy: entropy_lst.append(entropy) @@ -138,25 +132,19 @@ def compute_log_prob( entropys = None if calculate_entropy: entropys = torch.concat(entropy_lst, dim=0) - # !!! Patch starts !!! necs = None if calculate_nec: necs = torch.concat(nec_lst, dim=0) - # !!! Patch ends !!! if use_dynamic_bsz: log_probs = restore_dynamic_batch(log_probs, batch_idx_list) if calculate_entropy: entropys = restore_dynamic_batch(entropys, batch_idx_list) - # !!! Patch starts !!! if calculate_nec: necs = restore_dynamic_batch(necs, batch_idx_list) - # !!! Patch ends !!! - # !!! Patch starts !!! if calculate_nec: return log_probs, entropys, necs - # !!! Patch ends !!! return log_probs, entropys def _forward_micro_batch( # type: ignore # noqa: C901 @@ -180,9 +168,7 @@ def _forward_micro_batch( # type: ignore # noqa: C901 attention_mask = micro_batch["attention_mask"] position_ids = micro_batch["position_ids"] entropy = None - # !!! Patch starts !!! nec = None - # !!! Patch ends !!! if position_ids.dim() == 3: # qwen2vl mrope position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) @@ -291,12 +277,10 @@ def _forward_micro_batch( # type: ignore # noqa: C901 if self.use_fused_kernels: log_probs = output.log_probs.squeeze(0) # (total_nnz,) entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) - # !!! Patch starts !!! if calculate_nec: raise RuntimeError( "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" ) - # !!! Patch ends !!! else: logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) @@ -322,11 +306,9 @@ def _forward_micro_batch( # type: ignore # noqa: C901 entropy_rmpad = torch.utils.checkpoint.checkpoint( self.compute_entropy_from_logits, logits_rmpad ) - # !!! Patch starts !!! if calculate_nec: H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) - # !!! Patch ends !!! # gather log_prob if sp > 1 if self.use_ulysses_sp: @@ -344,7 +326,6 @@ def _forward_micro_batch( # type: ignore # noqa: C901 unpad_dim=0, padding_size=pad_size, ) - # !!! Patch starts !!! if calculate_nec: N_rmpad = gather_outputs_and_unpad( N_rmpad, @@ -352,16 +333,13 @@ def _forward_micro_batch( # type: ignore # noqa: C901 unpad_dim=0, padding_size=pad_size, ) - # !!! Patch ends !!! if is_mask_all_zero: log_probs = log_probs[:0] if calculate_entropy: entropy_rmpad = entropy_rmpad[:0] - # !!! Patch starts !!! if calculate_nec: N_rmpad = N_rmpad[:0] - # !!! Patch ends !!! # pad back to (bsz, seqlen) if calculate_entropy: @@ -371,7 +349,6 @@ def _forward_micro_batch( # type: ignore # noqa: C901 batch=batch_size, seqlen=seqlen, ) - # !!! Patch starts !!! if calculate_nec: full_N = pad_input( hidden_states=N_rmpad.unsqueeze(-1), @@ -379,7 +356,6 @@ def _forward_micro_batch( # type: ignore # noqa: C901 batch=batch_size, seqlen=seqlen, ) - # !!! Patch ends !!! full_log_probs = pad_input( hidden_states=log_probs.unsqueeze(-1), indices=indices, @@ -392,10 +368,8 @@ def _forward_micro_batch( # type: ignore # noqa: C901 entropy = full_entropy.squeeze(-1)[ :, -response_length - 1 : -1 ] # (bsz, response_length) - # !!! Patch starts !!! if calculate_nec: nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - # !!! Patch ends !!! log_probs = full_log_probs.squeeze(-1)[ :, -response_length - 1 : -1 ] # (bsz, response_length) @@ -418,12 +392,10 @@ def _forward_micro_batch( # type: ignore # noqa: C901 if self.use_fused_kernels: log_probs = output.log_probs[:, -response_length - 1 : -1] entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) - # !!! Patch starts !!! if calculate_nec: raise RuntimeError( "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" ) - # !!! Patch ends !!! else: logits = output.logits @@ -440,14 +412,10 @@ def _forward_micro_batch( # type: ignore # noqa: C901 entropy = torch.utils.checkpoint.checkpoint( verl_F.entropy_from_logits, logits ) - # !!! Patch starts !!! if calculate_nec: H_for_N = entropy.to(torch.float32) if calculate_entropy else None nec = compute_N_from_logits(logits, entropy=H_for_N) - # !!! Patch ends !!! - # !!! Patch starts !!! if calculate_nec: return entropy, log_probs, nec - # !!! Patch ends !!! return entropy, log_probs diff --git a/examples/entropy/clipv_trainer.patch b/examples/entropy/clipv_trainer.patch index 160172b414..5bc619c4e4 100644 --- a/examples/entropy/clipv_trainer.patch +++ b/examples/entropy/clipv_trainer.patch @@ -12,16 +12,12 @@ with adapter_ctx: - output, entropys = self.actor.compute_log_prob( - data=data, calculate_entropy=not is_lora -+ # !!! Patch starts !!! + output, entropys, necs = self.actor.compute_log_prob( + data=data, calculate_entropy=not is_lora, calculate_nec=True ) -+ # !!! Patch ends !!! tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} -+ # !!! Patch starts !!! + if necs is not None: + tensors["necs"] = necs -+ # !!! Patch ends !!! if not is_lora: tensors["entropys"] = entropys output = DataProto.from_dict( @@ -84,10 +80,8 @@ with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) entropys = old_log_prob.batch["entropys"] -+ # !!! Patch starts !!! + new_log_probs = old_log_prob.batch["old_log_probs"] + necs = old_log_prob.batch["necs"] -+ # !!! Patch ends !!! response_masks = batch.batch["response_mask"] actor_config = self.config.actor_rollout_ref.actor entropy_agg = agg_loss( @@ -104,11 +98,9 @@ metrics.update(calculate_debug_metrics(batch)) + -+ # !!! Patch starts !!! + batch.batch["new_log_probs"] = new_log_probs + batch.batch["new_entropys"] = entropys + batch.batch["necs"] = necs -+ # !!! Patch ends !!! if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob From 92a23e52ff373aa4971f7738704d97b5fd4df0d8 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 24 Feb 2026 16:31:51 +0800 Subject: [PATCH 5/7] add registry for compute_log_prob --- examples/entropy/clipv.yaml | 3 ++ examples/entropy/clipv_trainer.patch | 49 +++++++--------------------- trinity/common/verl_config.py | 3 ++ trinity/trainer/verl/fsdp_workers.py | 23 +++++++++---- 4 files changed, 34 insertions(+), 44 deletions(-) diff --git a/examples/entropy/clipv.yaml b/examples/entropy/clipv.yaml index 423e048a5d..280990bf08 100644 --- a/examples/entropy/clipv.yaml +++ b/examples/entropy/clipv.yaml @@ -91,6 +91,9 @@ trainer: trainer_type: 'verl' save_interval: 100 trainer_config: + actor_rollout_ref: + actor: + log_prob_fn: clipv_entropy_nec algorithm: rollout_correction: bypass_mode: false diff --git a/examples/entropy/clipv_trainer.patch b/examples/entropy/clipv_trainer.patch index 5bc619c4e4..b80d957125 100644 --- a/examples/entropy/clipv_trainer.patch +++ b/examples/entropy/clipv_trainer.patch @@ -1,27 +1,4 @@ ---- a/trinity/trainer/verl/fsdp_workers.py -+++ b/trinity/trainer/verl/fsdp_workers.py -@@ -610,7 +610,8 @@ - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): -- from trinity.trainer.verl.dp_actor import DataParallelPPOActor -+ from examples.entropy.clipv_dp_actor import DataParallelPPOActor -@@ -903,10 +904,16 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension): - # perform recompute log_prob - with self.ulysses_sharding_manager: - with adapter_ctx: -- output, entropys = self.actor.compute_log_prob( -- data=data, calculate_entropy=not is_lora -+ output, entropys, necs = self.actor.compute_log_prob( -+ data=data, calculate_entropy=not is_lora, calculate_nec=True - ) - tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} -+ if necs is not None: -+ tensors["necs"] = necs - if not is_lora: - tensors["entropys"] = entropys - output = DataProto.from_dict( - +diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -24,13 +24,15 @@ from verl.trainer.ppo.ray_trainer import ( @@ -41,11 +18,14 @@ from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config -@@ -433,6 +435,34 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): +@@ -433,6 +435,37 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): self.config.actor_rollout_ref.actor.optim.total_training_steps = self.total_training_steps self.config.critic.optim.total_training_steps = self.total_training_steps + def _compute_old_log_prob(self, batch: DataProto): ++ """ ++ We add nec to the batch to make the advantage function (e.g. Clip_V) use it. ++ """ + if self.use_legacy_worker_impl == "disable": + # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free + # step 1: convert dataproto to tensordict. @@ -76,16 +56,7 @@ async def save_state_dict(self): # checkpoint sync actor_local_path = os.path.join( self.config.trainer.default_local_dir, f"global_step_{self.global_steps}", "actor" -@@ -488,6 +518,10 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): - with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) - entropys = old_log_prob.batch["entropys"] -+ new_log_probs = old_log_prob.batch["old_log_probs"] -+ necs = old_log_prob.batch["necs"] - response_masks = batch.batch["response_mask"] - actor_config = self.config.actor_rollout_ref.actor - entropy_agg = agg_loss( -@@ -501,13 +535,19 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): +@@ -501,13 +534,19 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): "perf/mfu/actor_infer": old_log_prob_mfu, } metrics.update(old_log_prob_metrics) @@ -98,13 +69,15 @@ metrics.update(calculate_debug_metrics(batch)) + -+ batch.batch["new_log_probs"] = new_log_probs ++ # !!! Patch starts !!! ++ batch.batch["new_log_probs"] = old_log_prob.batch["old_log_probs"] + batch.batch["new_entropys"] = entropys -+ batch.batch["necs"] = necs ++ batch.batch["necs"] = old_log_prob.batch["necs"] ++ # !!! Patch ends !!! if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob -@@ -526,7 +566,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): +@@ -526,7 +565,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) metrics.update(prefix_metrics(kl_metrics, prefix="critic")) # compute advantages, executed on the driver process diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 689b3231b4..92972406ea 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -182,6 +182,9 @@ class Actor: entropy_coeff: float = 0.001 use_kl_loss: bool = False + # custom log_prob_fn + log_prob_fn: Optional[str] = None + @dataclass class Ref: diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 1b893bdfef..fd54acebb2 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -901,14 +901,20 @@ def compute_log_prob(self, data: DataProto): data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob + calculate_entropy = not is_lora with self.ulysses_sharding_manager: with adapter_ctx: - output, entropys = self.actor.compute_log_prob( - data=data, calculate_entropy=not is_lora + outputs = self.actor.compute_log_prob( + data=data, calculate_entropy=calculate_entropy ) - tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} if not is_lora: - tensors["entropys"] = entropys + tensors = {"old_log_probs": outputs["log_probs"]} + else: + tensors = {"ref_log_prob": outputs["log_probs"]} + if calculate_entropy: + tensors["entropys"] = outputs["entropys"] + if "necs" in outputs: + tensors["necs"] = outputs["necs"] output = DataProto.from_dict( tensors=tensors, meta_info={"temperature": self.config.rollout.temperature}, @@ -947,8 +953,13 @@ def compute_ref_log_prob(self, data: DataProto): data = data.to( "cpu" ) # data will to device with each micro batch on ref.compute_log_prob - output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - output = DataProto.from_dict(tensors={"ref_log_prob": output}) + outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + if isinstance(outputs, dict): + ref_log_prob = outputs["log_probs"] + else: + # Backward compatibility with old tuple return style. + ref_log_prob, _ = outputs + output = DataProto.from_dict(tensors={"ref_log_prob": ref_log_prob}) output = output.to("cpu") From 9df93555ea198510ffff3c0c5344310c75067cd1 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 24 Feb 2026 16:44:55 +0800 Subject: [PATCH 6/7] update dp_actor --- examples/entropy/clipv_dp_actor.py | 660 ++++++++++++++--------------- trinity/trainer/verl/dp_actor.py | 42 ++ 2 files changed, 371 insertions(+), 331 deletions(-) diff --git a/examples/entropy/clipv_dp_actor.py b/examples/entropy/clipv_dp_actor.py index ede043607f..4c40960676 100644 --- a/examples/entropy/clipv_dp_actor.py +++ b/examples/entropy/clipv_dp_actor.py @@ -21,7 +21,9 @@ from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch -from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.torch_functional import ( + logprobs_from_logits as verl_logprobs_from_logits, +) from verl.utils.ulysses import ( gather_outputs_and_unpad, ulysses_pad, @@ -30,12 +32,20 @@ from trinity.trainer.verl.dp_actor import DataParallelPPOActor as OriginalDPActor -__all__ = ["DataParallelPPOActor"] +__all__ = ["clipv_compute_log_prob_with_nec"] logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +def clipv_logprobs_from_logits( + logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True +) -> torch.Tensor: + return verl_logprobs_from_logits( + logits=logits, labels=labels, inplace_backward=inplace_backward + ) + + def compute_N_from_logits( logits: torch.Tensor, entropy: torch.Tensor | None = None ) -> torch.Tensor: @@ -58,364 +68,352 @@ def compute_N_from_logits( return N -class DataParallelPPOActor(OriginalDPActor): - @GPUMemoryLogger(role="dp actor", logger=logger) - def compute_log_prob( - self, data: DataProto, calculate_entropy: bool = False, calculate_nec: bool = False - ) -> ( - tuple[torch.Tensor, torch.Tensor | None] - | tuple[torch.Tensor, torch.Tensor | None, torch.Tensor] - ): - """Compute the log probability of the responses given input_ids, attention_mask and position_ids - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the - concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. - - Returns: - - when calculate_nec is False: (log_probs, entropys) - - when calculate_nec is True: (log_probs, entropys, necs) - """ - # set to eval - self.actor_module.eval() - - micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info[ - "temperature" - ] # temperature must be in the data.meta_info to avoid silent error - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - if use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) - else: - micro_batches = data.split(micro_batch_size) - - log_probs_lst = [] - entropy_lst = [] - nec_lst = [] - for micro_batch in micro_batches: - micro_batch = micro_batch.to(get_device_id()) - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - outputs = self._forward_micro_batch( - model_inputs, - temperature=temperature, - calculate_entropy=calculate_entropy, - calculate_nec=calculate_nec, +def _forward_micro_batch_with_nec( # noqa: C901 + actor: OriginalDPActor, + micro_batch, + temperature, + calculate_entropy: bool = True, + calculate_nec: bool = True, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns: + entropy: # (bs, response_len) + log_probs: # (bs, response_len) + nec: # (bs, response_len) + """ + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=actor.device_name, dtype=actor.param_dtype): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + nec = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if actor.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + is_mask_all_zero = attention_mask.sum() == 0 + if is_mask_all_zero: + input_ids_rmpad = torch.zeros( + (1, actor.ulysses_sequence_parallel_size), + device=input_ids.device, + dtype=input_ids.dtype, ) - if calculate_nec: - entropy, log_probs, nec = cast( - tuple[torch.Tensor, torch.Tensor, torch.Tensor], outputs + if position_ids.dim() == 3: + position_ids_rmpad = torch.zeros( + (position_ids.shape[0], 1, actor.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + else: + position_ids_rmpad = torch.zeros( + (1, actor.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import ( + process_multi_modal_inputs_for_minicpmo, ) - nec_lst.append(nec) - else: - entropy, log_probs = cast(tuple[torch.Tensor, torch.Tensor], outputs) - log_probs_lst.append(log_probs) - if calculate_entropy: - entropy_lst.append(entropy) - log_probs = torch.concat(log_probs_lst, dim=0) - entropys = None - if calculate_entropy: - entropys = torch.concat(entropy_lst, dim=0) - necs = None - if calculate_nec: - necs = torch.concat(nec_lst, dim=0) + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) - if use_dynamic_bsz: - log_probs = restore_dynamic_batch(log_probs, batch_idx_list) - if calculate_entropy: - entropys = restore_dynamic_batch(entropys, batch_idx_list) - if calculate_nec: - necs = restore_dynamic_batch(necs, batch_idx_list) + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) - if calculate_nec: - return log_probs, entropys, necs - return log_probs, entropys - - def _forward_micro_batch( # type: ignore # noqa: C901 - self, micro_batch, temperature, calculate_entropy: bool = False, calculate_nec: bool = False - ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Returns: - entropy: # (bs, response_len) - log_probs: # (bs, response_len) - """ - response_length = micro_batch["responses"].size(-1) - multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch.keys(): - from verl.utils.model import extract_multi_modal_inputs - - multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) - - with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - entropy = None - nec = None - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) - - if self.use_remove_padding: - input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis( - rearrange(position_ids, "c b s ... -> (b s) c ..."), indices - ) - .transpose(0, 1) - .unsqueeze(1) - ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + # pad and slice the inputs if sp > 1 + if actor.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(actor.actor_module, "module", actor.actor_module).config, + "vision_config", + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=actor.ulysses_sequence_parallel_size, + ) else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - is_mask_all_zero = attention_mask.sum() == 0 - if is_mask_all_zero: - input_ids_rmpad = torch.zeros( - (1, self.ulysses_sequence_parallel_size), - device=input_ids.device, - dtype=input_ids.dtype, + ( + input_ids_rmpad, + position_ids_rmpad, + pad_size, + ) = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=actor.ulysses_sequence_parallel_size, ) - if position_ids.dim() == 3: - position_ids_rmpad = torch.zeros( - (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), - device=position_ids.device, - dtype=position_ids.dtype, - ) - else: - position_ids_rmpad = torch.zeros( - (1, self.ulysses_sequence_parallel_size), - device=position_ids.device, - dtype=position_ids.dtype, - ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=actor.ulysses_sequence_parallel_size, + ) - if "image_bound" in multi_modal_inputs: - from verl.utils.dataset.vision_utils import ( - process_multi_modal_inputs_for_minicpmo, + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if actor.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = actor.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if actor.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" ) - multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( - input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs - ) + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll( - input_ids_rmpad, shifts=-1, dims=1 - ) # (1, total_nnz) + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = actor.compute_log_probs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) - # pad and slice the inputs if sp > 1 - if self.use_ulysses_sp: - is_vlm_model = hasattr( - getattr(self.actor_module, "module", self.actor_module).config, - "vision_config", - ) - if is_vlm_model: - # vlm model's inputs will be sliced after embedding - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) + # compute entropy + if calculate_entropy: + if not actor.config.entropy_checkpointing: + entropy_rmpad = actor.compute_entropy_from_logits( + logits_rmpad + ) # ((total_nnz / sp) + pad) else: - ( - input_ids_rmpad, - position_ids_rmpad, - pad_size, - ) = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, - position_ids_rmpad=None, - sp_size=self.ulysses_sequence_parallel_size, - ) - - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( - 0 - ) # ((total_nnz / sp) + pad) - - # only pass input_ids and position_ids to enable flash_attn_varlen - extra_args = {} - if self.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = self.actor_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if self.use_fused_kernels: - log_probs = output.log_probs.squeeze(0) # (total_nnz,) - entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) - if calculate_nec: - raise RuntimeError( - "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + entropy_rmpad = torch.utils.checkpoint.checkpoint( + actor.compute_entropy_from_logits, logits_rmpad ) - - else: - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - logits_rmpad.div_(temperature) - - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - inplace_backward = True - if calculate_entropy: - inplace_backward = False - log_probs = logprobs_from_logits( - logits=logits_rmpad, - labels=input_ids_rmpad_rolled, - inplace_backward=inplace_backward, + if calculate_nec: + H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None + N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) + + # gather log_prob if sp > 1 + if actor.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, ) - - # compute entropy - if calculate_entropy: - if not self.config.entropy_checkpointing: - entropy_rmpad = self.compute_entropy_from_logits( - logits_rmpad - ) # ((total_nnz / sp) + pad) - else: - entropy_rmpad = torch.utils.checkpoint.checkpoint( - self.compute_entropy_from_logits, logits_rmpad - ) - if calculate_nec: - H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None - N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) - - # gather log_prob if sp > 1 - if self.use_ulysses_sp: - # gather and unpad for the ulysses sp - log_probs = gather_outputs_and_unpad( - log_probs, + if calculate_nec: + N_rmpad = gather_outputs_and_unpad( + N_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size, ) - if calculate_entropy: - entropy_rmpad = gather_outputs_and_unpad( - entropy_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - if calculate_nec: - N_rmpad = gather_outputs_and_unpad( - N_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - - if is_mask_all_zero: - log_probs = log_probs[:0] - if calculate_entropy: - entropy_rmpad = entropy_rmpad[:0] - if calculate_nec: - N_rmpad = N_rmpad[:0] - # pad back to (bsz, seqlen) + if is_mask_all_zero: + log_probs = log_probs[:0] if calculate_entropy: - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) + entropy_rmpad = entropy_rmpad[:0] if calculate_nec: - full_N = pad_input( - hidden_states=N_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), + N_rmpad = N_rmpad[:0] + + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen, ) - - # only return response part: - if calculate_entropy: - entropy = full_entropy.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - if calculate_nec: - nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[ + if calculate_nec: + full_N = pad_input( + hidden_states=N_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[ :, -response_length - 1 : -1 ] # (bsz, response_length) + if calculate_nec: + nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if actor.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = actor.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if actor.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) + + else: + logits = output.logits - else: # not using rmpad and no ulysses sp - extra_args = {} - if self.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = self.actor_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if self.use_fused_kernels: - log_probs = output.log_probs[:, -response_length - 1 : -1] - entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) - if calculate_nec: - raise RuntimeError( - "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + logits.div_(temperature) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, response_length, vocab_size) + log_probs = actor.compute_log_probs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not actor.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint( + verl_F.entropy_from_logits, logits ) + if calculate_nec: + H_for_N = entropy.to(torch.float32) if calculate_entropy else None + nec = compute_N_from_logits(logits, entropy=H_for_N) - else: - logits = output.logits - - logits.div_(temperature) - logits = logits[ - :, -response_length - 1 : -1, : - ] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch["responses"]) - if calculate_entropy: - if not self.config.entropy_checkpointing: - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - else: - entropy = torch.utils.checkpoint.checkpoint( - verl_F.entropy_from_logits, logits - ) - if calculate_nec: - H_for_N = entropy.to(torch.float32) if calculate_entropy else None - nec = compute_N_from_logits(logits, entropy=H_for_N) + if calculate_nec: + return entropy, log_probs, nec + return entropy, log_probs - if calculate_nec: - return entropy, log_probs, nec - return entropy, log_probs + +@GPUMemoryLogger(role="dp actor", logger=logger) +def clipv_compute_log_prob_with_nec( + actor: OriginalDPActor, + data: DataProto, + calculate_entropy: bool = True, + calculate_nec: bool = True, +) -> dict[str, torch.Tensor]: + """ + Returns: dict[str, torch.Tensor] + "log_probs": (bs, response_len) + "entropys": (bs, response_len) + "necs": (bs, response_len) + """ + # set to eval + actor.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * actor.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + nec_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + outputs = _forward_micro_batch_with_nec( + actor, + model_inputs, + temperature=temperature, + calculate_entropy=calculate_entropy, + calculate_nec=calculate_nec, + ) + if calculate_nec: + entropy, log_probs, nec = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], outputs) + nec_lst.append(nec) + else: + entropy, log_probs = cast(tuple[torch.Tensor, torch.Tensor], outputs) + log_probs_lst.append(log_probs) + if calculate_entropy: + entropy_lst.append(entropy) + + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + necs = None + if calculate_nec: + necs = torch.concat(nec_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + if calculate_nec: + necs = restore_dynamic_batch(necs, batch_idx_list) + + if calculate_nec: + return {"log_probs": log_probs, "entropys": entropys, "necs": necs} + return {"log_probs": log_probs, "entropys": entropys} diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 80bc4a16c8..6f66d1bf13 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -28,6 +28,9 @@ from verl.utils.device import get_device_id from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import prepare_dynamic_batch +from verl.utils.torch_functional import ( + logprobs_from_logits as verl_logprobs_from_logits, +) from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN @@ -35,12 +38,20 @@ from trinity.algorithm.kl_fn.kl_fn import DummyKLFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig +from trinity.utils.registry import Registry __all__ = ["DataParallelPPOActor"] logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +LOG_PROB_FN: Registry = Registry( + "log_prob_fn", + default_mapping={ + "clipv_entropy_nec": "examples.entropy.clipv_dp_actor.clipv_compute_log_prob_with_nec", + }, +) + class DataParallelPPOActor(DPActor): def __init__( @@ -51,6 +62,22 @@ def __init__( self.policy_loss_fn = None self.kl_loss_fn = None self.entropy_loss_fn = None + log_prob_fn_key = config.get("log_prob_fn", "default") + if not log_prob_fn_key or log_prob_fn_key == "default": + self._compute_log_prob_fn_name = "default" + self._compute_log_prob_fn = None # use default log_prob_fn + else: + self._compute_log_prob_fn_name = str(log_prob_fn_key) + self._compute_log_prob_fn = LOG_PROB_FN.get( + self._compute_log_prob_fn_name + ) # use custom log_prob_fn + + def compute_log_probs_from_logits( + self, logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True + ) -> torch.Tensor: + return verl_logprobs_from_logits( + logits=logits, labels=labels, inplace_backward=inplace_backward + ) def set_algorithm(self, algorithm_config: AlgorithmConfig): self.loss_agg_mode = algorithm_config.loss_agg_mode @@ -221,3 +248,18 @@ def update_policy(self, data: DataProto): # noqa: C901 append_to_dict(metrics, mini_batch_metrics) self.actor_optimizer.zero_grad() return metrics + + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob( + self, data: DataProto, calculate_entropy: bool = False, calculate_nec: bool = False + ): + if self._compute_log_prob_fn_name == "default" or self._compute_log_prob_fn is None: + log_probs, entropys = super().compute_log_prob( + data=data, calculate_entropy=calculate_entropy + ) + return {"entropys": entropys, "log_probs": log_probs} + else: + outputs = self._compute_log_prob_fn( + actor=self, data=data, calculate_entropy=calculate_entropy + ) + return outputs From 373e6060312b47914c8a9eb96cc28d3948de9d97 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 25 Feb 2026 16:10:35 +0800 Subject: [PATCH 7/7] update to patch version --- examples/entropy/README.md | 40 +- examples/entropy/clipv.yaml | 3 - examples/entropy/clipv_dp_actor.py | 669 ++++++++++++++------------- examples/entropy/clipv_trainer.patch | 29 +- trinity/trainer/verl/dp_actor.py | 109 +++-- trinity/trainer/verl/fsdp_workers.py | 13 +- 6 files changed, 472 insertions(+), 391 deletions(-) diff --git a/examples/entropy/README.md b/examples/entropy/README.md index 941cd3657d..58463f892d 100644 --- a/examples/entropy/README.md +++ b/examples/entropy/README.md @@ -2,7 +2,7 @@ This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392). -NOTE: This example is only tested on verl==0.7.0. +NOTE: This example is only tested on trinity==0.5.1 and verl==0.7.0. The following experiments require `synchronizer.sync_interval=1` and `trainer.trainer_config.algorithm.rollout_correction.bypass_mode=false` to be set. ## Data Preparation @@ -16,9 +16,11 @@ The training set is filtered out samples from the training set with excessively ```bash cd /path/to/Trinity-RFT git apply examples/entropy/clipb_trainer.patch +# if not successful, try: +# git apply --3way --ignore-whitespace examples/entropy/clipb_trainer.patch ``` -2. Update the dataset paths and other configurations in the file [`clipb.yaml`](clipb.yaml) to point to your local data. +2. Update the dataset paths and other configurations in the file [`clipb.yaml`](./clipb.yaml) to point to your local data. 3. Run the experiment: @@ -33,12 +35,44 @@ trinity run examples/entropy/clipb.yaml ```bash cd /path/to/Trinity-RFT git apply examples/entropy/clipv_trainer.patch +# if not successful, try: +# git apply --3way --ignore-whitespace examples/entropy/clipv_trainer.patch ``` -2. Update the dataset paths and other configurations in the file [`clipv.yaml`](clipv.yaml) to point to your local data. +2. Update the dataset paths and other configurations in the file [`clipv.yaml`](./clipv.yaml) to point to your local data. 3. Run the experiment: ```bash trinity run examples/entropy/clipv.yaml ``` + +## Clip_V Code Logic + +As shown in the following flowchart, the forward pass of [examples/entropy/clipv_dp_actor.py](./clipv_dp_actor.py) outputs `log_probs`, `entropy`, and `nec`. +These signals are then used by [Clip_V advantage function](../../trinity/algorithm/advantage_fn/clipv_advantage.py) to compute `xD` and clip only negative-advantage tokens. This process returns the revised `advantages`. + +```mermaid +flowchart TD + A["data"] + B["forward pass"] + C1["log_probs"] + C2["entropy (additional)"] + C3["nec (additional)"] + subgraph D["advantage computation"] + direction TB + F["xD = nec - exp(log_probs) * (entropy + log_probs)"] + G["only clip negative-advantage tokens"] + F --> G + end + E["advantages"] + + A --> B + B --> C1 + B --> C2 + B --> C3 + C1 --> D + C2 --> D + C3 --> D + D --> E +``` diff --git a/examples/entropy/clipv.yaml b/examples/entropy/clipv.yaml index 280990bf08..423e048a5d 100644 --- a/examples/entropy/clipv.yaml +++ b/examples/entropy/clipv.yaml @@ -91,9 +91,6 @@ trainer: trainer_type: 'verl' save_interval: 100 trainer_config: - actor_rollout_ref: - actor: - log_prob_fn: clipv_entropy_nec algorithm: rollout_correction: bypass_mode: false diff --git a/examples/entropy/clipv_dp_actor.py b/examples/entropy/clipv_dp_actor.py index 4c40960676..6cfa6c6866 100644 --- a/examples/entropy/clipv_dp_actor.py +++ b/examples/entropy/clipv_dp_actor.py @@ -6,7 +6,6 @@ import logging import os -from typing import cast import torch import torch.nn.functional as F @@ -21,9 +20,7 @@ from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch -from verl.utils.torch_functional import ( - logprobs_from_logits as verl_logprobs_from_logits, -) +from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import ( gather_outputs_and_unpad, ulysses_pad, @@ -32,20 +29,12 @@ from trinity.trainer.verl.dp_actor import DataParallelPPOActor as OriginalDPActor -__all__ = ["clipv_compute_log_prob_with_nec"] +__all__ = ["DataParallelPPOActor"] logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -def clipv_logprobs_from_logits( - logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True -) -> torch.Tensor: - return verl_logprobs_from_logits( - logits=logits, labels=labels, inplace_backward=inplace_backward - ) - - def compute_N_from_logits( logits: torch.Tensor, entropy: torch.Tensor | None = None ) -> torch.Tensor: @@ -68,352 +57,370 @@ def compute_N_from_logits( return N -def _forward_micro_batch_with_nec( # noqa: C901 - actor: OriginalDPActor, - micro_batch, - temperature, - calculate_entropy: bool = True, - calculate_nec: bool = True, -) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Returns: - entropy: # (bs, response_len) - log_probs: # (bs, response_len) - nec: # (bs, response_len) - """ - response_length = micro_batch["responses"].size(-1) - multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch.keys(): - from verl.utils.model import extract_multi_modal_inputs - - multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) - - with torch.autocast(device_type=actor.device_name, dtype=actor.param_dtype): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - entropy = None - nec = None - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) - - if actor.use_remove_padding: - input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) - .transpose(0, 1) - .unsqueeze(1) - ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - is_mask_all_zero = attention_mask.sum() == 0 - if is_mask_all_zero: - input_ids_rmpad = torch.zeros( - (1, actor.ulysses_sequence_parallel_size), - device=input_ids.device, - dtype=input_ids.dtype, - ) - if position_ids.dim() == 3: - position_ids_rmpad = torch.zeros( - (position_ids.shape[0], 1, actor.ulysses_sequence_parallel_size), - device=position_ids.device, - dtype=position_ids.dtype, - ) - else: - position_ids_rmpad = torch.zeros( - (1, actor.ulysses_sequence_parallel_size), - device=position_ids.device, - dtype=position_ids.dtype, - ) +class DataParallelPPOActor(OriginalDPActor): + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob( + self, data: DataProto, calculate_entropy: bool = False, calculate_nec: bool = True + ) -> dict[str, torch.Tensor | None]: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids - if "image_bound" in multi_modal_inputs: - from verl.utils.dataset.vision_utils import ( - process_multi_modal_inputs_for_minicpmo, - ) + Args: + data (DataProto): a DataProto containing keys - multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( - input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs - ) + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + dict containing: + - ``log_probs``: tensor, shape (bsz, response_len) + - ``entropys``: tensor | None, shape (bsz, response_len) + - ``necs``: tensor | None, shape (bsz, response_len) + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll( - input_ids_rmpad, shifts=-1, dims=1 - ) # (1, total_nnz) + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - # pad and slice the inputs if sp > 1 - if actor.use_ulysses_sp: - is_vlm_model = hasattr( - getattr(actor.actor_module, "module", actor.actor_module).config, - "vision_config", + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + nec_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + outputs = self._forward_micro_batch( + model_inputs, + temperature=temperature, + calculate_entropy=calculate_entropy, + calculate_nec=calculate_nec, ) - if is_vlm_model: - # vlm model's inputs will be sliced after embedding - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=actor.ulysses_sequence_parallel_size, - ) + log_probs = outputs["log_probs"] + entropy = outputs["entropy"] + nec = outputs["nec"] + + log_probs_lst.append(log_probs) + if calculate_entropy: + entropy_lst.append(entropy) + if calculate_nec: + nec_lst.append(nec) + + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + necs = None + if calculate_nec: + necs = torch.concat(nec_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + if calculate_nec: + necs = restore_dynamic_batch(necs, batch_idx_list) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropys + if calculate_nec: + outputs["necs"] = necs + return outputs + + def _forward_micro_batch( # type: ignore # noqa: C901 + self, micro_batch, temperature, calculate_entropy: bool = False, calculate_nec: bool = True + ) -> dict[str, torch.Tensor | None]: + """ + Returns: + dict containing: + - ``log_probs``: tensor, shape (bs, response_len) + - ``entropy``: tensor | None, shape (bs, response_len) + - ``nec``: tensor | None, shape (bs, response_len) + """ + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + nec = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) else: - ( - input_ids_rmpad, - position_ids_rmpad, - pad_size, - ) = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=actor.ulysses_sequence_parallel_size, + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + is_mask_all_zero = attention_mask.sum() == 0 + if is_mask_all_zero: + input_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=input_ids.device, + dtype=input_ids.dtype, ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, - position_ids_rmpad=None, - sp_size=actor.ulysses_sequence_parallel_size, - ) + if position_ids.dim() == 3: + position_ids_rmpad = torch.zeros( + (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + else: + position_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) - - # only pass input_ids and position_ids to enable flash_attn_varlen - extra_args = {} - if actor.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = actor.actor_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if actor.use_fused_kernels: - log_probs = output.log_probs.squeeze(0) # (total_nnz,) - entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) - if calculate_nec: - raise RuntimeError( - "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import ( + process_multi_modal_inputs_for_minicpmo, ) - else: - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - logits_rmpad.div_(temperature) + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - inplace_backward = True - if calculate_entropy: - inplace_backward = False - log_probs = actor.compute_log_probs_from_logits( - logits=logits_rmpad, - labels=input_ids_rmpad_rolled, - inplace_backward=inplace_backward, - ) + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) - # compute entropy - if calculate_entropy: - if not actor.config.entropy_checkpointing: - entropy_rmpad = actor.compute_entropy_from_logits( - logits_rmpad - ) # ((total_nnz / sp) + pad) + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, + "vision_config", + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) else: - entropy_rmpad = torch.utils.checkpoint.checkpoint( - actor.compute_entropy_from_logits, logits_rmpad + ( + input_ids_rmpad, + position_ids_rmpad, + pad_size, + ) = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, ) - if calculate_nec: - H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None - N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) - - # gather log_prob if sp > 1 - if actor.use_ulysses_sp: - # gather and unpad for the ulysses sp - log_probs = gather_outputs_and_unpad( - log_probs, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - if calculate_entropy: - entropy_rmpad = gather_outputs_and_unpad( - entropy_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, ) - if calculate_nec: - N_rmpad = gather_outputs_and_unpad( - N_rmpad, + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits( + logits_rmpad + ) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) + if calculate_nec: + H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None + N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size, ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_nec: + N_rmpad = gather_outputs_and_unpad( + N_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) - if is_mask_all_zero: - log_probs = log_probs[:0] + if is_mask_all_zero: + log_probs = log_probs[:0] + if calculate_entropy: + entropy_rmpad = entropy_rmpad[:0] + if calculate_nec: + N_rmpad = N_rmpad[:0] + + # pad back to (bsz, seqlen) if calculate_entropy: - entropy_rmpad = entropy_rmpad[:0] + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) if calculate_nec: - N_rmpad = N_rmpad[:0] - - # pad back to (bsz, seqlen) - if calculate_entropy: - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - if calculate_nec: - full_N = pad_input( - hidden_states=N_rmpad.unsqueeze(-1), + full_N = pad_input( + hidden_states=N_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen, ) - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - - # only return response part: - if calculate_entropy: - entropy = full_entropy.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - if calculate_nec: - nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - - else: # not using rmpad and no ulysses sp - extra_args = {} - if actor.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = actor.actor_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if actor.use_fused_kernels: - log_probs = output.log_probs[:, -response_length - 1 : -1] - entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) - if calculate_nec: - raise RuntimeError( - "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" - ) - - else: - logits = output.logits - logits.div_(temperature) - logits = logits[ - :, -response_length - 1 : -1, : - ] # (bsz, response_length, vocab_size) - log_probs = actor.compute_log_probs_from_logits(logits, micro_batch["responses"]) + # only return response part: if calculate_entropy: - if not actor.config.entropy_checkpointing: - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - else: - entropy = torch.utils.checkpoint.checkpoint( - verl_F.entropy_from_logits, logits - ) + entropy = full_entropy.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) if calculate_nec: - H_for_N = entropy.to(torch.float32) if calculate_entropy else None - nec = compute_N_from_logits(logits, entropy=H_for_N) - - if calculate_nec: - return entropy, log_probs, nec - return entropy, log_probs - + nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) -@GPUMemoryLogger(role="dp actor", logger=logger) -def clipv_compute_log_prob_with_nec( - actor: OriginalDPActor, - data: DataProto, - calculate_entropy: bool = True, - calculate_nec: bool = True, -) -> dict[str, torch.Tensor]: - """ - Returns: dict[str, torch.Tensor] - "log_probs": (bs, response_len) - "entropys": (bs, response_len) - "necs": (bs, response_len) - """ - # set to eval - actor.actor_module.eval() - - micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info[ - "temperature" - ] # temperature must be in the data.meta_info to avoid silent error - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - if use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * actor.ulysses_sequence_parallel_size - micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) - else: - micro_batches = data.split(micro_batch_size) - - log_probs_lst = [] - entropy_lst = [] - nec_lst = [] - for micro_batch in micro_batches: - micro_batch = micro_batch.to(get_device_id()) - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - outputs = _forward_micro_batch_with_nec( - actor, - model_inputs, - temperature=temperature, - calculate_entropy=calculate_entropy, - calculate_nec=calculate_nec, - ) - if calculate_nec: - entropy, log_probs, nec = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], outputs) - nec_lst.append(nec) - else: - entropy, log_probs = cast(tuple[torch.Tensor, torch.Tensor], outputs) - log_probs_lst.append(log_probs) - if calculate_entropy: - entropy_lst.append(entropy) - - log_probs = torch.concat(log_probs_lst, dim=0) - entropys = None - if calculate_entropy: - entropys = torch.concat(entropy_lst, dim=0) - necs = None - if calculate_nec: - necs = torch.concat(nec_lst, dim=0) - - if use_dynamic_bsz: - log_probs = restore_dynamic_batch(log_probs, batch_idx_list) - if calculate_entropy: - entropys = restore_dynamic_batch(entropys, batch_idx_list) - if calculate_nec: - necs = restore_dynamic_batch(necs, batch_idx_list) + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) - if calculate_nec: - return {"log_probs": log_probs, "entropys": entropys, "necs": necs} - return {"log_probs": log_probs, "entropys": entropys} + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint( + verl_F.entropy_from_logits, logits + ) + if calculate_nec: + H_for_N = entropy.to(torch.float32) if calculate_entropy else None + nec = compute_N_from_logits(logits, entropy=H_for_N) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropy + if calculate_nec: + outputs["necs"] = nec + return outputs diff --git a/examples/entropy/clipv_trainer.patch b/examples/entropy/clipv_trainer.patch index b80d957125..cb972124c9 100644 --- a/examples/entropy/clipv_trainer.patch +++ b/examples/entropy/clipv_trainer.patch @@ -1,4 +1,27 @@ +diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py +index b82fd264..0fff4233 100644 +--- a/trinity/trainer/verl/fsdp_workers.py ++++ b/trinity/trainer/verl/fsdp_workers.py +@@ -610,7 +610,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension): + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): +- from trinity.trainer.verl.dp_actor import DataParallelPPOActor ++ from examples.entropy.clipv_dp_actor import DataParallelPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) +@@ -911,6 +911,8 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension): + tensors = {"ref_log_prob": outputs["log_probs"]} + if calculate_entropy: + tensors["entropys"] = outputs["entropys"] ++ if "necs" in outputs: ++ tensors["necs"] = outputs["necs"] + output = DataProto.from_dict( + tensors=tensors, + meta_info={"temperature": self.config.rollout.temperature}, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py +index 849e176b..734b8b26 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -24,13 +24,15 @@ from verl.trainer.ppo.ray_trainer import ( @@ -56,7 +79,7 @@ diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py async def save_state_dict(self): # checkpoint sync actor_local_path = os.path.join( self.config.trainer.default_local_dir, f"global_step_{self.global_steps}", "actor" -@@ -501,13 +534,19 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): +@@ -501,13 +534,17 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): "perf/mfu/actor_infer": old_log_prob_mfu, } metrics.update(old_log_prob_metrics) @@ -69,15 +92,13 @@ diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py metrics.update(calculate_debug_metrics(batch)) + -+ # !!! Patch starts !!! + batch.batch["new_log_probs"] = old_log_prob.batch["old_log_probs"] + batch.batch["new_entropys"] = entropys + batch.batch["necs"] = old_log_prob.batch["necs"] -+ # !!! Patch ends !!! if self.algorithm.use_reference: # ref_logprob may not be used # compute reference log_prob -@@ -526,7 +565,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): +@@ -526,7 +563,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) metrics.update(prefix_metrics(kl_metrics, prefix="critic")) # compute advantages, executed on the driver process diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 6f66d1bf13..ab2def977a 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -27,10 +27,7 @@ from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import prepare_dynamic_batch -from verl.utils.torch_functional import ( - logprobs_from_logits as verl_logprobs_from_logits, -) +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN @@ -38,20 +35,12 @@ from trinity.algorithm.kl_fn.kl_fn import DummyKLFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig -from trinity.utils.registry import Registry __all__ = ["DataParallelPPOActor"] logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -LOG_PROB_FN: Registry = Registry( - "log_prob_fn", - default_mapping={ - "clipv_entropy_nec": "examples.entropy.clipv_dp_actor.clipv_compute_log_prob_with_nec", - }, -) - class DataParallelPPOActor(DPActor): def __init__( @@ -62,22 +51,6 @@ def __init__( self.policy_loss_fn = None self.kl_loss_fn = None self.entropy_loss_fn = None - log_prob_fn_key = config.get("log_prob_fn", "default") - if not log_prob_fn_key or log_prob_fn_key == "default": - self._compute_log_prob_fn_name = "default" - self._compute_log_prob_fn = None # use default log_prob_fn - else: - self._compute_log_prob_fn_name = str(log_prob_fn_key) - self._compute_log_prob_fn = LOG_PROB_FN.get( - self._compute_log_prob_fn_name - ) # use custom log_prob_fn - - def compute_log_probs_from_logits( - self, logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True - ) -> torch.Tensor: - return verl_logprobs_from_logits( - logits=logits, labels=labels, inplace_backward=inplace_backward - ) def set_algorithm(self, algorithm_config: AlgorithmConfig): self.loss_agg_mode = algorithm_config.loss_agg_mode @@ -157,11 +130,13 @@ def update_policy(self, data: DataProto): # noqa: C901 # all return: (bsz, response_length) calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn - entropy, log_prob = self._forward_micro_batch( + outputs = self._forward_micro_batch( micro_batch=model_inputs, temperature=temperature, calculate_entropy=calculate_entropy, ) + log_prob = outputs["log_probs"] + entropy = outputs["entropys"] if calculate_entropy else None pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore logprob=log_prob, **model_inputs @@ -249,17 +224,71 @@ def update_policy(self, data: DataProto): # noqa: C901 self.actor_optimizer.zero_grad() return metrics + # TODO: remove this method after upgrading verl @GPUMemoryLogger(role="dp actor", logger=logger) - def compute_log_prob( - self, data: DataProto, calculate_entropy: bool = False, calculate_nec: bool = False - ): - if self._compute_log_prob_fn_name == "default" or self._compute_log_prob_fn is None: - log_probs, entropys = super().compute_log_prob( - data=data, calculate_entropy=calculate_entropy - ) - return {"entropys": entropys, "log_probs": log_probs} + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> dict[str, torch.Tensor]: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + dict[str, torch.Tensor]: a dict containing keys + - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32. + - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32. + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) else: - outputs = self._compute_log_prob_fn( - actor=self, data=data, calculate_entropy=calculate_entropy - ) + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + outputs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_probs_lst.append(outputs["log_probs"]) + if calculate_entropy: + entropy_lst.append(outputs["entropys"]) + + log_probs = torch.concat(log_probs_lst, dim=0) + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropys return outputs diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index fd54acebb2..0fff42338e 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -610,7 +610,7 @@ def _build_model_optimizer( # noqa: C901 @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from trinity.trainer.verl.dp_actor import DataParallelPPOActor + from examples.entropy.clipv_dp_actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) @@ -904,9 +904,7 @@ def compute_log_prob(self, data: DataProto): calculate_entropy = not is_lora with self.ulysses_sharding_manager: with adapter_ctx: - outputs = self.actor.compute_log_prob( - data=data, calculate_entropy=calculate_entropy - ) + outputs = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora) if not is_lora: tensors = {"old_log_probs": outputs["log_probs"]} else: @@ -954,12 +952,7 @@ def compute_ref_log_prob(self, data: DataProto): "cpu" ) # data will to device with each micro batch on ref.compute_log_prob outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - if isinstance(outputs, dict): - ref_log_prob = outputs["log_probs"] - else: - # Backward compatibility with old tuple return style. - ref_log_prob, _ = outputs - output = DataProto.from_dict(tensors={"ref_log_prob": ref_log_prob}) + output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]}) output = output.to("cpu")