diff --git a/recipe/dapo/config/dapo_fsdp_config_with_resampling.yaml b/recipe/dapo/config/dapo_fsdp_config_with_resampling.yaml
new file mode 100644
index 00000000..044255e1
--- /dev/null
+++ b/recipe/dapo/config/dapo_fsdp_config_with_resampling.yaml
@@ -0,0 +1,33 @@
+hydra:
+  searchpath:
+    - file://verl/trainer/config
+
+defaults:
+  - ppo_trainer
+  - _self_
+
+# Parameters added to enable PassRateWeightedSampler with DAPO; these override verl/trainer/config/data/legacy_data.yaml
+data:
+  gen_batch_size: ${data.train_batch_size}
+  dataloader_num_workers: 0  # Recommended to be 0 with curriculum-learning samplers (e.g., PassRateWeightedSampler) to prevent data prefetching before batches are reordered.
+  sampler:
+    pass_rate_temperature: 1.0  # Temperature for PassRateWeightedSampler; controls the sharpness of the weighting distribution
+    use_ema: False  # Whether to weight by EMA-smoothed pass rates
+    ema_alpha: 0.1  # Alpha for EMA smoothing of pass rates
+
+reward_model:
+  reward_manager: dapo
+  overlong_buffer:
+    enable: False  # Set explicitly so enabling it is never forgotten
+    len: 0
+    penalty_factor: 0.0
+    log: False
+
+algorithm:
+  filter_groups:
+    _target_: verl.trainer.config.FilterGroupsConfig
+    enable: False  # Set explicitly so enabling it is never forgotten
+    metric: null  # acc / score / seq_reward / seq_final_reward / ...
+    max_num_gen_batches: 0  # Non-positive values mean no upper limit
+
+
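Reviewer aside: a minimal sketch (plain numpy, hypothetical pass-rate values, not part of this patch) of how `pass_rate_temperature` above shapes the sampler's negative-exponential weighting defined later in this diff:

```python
import numpy as np

# Hypothetical pass rates for three already-attempted samples.
pass_rates = np.array([0.0, 0.5, 1.0])
for temperature in (0.5, 1.0, 2.0):
    w = np.exp(-pass_rates / temperature)       # weight = exp(-pass_rate / T)
    print(temperature, np.round(w / w.sum(), 3))
# T < 1.0 sharpens the distribution toward hard (low pass rate) samples;
# T > 1.0 flattens it toward uniform sampling.
```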
diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py
index 2e144130..7ccec043 100644
--- a/recipe/dapo/dapo_ray_trainer.py
+++ b/recipe/dapo/dapo_ray_trainer.py
@@ -43,7 +43,7 @@
 )
 from verl.utils.profiler import marked_timer
 from verl.utils.rollout_skip import RolloutSkip
-
+from verl.utils.pass_rate_weighted_sampler import PassRateWeightedSampler
 
 class RayDAPOTrainer(RayPPOTrainer):
     """
@@ -68,12 +68,19 @@ def fit(self):
            config=OmegaConf.to_container(self.config, resolve=True),
        )
 
-        self.global_steps = 0
+        self.global_steps = 0
         self.gen_steps = 0
-        # load checkpoint before doing anything
        self._load_checkpoint()
 
+        # Extract the pass rate tracker from the sampler when curriculum learning is enabled.
+        # PassRateWeightedSampler owns the tracker internally, but the trainer must update it manually during training.
+        # Currently, PassRateWeightedSampler is the only supported curriculum-learning sampler.
+        self.pass_rate_tracker = None
+        self.data_sampler = self.train_dataloader.sampler  # train_dataloader is created in `RayPPOTrainer._create_dataloader()` and always has a sampler
+        if isinstance(self.data_sampler, PassRateWeightedSampler):
+            self.pass_rate_tracker = self.data_sampler.pass_rate_tracker
+
         # perform validation before training
         # currently, we only support validation using the reward_function.
         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
@@ -135,7 +142,6 @@ def fit(self):
                    non_tensor_batch_keys=["raw_prompt_ids"],
                )
                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
-
                is_last_step = self.global_steps >= self.total_training_steps
 
                with marked_timer("step", timing_raw):
@@ -189,7 +195,6 @@ def fit(self):
                            reward_extra_infos_dict = {}
 
                        new_batch.batch["token_level_scores"] = reward_tensor
-
                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update(
                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
@@ -206,6 +211,47 @@ def fit(self):
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
 
+                        # === Curriculum Learning: update the pass rate tracker for weighted resampling ===
+                        # When using PassRateWeightedSampler, track per-sample success rates to drive the curriculum.
+                        # The sampler uses these pass rates to adjust sampling probabilities in the next epoch.
+
+                        # TODO: factor this pass-rate-tracker update out into a utility function.
+                        # 1. If the sampler is a PassRateWeightedSampler, self.pass_rate_tracker is not None.
+                        # 2. A `dataset_index` field is added to the RL dataset to identify samples.
+                        if "dataset_index" in new_batch.non_tensor_batch and self.pass_rate_tracker is not None:
+                            dataset_indices = new_batch.non_tensor_batch["dataset_index"]
+                            # Sum token-level rewards to get the sequence-level reward
+                            seq_rewards = new_batch.batch["token_level_rewards"].sum(dim=-1).cpu().numpy()
+                            # Success is 1 if the sequence reward > 0, else 0
+                            successes = (seq_rewards > 0).astype(float)
+
+                            # Deduplicate: the batch was repeated n times (interleaved), so aggregate per prompt
+                            unique_indices, inverse_indices = np.unique(dataset_indices, return_inverse=True)
+
+                            assert len(unique_indices) > 0, "No unique samples found in batch. Check data pipeline configuration."
+                            # Aggregate successes: take the mean across rollouts for each sample
+                            aggregated_successes = np.zeros(len(unique_indices), dtype=float)
+                            for i, _ in enumerate(unique_indices):
+                                mask = inverse_indices == i  # boolean mask marking positions of unique index i
+                                aggregated_successes[i] = np.mean(successes[mask])  # average success across rollouts for sample i
+
+                            pass_rates = self.pass_rate_tracker.get_pass_rates()
+
+                            # Log curriculum metrics BEFORE updating the tracker.
+                            # Track improvement of the hardest samples (across all samples, not just attempted ones).
+                            metrics['curriculum/hardest_10pct_pass_rate'] = float(np.percentile(pass_rates, 10))
+                            metrics['curriculum/hardest_25pct_pass_rate'] = float(np.percentile(pass_rates, 25))
+                            metrics['curriculum/hardest_50pct_pass_rate'] = float(np.percentile(pass_rates, 50))
+                            metrics['curriculum/hardest_75pct_pass_rate'] = float(np.percentile(pass_rates, 75))
+
+                            # Batch-level statistics
+                            metrics['curriculum/min_batch_pass_rate'] = float(np.min(aggregated_successes))
+                            metrics['curriculum/mean_batch_pass_rate'] = float(np.mean(aggregated_successes))
+                            metrics['curriculum/effective_batch_size'] = np.sum(aggregated_successes > 0) / len(unique_indices)  # fraction of prompts with at least one successful rollout
+
+                            # Update the tracker with the current batch results
+                            self.pass_rate_tracker.update(sample_indices=unique_indices.astype(int), batch_pass_rate=aggregated_successes)
+
                        if not self.config.algorithm.filter_groups.enable:
                            batch = new_batch
                        else:  # NOTE: When prompts after filtering is less than train batch size,
@@ -280,7 +326,6 @@ def fit(self):
 
                # === Updating ===
                batch.batch["response_mask"] = compute_response_mask(batch)
-
                # Balance the number of valid tokens across DP ranks.
                # NOTE: This usually changes the order of data in the `batch`,
                # which won't affect the advantage calculation (since it's based on uid),
@@ -430,6 +476,31 @@ def _to_sequence(value):
                    num_total_prompts = 0
                    num_gen_batches = 0
 
+                # Add curriculum learning metrics to W&B
+                if isinstance(self.data_sampler, PassRateWeightedSampler):
+                    # Add 3D plot data for weight and count distributions (percentile-based)
+                    try:
+                        import wandb
+                        import pandas as pd
+
+                        weight_3d_data = self.data_sampler.get_wandb_3d_plot_data(metric_type='weight')
+                        count_3d_data = self.data_sampler.get_wandb_3d_plot_data(metric_type='count')
+
+                        # Attach the step to each data point for 3D visualization
+                        for point in weight_3d_data:
+                            point['step'] = self.global_steps
+                        for point in count_3d_data:
+                            point['step'] = self.global_steps
+
+                        metrics['curriculum/weight_distribution_3d'] = wandb.Table(
+                            dataframe=pd.DataFrame(weight_3d_data)
+                        ) if weight_3d_data else None
+                        metrics['curriculum/count_distribution_3d'] = wandb.Table(
+                            dataframe=pd.DataFrame(count_3d_data)
+                        ) if count_3d_data else None
+                    except ImportError:
+                        pass  # wandb or pandas not available
+
                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)
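Reviewer aside on the per-rollout aggregation hunk above: the Python loop scans the whole batch once per unique prompt. A vectorized equivalent (a sketch with toy stand-in arrays, not part of this patch) using the same `inverse_indices`/`successes` shapes:

```python
import numpy as np

# Toy stand-ins for the arrays computed in the hunk above.
successes = np.array([1.0, 0.0, 1.0, 1.0])   # per-rollout success flags
inverse_indices = np.array([0, 0, 1, 1])     # rollout -> unique-prompt position

counts = np.bincount(inverse_indices)                       # rollouts per unique prompt
sums = np.bincount(inverse_indices, weights=successes)      # successes per unique prompt
aggregated_successes = sums / counts                        # -> [0.5, 1.0], mean success per prompt
```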
diff --git a/scripts/train/pass_rate_weighted_sampler_multinode_rl_qwen2.5_32b_base_fsdp.sh b/scripts/train/pass_rate_weighted_sampler_multinode_rl_qwen2.5_32b_base_fsdp.sh
new file mode 100644
index 00000000..3d93ce3c
--- /dev/null
+++ b/scripts/train/pass_rate_weighted_sampler_multinode_rl_qwen2.5_32b_base_fsdp.sh
@@ -0,0 +1,295 @@
+#!/bin/bash
+#SBATCH --job-name=example-multinode-rl-qwen2.5-32b-base-fsdp
+#SBATCH --nodes=2
+#SBATCH --ntasks=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=128
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=main
+#SBATCH --account=iq
+
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME=""       # Fill in the checkpoint directory name to resume from; otherwise train from scratch
+export STEM_LLM_JUDGE_URL=""  # Fill in the hosted LLM-as-judge URL; currently used only in the 'STEM' domain
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/mnt/weka/home/jalaj.bhandari/miniconda3/envs/jalaj_sync_rl/bin/
+export NCCL_TIMEOUT_SECONDS=4800
+export TORCH_NCCL_ENABLE_MONITORING=0
+export NCCL_DEBUG=warn
+export NCCL_NET=IB
+export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7"
+export NCCL_CROSS_NIC=1
+export NCCL_IB_TC=136
+export NCCL_SOCKET_IFNAME="^lo,docker,virbr"
+export CUDA_DEVICE_MAX_CONNECTIONS=8
+export NCCL_NVLS_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+echo "Number of nodes (workers): $worker_num"
+
+# =================== Data Mixture ===================
+# SHARED_DATA_PATH=/mnt/weka/shrd/k2tls/k2v2rl-data/
+# TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/
+# TEST_DATA_DIR=${SHARED_DATA_PATH}/online_eval/
+
+# TRAIN_DATA_DIR=/mnt/weka/shrd/k2tls/k2v2rl-data/data_mix_1/main_questions
+TRAIN_DATA_DIR=/mnt/weka/shrd/k2tls/jalaj/data-mixtures_round2/medium/data_mix_math_9000/main_questions
+TEST_DATA_DIR=/mnt/weka/shrd/k2tls/rl-test-data-12k
+
+# Math (train)
+math_train_path=/mnt/weka/shrd/k2tls/k2v2rl-data/data-mixtures/pass_rate_sampling/medium/data_mix_math_20000/main_questions/math__combined_118.2k.part2_scored.parquet
+# math_train_path=${TRAIN_DATA_DIR}/math__combined_118.2k.part1_scored.parquet
+# math_train_path2=${TEST_DATA_DIR}/math__combined_118.2k.part2_scored.parquet
+
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (train)
+leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet
+livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet
+primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet
+taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_200.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (train)
+arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet
+arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet
+barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet
+graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet
+ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet
+zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet
+# Logic (test)
+ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_100.parquet
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+arcagi_test_path=${TEST_DATA_DIR}/logic__arcagi1_200.parquet
+
+# Simulation (train)
+codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet
+# Simulation (test)
+codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet
+
+# Table (train)
+hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet
+multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_200.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_200.parquet
+
+# Stem (train)
+webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet
+# Stem (test)
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet
+
+train_files="['${math_train_path}']"
+# train_files="['${math_train_path}', '${math_train_path2}']"  # Math is used as the example; add more tasks as needed
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}']"  # Math is used as the example; add more tasks as needed
+
+
+# =================== Model ===================
+BASE_MODEL=Qwen/Qwen2.5-32B
+
+# =================== Logging ===================
+WANDB_PROJECT=Reasoning360
+WANDB_EXPERIMENT_NAME=Curriculum-${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/}
+
+# If RESUME_CKPT_DIR_NAME is not empty, resume from that checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+    WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# Stop Ray on all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove any existing Ray cluster state
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start the Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+    ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+    --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start the Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+    node_i=${nodes[$i]}
+    echo "Starting WORKER $i at $node_i"
+    srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+        ${CONDA_BIN_PATH}ray start --address "$address_head" \
+        --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrow the config format from DAPO but disable all DAPO-specific features here to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.2
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 8))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=32  # on-policy model update batch size: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=8
+train_prompt_mini_bsz=32  # model gradient update batch size
+
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1  # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=1
+gen_tp=2
+gen_max_num_seqs=1024
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2))  # increase to speed up the model forward & backward, but watch for memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2))  # increase to speed up the model forward, but watch for memory overflow
+offload=True
+
+# =================== Start RL training ===================
+# These arguments enable pass-rate based weighted sampling:
+#   - data.sampler.class_path='pkg://verl.utils.pass_rate_weighted_sampler' \
+#   - data.sampler.class_name='PassRateWeightedSampler' \
+#   - data.sampler.pass_rate_temperature=0.5 \
+
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+    --config-path=config \
+    --config-name="dapo_fsdp_config_with_resampling.yaml" \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    algorithm.filter_groups.enable=${enable_filter_groups} \
+    algorithm.filter_groups.metric=${filter_groups_metric} \
+    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.sampler.class_path='pkg://verl.utils.pass_rate_weighted_sampler' \
+    data.sampler.class_name='PassRateWeightedSampler' \
+    data.sampler.pass_rate_temperature=0.5 \
+    data.sampler.use_ema=False \
+    data.sampler.ema_alpha=0.1 \
+    data.prompt_key=prompt \
+    data.truncation='right' \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.train_batch_size=${train_prompt_bsz} \
+    data.gen_batch_size=${gen_prompt_bsz} \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+    actor_rollout_ref.actor.strategy="fsdp" \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.optim.warmup_style=constant \
+    actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+    actor_rollout_ref.rollout.enable_chunked_prefill=True \
+    actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+    actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+    actor_rollout_ref.rollout.temperature=${temperature} \
+    actor_rollout_ref.rollout.top_p=${top_p} \
+    actor_rollout_ref.rollout.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+    actor_rollout_ref.model.path=$BASE_MODEL \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.rollout.multi_turn.enable=False \
+    actor_rollout_ref.rollout.mode="sync" \
+    +actor_rollout_ref.model.override_config.attention_dropout=0. \
+    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+    +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    reward_model.reward_manager=async_multi_process \
+    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+    reward_model.overlong_buffer.len=${overlong_buffer_len} \
+    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name=${WANDB_PROJECT} \
+    trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+    trainer.val_before_train=True \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=$worker_num \
+    trainer.save_freq=10 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=5 \
+    trainer.log_val_generations=50 \
+    trainer.resume_mode=auto \
+    trainer.max_actor_ckpt_to_keep=2
\ No newline at end of file
diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py
index 2a62216f..a17b5ee4 100644
--- a/verl/trainer/ppo/metric_utils.py
+++ b/verl/trainer/ppo/metric_utils.py
@@ -133,6 +133,7 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
    score_min = torch.min(non_aborted_sequence_score).detach().item()
 
    reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
+    reward_std = torch.std(non_aborted_sequence_reward).detach().item()
    reward_max = torch.max(non_aborted_sequence_reward).detach().item()
    reward_min = torch.min(non_aborted_sequence_reward).detach().item()
 
@@ -175,6 +176,7 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
        "critic/score/min": score_min,
        # reward
        "critic/rewards/mean": reward_mean,
+        "critic/rewards/std": reward_std,
        "critic/rewards/max": reward_max,
        "critic/rewards/min": reward_min,
        # adv
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index c99530e8..ec07668c 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -375,7 +375,11 @@ def __getitem__(self, item):
            logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs
-        row_dict["interaction_kwargs"] = interaction_kwargs
+        row_dict["interaction_kwargs"] = interaction_kwargs
+        # Add a unique dataset index for pass rate tracking.
+        # Note: `item` is the row number in the dataframe the dataloader uses; we store it as the dataset index for tracking.
+        row_dict["dataset_index"] = item
+
        return row_dict
 
    def __getstate__(self):
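Reviewer aside, tracing the new `dataset_index` field: verl's default collation routes non-tensor values into `non_tensor_batch`, so after the interleaved `repeat` in `fit()` each index appears once per rollout, which is why the trainer deduplicates with `np.unique`. A toy sketch (not part of this patch):

```python
import numpy as np

# Dataframe rows 3 and 7, each repeated rollout.n = 2 times (interleaved),
# mirroring gen_batch.repeat(..., interleave=True) in dapo_ray_trainer.py above.
dataset_index = np.repeat(np.array([3, 7]), 2)              # -> [3, 3, 7, 7]
unique, inverse = np.unique(dataset_index, return_inverse=True)
# unique == [3, 7]; inverse == [0, 0, 1, 1] groups rollouts back to their source prompt.
```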
+ """ + + def __init__(self, dataset_size: int, use_ema: bool = False, ema_alpha: float = 0.1): + """ + Args: + dataset_size: Total number of samples in the dataset + use_ema: If True, use exponential moving average for pass rates + ema_alpha: EMA smoothing factor (0 to 1). Higher = more weight to recent updates + """ + self.dataset_size = dataset_size + self.use_ema = use_ema + self.ema_alpha = ema_alpha + + # Track stats for each sample index + self.attempt_counts = np.zeros(dataset_size, dtype=np.int32) + + # Initialize pass_rate to -5.0 for all untried samples which forces the model to sample each prompt at least once + # before curriculum based sampling starts + # TODO (Jalaj): Consider changing this strategy to read and write from a file to persist across training runs rather than keeping pass rates in memory + + # Keep both pass_rate and ema_pass_rate to enable future analysis and comparison: + # TODO (Jalaj): Detecting high-variance samples or sudden performance changes and do adaptive weighting strategies based on learning dynamics + self.pass_rate = -5 * np.ones(dataset_size, dtype=np.float16) + self.ema_pass_rate = -5 * np.ones(dataset_size, dtype=np.float16) + + def update(self, sample_indices: np.ndarray, batch_pass_rate: np.ndarray): + """ + Update pass rate statistics for a batch of samples. + + Args: + sample_indices: Array of dataset indices, shape (batch_size,) + batch_pass_rate: Array indicating average pass rate for each sample, shape (batch_size,) + """ + assert len(sample_indices) == len(batch_pass_rate), \ + f"Mismatch: {len(sample_indices)} indices vs {len(batch_pass_rate)} pass rates" + + # Increment attempt count, this can be used for sampling with bandit style algorithms + self.attempt_counts[sample_indices] += 1 + + # Update latest pass rate with this batch's result + self.pass_rate[sample_indices] = batch_pass_rate + + # Update EMA pass rate if enabled + if self.use_ema: + old_ema = self.ema_pass_rate[sample_indices] # Get current EMA values for the batch + first_time_mask = old_ema < 0 # Create mask for first-time updates (negative values) + + # Compute new EMA values: + # For first time: use batch_pass_rate directly + # For subsequent: apply EMA formula + new_ema = np.where( + first_time_mask, + batch_pass_rate, # First time: just use current value + self.ema_alpha * batch_pass_rate + (1 - self.ema_alpha) * old_ema # EMA update + ) + + # Update all at once + self.ema_pass_rate[sample_indices] = new_ema + + def get_pass_rates(self) -> np.ndarray: + """ + Compute pass rates for all samples. For now, just use the historical pass rates (and optionally an + exponential moving average); can be extended in many different ways + + Returns: + Array of shape (dataset_size,) with pass rates in [0, 1] or [-5, 0] for untried. + Uses EMA if self.use_ema=True, otherwise returns latest pass rates. 
+ """ + if self.use_ema: + return self.ema_pass_rate + else: + return self.pass_rate + + def get_count_distribution_statistics(self) -> dict: + """ + Return metrics for monitoring how count distribution changes over training steps + + Returns: + Dict with percentiles of `attempt_counts` distribution + """ + return { + f'count_p{p}': float(np.percentile(self.attempt_counts, p)) + for p in range(5, 100, 5) + } + + # For saving pass_rate_tracker state when saving training checkpoint + def state_dict(self) -> dict: + """Return state for checkpointing.""" + return { + 'attempt_counts': self.attempt_counts.copy(), + 'pass_rate': self.pass_rate.copy(), + 'ema_pass_rate': self.ema_pass_rate.copy(), + } + + # For loading pass_rate_tracker state when resuming training from checkpoint + def load_state_dict(self, state_dict: dict): + """Load state from checkpoint.""" + self.attempt_counts = state_dict['attempt_counts'].copy() + self.pass_rate = state_dict['pass_rate'].copy() + self.ema_pass_rate = state_dict['ema_pass_rate'].copy() \ No newline at end of file diff --git a/verl/utils/pass_rate_weighted_sampler.py b/verl/utils/pass_rate_weighted_sampler.py new file mode 100644 index 00000000..5c0136ef --- /dev/null +++ b/verl/utils/pass_rate_weighted_sampler.py @@ -0,0 +1,156 @@ +""" +Weighted Sampler for curriculum learning. + +Uses PassRateTracker to compute dynamic sampling weights based on pass rates. +""" + +import numpy as np +from omegaconf import DictConfig + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.utils.pass_rate_tracker import PassRateTracker + + +class PassRateWeightedSampler(AbstractSampler): + """ + Weighted sampler that uses pass rates to adjust sampling probabilities. + + Implements curriculum learning by dynamically adjusting sampling weights + based on per-sample success rates. Samples with lower pass rates are sampled more frequently. + """ + + def __init__(self, data_source, data_config: DictConfig): + """ + Args: + data_source: The dataset object (Sized) + data_config: Configuration dictionary containing the entire data config + """ + self.data_source = data_source + self.data_config = data_config + self.dataset_size = len(data_source) + + # Temperature parameter for controlling the sharpness of the weighting distribution (from sampler config) + # - temperature < 1.0: Sharp (hard samples dominate) + # - temperature = 1.0: Balanced + # - temperature > 1.0: Soft (nearly uniform) + self.temperature = data_config.sampler.get("pass_rate_temperature", 1.0) + use_ema = data_config.sampler.get("use_ema", False) + ema_alpha = data_config.sampler.get("ema_alpha", 0.1) + + # Create tracker for this dataset: set `use_ema=True` for exponential moving average pass rates + self.pass_rate_tracker = PassRateTracker(dataset_size=self.dataset_size, use_ema=use_ema, ema_alpha=ema_alpha) + self._cached_weights = None # Cache for weights to avoid recomputation each iteration + self._last_pass_rate = None # Track pass rate in the previous training step + self._update_tolerance = 0.01 # Tolerance for detecting significant pass rate changes which require recomputing the weight vector + + def __len__(self): + return self.dataset_size + + def get_weights(self) -> np.ndarray: + """ + We can add different weighting strategies here. Todo (Jalaj): add an additional argument to select strategy + + Current strategy: compute sampling weights inversely proportional to pass rates. 
diff --git a/verl/utils/pass_rate_weighted_sampler.py b/verl/utils/pass_rate_weighted_sampler.py
new file mode 100644
index 00000000..5c0136ef
--- /dev/null
+++ b/verl/utils/pass_rate_weighted_sampler.py
@@ -0,0 +1,156 @@
+"""
+Weighted sampler for curriculum learning.
+
+Uses PassRateTracker to compute dynamic sampling weights based on pass rates.
+"""
+
+import numpy as np
+from omegaconf import DictConfig
+
+from verl.experimental.dataset.sampler import AbstractSampler
+from verl.utils.pass_rate_tracker import PassRateTracker
+
+
+class PassRateWeightedSampler(AbstractSampler):
+    """
+    Weighted sampler that uses pass rates to adjust sampling probabilities.
+
+    Implements curriculum learning by dynamically adjusting sampling weights based on
+    per-sample success rates: samples with lower pass rates are sampled more frequently.
+    """
+
+    def __init__(self, data_source, data_config: DictConfig):
+        """
+        Args:
+            data_source: The dataset object (Sized)
+            data_config: Configuration object containing the entire data config
+        """
+        self.data_source = data_source
+        self.data_config = data_config
+        self.dataset_size = len(data_source)
+
+        # Temperature controlling the sharpness of the weighting distribution (from the sampler config):
+        #   - temperature < 1.0: sharp (hard samples dominate)
+        #   - temperature = 1.0: balanced
+        #   - temperature > 1.0: soft (nearly uniform)
+        self.temperature = data_config.sampler.get("pass_rate_temperature", 1.0)
+        use_ema = data_config.sampler.get("use_ema", False)
+        ema_alpha = data_config.sampler.get("ema_alpha", 0.1)
+
+        # Create the tracker for this dataset; set `use_ema=True` for exponential-moving-average pass rates
+        self.pass_rate_tracker = PassRateTracker(dataset_size=self.dataset_size, use_ema=use_ema, ema_alpha=ema_alpha)
+        self._cached_weights = None  # Cache weights to avoid recomputation on every iteration
+        self._last_pass_rate = None  # Pass rates from the previous training step
+        self._update_tolerance = 0.01  # Threshold for pass rate changes that trigger recomputing the weight vector
+
+    def __len__(self):
+        return self.dataset_size
+
+    def get_weights(self) -> np.ndarray:
+        """
+        Different weighting strategies can be added here. TODO (Jalaj): add an argument to select the strategy.
+
+        Current strategy: sampling weights inversely proportional to pass rates
+        (shown unnormalized; the implementation subtracts the max exponent for numerical stability):
+          - Untried samples (pass_rate = -5.0): weight = exp(5.0/temperature) -- highest, so every prompt is tried at least once
+          - Tried but failing (pass_rate ≈ 0): weight ≈ exp(0) = 1.0 -- highest priority among attempted samples
+          - Tried and succeeding (pass_rate > 0): weight = exp(-pass_rate/temperature) -- lower priority
+
+        Returns:
+            Array of shape (dataset_size,) with unnormalized sampling weights
+        """
+        pass_rates_current_train_step = self.pass_rate_tracker.get_pass_rates()
+        # Check whether we need to recompute (vectorized comparison)
+        if self._cached_weights is None or self._last_pass_rate is None:
+            needs_update = True
+        else:
+            max_change = np.abs(pass_rates_current_train_step - self._last_pass_rate).max()
+            needs_update = max_change > self._update_tolerance
+
+        if needs_update:
+            # ------ Weight inversely proportional to pass rate ---------
+
+            ## Option 1: polynomial scaling
+            # weights = np.power(1.0 - pass_rates, 1.0 / max(temperature, 0.01))
+            ## Stable implementation using the log-exp trick:
+            # log_weights = (1.0 / max(temperature, 0.01)) * np.log(1.0 - pass_rates + 1e-10)  # log(1 - p)
+            # weights = np.exp(log_weights - log_weights.max())
+
+            # Option 2: negative exponential scaling
+            x = -pass_rates_current_train_step / max(self.temperature, 0.01)
+            # weights = np.exp(x)
+            self._cached_weights = np.exp(x - x.max())  # numerically stable: subtract the max exponent before exponentiating
+            self._last_pass_rate = pass_rates_current_train_step.copy()
+
+        return self._cached_weights
+
+    def __iter__(self):
+        """
+        Generate indices for one epoch using the current pass-rate weights.
+        """
+        # Get current weights from the tracker
+        weights = self.get_weights()
+
+        # Sample with replacement using the weights
+        # TODO: make this scalable for large datasets
+        indices = np.random.choice(
+            self.dataset_size,
+            size=self.dataset_size,
+            replace=True,
+            p=weights / weights.sum()  # Normalize to a probability distribution
+        )
+
+        return iter(indices)
+
+    def get_weight_distribution_statistics(self) -> dict:
+        """
+        Return metrics for monitoring how the weight distribution changes over training steps.
+
+        Returns:
+            Dict with percentiles of the weight distribution
+        """
+        weights = self.get_weights()
+        return {
+            f'weight_p{p}': float(np.percentile(weights, p))
+            for p in range(5, 100, 5)
+        }
+
+    def get_wandb_3d_plot_data(self, metric_type: str = 'weight') -> list:
+        """
+        Prepare data for a W&B 3D plot: percentile (x), value (y), step (z).
+
+        Args:
+            metric_type: 'weight' or 'count'
+
+        Returns:
+            List of dicts with percentile and value for 3D plotting
+        """
+        if metric_type == 'weight':
+            stats = self.get_weight_distribution_statistics()
+        elif metric_type == 'count':
+            stats = self.pass_rate_tracker.get_count_distribution_statistics()  # count stats come from the tracker
+        else:
+            raise ValueError(f"metric_type must be 'weight' or 'count', got {metric_type}")
+
+        # Build the data list directly from the percentiles
+        return [
+            {
+                'percentile': p,
+                'percentile_name': f'p{p}',
+                'value': stats[f'{metric_type}_p{p}'],
+            }
+            for p in range(5, 100, 5)
+        ]
+
+    def state_dict(self) -> dict:
+        """
+        Return state for checkpointing.
+        Includes the pass rate tracker state so it can be restored on resume.
+        """
+        return self.pass_rate_tracker.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """
+        Load state from checkpoint.
+        Restores the pass rate tracker state.
+        """
+        self.pass_rate_tracker.load_state_dict(state_dict)
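Reviewer aside: an end-to-end sketch of the sampler on a toy dataset (hypothetical values, assuming this patch is installed; in training the sampler is instead wired up via `data.sampler.class_path`/`class_name` as in the script above):

```python
import numpy as np
from omegaconf import OmegaConf
from verl.utils.pass_rate_weighted_sampler import PassRateWeightedSampler

cfg = OmegaConf.create({"sampler": {"pass_rate_temperature": 0.5, "use_ema": False, "ema_alpha": 0.1}})
sampler = PassRateWeightedSampler(data_source=list(range(8)), data_config=cfg)

# Pretend pass rates increase linearly across the 8 prompts.
sampler.pass_rate_tracker.update(np.arange(8), np.linspace(0.0, 1.0, 8))
w = sampler.get_weights()
print(np.round(w / w.sum(), 3))  # monotonically decreasing: hard samples are drawn more often

epoch_indices = list(iter(sampler))  # one epoch of weighted, with-replacement indices
```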