diff --git a/examples/entropy/README.md b/examples/entropy/README.md new file mode 100644 index 0000000000..58463f892d --- /dev/null +++ b/examples/entropy/README.md @@ -0,0 +1,78 @@ +# Entropy dynamics of RL training + +This example shows the two algorithms **Clip_B** and **Clip_V** from the work [On the Entropy Dynamics in Reinforcement Fine-Tuning of Large Language Models](https://arxiv.org/pdf/2602.03392). + +NOTE: This example is only tested on trinity==0.5.1 and verl==0.7.0. The following experiments require `synchronizer.sync_interval=1` and `trainer.trainer_config.algorithm.rollout_correction.bypass_mode=false` to be set. + +## Data Preparation + +We utilize the [DAPO-Math-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset as our training set. We exclude 500 questions from the training set to form the validation set (denoted by dapo-validation-500). +We then filter out samples from the training set with excessively high (≥ 15/16) or low (≤ 1/16) pass rates, as evaluated by Qwen2.5-7B-Instruct. + +## Clip_B Experiment + +1. Apply the patch to keep entropy information in the trainer batch: + +```bash +cd /path/to/Trinity-RFT +git apply examples/entropy/clipb_trainer.patch +# if not successful, try: +# git apply --3way --ignore-whitespace examples/entropy/clipb_trainer.patch +``` + +2. Update the dataset paths and other configurations in the file [`clipb.yaml`](./clipb.yaml) to point to your local data. + +3. Run the experiment: + +```bash +trinity run examples/entropy/clipb.yaml +``` + +## Clip_V Experiment + +1. Apply the patch to keep entropy information in the trainer batch: + +```bash +cd /path/to/Trinity-RFT +git apply examples/entropy/clipv_trainer.patch +# if not successful, try: +# git apply --3way --ignore-whitespace examples/entropy/clipv_trainer.patch +``` + +2. Update the dataset paths and other configurations in the file [`clipv.yaml`](./clipv.yaml) to point to your local data. + +3. 
Run the experiment: + +```bash +trinity run examples/entropy/clipv.yaml +``` + +## Clip_V Code Logic + +As shown in the following flowchart, the forward pass of [examples/entropy/clipv_dp_actor.py](./clipv_dp_actor.py) outputs `log_probs`, `entropy`, and `nec`. +These signals are then used by [Clip_V advantage function](../../trinity/algorithm/advantage_fn/clipv_advantage.py) to compute `xD` and clip only negative-advantage tokens. This process returns the revised `advantages`. + +```mermaid +flowchart TD + A["data"] + B["forward pass"] + C1["log_probs"] + C2["entropy (additional)"] + C3["nec (additional)"] + subgraph D["advantage computation"] + direction TB + F["xD = nec - exp(log_probs) * (entropy + log_probs)"] + G["only clip negative-advantage tokens"] + F --> G + end + E["advantages"] + + A --> B + B --> C1 + B --> C2 + B --> C3 + C1 --> D + C2 --> D + C3 --> D + D --> E +``` diff --git a/examples/entropy/clipb.yaml b/examples/entropy/clipb.yaml new file mode 100644 index 0000000000..3e0fdebfcf --- /dev/null +++ b/examples/entropy/clipb.yaml @@ -0,0 +1,100 @@ +project: math_dapo +name: clipb_example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} + max_prompt_tokens: 1024 + max_response_tokens: 7168 +algorithm: + algorithm_type: grpo_verl + advantage_fn: clipb + advantage_fn_args: + mu: 2.5 + repeat_times: 16 + kl_loss_fn_args: + kl_coef: 0.0 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 20 + batch_size: 64 + explorer_input: + taskset: + name: dapo_235 + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 1.0 + logprobs: 20 + eval_tasksets: + - name: dapo-validation-500 + storage_type: file + path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k + split: 'test' + repeat_times: 32 + format: + 
prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 0.7 + - name: amc23 + storage_type: file + path: math-ai/amc23 # Path to the AMC23 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name: aime24 + storage_type: file + path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset + split: 'train' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name: aime25 + storage_type: file + path: math-ai/aime25 # Path to the AIME2025 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + default_workflow_type: 'async_math_workflow' + default_reward_fn_type: 'math_boxed_reward' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + max_read_timeout: 7200 +explorer: + eval_interval: 20 + eval_on_startup: true + runner_per_model: 8 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + seed: 42 +trainer: + trainer_type: 'verl' + save_interval: 200 + trainer_config: + algorithm: + rollout_correction: + bypass_mode: false +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 3200 diff --git a/examples/entropy/clipb_trainer.patch b/examples/entropy/clipb_trainer.patch new file mode 100644 index 0000000000..03ca08b29c --- /dev/null +++ b/examples/entropy/clipb_trainer.patch @@ -0,0 +1,11 @@ +--- a/trinity/trainer/verl_trainer.py ++++ b/trinity/trainer/verl_trainer.py +@@ -501,7 +501,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + } + metrics.update(old_log_prob_metrics) +- old_log_prob.batch.pop("entropys") ++ # Keep entropys in batch so advantage_fn (e.g. 
Clip_B) can use it ++ # old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. diff --git a/examples/entropy/clipv.yaml b/examples/entropy/clipv.yaml new file mode 100644 index 0000000000..423e048a5d --- /dev/null +++ b/examples/entropy/clipv.yaml @@ -0,0 +1,100 @@ +project: math_dapo +name: clipv_example +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} + max_prompt_tokens: 1024 + max_response_tokens: 7168 +algorithm: + algorithm_type: grpo_verl + advantage_fn: clipv + advantage_fn_args: + mu: 8.5 + repeat_times: 8 + kl_loss_fn_args: + kl_coef: 0.0 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 20 + batch_size: 64 + explorer_input: + taskset: + name: dapo_235 + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} # processed DAPO-Math-17k + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 1.0 + logprobs: 20 + eval_tasksets: + - name: dapo-validation-500 + storage_type: file + path: '/path/to/dapo-validation' # validation samples from DAPO-Math-17k + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'ground_truth' + rollout_args: + temperature: 0.7 + - name: amc23 + storage_type: file + path: math-ai/amc23 # Path to the AMC23 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name: aime24 + storage_type: file + path: HuggingFaceH4/aime_2024 # Path to the AIME2024 dataset + split: 'train' + repeat_times: 32 + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 0.7 + - name: aime25 + storage_type: file + path: math-ai/aime25 # Path to the AIME2025 dataset + split: 'test' + repeat_times: 32 + format: + prompt_key: 'problem' + 
response_key: 'answer' + rollout_args: + temperature: 0.7 + default_workflow_type: 'async_math_workflow' + default_reward_fn_type: 'math_boxed_reward' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + max_read_timeout: 7200 +explorer: + eval_interval: 20 + eval_on_startup: true + runner_per_model: 8 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + seed: 42 +trainer: + trainer_type: 'verl' + save_interval: 100 + trainer_config: + algorithm: + rollout_correction: + bypass_mode: false +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 3600 diff --git a/examples/entropy/clipv_dp_actor.py b/examples/entropy/clipv_dp_actor.py new file mode 100644 index 0000000000..6cfa6c6866 --- /dev/null +++ b/examples/entropy/clipv_dp_actor.py @@ -0,0 +1,426 @@ +""" +Single Process Actor. +Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py +Note: This patch only works for verl==0.7.0 +""" + +import logging +import os + +import torch +import torch.nn.functional as F +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.attention_utils import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, +) +from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_id +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import ( + gather_outputs_and_unpad, + ulysses_pad, + ulysses_pad_and_slice_inputs, +) + +from trinity.trainer.verl.dp_actor import DataParallelPPOActor as OriginalDPActor + +__all__ = ["DataParallelPPOActor"] + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def compute_N_from_logits( + logits: torch.Tensor, entropy: torch.Tensor | None = None +) -> torch.Tensor: + """ + logits: [..., V], better to use 
float32 for stability + entropy: [...], if None, compute H = -sum p log p from logits + return: [...], the same shape as logits without the last dimension + """ + z = logits # use float32 for stability + logp = F.log_softmax(z, dim=-1) # [..., V] + p = logp.exp() # [..., V] + + p2 = p * p # [..., V] + sum_p2 = p2.sum(dim=-1) # [...] + term2 = (p2 * logp).sum(dim=-1) # sum p_i^2 log p_i + + if entropy is None: + entropy = -(p * logp).sum(dim=-1) # H = -sum p log p + N = entropy * sum_p2 + term2 # [... ] + return N + + +class DataParallelPPOActor(OriginalDPActor): + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob( + self, data: DataProto, calculate_entropy: bool = False, calculate_nec: bool = True + ) -> dict[str, torch.Tensor | None]: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. 
+ + Returns: + dict containing: + - ``log_probs``: tensor, shape (bsz, response_len) + - ``entropys``: tensor | None, shape (bsz, response_len) + - ``necs``: tensor | None, shape (bsz, response_len) + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + nec_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + outputs = self._forward_micro_batch( + model_inputs, + temperature=temperature, + calculate_entropy=calculate_entropy, + calculate_nec=calculate_nec, + ) + log_probs = outputs["log_probs"] + entropy = outputs["entropy"] + nec = outputs["nec"] + + log_probs_lst.append(log_probs) + if calculate_entropy: + entropy_lst.append(entropy) + if calculate_nec: + nec_lst.append(nec) + + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + necs = None + if calculate_nec: + necs = torch.concat(nec_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if 
calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + if calculate_nec: + necs = restore_dynamic_batch(necs, batch_idx_list) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropys + if calculate_nec: + outputs["necs"] = necs + return outputs + + def _forward_micro_batch( # type: ignore # noqa: C901 + self, micro_batch, temperature, calculate_entropy: bool = False, calculate_nec: bool = True + ) -> dict[str, torch.Tensor | None]: + """ + Returns: + dict containing: + - ``log_probs``: tensor, shape (bs, response_len) + - ``entropy``: tensor | None, shape (bs, response_len) + - ``nec``: tensor | None, shape (bs, response_len) + """ + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + nec = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis( + rearrange(position_ids, "c b s ... 
-> (b s) c ..."), indices + ) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + is_mask_all_zero = attention_mask.sum() == 0 + if is_mask_all_zero: + input_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=input_ids.device, + dtype=input_ids.dtype, + ) + if position_ids.dim() == 3: + position_ids_rmpad = torch.zeros( + (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + else: + position_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import ( + process_multi_modal_inputs_for_minicpmo, + ) + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, + "vision_config", + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + ( + input_ids_rmpad, + position_ids_rmpad, + pad_size, + ) = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + 
sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits( + logits_rmpad + ) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) + if calculate_nec: + H_for_N = entropy_rmpad.to(torch.float32) if calculate_entropy else None + N_rmpad = compute_N_from_logits(logits_rmpad, entropy=H_for_N) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + 
gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_nec: + N_rmpad = gather_outputs_and_unpad( + N_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + + if is_mask_all_zero: + log_probs = log_probs[:0] + if calculate_entropy: + entropy_rmpad = entropy_rmpad[:0] + if calculate_nec: + N_rmpad = N_rmpad[:0] + + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + if calculate_nec: + full_N = pad_input( + hidden_states=N_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + if calculate_nec: + nec = full_N.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + if calculate_nec: + raise RuntimeError( + "calculate_nec=True is not supported with fused kernels in _forward_micro_batch" + ) + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, 
response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint( + verl_F.entropy_from_logits, logits + ) + if calculate_nec: + H_for_N = entropy.to(torch.float32) if calculate_entropy else None + nec = compute_N_from_logits(logits, entropy=H_for_N) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropy + if calculate_nec: + outputs["necs"] = nec + return outputs diff --git a/examples/entropy/clipv_trainer.patch b/examples/entropy/clipv_trainer.patch new file mode 100644 index 0000000000..cb972124c9 --- /dev/null +++ b/examples/entropy/clipv_trainer.patch @@ -0,0 +1,110 @@ +diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py +index b82fd264..0fff4233 100644 +--- a/trinity/trainer/verl/fsdp_workers.py ++++ b/trinity/trainer/verl/fsdp_workers.py +@@ -610,7 +610,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension): + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): +- from trinity.trainer.verl.dp_actor import DataParallelPPOActor ++ from examples.entropy.clipv_dp_actor import DataParallelPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) +@@ -911,6 +911,8 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension): + tensors = {"ref_log_prob": outputs["log_probs"]} + if calculate_entropy: + tensors["entropys"] = outputs["entropys"] ++ if "necs" in outputs: ++ tensors["necs"] = outputs["necs"] + output = DataProto.from_dict( + tensors=tensors, + meta_info={"temperature": self.config.rollout.temperature}, +diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py +index 849e176b..734b8b26 100644 +--- 
a/trinity/trainer/verl_trainer.py ++++ b/trinity/trainer/verl_trainer.py +@@ -24,13 +24,15 @@ from verl.trainer.ppo.ray_trainer import ( + Role, + create_colocated_worker_cls, + ) ++from verl import DataProto ++from verl.utils import tensordict_utils as tu ++from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path + from verl.utils.debug import marked_timer + from verl.utils.fs import copy_local_path_from_hdfs + from verl.utils.metric import reduce_metrics + from verl.workers.config import FSDPEngineConfig +- + from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN + from trinity.algorithm.utils import prefix_metrics + from trinity.common.config import Config +@@ -433,6 +435,37 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + self.config.actor_rollout_ref.actor.optim.total_training_steps = self.total_training_steps + self.config.critic.optim.total_training_steps = self.total_training_steps + ++ def _compute_old_log_prob(self, batch: DataProto): ++ """ ++ We add nec to the batch to make the advantage function (e.g. Clip_V) use it. ++ """ ++ if self.use_legacy_worker_impl == "disable": ++ # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free ++ # step 1: convert dataproto to tensordict. ++ batch_td = batch.to_tensordict() ++ # step 2: convert from padding to nopadding ++ batch_td = left_right_2_no_padding(batch_td) ++ # step 3: add meta info ++ tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False) ++ output = self.actor_rollout_wg.compute_log_prob(batch_td) ++ # gather output ++ entropy = tu.get(output, "entropy") ++ log_probs = tu.get(output, "log_probs") ++ old_log_prob_mfu = tu.get(output, "metrics")["mfu"] ++ necs = tu.get(output, "necs") ++ # step 4. 
No padding to padding ++ entropy = no_padding_2_padding(entropy, batch_td) ++ log_probs = no_padding_2_padding(log_probs, batch_td) ++ necs = no_padding_2_padding(necs, batch_td) ++ # step 5: rebuild a tensordict and convert to dataproto ++ old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float(), "necs": necs.float()}) ++ old_log_prob = DataProto.from_tensordict(old_log_prob) ++ else: ++ old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) ++ old_log_prob_mfu = 0 ++ return old_log_prob, old_log_prob_mfu ++ ++ + async def save_state_dict(self): # checkpoint sync + actor_local_path = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}", "actor" +@@ -501,13 +534,17 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + "perf/mfu/actor_infer": old_log_prob_mfu, + } + metrics.update(old_log_prob_metrics) +- old_log_prob.batch.pop("entropys") ++ # old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. 
+ from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) ++ ++ batch.batch["new_log_probs"] = old_log_prob.batch["old_log_probs"] ++ batch.batch["new_entropys"] = entropys ++ batch.batch["necs"] = old_log_prob.batch["necs"] + + if self.algorithm.use_reference: # ref_logprob may not be used + # compute reference log_prob +@@ -526,7 +563,8 @@ class VerlPPOTrainerWrapper(RayPPOTrainer, TrainEngineWrapper): + batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch) + metrics.update(prefix_metrics(kl_metrics, prefix="critic")) + # compute advantages, executed on the driver process +- batch, _ = self.advantage_fn(batch) ++ batch, adv_metrics = self.advantage_fn(batch) ++ metrics.update(prefix_metrics(adv_metrics, prefix="clipv")) + else: + # skip token_level_scores for sft/dpo + if "token_level_scores" in batch.batch.keys(): diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 52bb605bcd..e693684cb9 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -29,6 +29,7 @@ "multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm", "on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm", "jsd": "trinity.algorithm.algorithm.JSDAlgorithm", + "grpo_verl": "trinity.algorithm.algorithm.GRPOverlAlgorithm", }, ) diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 239862ba58..fc23a412df 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -19,6 +19,8 @@ "rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage", "on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage", "jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage", + "clipb": "trinity.algorithm.advantage_fn.clipb_advantage.ClipBAdvantageFn", + "clipv": 
"trinity.algorithm.advantage_fn.clipv_advantage.ClipVAdvantageFn", }, ) diff --git a/trinity/algorithm/advantage_fn/clipb_advantage.py b/trinity/algorithm/advantage_fn/clipb_advantage.py new file mode 100644 index 0000000000..62898ada1a --- /dev/null +++ b/trinity/algorithm/advantage_fn/clipb_advantage.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +"""Advantage computation for Clip_B +Ref: https://arxiv.org/pdf/2602.03392""" + +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Tuple + +import torch + +if TYPE_CHECKING: + from verl import DataProto + +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn + + +class ClipBAdvantageFn(AdvantageFn): + """Clip_B advantage: keep all positive-advantage tokens, + one-side clip negative-advantage tokens by entropy signal.""" + + def __init__( + self, + epsilon: float = 1e-6, + mu: float = 2.5, + ) -> None: + self.epsilon = epsilon + self.mu = mu + + def __call__( + self, + exps: "DataProto", + **kwargs, + ) -> Tuple["DataProto", Dict]: + """ + Compute advantage for Clip_B. 
+ exps should contain the following fields: + - token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + - response_mask: `(torch.Tensor)` + shape: (bs, response_length) + - uid: `(torch.Tensor)` + shape: (bs,) + - rollout_log_probs: `(torch.Tensor)` + shape: (bs, response_length) + - entropys: `(torch.Tensor)` + shape: (bs, response_length) + Returns: + exps: DataProto with advantages and returns added + metrics: Dict with clipping metrics + """ + token_level_rewards = exps.batch["token_level_rewards"] + response_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device) + id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device) + elif len(id2score[idx]) > 1: + group_scores = torch.stack(id2score[idx]).to( + dtype=scores.dtype, device=scores.device + ) + id2mean[idx] = torch.mean(group_scores) + id2std[idx] = torch.std(group_scores) + else: + raise ValueError(f"no score in prompt index: {idx}") + + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores.clone() + + # --- BEGIN: token filtering logic --- + # Use recomputed logprobs & entropy from current model (not rollout) + LP = exps.batch["rollout_log_probs"] # [B, T], recomputed logprobs + H = exps.batch["entropys"] # [B, T], recomputed entropy + M = response_mask # [B, T], mask of valid tokens + p = LP.exp() # [B, T], probability of valid tokens + S = p * (H + LP) # [B, T], indicator 
+ + # Detach for constructing clip mask (no gradient needed) + xS = S.detach().to(torch.float32) # [B, T] + m = M.to(torch.float32) # [B, T] + + # Masked global mean & variance (population variance, denominator = n) + n = m.sum().clamp_min(1.0) + ES = (xS * m).sum() / n # scalar + varS = ((xS - ES) ** 2 * m).sum() / n # scalar + stdS = varS.sqrt() # scalar + + # Centered signal + z = xS - ES # [B, T] + + # if stdS is too small, keep all tokens; otherwise + # keep all positive-advantage tokens; one-side clip negative-advantage tokens + if stdS.item() < 1e-12: + keep = torch.ones_like(M, dtype=M.dtype) # all kept + else: + A = exps.batch["advantages"].detach().to(torch.float32) # [B, T] + pos_mask = A > 0 + neg_mask = A < 0 + + keep_pos = torch.ones_like(pos_mask, dtype=torch.bool) # positive: all kept + keep_neg = z >= -(self.mu * stdS) # negative: lower-side clip + keep_zero = torch.ones_like(pos_mask, dtype=torch.bool) # zero: all kept + + keep_bool = torch.where(pos_mask, keep_pos, torch.where(neg_mask, keep_neg, keep_zero)) + keep = keep_bool.to(M.dtype) + + M_clipped = M * keep + exps.batch["response_mask"] = M_clipped + # --- END: token filtering logic --- + + # Monitoring metrics + total_tokens = m.sum().clamp_min(1.0) + frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item() + + A = exps.batch["advantages"].detach().to(torch.float32) + pos_mask = (A > 0).to(M.dtype) + neg_mask = (A < 0).to(M.dtype) + total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0) + total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0) + frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item() + frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item() + + metrics = { + "frac_clipped": frac_clipped, + "frac_clipped_pos": frac_clipped_pos, + "frac_clipped_neg": frac_clipped_neg, + "ES": ES.item(), + "varS": varS.item(), + } + return exps, metrics + + @classmethod + def 
default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + "mu": 2.5, + } diff --git a/trinity/algorithm/advantage_fn/clipv_advantage.py b/trinity/algorithm/advantage_fn/clipv_advantage.py new file mode 100644 index 0000000000..0a126c839e --- /dev/null +++ b/trinity/algorithm/advantage_fn/clipv_advantage.py @@ -0,0 +1,153 @@ +"""GRPO advantage computation with Clip_V token filtering. +""" + +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Tuple + +import torch + +if TYPE_CHECKING: + from verl import DataProto + +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn + + +class ClipVAdvantageFn(AdvantageFn): + """Clip_V advantage: one-side clip only negative-advantage tokens, + and cap the global clipped-token ratio.""" + + def __init__( + self, + epsilon: float = 1e-6, + mu: float = 2.0, + max_frac: float = 1e-4, + ) -> None: + self.epsilon = epsilon + self.mu = mu + self.max_frac = max_frac + + def __call__( + self, + exps: "DataProto", + **kwargs, + ) -> Tuple["DataProto", Dict]: + token_level_rewards = exps.batch["token_level_rewards"] + response_mask = exps.batch["response_mask"] + index = exps.non_tensor_batch["uid"] + + new_log_probs = exps.batch["old_log_probs"] + new_entropys = exps.batch["entropys"] + necs = exps.batch["necs"] + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx, grouped_scores in id2score.items(): + if len(grouped_scores) == 1: + id2mean[idx] = torch.tensor(0.0, dtype=scores.dtype, device=scores.device) + id2std[idx] = torch.tensor(1.0, dtype=scores.dtype, device=scores.device) + elif len(grouped_scores) > 1: + group_scores = torch.stack(grouped_scores).to( + dtype=scores.dtype, device=scores.device + ) + id2mean[idx] = torch.mean(group_scores) + id2std[idx] = 
torch.std(group_scores) + else: + raise ValueError(f"no score in prompt index: {idx}") + + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + self.epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores.clone() + + LP = new_log_probs + H = new_entropys + N = necs + M = response_mask + p = LP.exp() + S = p * (H + LP) + + xD = (N - S).detach().to(torch.float32) + A = exps.batch["advantages"].detach().to(torch.float32) + m = M.to(torch.float32) + + n = m.sum().clamp_min(1.0) + mean_d = (xD * m).sum() / n + var_d = ((xD - mean_d) ** 2 * m).sum() / n + std_d = var_d.sqrt() + + if std_d.item() < 1e-12: + keep = torch.ones_like(M, dtype=M.dtype) + else: + pos_mask = A > 0 + neg_mask = A < 0 + + keep_neg = xD <= (self.mu * std_d) + keep_bool = torch.where(pos_mask, torch.ones_like(pos_mask), keep_neg) + + total_tokens = m.sum().clamp_min(1.0) + clipped_mask = (M > 0) & (~keep_bool) + frac_clipped = (clipped_mask.to(torch.float32).sum() / total_tokens).item() + + if frac_clipped <= self.max_frac: + keep = keep_bool.to(M.dtype) + else: + max_clipped_tokens = max(int(self.max_frac * total_tokens.item()), 1) + neg_to_clip = neg_mask & (M > 0) & (~keep_neg) + neg_to_clip_count = int(neg_to_clip.to(torch.int32).sum().item()) + + if neg_to_clip_count <= max_clipped_tokens: + keep = keep_bool.to(M.dtype) + else: + # Keep only top-K most violating negative-advantage tokens clipped. 
+ candidate_scores = xD.masked_fill(~neg_to_clip, float("-inf")).view(-1) + k = min(max_clipped_tokens, neg_to_clip_count) + _, indices = torch.topk(candidate_scores, k, largest=True) + + limited_clip_mask = torch.zeros_like(candidate_scores, dtype=torch.bool) + limited_clip_mask[indices] = True + limited_clip_mask = limited_clip_mask.view_as(xD) + + final_keep = keep_bool.clone() + final_keep[neg_to_clip] = True + final_keep[limited_clip_mask] = False + keep = final_keep.to(M.dtype) + + M_clipped = M * keep + exps.batch["response_mask"] = M_clipped + + total_tokens = M.to(torch.float32).sum().clamp_min(1.0) + frac_clipped = 1.0 - (M_clipped.to(torch.float32).sum() / total_tokens).item() + + pos_mask = (A > 0).to(M.dtype) + neg_mask = (A < 0).to(M.dtype) + total_pos = (M * pos_mask).to(torch.float32).sum().clamp_min(1.0) + total_neg = (M * neg_mask).to(torch.float32).sum().clamp_min(1.0) + frac_clipped_pos = 1.0 - ((M_clipped * pos_mask).to(torch.float32).sum() / total_pos).item() + frac_clipped_neg = 1.0 - ((M_clipped * neg_mask).to(torch.float32).sum() / total_neg).item() + + metrics = { + "frac_clipped": frac_clipped, + "frac_clipped_pos": frac_clipped_pos, + "frac_clipped_neg": frac_clipped_neg, + "stdD": std_d.item(), + } + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + "mu": 8.5, + "max_frac": 1e-4, + } diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 5352bba449..35b7453222 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -540,3 +540,25 @@ def default_config(cls) -> Dict: "kl_loss_fn": "none", "entropy_loss_fn": "none", } + + +class GRPOverlAlgorithm(AlgorithmType): + """GRPO algorithm, but advantage computation is done in trainer.""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = True + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def 
default_config(cls) -> Dict: + return { + "repeat_times": 2, + "advantage_fn": "grpo", + "sample_strategy": "default", + "policy_loss_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "k2", + "entropy_loss_fn": "default", + } diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 1aaa4a3c4a..92972406ea 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -175,12 +175,16 @@ class Actor: router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig) # do not set loss_agg_mode: str = "token-mean" + loss_scale_factor: Optional[float] = None clip_ratio: float = 0.2 clip_ratio_low: Optional[float] = None clip_ratio_high: Optional[float] = None entropy_coeff: float = 0.001 use_kl_loss: bool = False + # custom log_prob_fn + log_prob_fn: Optional[str] = None + @dataclass class Ref: diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 80bc4a16c8..ab2def977a 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -27,7 +27,7 @@ from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import prepare_dynamic_batch +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN @@ -130,11 +130,13 @@ def update_policy(self, data: DataProto): # noqa: C901 # all return: (bsz, response_length) calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn - entropy, log_prob = self._forward_micro_batch( + outputs = self._forward_micro_batch( micro_batch=model_inputs, temperature=temperature, calculate_entropy=calculate_entropy, ) + log_prob = outputs["log_probs"] + entropy = outputs["entropys"] if calculate_entropy else None pg_loss, pg_loss_metrics = 
self.policy_loss_fn( # type: ignore + logprob=log_prob, **model_inputs @@ -221,3 +223,72 @@ def update_policy(self, data: DataProto): # noqa: C901 append_to_dict(metrics, mini_batch_metrics) self.actor_optimizer.zero_grad() return metrics + + # TODO: remove this method after upgrading verl + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> dict[str, torch.Tensor]: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + dict[str, torch.Tensor]: a dict containing keys + - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32. + - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32. Only present when ``calculate_entropy=True``.
+ """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + outputs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_probs_lst.append(outputs["log_probs"]) + if calculate_entropy: + entropy_lst.append(outputs["entropys"]) + + log_probs = torch.concat(log_probs_lst, dim=0) + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropys + return outputs diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 1b893bdfef..0fff42338e 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -610,7 +610,7 @@ def _build_model_optimizer( # noqa: 
C901 @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from trinity.trainer.verl.dp_actor import DataParallelPPOActor + from examples.entropy.clipv_dp_actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) @@ -901,14 +901,18 @@ def compute_log_prob(self, data: DataProto): data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob + calculate_entropy = not is_lora with self.ulysses_sharding_manager: with adapter_ctx: - output, entropys = self.actor.compute_log_prob( - data=data, calculate_entropy=not is_lora - ) - tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} + outputs = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora) if not is_lora: - tensors["entropys"] = entropys + tensors = {"old_log_probs": outputs["log_probs"]} + else: + tensors = {"ref_log_prob": outputs["log_probs"]} + if calculate_entropy: + tensors["entropys"] = outputs["entropys"] + if "necs" in outputs: + tensors["necs"] = outputs["necs"] output = DataProto.from_dict( tensors=tensors, meta_info={"temperature": self.config.rollout.temperature}, @@ -947,8 +951,8 @@ def compute_ref_log_prob(self, data: DataProto): data = data.to( "cpu" ) # data will to device with each micro batch on ref.compute_log_prob - output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - output = DataProto.from_dict(tensors={"ref_log_prob": output}) + outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]}) output = output.to("cpu")