From 563fbddf979faf915f10ed2a4b49cd8aaff0be3b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 21:05:38 +0800 Subject: [PATCH 01/14] support dpo --- cookbook/rl/dpo.py | 267 +++++++++++ cookbook/rl/dpo.sh | 84 ++++ src/twinkle/loss/__init__.py | 6 + src/twinkle/loss/dpo.py | 655 +++++++++++++++++++++++++++ src/twinkle/preprocessor/__init__.py | 1 + src/twinkle/preprocessor/dpo.py | 387 ++++++++++++++++ 6 files changed, 1400 insertions(+) create mode 100644 cookbook/rl/dpo.py create mode 100644 cookbook/rl/dpo.sh create mode 100644 src/twinkle/loss/dpo.py create mode 100644 src/twinkle/preprocessor/dpo.py diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py new file mode 100644 index 00000000..1e03f84b --- /dev/null +++ b/cookbook/rl/dpo.py @@ -0,0 +1,267 @@ +"""DPO (Direct Preference Optimization) Training via Ray. + +Off-policy preference alignment: trains the model to prefer chosen responses +over rejected responses using preference data, without explicit reward modeling. + +Pipeline: + 1. Load preference dataset with chosen/rejected pairs. + 2. Compute reference model log probabilities (frozen). + 3. Train policy model using DPO loss. + +Architecture (Ray): + ┌─────────────────────────────────────────────────────────────────┐ + │ Driver (CPU) │ + │ dataloader ──► batched preference pairs │ + │ ref_model.forward_only() ──► reference log probs │ + │ policy_model.forward_backward() ──► DPO loss + gradient │ + └─────────────────────────────────────────────────────────────────┘ + │ │ │ + DataLoader RefModel (frozen) PolicyModel (trainable) + (ref GPUs) (policy GPUs) + +For SimPO/ORPO variants that don't require a reference model, +set USE_REFERENCE_MODEL=0 to skip reference model computation. + +Environment variables (all optional): + MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) + DATASET_ID – (default: ms://argilla/ultrafeedback-binarized-preferences-cleaned) + MODEL_GPUS – GPUs for policy model (default: 4) + REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) + USE_REFERENCE_MODEL – Whether to use reference model (default: 1) + BATCH_SIZE – global batch size (pairs) (default: 8) + MICRO_BATCH_SIZE – per-device micro batch size (default: 2) + MAX_STEPS – total optimization steps (default: 1000) + LR – learning rate (default: 5e-6) + DPO_BETA – DPO temperature parameter (default: 0.1) + LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) + SAVE_STEPS – checkpoint save interval (default: 100) + MAX_LENGTH – max sequence length (default: 2048) +""" + +import os +from typing import Any, Dict, List, Optional + +import torch +from peft import LoraConfig + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss +from twinkle.model import TransformersModel +from twinkle.preprocessor import DPOProcessor +from twinkle.processor import InputProcessor +from twinkle.template import Template + +logger = get_logger() + +# ── Configuration ───────────────────────────────────────────────────────────── +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://argilla/ultrafeedback-binarized-preferences-cleaned') + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) +USE_REFERENCE_MODEL = bool(int(os.environ.get('USE_REFERENCE_MODEL', 1))) + +# 
Adjust total GPUs based on whether reference model is used +if USE_REFERENCE_MODEL and REF_MODEL_GPUS > 0: + NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS +else: + NUM_GPUS = MODEL_GPUS + USE_REFERENCE_MODEL = False + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) +LEARNING_RATE = float(os.environ.get('LR', 5e-6)) +DPO_BETA = float(os.environ.get('DPO_BETA', 0.1)) +LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) +MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) +ADAPTER_NAME = 'default' + + +# ── Dataset ─────────────────────────────────────────────────────────────────── + +def create_dpo_dataset(): + """Create preference dataset for DPO training. + + The dataset should contain 'chosen' and 'rejected' columns after preprocessing. + Each sample will be duplicated: first the chosen, then the rejected version. + """ + dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) + dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) + + # Use DPOProcessor to convert dataset to standard format + # Adjust processor based on your dataset format + dataset.map(DPOProcessor( + system='You are a helpful, harmless, and honest assistant.', + chosen_key='chosen', + rejected_key='rejected', + prompt_key='prompt', + )) + + # Encode both chosen and rejected trajectories + dataset.encode() + return dataset + + +def collate_preference_batch(batch: List[Dict[str, Any]]) -> Dict[str, List]: + """Collate preference pairs into DPO batch format. + + DPO loss expects: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + """ + chosen_samples = [] + rejected_samples = [] + + for item in batch: + if 'chosen' in item and 'rejected' in item: + chosen_samples.append(item['chosen']) + rejected_samples.append(item['rejected']) + else: + # Assume alternating format if not explicitly separated + chosen_samples.append(item) + + # Concatenate: all chosen first, then all rejected + return chosen_samples + rejected_samples + + +# ── Loss Factory ────────────────────────────────────────────────────────────── + +def create_loss(loss_type: str, beta: float, reference_free: bool = False): + """Create the appropriate loss function based on configuration.""" + if loss_type == 'simpo': + return SimPOLoss(beta=beta, gamma=0.5) + elif loss_type == 'orpo': + return ORPOLoss(lambda_orpo=beta) + elif loss_type == 'cpo': + return CPOLoss(beta=beta, bc_coef=1.0) + else: + # Standard DPO variants: sigmoid, hinge, ipo + return DPOLoss( + beta=beta, + loss_type=loss_type, + reference_free=reference_free, + ) + + +# ── Main Training Loop ──────────────────────────────────────────────────────── + +def main(): + # Set up device groups + if USE_REFERENCE_MODEL: + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + else: + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + ] + + policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + # ── Policy Model Setup ──────────────────────────────────────────────────── + 
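# Usage sketch for the create_loss factory above (illustrative, not part of
# the patch; assumes this module's twinkle imports resolve):
dpo_loss = create_loss('sigmoid', beta=0.1)                  # -> DPOLoss
ipo_loss = create_loss('ipo', beta=0.1)                      # -> DPOLoss(loss_type='ipo')
simpo_loss = create_loss('simpo', beta=2.5)                  # -> SimPOLoss(beta=2.5, gamma=0.5)
free_dpo = create_loss('sigmoid', beta=0.1, reference_free=True)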
lora_config = LoraConfig( + target_modules=[ + 'q_proj', 'k_proj', 'v_proj', 'o_proj', + 'gate_proj', 'up_proj', 'down_proj', + ], + r=16, + lora_alpha=32, + lora_dropout=0.05, + ) + + policy_model = TransformersModel( + model_id=MODEL_ID, + device_mesh=policy_mesh, + remote_group='policy', + ) + policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) + policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01) + policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1) + + # Determine if we need reference model based on loss type + reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] + + # Set up loss function + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=not USE_REFERENCE_MODEL) + policy_model.set_loss(loss_fn) + policy_model.set_processor(InputProcessor) + policy_model.set_template('Template', model_id=MODEL_ID) + + # ── Reference Model Setup (if needed) ───────────────────────────────────── + ref_model = None + if USE_REFERENCE_MODEL and not reference_free: + ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS) + ref_model = TransformersModel( + model_id=MODEL_ID, + device_mesh=ref_mesh, + remote_group='reference', + ) + ref_model.set_processor(InputProcessor) + ref_model.set_template('Template', model_id=MODEL_ID) + logger.info('Reference model initialized for DPO training') + else: + logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') + + # ── DataLoader Setup ────────────────────────────────────────────────────── + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + dataloader = DataLoader( + dataset=create_dpo_dataset, + batch_size=GLOBAL_BATCH_SIZE, + min_batch_size=GLOBAL_BATCH_SIZE, + device_mesh=policy_mesh, + remote_group='policy', + ) + + optim_step = 0 + logger.info(get_device_placement()) + logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, ' + f'use_ref_model={USE_REFERENCE_MODEL}') + + # ── Training Loop ───────────────────────────────────────────────────────── + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + # Collate preference pairs: [chosen..., rejected...] + preference_batch = collate_preference_batch(batch if isinstance(batch, list) else [batch]) + + # Compute reference log probabilities if using reference model + ref_logps = None + if ref_model is not None: + with torch.no_grad(): + ref_outputs = ref_model.forward_only(inputs=preference_batch) + ref_logps = ref_outputs.get('logps') + + # Forward-backward pass with DPO loss + policy_model.forward_backward( + inputs=preference_batch, + ref_logps=ref_logps, + micro_batch_size=MICRO_BATCH_SIZE, + ) + + # Gradient clipping and optimizer step + policy_model.clip_grad_and_step() + optim_step += 1 + + # Logging + if optim_step % 10 == 0: + metrics = policy_model.calculate_metric(is_training=True) + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}') + + # Checkpointing + if optim_step % SAVE_STEPS == 0: + policy_model.save(f'dpo-checkpoint-{optim_step}') + + # ── Save Final Checkpoint ───────────────────────────────────────────────── + logger.info(f'Training completed. 
Total steps: {optim_step}') + policy_model.save('dpo-final-checkpoint') + + +if __name__ == '__main__': + main() diff --git a/cookbook/rl/dpo.sh b/cookbook/rl/dpo.sh new file mode 100644 index 00000000..65206839 --- /dev/null +++ b/cookbook/rl/dpo.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# DPO Training Script for Ray Mode +# +# This script launches DPO (Direct Preference Optimization) training using Ray +# for distributed training across multiple GPUs. +# +# Usage: +# ./dpo.sh # Default settings (8 GPUs: 4 policy + 4 ref) +# ./dpo.sh simpo # Use SimPO (no reference model needed) +# ./dpo.sh orpo # Use ORPO (no reference model needed) +# +# Environment variables can be set to customize training: +# MODEL_ID - Model to train (default: ms://Qwen/Qwen3.5-4B) +# DATASET_ID - Preference dataset (default: UltraFeedback) +# MODEL_GPUS - GPUs for policy model (default: 4) +# REF_MODEL_GPUS - GPUs for reference model (default: 4) +# USE_REFERENCE_MODEL - Use reference model (default: 1) +# BATCH_SIZE - Global batch size (default: 8) +# MAX_STEPS - Training steps (default: 1000) +# LR - Learning rate (default: 5e-6) +# DPO_BETA - DPO beta parameter (default: 0.1) +# LOSS_TYPE - Loss variant: sigmoid/hinge/ipo/simpo/orpo/cpo (default: sigmoid) + +set -e + +# Parse command line argument for loss type +LOSS_TYPE_ARG=${1:-sigmoid} + +# Set default environment variables if not already set +export MODEL_ID=${MODEL_ID:-"ms://Qwen/Qwen3.5-4B"} +export DATASET_ID=${DATASET_ID:-"ms://argilla/ultrafeedback-binarized-preferences-cleaned"} +export MODEL_GPUS=${MODEL_GPUS:-4} +export BATCH_SIZE=${BATCH_SIZE:-8} +export MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2} +export MAX_STEPS=${MAX_STEPS:-1000} +export LR=${LR:-5e-6} +export DPO_BETA=${DPO_BETA:-0.1} +export SAVE_STEPS=${SAVE_STEPS:-100} +export MAX_LENGTH=${MAX_LENGTH:-2048} + +# Set loss type from argument or environment +export LOSS_TYPE=${LOSS_TYPE:-$LOSS_TYPE_ARG} + +# Reference-free losses don't need reference model +if [[ "$LOSS_TYPE" == "simpo" || "$LOSS_TYPE" == "orpo" || "$LOSS_TYPE" == "cpo" ]]; then + export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-0} + export REF_MODEL_GPUS=${REF_MODEL_GPUS:-0} + echo "Using $LOSS_TYPE loss (reference-free)" +else + export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-1} + export REF_MODEL_GPUS=${REF_MODEL_GPUS:-4} + echo "Using $LOSS_TYPE loss with reference model" +fi + +# Calculate total GPUs +if [[ "$USE_REFERENCE_MODEL" == "1" && "$REF_MODEL_GPUS" -gt 0 ]]; then + TOTAL_GPUS=$((MODEL_GPUS + REF_MODEL_GPUS)) +else + TOTAL_GPUS=$MODEL_GPUS +fi + +echo "==========================================" +echo "DPO Training Configuration" +echo "==========================================" +echo "Model: $MODEL_ID" +echo "Dataset: $DATASET_ID" +echo "Loss Type: $LOSS_TYPE" +echo "DPO Beta: $DPO_BETA" +echo "Policy GPUs: $MODEL_GPUS" +echo "Reference GPUs: $REF_MODEL_GPUS" +echo "Total GPUs: $TOTAL_GPUS" +echo "Batch Size: $BATCH_SIZE" +echo "Micro Batch Size: $MICRO_BATCH_SIZE" +echo "Max Steps: $MAX_STEPS" +echo "Learning Rate: $LR" +echo "Max Length: $MAX_LENGTH" +echo "Save Steps: $SAVE_STEPS" +echo "==========================================" + +# Get script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Run training +python "$SCRIPT_DIR/dpo.py" diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 7870f5a4..4e4d0e82 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -2,6 +2,7 @@ from .base import Loss from .chunked_cross_entropy import 
ChunkedCrossEntropyLoss from .cross_entropy import CrossEntropyLoss +from .dpo import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss @@ -19,4 +20,9 @@ 'cispo': CISPOLoss, 'bnpo': BNPOLoss, 'dr_grpo': DRGRPOLoss, + # DPO family losses + 'dpo': DPOLoss, + 'simpo': SimPOLoss, + 'cpo': CPOLoss, + 'orpo': ORPOLoss, } diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py new file mode 100644 index 00000000..0f3e364a --- /dev/null +++ b/src/twinkle/loss/dpo.py @@ -0,0 +1,655 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Loss Implementation. + +Reference: + "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" + (https://arxiv.org/abs/2305.18290) +""" +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import numpy as np + +from twinkle.data_format import LossOutput +from twinkle.kernel import selective_log_softmax +from twinkle.loss.base import Loss + +if TYPE_CHECKING: + import torch + + +class DPOLoss(Loss): + """Direct Preference Optimization (DPO) Loss. + + DPO directly optimizes the policy using preference data without explicit reward modeling. + The loss function is derived from the Bradley-Terry preference model: + + L_DPO = -log(σ(β * (log π(y_w|x)/π_ref(y_w|x) - log π(y_l|x)/π_ref(y_l|x)))) + + where: + - y_w is the preferred (chosen) response + - y_l is the dispreferred (rejected) response + - β is the temperature parameter controlling deviation from reference + - π is the current policy + - π_ref is the reference policy (frozen) + + Args: + beta: Temperature parameter controlling how much to deviate from ref policy (default: 0.1). + label_smoothing: Label smoothing parameter for soft labels (default: 0.0). + ignore_index: Index to ignore in labels (default: -100). + loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid'). + reference_free: Whether to use reference-free DPO (default: False). + """ + + def __init__( + self, + beta: float = 0.1, + label_smoothing: float = 0.0, + ignore_index: int = -100, + loss_type: str = 'sigmoid', + reference_free: bool = False, + **kwargs, + ): + self.beta = beta + self.label_smoothing = label_smoothing + self.ignore_index = ignore_index + self.loss_type = loss_type + self.reference_free = reference_free + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits. + + Args: + logits: [batch, seq_len, vocab_size] model logits + labels: [batch, seq_len] target token ids + + Returns: + logps: [batch, seq_len] per-token log probabilities + """ + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities by summing valid token logps. 
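# Illustrative sketch (not part of the patch): the dense computation that
# selective_log_softmax performs as used above - gather the log-softmax at
# each label id. Shapes below are made up.
import torch

logits = torch.randn(2, 5, 32)                # [batch, seq_len, vocab]
labels = torch.randint(0, 32, (2, 5))         # [batch, seq_len], already masked to 0
per_token_logps = torch.log_softmax(logits, dim=-1)
per_token_logps = per_token_logps.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# per_token_logps[b, t] == log p(labels[b, t] | context); positions whose
# label was ignore_index were clamped to 0 above and are masked out later.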
+ + Args: + per_token_logps: [batch, seq_len] per-token log probabilities + labels: [batch, seq_len] labels for computing mask + + Returns: + seq_logps: [batch] sequence-level log probabilities + """ + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _pad_and_align_logps( + self, + logps: Union['torch.Tensor', List[List[float]]], + target_shape: tuple, + loss_mask: 'torch.Tensor', + device: 'torch.device', + dtype: 'torch.dtype', + ) -> 'torch.Tensor': + """Pad and align log probabilities to target shape. + + Args: + logps: Input log probabilities (tensor or ragged list) + target_shape: Target (batch, seq_len) shape + loss_mask: Boolean mask for valid positions + device: Target device + dtype: Target dtype + + Returns: + Aligned tensor of shape target_shape + """ + import torch + + if torch.is_tensor(logps): + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + elif logps.dim() == 1: + logps = logps.unsqueeze(0) + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + + # Handle ragged list input + if isinstance(logps, (list, tuple)): + batch_size, seq_len = target_shape + padded = torch.zeros(target_shape, device=device, dtype=dtype) + for i, row in enumerate(logps): + if row is None: + continue + row_t = torch.as_tensor(row, device=device, dtype=dtype) + valid_positions = loss_mask[i].nonzero(as_tuple=True)[0] + length = min(len(row_t), len(valid_positions)) + if length > 0: + padded[i, valid_positions[:length]] = row_t[:length] + return padded + + return logps.to(device=device, dtype=dtype) + + def _compute_dpo_loss( + self, + policy_chosen_logps: 'torch.Tensor', + policy_rejected_logps: 'torch.Tensor', + reference_chosen_logps: 'torch.Tensor', + reference_rejected_logps: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute the DPO loss. 
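# Illustrative sketch (not part of the patch): the sigmoid-DPO objective on
# toy sequence log-probs, matching the formula implemented below.
import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen, policy_rejected = torch.tensor(-12.0), torch.tensor(-15.0)
ref_chosen, ref_rejected = torch.tensor(-13.0), torch.tensor(-14.0)
margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)   # 2.0
loss = -F.logsigmoid(beta * margin)   # ~= 0.598; shrinks as the margin grows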
+ + Args: + policy_chosen_logps: [batch/2] log probs of chosen under current policy + policy_rejected_logps: [batch/2] log probs of rejected under current policy + reference_chosen_logps: [batch/2] log probs of chosen under reference policy + reference_rejected_logps: [batch/2] log probs of rejected under reference policy + + Returns: + loss: Scalar DPO loss + """ + import torch + import torch.nn.functional as F + + # Compute log ratios + if self.reference_free: + # Reference-free: only use policy log probs + chosen_logratios = policy_chosen_logps + rejected_logratios = policy_rejected_logps + else: + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + # Compute preference margin + logits = self.beta * (chosen_logratios - rejected_logratios) + + if self.loss_type == 'sigmoid': + # Standard DPO loss: -log(sigmoid(beta * margin)) + losses = -F.logsigmoid(logits) + elif self.loss_type == 'hinge': + # Hinge loss variant + losses = torch.relu(1 - logits) + elif self.loss_type == 'ipo': + # IPO (Identity Preference Optimization) loss + # Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback" + losses = (logits - 1 / (2 * self.beta)) ** 2 + elif self.loss_type == 'kto_pair': + # KTO pair loss (simplified version) + chosen_logratios_scaled = self.beta * chosen_logratios + rejected_logratios_scaled = self.beta * rejected_logratios + chosen_losses = 1 - F.sigmoid(chosen_logratios_scaled) + rejected_losses = F.sigmoid(rejected_logratios_scaled) + losses = chosen_losses + rejected_losses + else: + raise ValueError(f"Unknown loss_type: {self.loss_type}") + + # Apply label smoothing if specified + if self.label_smoothing > 0: + # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected + smooth_losses = -F.logsigmoid(-logits) # Loss for flipped preference + losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses + + return losses.mean() + + def __call__( + self, + inputs: Dict, + outputs: Dict, + *, + ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None, + ref_chosen_logps: Optional['torch.Tensor'] = None, + ref_rejected_logps: Optional['torch.Tensor'] = None, + **kwargs, + ) -> LossOutput: + """Compute DPO loss. + + The inputs should contain concatenated chosen and rejected examples: + - First half of batch: chosen responses + - Second half of batch: rejected responses + + Args: + inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. + Batch should be organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + outputs: Dict containing either: + - 'logps': [batch, seq_len] pre-computed log probs, OR + - 'logits': [batch, seq_len, vocab] from which logps will be computed + ref_logps: [batch, seq_len] or List[List[float]] reference model log probs. + Can also be provided as separate ref_chosen_logps and ref_rejected_logps. + ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen. + ref_rejected_logps: [batch/2] pre-computed reference log probs for rejected. + **kwargs: Additional arguments. + + Returns: + LossOutput with DPO loss and metrics. 
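# A toy example (not from the patch) of the half-and-half batch layout this
# method expects - with n pairs, rows i and i + n belong to the same prompt.
batch = ['chosen_1', 'chosen_2', 'rejected_1', 'rejected_2']   # n = 2
half = len(batch) // 2
pairs = list(zip(batch[:half], batch[half:]))
assert pairs == [('chosen_1', 'rejected_1'), ('chosen_2', 'rejected_2')]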
+ """ + import torch + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + half_batch = batch_size // 2 + + # Get log probabilities from outputs + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + device = logps.device + dtype = logps.dtype + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # Compute sequence-level log probs for policy + policy_chosen_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + policy_rejected_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Handle reference log probs + if ref_chosen_logps is not None and ref_rejected_logps is not None: + # Pre-computed sequence-level reference log probs provided + reference_chosen_logps = ref_chosen_logps.to(device=device, dtype=dtype) + reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) + elif ref_logps is not None: + # Per-token reference log probs provided, need to align and sum + loss_mask = (labels != self.ignore_index).bool() + ref_logps_aligned = self._pad_and_align_logps( + ref_logps, labels.shape, loss_mask, device, dtype + ) + ref_chosen = ref_logps_aligned[:half_batch] + ref_rejected = ref_logps_aligned[half_batch:] + reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) + reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) + elif self.reference_free: + # Reference-free mode: no reference model needed + reference_chosen_logps = torch.zeros_like(policy_chosen_logps) + reference_rejected_logps = torch.zeros_like(policy_rejected_logps) + else: + raise ValueError( + "ref_logps or (ref_chosen_logps, ref_rejected_logps) must be provided " + "unless reference_free=True" + ) + + # Compute DPO loss + loss = self._compute_dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + + return LossOutput(loss=loss, num_tokens=0) + + +class SimPOLoss(Loss): + """SimPO (Simple Preference Optimization) Loss. + + SimPO is a simpler variant of DPO that doesn't require a reference model. + It uses length-normalized log probabilities. + + Reference: + "SimPO: Simple Preference Optimization with a Reference-Free Reward" + (https://arxiv.org/abs/2405.14734) + + Args: + beta: Temperature parameter (default: 2.5). + gamma: Target reward margin (default: 0.5). + ignore_index: Index to ignore in labels (default: -100). 
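# Illustrative sketch (not part of the patch): SimPO's reward is the
# length-normalized log-prob; the margin must exceed gamma.
import torch
import torch.nn.functional as F

beta, gamma = 2.5, 0.5
r_chosen = torch.tensor(-1.2)     # mean per-token logp, chosen
r_rejected = torch.tensor(-1.8)   # mean per-token logp, rejected
loss = -F.logsigmoid(beta * (r_chosen - r_rejected) - gamma)   # ~= 0.313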
+ """ + + def __init__( + self, + beta: float = 2.5, + gamma: float = 0.5, + ignore_index: int = -100, + **kwargs, + ): + self.beta = beta + self.gamma = gamma + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits.""" + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_length_normalized_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute length-normalized sequence log probabilities. + + Args: + per_token_logps: [batch, seq_len] per-token log probabilities + labels: [batch, seq_len] labels for computing mask + + Returns: + normalized_logps: [batch] length-normalized log probabilities + """ + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + seq_logps = (per_token_logps * loss_mask).sum(dim=-1) + return seq_logps / seq_lengths + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute SimPO loss. + + Args: + inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. + Batch: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + outputs: Dict containing 'logps' or 'logits'. + + Returns: + LossOutput with SimPO loss. + """ + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + half_batch = batch_size // 2 + + # Get log probabilities + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # Compute length-normalized log probs + chosen_rewards = self._compute_length_normalized_logps(chosen_logps, chosen_labels) + rejected_rewards = self._compute_length_normalized_logps(rejected_logps, rejected_labels) + + # SimPO loss: -log(sigmoid(beta * (r_w - r_l) - gamma)) + logits = self.beta * (chosen_rewards - rejected_rewards) - self.gamma + loss = -F.logsigmoid(logits).mean() + + return LossOutput(loss=loss, num_tokens=0) + + +class CPOLoss(Loss): + """CPO (Contrastive Preference Optimization) Loss. + + CPO adds a behavior cloning term to preference optimization. + + Reference: + "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation" + (https://arxiv.org/abs/2401.08417) + + Args: + beta: Temperature parameter for preference (default: 0.1). + bc_coef: Behavior cloning coefficient (default: 1.0). + ignore_index: Index to ignore in labels (default: -100). 
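# Illustrative sketch (not part of the patch): CPO combines a reference-free
# preference term with an NLL (behavior-cloning) term on the chosen response.
import torch
import torch.nn.functional as F

beta, bc_coef = 0.1, 1.0
chosen_seq_logp, rejected_seq_logp = torch.tensor(-12.0), torch.tensor(-15.0)
chosen_token_nll = torch.tensor(1.2)   # mean per-token NLL of the chosen response
loss = (-F.logsigmoid(beta * (chosen_seq_logp - rejected_seq_logp))
        + bc_coef * chosen_token_nll)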
+ """ + + def __init__( + self, + beta: float = 0.1, + bc_coef: float = 1.0, + ignore_index: int = -100, + **kwargs, + ): + self.beta = beta + self.bc_coef = bc_coef + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits.""" + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities.""" + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss for chosen responses.""" + loss_mask = (labels != self.ignore_index).float() + nll = -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + return nll + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute CPO loss. + + Args: + inputs: Dict containing 'labels' [batch, seq_len]. + outputs: Dict containing 'logps' or 'logits'. + + Returns: + LossOutput with CPO loss. + """ + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even" + half_batch = batch_size // 2 + + # Get log probabilities + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # Compute sequence-level log probs + chosen_seq_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + rejected_seq_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Preference loss (reference-free DPO) + logits = self.beta * (chosen_seq_logps - rejected_seq_logps) + preference_loss = -F.logsigmoid(logits).mean() + + # Behavior cloning loss on chosen + bc_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Combined loss + loss = preference_loss + self.bc_coef * bc_loss + + return LossOutput(loss=loss, num_tokens=0) + + +class ORPOLoss(Loss): + """ORPO (Odds Ratio Preference Optimization) Loss. + + ORPO combines SFT and preference alignment in a single objective using odds ratios. + + Reference: + "ORPO: Monolithic Preference Optimization without Reference Model" + (https://arxiv.org/abs/2403.07691) + + Args: + lambda_orpo: Weight for the odds ratio term (default: 0.1). + ignore_index: Index to ignore in labels (default: -100). 
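# Illustrative sketch (not part of the patch): ORPO's odds-ratio term on
# average-token probabilities, where odds(p) = p / (1 - p).
import torch
import torch.nn.functional as F

avg_logp_chosen, avg_logp_rejected = torch.tensor(-0.5), torch.tensor(-1.5)
p_c, p_r = avg_logp_chosen.exp(), avg_logp_rejected.exp()
log_odds_ratio = (p_c / (1 - p_c)).log() - (p_r / (1 - p_r)).log()
orpo_term = -F.logsigmoid(log_odds_ratio)   # added to SFT loss, scaled by lambda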
+ """ + + def __init__( + self, + lambda_orpo: float = 0.1, + ignore_index: int = -100, + **kwargs, + ): + self.lambda_orpo = lambda_orpo + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits.""" + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + logps = selective_log_softmax(logits, masked_labels) + return logps + + def _compute_avg_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute average log probabilities over valid tokens.""" + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss.""" + loss_mask = (labels != self.ignore_index).float() + return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute ORPO loss. + + Args: + inputs: Dict containing 'labels' [batch, seq_len]. + outputs: Dict containing 'logps' or 'logits'. + + Returns: + LossOutput with ORPO loss. + """ + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even" + half_batch = batch_size // 2 + + # Get log probabilities + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + + # Split into chosen and rejected + chosen_labels = labels[:half_batch] + rejected_labels = labels[half_batch:] + chosen_logps = logps[:half_batch] + rejected_logps = logps[half_batch:] + + # SFT loss on chosen + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Compute average log probs for odds ratio + chosen_avg_logps = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_avg_logps = self._compute_avg_logps(rejected_logps, rejected_labels) + + # Odds ratio: log(odds_chosen / odds_rejected) + # log(p/(1-p)) ≈ log(p) - log(1-p) ≈ log(p) + p for small p + # Simplified: log_odds = avg_logp (since p is small) + log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps).clamp(max=1-1e-7)) + log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps).clamp(max=1-1e-7)) + + # ORPO odds ratio loss + odds_ratio = log_odds_chosen - log_odds_rejected + orpo_loss = -F.logsigmoid(odds_ratio).mean() + + # Combined loss + loss = sft_loss + self.lambda_orpo * orpo_loss + + return LossOutput(loss=loss, num_tokens=0) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 13b52d99..18dae667 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .base import DataFilter, Preprocessor +from .dpo import DPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, UltraFeedbackProcessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py new file mode 100644 index 00000000..c951caab --- /dev/null +++ b/src/twinkle/preprocessor/dpo.py @@ -0,0 +1,387 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Data Preprocessors. + +These preprocessors convert various preference dataset formats into the standard +Trajectory format required by Twinkle for DPO training. +""" +from typing import Any, Dict, List, Optional, Union + +from twinkle.data_format import Message, Trajectory +from .base import Preprocessor + + +class DPOProcessor(Preprocessor): + """Generic DPO preference data preprocessor. + + Converts preference data with chosen/rejected pairs into Trajectory format. + Supports multiple common dataset formats. + + Expected input format (one of): + 1. {'prompt': str, 'chosen': str, 'rejected': str} + 2. {'prompt': str, 'chosen': List[Message], 'rejected': List[Message]} + 3. {'messages': List[Message], 'chosen': str, 'rejected': str} + 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) + + Output: Each sample generates TWO Trajectories: + - First: chosen response trajectory + - Second: rejected response trajectory + The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + + Args: + system: Optional system prompt to prepend. + chosen_key: Key for chosen response (default: 'chosen'). + rejected_key: Key for rejected response (default: 'rejected'). + prompt_key: Key for prompt/question (default: 'prompt'). + messages_key: Key for conversation messages (default: 'messages'). 
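# Illustrative sketch: one input row and the two trajectories that
# DPOProcessor.preprocess derives from it (system message added when set).
row = {'prompt': 'What is 2 + 2?', 'chosen': '4', 'rejected': '5'}
# -> {'chosen':   Trajectory([user('What is 2 + 2?'), assistant('4')]),
#     'rejected': Trajectory([user('What is 2 + 2?'), assistant('5')])}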
+ """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'chosen', + rejected_key: str = 'rejected', + prompt_key: str = 'prompt', + messages_key: str = 'messages', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + self.messages_key = messages_key + + def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: + """Parse response into list of Messages.""" + if isinstance(response, str): + return [Message(role='assistant', content=response)] + elif isinstance(response, list): + messages = [] + for msg in response: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role', 'assistant'), content=msg.get('content', ''))) + return messages + return [Message(role='assistant', content=str(response))] + + def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: + """Build prompt messages from row data.""" + messages = [] + + # Add system message if provided + if self.system: + messages.append(Message(role='system', content=self.system)) + + # Check for messages field (conversation format) + if self.messages_key in row and row[self.messages_key]: + raw_messages = row[self.messages_key] + for msg in raw_messages: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) + return messages + + # Check for prompt field + if self.prompt_key in row and row[self.prompt_key]: + prompt = row[self.prompt_key] + if isinstance(prompt, str): + messages.append(Message(role='user', content=prompt)) + elif isinstance(prompt, list): + for msg in prompt: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) + + return messages + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row into chosen and rejected trajectories. + + Returns: + Dict with 'chosen' and 'rejected' Trajectory objects. + """ + # Build prompt messages + prompt_messages = self._build_prompt_messages(row) + + # Get chosen response + chosen_raw = row.get(self.chosen_key, '') + chosen_messages = self._parse_response(chosen_raw) + + # Get rejected response + rejected_raw = row.get(self.rejected_key, '') + rejected_messages = self._parse_response(rejected_raw) + + # Build full trajectories + chosen_trajectory = Trajectory(messages=prompt_messages + chosen_messages) + rejected_trajectory = Trajectory(messages=prompt_messages + rejected_messages) + + return { + 'chosen': chosen_trajectory, + 'rejected': rejected_trajectory, + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Process batched data into paired trajectories. + + Note: Output maintains separate 'chosen' and 'rejected' columns. + The DataLoader/collator should handle pairing them appropriately + for the DPO loss (concatenating chosen batch + rejected batch). + """ + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) + + +class HHRLHFProcessor(Preprocessor): + """Preprocessor for Anthropic HH-RLHF dataset format. + + HH-RLHF format: + {'chosen': "Human: ... Assistant: ...", 'rejected': "Human: ... Assistant: ..."} + + The conversations use "Human:" and "Assistant:" prefixes. 
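# Illustrative sample (not from the patch) of the raw HH-RLHF layout
# handled below:
text = ("Human: How do I boil an egg?"
        "\n\nAssistant: Cover it with cold water and simmer for ten minutes.")
# _parse_hh_conversation splits on '\n\nHuman: ' / '\n\nAssistant: ' and
# emits alternating user/assistant Messages.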
+ """ + + def __init__(self, system: Optional[str] = None): + self.system = system + + def _parse_hh_conversation(self, text: str) -> List[Message]: + """Parse HH-RLHF style conversation text into Messages.""" + messages = [] + + if self.system: + messages.append(Message(role='system', content=self.system)) + + # Split by Human/Assistant markers + parts = text.split('\n\nHuman: ') + for i, part in enumerate(parts): + if i == 0 and not part.startswith('Human: '): + if part.strip(): + # Initial text before first Human marker + if part.startswith('Human: '): + part = part[7:] # Remove "Human: " prefix + messages.append(Message(role='user', content=part.strip())) + continue + + # Split Human and Assistant parts + if '\n\nAssistant: ' in part: + human_part, assistant_part = part.split('\n\nAssistant: ', 1) + messages.append(Message(role='user', content=human_part.strip())) + messages.append(Message(role='assistant', content=assistant_part.strip())) + else: + messages.append(Message(role='user', content=part.strip())) + + return messages + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process HH-RLHF format row.""" + chosen_text = row.get('chosen', '') + rejected_text = row.get('rejected', '') + + chosen_messages = self._parse_hh_conversation(chosen_text) + rejected_messages = self._parse_hh_conversation(rejected_text) + + return { + 'chosen': Trajectory(messages=chosen_messages), + 'rejected': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) + + +class UltraFeedbackProcessor(Preprocessor): + """Preprocessor for UltraFeedback dataset format. + + UltraFeedback format: + { + 'instruction': str, + 'completions': [ + {'response': str, 'overall_score': float, ...}, + ... + ] + } + + Selects highest and lowest scored completions as chosen/rejected. 
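# Illustrative sketch (not part of the patch): how chosen/rejected fall out
# of the completion scores.
completions = [
    {'response': 'A', 'overall_score': 8.5},
    {'response': 'B', 'overall_score': 3.0},
    {'response': 'C', 'overall_score': 6.0},
]
ranked = sorted((c['overall_score'], c['response']) for c in completions)
chosen, rejected = ranked[-1][1], ranked[0][1]   # 'A' (8.5) and 'B' (3.0)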
+ """ + + def __init__( + self, + system: Optional[str] = None, + instruction_key: str = 'instruction', + completions_key: str = 'completions', + response_key: str = 'response', + score_key: str = 'overall_score', + ): + self.system = system + self.instruction_key = instruction_key + self.completions_key = completions_key + self.response_key = response_key + self.score_key = score_key + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: + """Process UltraFeedback format row.""" + instruction = row.get(self.instruction_key, '') + completions = row.get(self.completions_key, []) + + if len(completions) < 2: + return None # Need at least 2 completions for preference + + # Sort by score + scored_completions = [ + (c.get(self.score_key, 0), c.get(self.response_key, '')) + for c in completions + if c.get(self.response_key) + ] + + if len(scored_completions) < 2: + return None + + scored_completions.sort(key=lambda x: x[0], reverse=True) + chosen_response = scored_completions[0][1] + rejected_response = scored_completions[-1][1] + + # Build messages + messages = [] + if self.system: + messages.append(Message(role='system', content=self.system)) + messages.append(Message(role='user', content=instruction)) + + chosen_trajectory = Trajectory( + messages=messages + [Message(role='assistant', content=chosen_response)] + ) + rejected_trajectory = Trajectory( + messages=messages + [Message(role='assistant', content=rejected_response)] + ) + + return { + 'chosen': chosen_trajectory, + 'rejected': rejected_trajectory, + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows if self.preprocess(row) is not None] + if not processed: + return {'chosen': [], 'rejected': []} + return self.map_row_to_col(processed) + + +class ShareGPTDPOProcessor(Preprocessor): + """Preprocessor for ShareGPT-style DPO datasets. + + Expected format: + { + 'conversations': [ + {'from': 'human', 'value': '...'}, + {'from': 'gpt', 'value': '...'}, + ... 
+ ], + 'chosen': {'from': 'gpt', 'value': '...'}, + 'rejected': {'from': 'gpt', 'value': '...'} + } + """ + + ROLE_MAPPING = { + 'human': 'user', + 'gpt': 'assistant', + 'system': 'system', + 'user': 'user', + 'assistant': 'assistant', + } + + def __init__(self, system: Optional[str] = None): + self.system = system + + def _parse_sharegpt_message(self, msg: Dict) -> Message: + """Parse ShareGPT format message.""" + role = self.ROLE_MAPPING.get(msg.get('from', ''), 'user') + content = msg.get('value', '') or msg.get('content', '') + return Message(role=role, content=content) + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process ShareGPT DPO format row.""" + conversations = row.get('conversations', []) + + # Build prompt messages (excluding last assistant turn if present) + messages = [] + if self.system: + messages.append(Message(role='system', content=self.system)) + + for msg in conversations: + messages.append(self._parse_sharegpt_message(msg)) + + # Remove last message if it's assistant (will be replaced by chosen/rejected) + if messages and messages[-1].role == 'assistant': + messages = messages[:-1] + + # Get chosen and rejected + chosen_msg = row.get('chosen', {}) + rejected_msg = row.get('rejected', {}) + + if isinstance(chosen_msg, dict): + chosen_response = Message( + role='assistant', + content=chosen_msg.get('value', '') or chosen_msg.get('content', '') + ) + else: + chosen_response = Message(role='assistant', content=str(chosen_msg)) + + if isinstance(rejected_msg, dict): + rejected_response = Message( + role='assistant', + content=rejected_msg.get('value', '') or rejected_msg.get('content', '') + ) + else: + rejected_response = Message(role='assistant', content=str(rejected_msg)) + + return { + 'chosen': Trajectory(messages=messages + [chosen_response]), + 'rejected': Trajectory(messages=messages + [rejected_response]), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) + + +class IntelOrcaDPOProcessor(Preprocessor): + """Preprocessor for Intel ORCA DPO dataset format. 
+ + Expected format: + { + 'system': str, + 'question': str, + 'chosen': str, + 'rejected': str + } + """ + + def __init__(self, default_system: Optional[str] = None): + self.default_system = default_system + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process Intel ORCA DPO format row.""" + system = row.get('system', self.default_system) + question = row.get('question', '') + chosen = row.get('chosen', '') + rejected = row.get('rejected', '') + + messages = [] + if system: + messages.append(Message(role='system', content=system)) + messages.append(Message(role='user', content=question)) + + return { + 'chosen': Trajectory(messages=messages + [Message(role='assistant', content=chosen)]), + 'rejected': Trajectory(messages=messages + [Message(role='assistant', content=rejected)]), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + processed = [self.preprocess(row) for row in rows] + return self.map_row_to_col(processed) From 52978e9312cc7091aa89265c59ac43e9404814f7 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:03:16 +0800 Subject: [PATCH 02/14] fix --- cookbook/rl/dpo.py | 48 ++++++++++++++++++++------------- src/twinkle/loss/dpo.py | 12 ++++----- src/twinkle/preprocessor/dpo.py | 38 ++++++++++++++++++++------ 3 files changed, 66 insertions(+), 32 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 1e03f84b..0a776f29 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -71,8 +71,8 @@ NUM_GPUS = MODEL_GPUS USE_REFERENCE_MODEL = False -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # Must be even (chosen + rejected) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-6)) @@ -88,41 +88,49 @@ def create_dpo_dataset(): """Create preference dataset for DPO training. - The dataset should contain 'chosen' and 'rejected' columns after preprocessing. - Each sample will be duplicated: first the chosen, then the rejected version. + The dataset will contain interleaved chosen/rejected pairs after preprocessing: + [chosen_1, rejected_1, chosen_2, rejected_2, ...] + + The collate function will reorder to: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] """ dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - # Use DPOProcessor to convert dataset to standard format - # Adjust processor based on your dataset format + # Use DPOProcessor with interleaved output format + # This creates alternating chosen/rejected pairs that can be properly encoded dataset.map(DPOProcessor( system='You are a helpful, harmless, and honest assistant.', chosen_key='chosen', rejected_key='rejected', prompt_key='prompt', + output_format='interleaved', # Output: [chosen_1, rejected_1, chosen_2, ...] )) - # Encode both chosen and rejected trajectories + # Encode the interleaved trajectories dataset.encode() return dataset -def collate_preference_batch(batch: List[Dict[str, Any]]) -> Dict[str, List]: - """Collate preference pairs into DPO batch format. 
+def collate_preference_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Collate interleaved preference pairs into DPO batch format. + + Input: [chosen_1, rejected_1, chosen_2, rejected_2, ...] (interleaved) + Output: [chosen_1, chosen_2, ..., rejected_1, rejected_2, ...] (grouped) - DPO loss expects: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + DPO loss expects: first half chosen, second half rejected. """ + if not batch: + return batch + + # Extract alternating chosen/rejected chosen_samples = [] rejected_samples = [] - for item in batch: - if 'chosen' in item and 'rejected' in item: - chosen_samples.append(item['chosen']) - rejected_samples.append(item['rejected']) - else: - # Assume alternating format if not explicitly separated + for i, item in enumerate(batch): + if i % 2 == 0: # Even indices are chosen chosen_samples.append(item) + else: # Odd indices are rejected + rejected_samples.append(item) # Concatenate: all chosen first, then all rejected return chosen_samples + rejected_samples @@ -209,7 +217,9 @@ def main(): logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') # ── DataLoader Setup ────────────────────────────────────────────────────── - GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + # Since dataset is interleaved (chosen, rejected, chosen, rejected, ...), + # we need batch_size * 2 samples to get BATCH_SIZE preference pairs + GLOBAL_BATCH_SIZE = BATCH_SIZE * 2 * GRADIENT_ACCUMULATION_STEPS dataloader = DataLoader( dataset=create_dpo_dataset, batch_size=GLOBAL_BATCH_SIZE, @@ -239,10 +249,12 @@ def main(): ref_logps = ref_outputs.get('logps') # Forward-backward pass with DPO loss + # micro_batch_size must be even to maintain chosen/rejected pairing + actual_micro_batch = MICRO_BATCH_SIZE * 2 # Convert pairs to samples policy_model.forward_backward( inputs=preference_batch, ref_logps=ref_logps, - micro_batch_size=MICRO_BATCH_SIZE, + micro_batch_size=actual_micro_batch, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 0f3e364a..d8f5b207 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -8,8 +8,6 @@ """ from typing import TYPE_CHECKING, Dict, List, Optional, Union -import numpy as np - from twinkle.data_format import LossOutput from twinkle.kernel import selective_log_softmax from twinkle.loss.base import Loss @@ -640,10 +638,12 @@ def __call__( rejected_avg_logps = self._compute_avg_logps(rejected_logps, rejected_labels) # Odds ratio: log(odds_chosen / odds_rejected) - # log(p/(1-p)) ≈ log(p) - log(1-p) ≈ log(p) + p for small p - # Simplified: log_odds = avg_logp (since p is small) - log_odds_chosen = chosen_avg_logps - torch.log1p(-torch.exp(chosen_avg_logps).clamp(max=1-1e-7)) - log_odds_rejected = rejected_avg_logps - torch.log1p(-torch.exp(rejected_avg_logps).clamp(max=1-1e-7)) + # log_odds = log(p/(1-p)) = log(p) - log(1-p) + # Use numerically stable computation + prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1-1e-7) + prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1-1e-7) + log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen) + log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected) # ORPO odds ratio loss odds_ratio = log_odds_chosen - log_odds_rejected diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index c951caab..0c2ab769 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -23,10 +23,12 @@ 
class DPOProcessor(Preprocessor): 3. {'messages': List[Message], 'chosen': str, 'rejected': str} 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) - Output: Each sample generates TWO Trajectories: - - First: chosen response trajectory - - Second: rejected response trajectory - The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + Output: Each sample is expanded into TWO rows in the dataset: + - Row 2i: chosen response trajectory + - Row 2i+1: rejected response trajectory + + The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n], + which should be handled by a custom collate function or DataLoader. Args: system: Optional system prompt to prepend. @@ -34,6 +36,9 @@ class DPOProcessor(Preprocessor): rejected_key: Key for rejected response (default: 'rejected'). prompt_key: Key for prompt/question (default: 'prompt'). messages_key: Key for conversation messages (default: 'messages'). + output_format: How to structure the output: + - 'interleaved': [chosen_1, rejected_1, chosen_2, rejected_2, ...] + - 'paired': {'chosen': [Traj1, ...], 'rejected': [Traj1, ...]} """ def __init__( @@ -43,12 +48,14 @@ def __init__( rejected_key: str = 'rejected', prompt_key: str = 'prompt', messages_key: str = 'messages', + output_format: str = 'interleaved', ): self.system = system self.chosen_key = chosen_key self.rejected_key = rejected_key self.prompt_key = prompt_key self.messages_key = messages_key + self.output_format = output_format def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: """Parse response into list of Messages.""" @@ -125,13 +132,28 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Process batched data into paired trajectories. - Note: Output maintains separate 'chosen' and 'rejected' columns. - The DataLoader/collator should handle pairing them appropriately - for the DPO loss (concatenating chosen batch + rejected batch). + Output format depends on self.output_format: + - 'interleaved': Returns standard Trajectory column with alternating + chosen/rejected pairs for proper encoding. The DataLoader should use + a custom collate function to reorder into [chosen..., rejected...]. + - 'paired': Returns separate 'chosen' and 'rejected' columns (requires + special handling in DataLoader). """ rows = self.map_col_to_row(rows) processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + + if self.output_format == 'interleaved': + # Flatten to interleaved format: [chosen_1, rejected_1, chosen_2, rejected_2, ...] 
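+            # e.g. [{'chosen': c1, 'rejected': r1}, {'chosen': c2, 'rejected': r2}]
+            #      flattens to [c1, r1, c2, r2]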
+ # This allows standard Dataset.encode() to work + trajectories = [] + for pair in processed: + trajectories.append(pair['chosen']) + trajectories.append(pair['rejected']) + # Return as standard trajectory column + return {'messages': trajectories} + else: + # Paired format: separate columns + return self.map_row_to_col(processed) class HHRLHFProcessor(Preprocessor): From f5d5961503284f85c7531dabf899a939dac28beb Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:09:43 +0800 Subject: [PATCH 03/14] fix --- src/twinkle/loss/dpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index d8f5b207..3cd3d631 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union from twinkle.data_format import LossOutput -from twinkle.kernel import selective_log_softmax +from twinkle.utils.torch_utils import selective_log_softmax from twinkle.loss.base import Loss if TYPE_CHECKING: From 3cf03cd3d0f35bd8a7ec9a084b8921ba3e282965 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:31:33 +0800 Subject: [PATCH 04/14] fix --- cookbook/rl/dpo.py | 120 ++++++---- src/twinkle/loss/dpo.py | 331 +++++++++------------------ src/twinkle/preprocessor/__init__.py | 3 +- src/twinkle/preprocessor/dpo.py | 291 ++++++++++++++--------- 4 files changed, 368 insertions(+), 377 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 0a776f29..6edb917d 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -5,8 +5,9 @@ Pipeline: 1. Load preference dataset with chosen/rejected pairs. - 2. Compute reference model log probabilities (frozen). - 3. Train policy model using DPO loss. + 2. Encode chosen and rejected separately. + 3. Compute reference model log probabilities (frozen). + 4. Train policy model using DPO loss. Architecture (Ray): ┌─────────────────────────────────────────────────────────────────┐ @@ -19,16 +20,20 @@ DataLoader RefModel (frozen) PolicyModel (trainable) (ref GPUs) (policy GPUs) +DPO Trajectory format: + - messages: List[Message] - chosen response + - extend_message: [('rejected_messages', List[Message])] - rejected response + For SimPO/ORPO variants that don't require a reference model, set USE_REFERENCE_MODEL=0 to skip reference model computation. 
Environment variables (all optional): MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) - DATASET_ID – (default: ms://argilla/ultrafeedback-binarized-preferences-cleaned) + DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 4) REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) USE_REFERENCE_MODEL – Whether to use reference model (default: 1) - BATCH_SIZE – global batch size (pairs) (default: 8) + BATCH_SIZE – global batch size (preference pairs) (default: 8) MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) LR – learning rate (default: 5e-6) @@ -46,11 +51,12 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger +from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.model import TransformersModel -from twinkle.preprocessor import DPOProcessor +from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor from twinkle.template import Template @@ -58,7 +64,7 @@ # ── Configuration ───────────────────────────────────────────────────────────── MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://argilla/ultrafeedback-binarized-preferences-cleaned') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji') MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) @@ -72,7 +78,7 @@ USE_REFERENCE_MODEL = False BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # Must be even (chosen + rejected) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-6)) @@ -88,51 +94,75 @@ def create_dpo_dataset(): """Create preference dataset for DPO training. - The dataset will contain interleaved chosen/rejected pairs after preprocessing: - [chosen_1, rejected_1, chosen_2, rejected_2, ...] + Uses shareAI/DPO-zh-en-emoji dataset: + - answer_zh: chosen response (Chinese) + - answer_en: rejected response (English) + + Output Trajectory format: + - messages: chosen response + - extend_message: [('rejected_messages', rejected response)] - The collate function will reorder to: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + Note: We do NOT call dataset.encode() here. Encoding is done in + prepare_dpo_batch() to properly handle both chosen and rejected. 
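+
+ e.g. Trajectory(messages=[system, user, chosen_assistant],
+ extend_message=[('rejected_messages', [system, user, rejected_assistant])])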
""" dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - # Use DPOProcessor with interleaved output format - # This creates alternating chosen/rejected pairs that can be properly encoded - dataset.map(DPOProcessor( - system='You are a helpful, harmless, and honest assistant.', - chosen_key='chosen', - rejected_key='rejected', + # Use EmojiDPOProcessor for shareAI/DPO-zh-en-emoji dataset + # answer_zh -> chosen (messages), answer_en -> rejected (extend_message) + dataset.map(EmojiDPOProcessor( + system='You are a helpful assistant.', + chosen_key='answer_zh', + rejected_key='answer_en', prompt_key='prompt', - output_format='interleaved', # Output: [chosen_1, rejected_1, chosen_2, ...] )) - # Encode the interleaved trajectories - dataset.encode() + # Do NOT encode here - encoding is done in prepare_dpo_batch + # to preserve extend_message for rejected encoding return dataset -def collate_preference_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Collate interleaved preference pairs into DPO batch format. +def prepare_dpo_batch( + batch: List[Dict[str, Any]], + template: Template, +) -> List[Dict[str, Any]]: + """Prepare DPO batch: encode both chosen and rejected. - Input: [chosen_1, rejected_1, chosen_2, rejected_2, ...] (interleaved) - Output: [chosen_1, chosen_2, ..., rejected_1, rejected_2, ...] (grouped) + Args: + batch: List of raw Trajectory dicts with messages (chosen) and + extend_message containing ('rejected_messages', rejected) - DPO loss expects: first half chosen, second half rejected. + Returns: + List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] + where each item is an encoded InputFeature dict. """ - if not batch: - return batch - - # Extract alternating chosen/rejected chosen_samples = [] rejected_samples = [] - for i, item in enumerate(batch): - if i % 2 == 0: # Even indices are chosen - chosen_samples.append(item) - else: # Odd indices are rejected - rejected_samples.append(item) - - # Concatenate: all chosen first, then all rejected + for item in batch: + # Get messages (chosen) and encode + messages = item.get('messages', []) + chosen_trajectory = Trajectory(messages=messages) + chosen_encoded = template.encode(chosen_trajectory) + chosen_samples.append(dict(chosen_encoded)) + + # Get rejected from extend_message and encode + extend_message = item.get('extend_message', []) + rejected_messages = None + for key, msgs in extend_message: + if key == 'rejected_messages': + rejected_messages = msgs + break + + if rejected_messages: + rejected_trajectory = Trajectory(messages=rejected_messages) + rejected_encoded = template.encode(rejected_trajectory) + rejected_samples.append(dict(rejected_encoded)) + else: + # Fallback: use chosen (should not happen with proper preprocessing) + rejected_samples.append(dict(chosen_encoded)) + + # Return [chosen..., rejected...] 
return chosen_samples + rejected_samples @@ -201,6 +231,9 @@ def main(): policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) + # Get template for encoding rejected messages + template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH) + # ── Reference Model Setup (if needed) ───────────────────────────────────── ref_model = None if USE_REFERENCE_MODEL and not reference_free: @@ -217,9 +250,7 @@ def main(): logger.info(f'Training without reference model (loss_type={LOSS_TYPE})') # ── DataLoader Setup ────────────────────────────────────────────────────── - # Since dataset is interleaved (chosen, rejected, chosen, rejected, ...), - # we need batch_size * 2 samples to get BATCH_SIZE preference pairs - GLOBAL_BATCH_SIZE = BATCH_SIZE * 2 * GRADIENT_ACCUMULATION_STEPS + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS dataloader = DataLoader( dataset=create_dpo_dataset, batch_size=GLOBAL_BATCH_SIZE, @@ -238,21 +269,22 @@ def main(): if optim_step >= MAX_STEPS: break - # Collate preference pairs: [chosen..., rejected...] - preference_batch = collate_preference_batch(batch if isinstance(batch, list) else [batch]) + # Prepare DPO batch: [chosen..., rejected...] + batch_list = batch if isinstance(batch, list) else [batch] + dpo_batch = prepare_dpo_batch(batch_list, template) # Compute reference log probabilities if using reference model ref_logps = None if ref_model is not None: with torch.no_grad(): - ref_outputs = ref_model.forward_only(inputs=preference_batch) + ref_outputs = ref_model.forward_only(inputs=dpo_batch) ref_logps = ref_outputs.get('logps') # Forward-backward pass with DPO loss - # micro_batch_size must be even to maintain chosen/rejected pairing - actual_micro_batch = MICRO_BATCH_SIZE * 2 # Convert pairs to samples + # micro_batch_size should be even to maintain chosen/rejected pairing + actual_micro_batch = MICRO_BATCH_SIZE * 2 policy_model.forward_backward( - inputs=preference_batch, + inputs=dpo_batch, ref_logps=ref_logps, micro_batch_size=actual_micro_batch, ) diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 3cd3d631..e727193a 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -16,7 +16,84 @@ import torch -class DPOLoss(Loss): +class PreferenceLossBase(Loss): + """Base class for preference optimization losses with shared utilities.""" + + def __init__(self, ignore_index: int = -100): + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits. 
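+
+ Labels equal to ignore_index are masked to 0 before the gather, so -100
+ never indexes the vocabulary; those positions are discarded downstream
+ via the same mask.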
+ + Args: + logits: [batch, seq_len, vocab_size] model logits + labels: [batch, seq_len] target token ids + + Returns: + logps: [batch, seq_len] per-token log probabilities + """ + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + return selective_log_softmax(logits, masked_labels) + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities by summing valid token logps.""" + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _compute_avg_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute length-normalized (average) log probabilities.""" + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss.""" + loss_mask = (labels != self.ignore_index).float() + return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + def _get_logps_from_outputs( + self, + outputs: Dict, + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Extract or compute log probabilities from model outputs.""" + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + return logps + + def _split_chosen_rejected( + self, + tensor: 'torch.Tensor', + ) -> tuple: + """Split tensor into chosen (first half) and rejected (second half).""" + half = tensor.shape[0] // 2 + return tensor[:half], tensor[half:] + + +class DPOLoss(PreferenceLossBase): """Direct Preference Optimization (DPO) Loss. DPO directly optimizes the policy using preference data without explicit reward modeling. @@ -48,49 +125,12 @@ def __init__( reference_free: bool = False, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.beta = beta self.label_smoothing = label_smoothing - self.ignore_index = ignore_index self.loss_type = loss_type self.reference_free = reference_free - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits. - - Args: - logits: [batch, seq_len, vocab_size] model logits - labels: [batch, seq_len] target token ids - - Returns: - logps: [batch, seq_len] per-token log probabilities - """ - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_sequence_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute sequence-level log probabilities by summing valid token logps. 
- - Args: - per_token_logps: [batch, seq_len] per-token log probabilities - labels: [batch, seq_len] labels for computing mask - - Returns: - seq_logps: [batch] sequence-level log probabilities - """ - loss_mask = (labels != self.ignore_index).float() - return (per_token_logps * loss_mask).sum(dim=-1) - def _pad_and_align_logps( self, logps: Union['torch.Tensor', List[List[float]]], @@ -240,25 +280,15 @@ def __call__( batch_size = labels.shape[0] assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" - half_batch = batch_size // 2 # Get log probabilities from outputs - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) - + logps = self._get_logps_from_outputs(outputs, labels) device = logps.device dtype = logps.dtype # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # Compute sequence-level log probs for policy policy_chosen_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) @@ -275,8 +305,7 @@ def __call__( ref_logps_aligned = self._pad_and_align_logps( ref_logps, labels.shape, loss_mask, device, dtype ) - ref_chosen = ref_logps_aligned[:half_batch] - ref_rejected = ref_logps_aligned[half_batch:] + ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) elif self.reference_free: @@ -300,7 +329,7 @@ def __call__( return LossOutput(loss=loss, num_tokens=0) -class SimPOLoss(Loss): +class SimPOLoss(PreferenceLossBase): """SimPO (Simple Preference Optimization) Loss. SimPO is a simpler variant of DPO that doesn't require a reference model. @@ -323,40 +352,9 @@ def __init__( ignore_index: int = -100, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.beta = beta self.gamma = gamma - self.ignore_index = ignore_index - - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits.""" - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_length_normalized_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute length-normalized sequence log probabilities. - - Args: - per_token_logps: [batch, seq_len] per-token log probabilities - labels: [batch, seq_len] labels for computing mask - - Returns: - normalized_logps: [batch] length-normalized log probabilities - """ - loss_mask = (labels != self.ignore_index).float() - seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) - seq_logps = (per_token_logps * loss_mask).sum(dim=-1) - return seq_logps / seq_lengths def __call__( self, @@ -364,16 +362,7 @@ def __call__( outputs: Dict, **kwargs, ) -> LossOutput: - """Compute SimPO loss. - - Args: - inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len]. 
- Batch: [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] - outputs: Dict containing 'logps' or 'logits'. - - Returns: - LossOutput with SimPO loss. - """ + """Compute SimPO loss.""" import torch import torch.nn.functional as F @@ -384,28 +373,18 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" - half_batch = batch_size // 2 + assert labels.shape[0] % 2 == 0, "Batch size must be even (chosen + rejected pairs)" # Get log probabilities - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) + logps = self._get_logps_from_outputs(outputs, labels) # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # Compute length-normalized log probs - chosen_rewards = self._compute_length_normalized_logps(chosen_logps, chosen_labels) - rejected_rewards = self._compute_length_normalized_logps(rejected_logps, rejected_labels) + chosen_rewards = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_rewards = self._compute_avg_logps(rejected_logps, rejected_labels) # SimPO loss: -log(sigmoid(beta * (r_w - r_l) - gamma)) logits = self.beta * (chosen_rewards - rejected_rewards) - self.gamma @@ -414,7 +393,7 @@ def __call__( return LossOutput(loss=loss, num_tokens=0) -class CPOLoss(Loss): +class CPOLoss(PreferenceLossBase): """CPO (Contrastive Preference Optimization) Loss. CPO adds a behavior cloning term to preference optimization. @@ -436,40 +415,9 @@ def __init__( ignore_index: int = -100, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.beta = beta self.bc_coef = bc_coef - self.ignore_index = ignore_index - - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits.""" - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_sequence_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute sequence-level log probabilities.""" - loss_mask = (labels != self.ignore_index).float() - return (per_token_logps * loss_mask).sum(dim=-1) - - def _compute_nll_loss( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute negative log likelihood loss for chosen responses.""" - loss_mask = (labels != self.ignore_index).float() - nll = -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) - return nll def __call__( self, @@ -477,15 +425,7 @@ def __call__( outputs: Dict, **kwargs, ) -> LossOutput: - """Compute CPO loss. - - Args: - inputs: Dict containing 'labels' [batch, seq_len]. - outputs: Dict containing 'logps' or 'logits'. - - Returns: - LossOutput with CPO loss. 
- """ + """Compute CPO loss.""" import torch import torch.nn.functional as F @@ -496,24 +436,14 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even" - half_batch = batch_size // 2 + assert labels.shape[0] % 2 == 0, "Batch size must be even" # Get log probabilities - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) + logps = self._get_logps_from_outputs(outputs, labels) # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # Compute sequence-level log probs chosen_seq_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) @@ -532,7 +462,7 @@ def __call__( return LossOutput(loss=loss, num_tokens=0) -class ORPOLoss(Loss): +class ORPOLoss(PreferenceLossBase): """ORPO (Odds Ratio Preference Optimization) Loss. ORPO combines SFT and preference alignment in a single objective using odds ratios. @@ -552,39 +482,8 @@ def __init__( ignore_index: int = -100, **kwargs, ): + super().__init__(ignore_index=ignore_index) self.lambda_orpo = lambda_orpo - self.ignore_index = ignore_index - - def _compute_logps_from_logits( - self, - logits: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute per-token log probabilities from logits.""" - loss_mask = (labels != self.ignore_index).bool() - masked_labels = labels.clone() - masked_labels[~loss_mask] = 0 - logps = selective_log_softmax(logits, masked_labels) - return logps - - def _compute_avg_logps( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute average log probabilities over valid tokens.""" - loss_mask = (labels != self.ignore_index).float() - seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) - return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths - - def _compute_nll_loss( - self, - per_token_logps: 'torch.Tensor', - labels: 'torch.Tensor', - ) -> 'torch.Tensor': - """Compute negative log likelihood loss.""" - loss_mask = (labels != self.ignore_index).float() - return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) def __call__( self, @@ -592,15 +491,7 @@ def __call__( outputs: Dict, **kwargs, ) -> LossOutput: - """Compute ORPO loss. - - Args: - inputs: Dict containing 'labels' [batch, seq_len]. - outputs: Dict containing 'logps' or 'logits'. - - Returns: - LossOutput with ORPO loss. 
- """ + """Compute ORPO loss.""" import torch import torch.nn.functional as F @@ -611,24 +502,14 @@ def __call__( if labels.dim() == 1: labels = labels.unsqueeze(0) - batch_size = labels.shape[0] - assert batch_size % 2 == 0, "Batch size must be even" - half_batch = batch_size // 2 + assert labels.shape[0] % 2 == 0, "Batch size must be even" # Get log probabilities - logps = outputs.get('logps') - if logps is None: - logits = outputs.get('logits') - assert logits is not None, "outputs must contain 'logps' or 'logits'" - if logits.shape[1] != labels.shape[1]: - logits = logits[:, -labels.shape[1]:] - logps = self._compute_logps_from_logits(logits, labels) + logps = self._get_logps_from_outputs(outputs, labels) # Split into chosen and rejected - chosen_labels = labels[:half_batch] - rejected_labels = labels[half_batch:] - chosen_logps = logps[:half_batch] - rejected_logps = logps[half_batch:] + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) # SFT loss on chosen sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 18dae667..6d9f6dd7 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor -from .dpo import DPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, UltraFeedbackProcessor +from .dpo import (DPOProcessor, EmojiDPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, + UltraFeedbackKTOProcessor, UltraFeedbackProcessor) from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index 0c2ab769..a81b5f57 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -4,6 +4,10 @@ These preprocessors convert various preference dataset formats into the standard Trajectory format required by Twinkle for DPO training. + +DPO Trajectory format: + - messages: List[Message] - chosen response messages + - extend_message: [('rejected_messages', List[Message])] - rejected response messages """ from typing import Any, Dict, List, Optional, Union @@ -23,12 +27,9 @@ class DPOProcessor(Preprocessor): 3. {'messages': List[Message], 'chosen': str, 'rejected': str} 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) - Output: Each sample is expanded into TWO rows in the dataset: - - Row 2i: chosen response trajectory - - Row 2i+1: rejected response trajectory - - The DPO loss expects batch to be [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n], - which should be handled by a custom collate function or DataLoader. + Output Trajectory format: + - messages: chosen response (prompt + chosen assistant message) + - extend_message: [('rejected_messages', rejected_messages)] Args: system: Optional system prompt to prepend. @@ -36,9 +37,6 @@ class DPOProcessor(Preprocessor): rejected_key: Key for rejected response (default: 'rejected'). prompt_key: Key for prompt/question (default: 'prompt'). messages_key: Key for conversation messages (default: 'messages'). - output_format: How to structure the output: - - 'interleaved': [chosen_1, rejected_1, chosen_2, rejected_2, ...] 
- - 'paired': {'chosen': [Traj1, ...], 'rejected': [Traj1, ...]} """ def __init__( @@ -48,14 +46,12 @@ def __init__( rejected_key: str = 'rejected', prompt_key: str = 'prompt', messages_key: str = 'messages', - output_format: str = 'interleaved', ): self.system = system self.chosen_key = chosen_key self.rejected_key = rejected_key self.prompt_key = prompt_key self.messages_key = messages_key - self.output_format = output_format def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: """Parse response into list of Messages.""" @@ -103,57 +99,38 @@ def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: - """Process a single row into chosen and rejected trajectories. + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row into a DPO Trajectory. Returns: - Dict with 'chosen' and 'rejected' Trajectory objects. + Trajectory with chosen in messages and rejected in extend_message. """ # Build prompt messages prompt_messages = self._build_prompt_messages(row) # Get chosen response chosen_raw = row.get(self.chosen_key, '') - chosen_messages = self._parse_response(chosen_raw) + chosen_response = self._parse_response(chosen_raw) # Get rejected response rejected_raw = row.get(self.rejected_key, '') - rejected_messages = self._parse_response(rejected_raw) + rejected_response = self._parse_response(rejected_raw) - # Build full trajectories - chosen_trajectory = Trajectory(messages=prompt_messages + chosen_messages) - rejected_trajectory = Trajectory(messages=prompt_messages + rejected_messages) + # Build full message lists + chosen_messages = prompt_messages + chosen_response + rejected_messages = prompt_messages + rejected_response - return { - 'chosen': chosen_trajectory, - 'rejected': rejected_trajectory, - } + # Return Trajectory with rejected in extend_message + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Process batched data into paired trajectories. - - Output format depends on self.output_format: - - 'interleaved': Returns standard Trajectory column with alternating - chosen/rejected pairs for proper encoding. The DataLoader should use - a custom collate function to reorder into [chosen..., rejected...]. - - 'paired': Returns separate 'chosen' and 'rejected' columns (requires - special handling in DataLoader). - """ + """Process batched data into DPO trajectories.""" rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - - if self.output_format == 'interleaved': - # Flatten to interleaved format: [chosen_1, rejected_1, chosen_2, rejected_2, ...] 
- # This allows standard Dataset.encode() to work - trajectories = [] - for pair in processed: - trajectories.append(pair['chosen']) - trajectories.append(pair['rejected']) - # Return as standard trajectory column - return {'messages': trajectories} - else: - # Paired format: separate columns - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} class HHRLHFProcessor(Preprocessor): @@ -180,9 +157,8 @@ def _parse_hh_conversation(self, text: str) -> List[Message]: for i, part in enumerate(parts): if i == 0 and not part.startswith('Human: '): if part.strip(): - # Initial text before first Human marker if part.startswith('Human: '): - part = part[7:] # Remove "Human: " prefix + part = part[7:] messages.append(Message(role='user', content=part.strip())) continue @@ -196,7 +172,7 @@ def _parse_hh_conversation(self, text: str) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + def preprocess(self, row: Dict[str, Any]) -> Trajectory: """Process HH-RLHF format row.""" chosen_text = row.get('chosen', '') rejected_text = row.get('rejected', '') @@ -204,15 +180,15 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: chosen_messages = self._parse_hh_conversation(chosen_text) rejected_messages = self._parse_hh_conversation(rejected_text) - return { - 'chosen': Trajectory(messages=chosen_messages), - 'rejected': Trajectory(messages=rejected_messages), - } + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} class UltraFeedbackProcessor(Preprocessor): @@ -244,13 +220,13 @@ def __init__( self.response_key = response_key self.score_key = score_key - def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: + def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: """Process UltraFeedback format row.""" instruction = row.get(self.instruction_key, '') completions = row.get(self.completions_key, []) if len(completions) < 2: - return None # Need at least 2 completions for preference + return None # Sort by score scored_completions = [ @@ -267,29 +243,29 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: rejected_response = scored_completions[-1][1] # Build messages - messages = [] + prompt_messages = [] if self.system: - messages.append(Message(role='system', content=self.system)) - messages.append(Message(role='user', content=instruction)) + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=instruction)) - chosen_trajectory = Trajectory( - messages=messages + [Message(role='assistant', content=chosen_response)] - ) - rejected_trajectory = Trajectory( - messages=messages + [Message(role='assistant', content=rejected_response)] - ) + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - return { - 'chosen': chosen_trajectory, - 'rejected': rejected_trajectory, - } + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', 
rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows if self.preprocess(row) is not None] - if not processed: - return {'chosen': [], 'rejected': []} - return self.map_row_to_col(processed) + trajectories = [] + for row in rows: + result = self.preprocess(row) + if result is not None: + trajectories.append(result) + if not trajectories: + return {'messages': []} + return {'messages': trajectories} class ShareGPTDPOProcessor(Preprocessor): @@ -324,51 +300,48 @@ def _parse_sharegpt_message(self, msg: Dict) -> Message: content = msg.get('value', '') or msg.get('content', '') return Message(role=role, content=content) - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + def preprocess(self, row: Dict[str, Any]) -> Trajectory: """Process ShareGPT DPO format row.""" conversations = row.get('conversations', []) - # Build prompt messages (excluding last assistant turn if present) - messages = [] + # Build prompt messages + prompt_messages = [] if self.system: - messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='system', content=self.system)) for msg in conversations: - messages.append(self._parse_sharegpt_message(msg)) + prompt_messages.append(self._parse_sharegpt_message(msg)) - # Remove last message if it's assistant (will be replaced by chosen/rejected) - if messages and messages[-1].role == 'assistant': - messages = messages[:-1] + # Remove last message if it's assistant (will be replaced) + if prompt_messages and prompt_messages[-1]['role'] == 'assistant': + prompt_messages = prompt_messages[:-1] # Get chosen and rejected chosen_msg = row.get('chosen', {}) rejected_msg = row.get('rejected', {}) if isinstance(chosen_msg, dict): - chosen_response = Message( - role='assistant', - content=chosen_msg.get('value', '') or chosen_msg.get('content', '') - ) + chosen_content = chosen_msg.get('value', '') or chosen_msg.get('content', '') else: - chosen_response = Message(role='assistant', content=str(chosen_msg)) + chosen_content = str(chosen_msg) if isinstance(rejected_msg, dict): - rejected_response = Message( - role='assistant', - content=rejected_msg.get('value', '') or rejected_msg.get('content', '') - ) + rejected_content = rejected_msg.get('value', '') or rejected_msg.get('content', '') else: - rejected_response = Message(role='assistant', content=str(rejected_msg)) + rejected_content = str(rejected_msg) - return { - 'chosen': Trajectory(messages=messages + [chosen_response]), - 'rejected': Trajectory(messages=messages + [rejected_response]), - } + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_content)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_content)] + + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} class IntelOrcaDPOProcessor(Preprocessor): @@ -386,24 +359,128 @@ class IntelOrcaDPOProcessor(Preprocessor): def __init__(self, default_system: Optional[str] = None): self.default_system = default_system - def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + def 
preprocess(self, row: Dict[str, Any]) -> Trajectory: """Process Intel ORCA DPO format row.""" system = row.get('system', self.default_system) question = row.get('question', '') chosen = row.get('chosen', '') rejected = row.get('rejected', '') - messages = [] + prompt_messages = [] if system: - messages.append(Message(role='system', content=system)) - messages.append(Message(role='user', content=question)) + prompt_messages.append(Message(role='system', content=system)) + prompt_messages.append(Message(role='user', content=question)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} - return { - 'chosen': Trajectory(messages=messages + [Message(role='assistant', content=chosen)]), - 'rejected': Trajectory(messages=messages + [Message(role='assistant', content=rejected)]), + +class EmojiDPOProcessor(Preprocessor): + """Preprocessor for shareAI/DPO-zh-en-emoji dataset format. + + Dataset format: + { + 'prompt': str, + 'answer_zh': str, # chosen response (Chinese) + 'answer_en': str, # rejected response (English) } + Output Trajectory format: + - messages: prompt + chosen (answer_zh) + - extend_message: [('rejected_messages', prompt + rejected (answer_en))] + + Args: + system: Optional system prompt. + chosen_key: Key for chosen response (default: 'answer_zh'). + rejected_key: Key for rejected response (default: 'answer_en'). + prompt_key: Key for prompt (default: 'prompt'). + """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'answer_zh', + rejected_key: str = 'answer_en', + prompt_key: str = 'prompt', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row.""" + prompt = row.get(self.prompt_key, '') + chosen = row.get(self.chosen_key, '') + rejected = row.get(self.rejected_key, '') + + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=prompt)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return Trajectory( + messages=chosen_messages, + extend_message=[('rejected_messages', rejected_messages)] + ) + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} + + +class UltraFeedbackKTOProcessor(Preprocessor): + """Preprocessor for ultrafeedback-binarized-preferences-cleaned-kto dataset. + + Dataset format: + { + 'prompt': str, + 'completion': str, + 'label': bool, # True for chosen, False for rejected + } + + For KTO training, we need (prompt, completion, label) format. + The label is stored in user_data. + + Args: + system: Optional system prompt. 
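+
+ Example output:
+ Trajectory(messages=[<system>, <user>, <assistant>], user_data=[('kto_label', True)])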
+ """ + + def __init__(self, system: Optional[str] = None): + self.system = system + + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row for KTO.""" + prompt = row.get('prompt', '') + completion = row.get('completion', '') + label = row.get('label', True) + + messages = [] + if self.system: + messages.append(Message(role='system', content=self.system)) + messages.append(Message(role='user', content=prompt)) + messages.append(Message(role='assistant', content=completion)) + + return Trajectory( + messages=messages, + user_data=[('kto_label', label)] + ) + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - processed = [self.preprocess(row) for row in rows] - return self.map_row_to_col(processed) + trajectories = [self.preprocess(row) for row in rows] + return {'messages': trajectories} From bcdad646f7e6fdcb6a8be8d5fe46b00a8504b93b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:38:27 +0800 Subject: [PATCH 05/14] fix --- cookbook/rl/dpo.py | 79 ++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 6edb917d..1e3d1cd2 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -41,6 +41,12 @@ LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid) SAVE_STEPS – checkpoint save interval (default: 100) MAX_LENGTH – max sequence length (default: 2048) + + Dataset field mapping (for custom datasets): + PROMPT_KEY – key for prompt field (default: 'prompt') + CHOSEN_KEY – key for chosen response (default: 'answer_zh') + REJECTED_KEY – key for rejected response (default: 'answer_en') + SYSTEM_PROMPT – system prompt to prepend (default: 'You are a helpful assistant.') """ import os @@ -51,12 +57,11 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger -from twinkle.data_format import Trajectory +from twinkle.data_format import Message, Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.model import TransformersModel -from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor from twinkle.template import Template @@ -91,6 +96,13 @@ # ── Dataset ─────────────────────────────────────────────────────────────────── +# Dataset field configuration for shareAI/DPO-zh-en-emoji +PROMPT_KEY = os.environ.get('PROMPT_KEY', 'prompt') +CHOSEN_KEY = os.environ.get('CHOSEN_KEY', 'answer_zh') +REJECTED_KEY = os.environ.get('REJECTED_KEY', 'answer_en') +SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') + + def create_dpo_dataset(): """Create preference dataset for DPO training. @@ -98,27 +110,12 @@ def create_dpo_dataset(): - answer_zh: chosen response (Chinese) - answer_en: rejected response (English) - Output Trajectory format: - - messages: chosen response - - extend_message: [('rejected_messages', rejected response)] - - Note: We do NOT call dataset.encode() here. Encoding is done in - prepare_dpo_batch() to properly handle both chosen and rejected. + Returns raw dataset without preprocessing - preprocessing is done + in prepare_dpo_batch() to avoid PyArrow serialization issues. 
""" dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - - # Use EmojiDPOProcessor for shareAI/DPO-zh-en-emoji dataset - # answer_zh -> chosen (messages), answer_en -> rejected (extend_message) - dataset.map(EmojiDPOProcessor( - system='You are a helpful assistant.', - chosen_key='answer_zh', - rejected_key='answer_en', - prompt_key='prompt', - )) - - # Do NOT encode here - encoding is done in prepare_dpo_batch - # to preserve extend_message for rejected encoding + # Do NOT apply preprocessor here - raw data will be processed in prepare_dpo_batch return dataset @@ -126,11 +123,10 @@ def prepare_dpo_batch( batch: List[Dict[str, Any]], template: Template, ) -> List[Dict[str, Any]]: - """Prepare DPO batch: encode both chosen and rejected. + """Prepare DPO batch: build trajectories and encode both chosen and rejected. Args: - batch: List of raw Trajectory dicts with messages (chosen) and - extend_message containing ('rejected_messages', rejected) + batch: List of raw data dicts from dataset (e.g., {prompt, answer_zh, answer_en}) Returns: List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] @@ -140,27 +136,28 @@ def prepare_dpo_batch( rejected_samples = [] for item in batch: - # Get messages (chosen) and encode - messages = item.get('messages', []) - chosen_trajectory = Trajectory(messages=messages) + # Build messages from raw data + prompt = item.get(PROMPT_KEY, '') + chosen_response = item.get(CHOSEN_KEY, '') + rejected_response = item.get(REJECTED_KEY, '') + + # Build prompt messages + prompt_messages = [] + if SYSTEM_PROMPT: + prompt_messages.append(Message(role='system', content=SYSTEM_PROMPT)) + prompt_messages.append(Message(role='user', content=prompt)) + + # Build chosen trajectory and encode + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] + chosen_trajectory = Trajectory(messages=chosen_messages) chosen_encoded = template.encode(chosen_trajectory) chosen_samples.append(dict(chosen_encoded)) - # Get rejected from extend_message and encode - extend_message = item.get('extend_message', []) - rejected_messages = None - for key, msgs in extend_message: - if key == 'rejected_messages': - rejected_messages = msgs - break - - if rejected_messages: - rejected_trajectory = Trajectory(messages=rejected_messages) - rejected_encoded = template.encode(rejected_trajectory) - rejected_samples.append(dict(rejected_encoded)) - else: - # Fallback: use chosen (should not happen with proper preprocessing) - rejected_samples.append(dict(chosen_encoded)) + # Build rejected trajectory and encode + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] + rejected_trajectory = Trajectory(messages=rejected_messages) + rejected_encoded = template.encode(rejected_trajectory) + rejected_samples.append(dict(rejected_encoded)) # Return [chosen..., rejected...] 
return chosen_samples + rejected_samples From ee3602cde993cfbfd4758a9530e7619782ef4caa Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 26 Mar 2026 23:44:37 +0800 Subject: [PATCH 06/14] fix --- cookbook/rl/dpo.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 1e3d1cd2..6f6db1ac 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -132,8 +132,8 @@ def prepare_dpo_batch( List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] where each item is an encoded InputFeature dict. """ - chosen_samples = [] - rejected_samples = [] + chosen_trajectories = [] + rejected_trajectories = [] for item in batch: # Build messages from raw data @@ -147,17 +147,20 @@ def prepare_dpo_batch( prompt_messages.append(Message(role='system', content=SYSTEM_PROMPT)) prompt_messages.append(Message(role='user', content=prompt)) - # Build chosen trajectory and encode + # Build chosen and rejected trajectories chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] - chosen_trajectory = Trajectory(messages=chosen_messages) - chosen_encoded = template.encode(chosen_trajectory) - chosen_samples.append(dict(chosen_encoded)) - - # Build rejected trajectory and encode rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - rejected_trajectory = Trajectory(messages=rejected_messages) - rejected_encoded = template.encode(rejected_trajectory) - rejected_samples.append(dict(rejected_encoded)) + + chosen_trajectories.append(Trajectory(messages=chosen_messages)) + rejected_trajectories.append(Trajectory(messages=rejected_messages)) + + # Batch encode all trajectories (properly handles multimodal preprocessing) + chosen_encoded = template.batch_encode(chosen_trajectories) + rejected_encoded = template.batch_encode(rejected_trajectories) + + # Convert to list of dicts + chosen_samples = [dict(enc) for enc in chosen_encoded] + rejected_samples = [dict(enc) for enc in rejected_encoded] # Return [chosen..., rejected...] return chosen_samples + rejected_samples From bebbe780c5407d7383fdf53b13952582e6dedc7b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 10:53:02 +0800 Subject: [PATCH 07/14] wip --- cookbook/rl/dpo.py | 101 ++++++++++++++++++++-------------------- src/twinkle/loss/dpo.py | 17 +++++-- 2 files changed, 65 insertions(+), 53 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 6f6db1ac..5850f8c7 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -25,14 +25,13 @@ - extend_message: [('rejected_messages', List[Message])] - rejected response For SimPO/ORPO variants that don't require a reference model, -set USE_REFERENCE_MODEL=0 to skip reference model computation. +set REF_MODEL_GPUS=0 to skip reference model computation. 
Environment variables (all optional): MODEL_ID – (default: ms://Qwen/Qwen3.5-4B) DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji) MODEL_GPUS – GPUs for policy model (default: 4) REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable) - USE_REFERENCE_MODEL – Whether to use reference model (default: 1) BATCH_SIZE – global batch size (preference pairs) (default: 8) MICRO_BATCH_SIZE – per-device micro batch size (default: 2) MAX_STEPS – total optimization steps (default: 1000) @@ -62,6 +61,7 @@ from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from twinkle.model import TransformersModel +from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor from twinkle.template import Template @@ -73,14 +73,7 @@ MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) -USE_REFERENCE_MODEL = bool(int(os.environ.get('USE_REFERENCE_MODEL', 1))) - -# Adjust total GPUs based on whether reference model is used -if USE_REFERENCE_MODEL and REF_MODEL_GPUS > 0: - NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -else: - NUM_GPUS = MODEL_GPUS - USE_REFERENCE_MODEL = False +NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) @@ -92,30 +85,18 @@ SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100)) MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048)) ADAPTER_NAME = 'default' - - -# ── Dataset ─────────────────────────────────────────────────────────────────── - -# Dataset field configuration for shareAI/DPO-zh-en-emoji -PROMPT_KEY = os.environ.get('PROMPT_KEY', 'prompt') -CHOSEN_KEY = os.environ.get('CHOSEN_KEY', 'answer_zh') -REJECTED_KEY = os.environ.get('REJECTED_KEY', 'answer_en') SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.') def create_dpo_dataset(): - """Create preference dataset for DPO training. - - Uses shareAI/DPO-zh-en-emoji dataset: - - answer_zh: chosen response (Chinese) - - answer_en: rejected response (English) - - Returns raw dataset without preprocessing - preprocessing is done - in prepare_dpo_batch() to avoid PyArrow serialization issues. - """ - dataset = Dataset(DatasetMeta(DATASET_ID, split='train')) + dataset = Dataset(DatasetMeta(DATASET_ID)) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) - # Do NOT apply preprocessor here - raw data will be processed in prepare_dpo_batch + dataset.map( + EmojiDPOProcessor, + init_args={ + 'system': SYSTEM_PROMPT, + } + ) return dataset @@ -130,7 +111,6 @@ def prepare_dpo_batch( Returns: List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n] - where each item is an encoded InputFeature dict. 
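+ where each item is an encoded feature dict ready for the policy forward pass.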
""" chosen_trajectories = [] rejected_trajectories = [] @@ -189,15 +169,10 @@ def create_loss(loss_type: str, beta: float, reference_free: bool = False): def main(): # Set up device groups - if USE_REFERENCE_MODEL: - device_groups = [ - DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), - DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), - ] - else: - device_groups = [ - DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), - ] + device_groups = [ + DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) @@ -226,7 +201,7 @@ def main(): reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] # Set up loss function - loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=not USE_REFERENCE_MODEL) + loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False) policy_model.set_loss(loss_fn) policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) @@ -234,9 +209,9 @@ def main(): # Get template for encoding rejected messages template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH) - # ── Reference Model Setup (if needed) ───────────────────────────────────── + # ── Reference Model Setup ───────────────────────────────────────────────── ref_model = None - if USE_REFERENCE_MODEL and not reference_free: + if not reference_free: ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS) ref_model = TransformersModel( model_id=MODEL_ID, @@ -261,8 +236,7 @@ def main(): optim_step = 0 logger.info(get_device_placement()) - logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}, ' - f'use_ref_model={USE_REFERENCE_MODEL}') + logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}') # ── Training Loop ───────────────────────────────────────────────────────── for batch in dataloader: @@ -274,19 +248,46 @@ def main(): dpo_batch = prepare_dpo_batch(batch_list, template) # Compute reference log probabilities if using reference model - ref_logps = None + # We compute sequence-level logps here to avoid alignment issues with micro-batching + ref_chosen_logps = None + ref_rejected_logps = None if ref_model is not None: with torch.no_grad(): ref_outputs = ref_model.forward_only(inputs=dpo_batch) - ref_logps = ref_outputs.get('logps') + ref_logps = ref_outputs.get('logps') # [batch, seq_len] + if ref_logps is not None: + # Get labels and pad to same length for stacking + label_tensors = [torch.as_tensor(s['labels']) for s in dpo_batch] + max_len = max(t.shape[0] for t in label_tensors) + # Pad labels with -100 (ignore_index) to max length + padded_labels = [] + for t in label_tensors: + if t.shape[0] < max_len: + pad_size = max_len - t.shape[0] + t = torch.cat([torch.full((pad_size,), -100, dtype=t.dtype), t]) + padded_labels.append(t) + ref_labels = torch.stack(padded_labels) + if ref_labels.device != ref_logps.device: + ref_labels = ref_labels.to(ref_logps.device) + # Align sequence lengths if needed + if ref_logps.shape[1] != ref_labels.shape[1]: + min_len = min(ref_logps.shape[1], ref_labels.shape[1]) + ref_logps = ref_logps[:, -min_len:] + ref_labels = ref_labels[:, -min_len:] + # Compute sequence-level logps (sum 
of valid token logps) + loss_mask = (ref_labels != -100).float() + seq_logps = (ref_logps * loss_mask).sum(dim=-1) # [batch] + + # Split into chosen and rejected + half = seq_logps.shape[0] // 2 + ref_chosen_logps = seq_logps[:half] + ref_rejected_logps = seq_logps[half:] # Forward-backward pass with DPO loss - # micro_batch_size should be even to maintain chosen/rejected pairing - actual_micro_batch = MICRO_BATCH_SIZE * 2 policy_model.forward_backward( inputs=dpo_batch, - ref_logps=ref_logps, - micro_batch_size=actual_micro_batch, + ref_chosen_logps=ref_chosen_logps, + ref_rejected_logps=ref_rejected_logps, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index e727193a..7f89d446 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -154,12 +154,23 @@ def _pad_and_align_logps( import torch if torch.is_tensor(logps): - if logps.shape == target_shape: - return logps.to(device=device, dtype=dtype) - elif logps.dim() == 1: + if logps.dim() == 1: logps = logps.unsqueeze(0) if logps.shape == target_shape: return logps.to(device=device, dtype=dtype) + # Handle tensor with different sequence length - align to target shape + if logps.dim() == 2 and logps.shape[0] == target_shape[0]: + batch_size, target_seq_len = target_shape + src_seq_len = logps.shape[1] + logps = logps.to(device=device, dtype=dtype) + if src_seq_len > target_seq_len: + # Truncate: take the last target_seq_len tokens (response part) + return logps[:, -target_seq_len:] + else: + # Pad: add zeros at the beginning + padded = torch.zeros(target_shape, device=device, dtype=dtype) + padded[:, -src_seq_len:] = logps + return padded # Handle ragged list input if isinstance(logps, (list, tuple)): From 3f8d1a3c4393aa0a5b666dbd6dea8dd6da3c0548 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 13:44:22 +0800 Subject: [PATCH 08/14] wip --- cookbook/rl/dpo.py | 53 +++++------- src/twinkle/data_format/trajectory.py | 2 +- src/twinkle/preprocessor/dpo.py | 47 +++++----- src/twinkle/template/base.py | 120 ++++++++++++++++++++------ 4 files changed, 140 insertions(+), 82 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index 5850f8c7..e7720714 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -56,7 +56,7 @@ import twinkle from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger -from twinkle.data_format import Message, Trajectory +from twinkle.data_format import Trajectory from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss @@ -101,13 +101,15 @@ def create_dpo_dataset(): def prepare_dpo_batch( - batch: List[Dict[str, Any]], + batch: List[Trajectory], template: Template, ) -> List[Dict[str, Any]]: - """Prepare DPO batch: build trajectories and encode both chosen and rejected. + """Prepare DPO batch: encode both chosen and rejected from preprocessed Trajectories. 
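+
+ Chosen messages come from traj['messages']; the rejected variant is looked
+ up under the 'rejected_messages' key in traj['extend_message'].
+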
     Args:
-        batch: List of raw data dicts from dataset (e.g., {prompt, answer_zh, answer_en})
+        batch: List of Trajectory objects with:
+            - messages: chosen response messages
+            - extend_message: [rejected_messages]

     Returns:
         List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n]
     """
     chosen_trajectories = []
     rejected_trajectories = []

-    for item in batch:
-        # Build messages from raw data
-        prompt = item.get(PROMPT_KEY, '')
-        chosen_response = item.get(CHOSEN_KEY, '')
-        rejected_response = item.get(REJECTED_KEY, '')
-
-        # Build prompt messages
-        prompt_messages = []
-        if SYSTEM_PROMPT:
-            prompt_messages.append(Message(role='system', content=SYSTEM_PROMPT))
-        prompt_messages.append(Message(role='user', content=prompt))
-
-        # Build chosen and rejected trajectories
-        chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)]
-        rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)]
-
-        chosen_trajectories.append(Trajectory(messages=chosen_messages))
+    for traj in batch:
+        chosen_trajectories.append(Trajectory(messages=traj['messages']))
+        rejected_messages = traj['extend_message'][0]
         rejected_trajectories.append(Trajectory(messages=rejected_messages))

-    # Batch encode all trajectories (properly handles multimodal preprocessing)
+    # Batch encode all trajectories
     chosen_encoded = template.batch_encode(chosen_trajectories)
     rejected_encoded = template.batch_encode(rejected_trajectories)

@@ -175,7 +163,16 @@ def main():
         ]

     policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
-    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
+    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
+
+    # ── DataLoader Setup ──────────────────────────────────────────────────────
+    dataloader = DataLoader(
+        dataset=create_dpo_dataset,
+        batch_size=BATCH_SIZE,
+        min_batch_size=BATCH_SIZE,
+        device_mesh=policy_mesh,
+    )
+    length = len(dataloader)

     # ── Policy Model Setup ────────────────────────────────────────────────────
     lora_config = LoraConfig(
@@ -224,16 +221,6 @@ def main():
     else:
         logger.info(f'Training without reference model (loss_type={LOSS_TYPE})')

-    # ── DataLoader Setup ──────────────────────────────────────────────────────
-    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
-    dataloader = DataLoader(
-        dataset=create_dpo_dataset,
-        batch_size=GLOBAL_BATCH_SIZE,
-        min_batch_size=GLOBAL_BATCH_SIZE,
-        device_mesh=policy_mesh,
-        remote_group='policy',
-    )
-
     optim_step = 0
     logger.info(get_device_placement())
     logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}')
diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py
index c7742d75..d6910eaf 100644
--- a/src/twinkle/data_format/trajectory.py
+++ b/src/twinkle/data_format/trajectory.py
@@ -13,6 +13,6 @@ class Trajectory(TypedDict, total=False):
     messages: List[Message]
-    extend_message: List[Tuple[str, List[Message]]]
+    extend_message: List[List[Message]]
     tools: List[Tool]
     user_data: List[Tuple[str, Any]]
diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py
index a81b5f57..7c35b676 100644
--- a/src/twinkle/preprocessor/dpo.py
+++ b/src/twinkle/preprocessor/dpo.py
@@ -123,14 +123,15 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory:

         # Return Trajectory with rejected in extend_message
return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Process batched data into DPO trajectories.""" rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class HHRLHFProcessor(Preprocessor): @@ -182,13 +183,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class UltraFeedbackProcessor(Preprocessor): @@ -253,7 +255,7 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: @@ -264,8 +266,9 @@ def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: if result is not None: trajectories.append(result) if not trajectories: - return {'messages': []} - return {'messages': trajectories} + return {} + rows = self.map_row_to_col(trajectories) + return rows class ShareGPTDPOProcessor(Preprocessor): @@ -335,13 +338,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class IntelOrcaDPOProcessor(Preprocessor): @@ -376,13 +380,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class EmojiDPOProcessor(Preprocessor): @@ -434,13 +439,14 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: return Trajectory( messages=chosen_messages, - extend_message=[('rejected_messages', rejected_messages)] + extend_message=[rejected_messages] ) def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows class UltraFeedbackKTOProcessor(Preprocessor): @@ -482,5 +488,6 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: 
rows = self.map_col_to_row(rows) - trajectories = [self.preprocess(row) for row in rows] - return {'messages': trajectories} + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 167a459d..cd783a9d 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -179,7 +179,7 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: if self.use_chat_template and self.default_system: if trajectory['messages'][0]['role'] == 'user': trajectory['messages'].insert(0, Message(role='system', content=self.default_system)) - for (_, messages) in trajectory.get('extend_message', []): + for messages in trajectory.get('extend_message', []): if messages and messages[0]['role'] == 'user': messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] @@ -209,43 +209,79 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: extra_messages = trajectory.get('extend_message', []) if extra_messages: result = [] - for key, extra_message in trajectory.get('extend_message', []): - result.append((key, _extract_reasoning_content(extra_message))) + for extra_message in trajectory.get('extend_message', []): + result.append(_extract_reasoning_content(extra_message)) trajectory['extend_message'] = result return [trajectory] + def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeature: + """Truncate input_ids and labels in a single InputFeature.""" + length = len(feature['input_ids']) + if length <= self.max_length: + return feature + if strategy == 'raise': + raise ValueError(f'Input length {length} exceeds max_length {self.max_length}') + result = dict(feature) + if strategy == 'left': + result['input_ids'] = result['input_ids'][-self.max_length:] + if 'labels' in result: + result['labels'] = result['labels'][-self.max_length:] + elif strategy == 'right': + result['input_ids'] = result['input_ids'][:self.max_length] + if 'labels' in result: + result['labels'] = result['labels'][:self.max_length] + return InputFeature(**result) + def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: - if self.max_length and len(input_feature['input_ids']) > self.max_length: - if self.truncation_strategy == 'raise': - raise ValueError(f'An input message(length: {len(input_feature["input_ids"])} ' - f'exceeds the maximum length({self.max_length})') - elif self.truncation_strategy == 'left': - return [InputFeature(**{key: value[-self.max_length:] for key, value in input_feature.items()})] - elif self.truncation_strategy == 'right': - return [InputFeature(**{key: value[:self.max_length] for key, value in input_feature.items()})] - else: # split - result = [] - total_length = len(input_feature['input_ids']) - for start in range(0, total_length, self.max_length): - end = min(start + self.max_length, total_length) - result.append(InputFeature(**{key: value[start:end] for key, value in input_feature.items()})) - return result - else: + if not self.max_length: return [input_feature] + strategy = self.truncation_strategy + + # Split strategy: doesn't support extend_message + if strategy == 'split': + if input_feature.get('extend_message'): + raise ValueError('Split strategy does not support extend_message.') + results = [] + for start in range(0, len(input_feature['input_ids']), self.max_length): + end = min(start + self.max_length, len(input_feature['input_ids'])) + feat = 
dict(input_feature) + feat['input_ids'] = feat['input_ids'][start:end] + if 'labels' in feat: + feat['labels'] = feat['labels'][start:end] + results.append(InputFeature(**feat)) + return results + + # left/right/raise: apply to main and extend_message + result = self._truncate_feature(input_feature, strategy) + if input_feature.get('extend_message'): + result['extend_message'] = [self._truncate_feature(f, strategy) for f in input_feature['extend_message']] + return [result] + + def _add_attention_to_feature(self, feature: InputFeature) -> InputFeature: + """Add attention fields to a single InputFeature.""" + input_ids = feature['input_ids'] + feature['attention_mask'] = np.ones_like(input_ids) + feature['position_ids'] = np.arange(len(input_ids)) + feature['length'] = len(input_ids) + return feature + def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]: - input_ids = input_feature['input_ids'] - input_feature['attention_mask'] = np.ones_like(input_ids) - input_feature['position_ids'] = np.arange(len(input_ids)) - input_feature['length'] = len(input_ids) + self._add_attention_to_feature(input_feature) + if input_feature.get('extend_message'): + for f in input_feature['extend_message']: + self._add_attention_to_feature(f) return [input_feature] def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1) + if input_feature.get('extend_message'): + for f in input_feature['extend_message']: + f['labels'] = np.roll(f['labels'], -1, axis=-1) return [input_feature] - def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: - messages = trajectory['messages'] + def _process_mm_messages(self, messages: List) -> List: + """Process multimodal content in a list of messages.""" new_messages = [] for message in messages: message = copy(message) @@ -265,8 +301,16 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: new_messages.append( transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, self.audio_placeholder, self.is_mm)) + return new_messages + + def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: + trajectory['messages'] = self._process_mm_messages(trajectory['messages']) + if trajectory.get('extend_message'): + new_extend_message = [] + for msgs in trajectory['extend_message']: + new_extend_message.append(self._process_mm_messages(msgs)) + trajectory['extend_message'] = new_extend_message - trajectory['messages'] = new_messages return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): @@ -283,7 +327,8 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo **kwargs) return inputs - def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + """Encode a single trajectory's messages into InputFeature.""" if self.use_chat_template: if add_generation_prompt: # For inference: just get input_ids with generation prompt, no labels needed @@ -306,11 +351,30 @@ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> input_ids = self.tokenizer.encode(text) encoded = {} labels = deepcopy(input_ids) - return InputFeature( + + input_feature = InputFeature( input_ids=np.array(input_ids), labels=np.array(labels), **encoded, ) + 
trajectory.update(input_feature) + return trajectory + + def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + # Encode main messages + result = self._encode_messages(trajectory, add_generation_prompt) + + # Encode extend_message (e.g., rejected messages in DPO) + if trajectory.get('extend_message'): + encoded_extend = [] + for msgs in trajectory['extend_message']: + # Create a temporary trajectory with the extended messages + ext_trajectory = Trajectory(messages=msgs) + ext_feature = self._encode_messages(ext_trajectory, add_generation_prompt) + encoded_extend.append(ext_feature) + result['extend_message'] = encoded_extend + + return result @staticmethod def map_col_to_row(trajectories: Dict[str, Any]): From d1f223fbc9db55362e3e12363a07139957293f1d Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 15:01:31 +0800 Subject: [PATCH 09/14] wip --- cookbook/rl/dpo.py | 48 +++--- .../Components/Data Format/Trajectory.md | 6 +- .../Trajectory.md" | 6 +- src/twinkle/data_format/trajectory.py | 1 - src/twinkle/dataset/base.py | 2 +- src/twinkle/preprocessor/dpo.py | 147 ++++++++-------- src/twinkle/template/base.py | 157 ++++++++++-------- 7 files changed, 199 insertions(+), 168 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index e7720714..a9937aa2 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -5,7 +5,7 @@ Pipeline: 1. Load preference dataset with chosen/rejected pairs. - 2. Encode chosen and rejected separately. + 2. Encode positive and negative separately. 3. Compute reference model log probabilities (frozen). 4. Train policy model using DPO loss. @@ -20,9 +20,9 @@ DataLoader RefModel (frozen) PolicyModel (trainable) (ref GPUs) (policy GPUs) -DPO Trajectory format: - - messages: List[Message] - chosen response - - extend_message: [('rejected_messages', List[Message])] - rejected response +DPO data format (after preprocessing): + - positive: List[Trajectory] - chosen responses + - negative: List[Trajectory] - rejected responses For SimPO/ORPO variants that don't require a reference model, set REF_MODEL_GPUS=0 to skip reference model computation. @@ -89,6 +89,7 @@ def create_dpo_dataset(): + """Create DPO dataset with positive/negative format.""" dataset = Dataset(DatasetMeta(DATASET_ID)) dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH) dataset.map( @@ -97,41 +98,33 @@ def create_dpo_dataset(): 'system': SYSTEM_PROMPT, } ) + # DPO preprocessor returns {'positive': [...], 'negative': [...]} + # batch_encode handles this format automatically + dataset.encode() return dataset def prepare_dpo_batch( - batch: List[Trajectory], + batch: Dict[str, List[Any]], template: Template, ) -> List[Dict[str, Any]]: - """Prepare DPO batch: encode both chosen and rejected from preprocessed Trajectories. + """Prepare DPO batch: convert encoded batch to list format for training. 
     Args:
-        batch: List of Trajectory objects with:
-            - messages: chosen response messages
-            - extend_message: [rejected_messages]
+        batch: Dict with 'positive' and 'negative' keys, each containing List[InputFeature]

     Returns:
-        List organized as [chosen_1, ..., chosen_n, rejected_1, ..., rejected_n]
+        List organized as [positive_1, ..., positive_n, negative_1, ..., negative_n]
     """
-    chosen_trajectories = []
-    rejected_trajectories = []
-
-    for traj in batch:
-        chosen_trajectories.append(Trajectory(messages=traj['messages']))
-        rejected_messages = traj['extend_message'][0]
-        rejected_trajectories.append(Trajectory(messages=rejected_messages))
-
-    # Batch encode all trajectories
-    chosen_encoded = template.batch_encode(chosen_trajectories)
-    rejected_encoded = template.batch_encode(rejected_trajectories)
+    positive_features = batch.get('positive', [])
+    negative_features = batch.get('negative', [])

     # Convert to list of dicts
-    chosen_samples = [dict(enc) for enc in chosen_encoded]
-    rejected_samples = [dict(enc) for enc in rejected_encoded]
+    positive_samples = [dict(f) for f in positive_features]
+    negative_samples = [dict(f) for f in negative_features]

-    # Return [chosen..., rejected...]
-    return chosen_samples + rejected_samples
+    # Return [positive..., negative...]
+    return positive_samples + negative_samples


 # ── Loss Factory ──────────────────────────────────────────────────────────────
@@ -230,9 +223,8 @@ def main():
         if optim_step >= MAX_STEPS:
             break

-        # Prepare DPO batch: [chosen..., rejected...]
-        batch_list = batch if isinstance(batch, list) else [batch]
-        dpo_batch = prepare_dpo_batch(batch_list, template)
+        # batch is Dict[str, List[Trajectory]] with 'positive' and 'negative' keys
+        dpo_batch = prepare_dpo_batch(batch, template)

         # Compute reference log probabilities if using reference model
         # We compute sequence-level logps here to avoid alignment issues with micro-batching
diff --git a/docs/source_en/Components/Data Format/Trajectory.md b/docs/source_en/Components/Data Format/Trajectory.md
index d0c14aec..0efc6b6e 100644
--- a/docs/source_en/Components/Data Format/Trajectory.md
+++ b/docs/source_en/Components/Data Format/Trajectory.md
@@ -5,12 +5,14 @@ The raw data structure input to Template after dataset ETL is `Trajectory` (traj
 ```python
 class Trajectory(TypedDict, total=False):
     messages: List[Message]
-    extend_message: List[Tuple[str, List[Message]]]
     tools: List[Tool]
+    user_data: List[Tuple[str, Any]]
 ```

 - messages: A list of Message messages, representing the multi-turn conversations actually conducted by the model, usually alternating between `user` and `assistant`.
-- extend_message: In training such as DPO and PPO, unusable trajectories or low-score trajectories are usually needed, which will be placed in extend_message
 - tools: A list of all available tools for the model in this call
+- user_data: User-defined data, such as labels in KTO training
+
+For preference alignment training like DPO, preprocessors return the `{'positive': List[Trajectory], 'negative': List[Trajectory]}` format.

 Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in Twinkle. The format conversion goes from the original dataset to Trajectory, and then to InputFeature.
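
To make the preference format documented above concrete, here is a minimal sketch of what a DPO preprocessor emits for a single row (the sample strings are illustrative only; `Message` and `Trajectory` are the twinkle data-format types shown above):

```python
from twinkle.data_format import Message, Trajectory

# One raw preference row: a prompt plus a preferred and a dispreferred answer.
row = {'prompt': 'What is 2 + 2?', 'chosen': '4', 'rejected': '5'}

prompt_messages = [Message(role='user', content=row['prompt'])]

# Columnar output consumed downstream: parallel positive/negative lists.
processed = {
    'positive': [Trajectory(messages=prompt_messages
                            + [Message(role='assistant', content=row['chosen'])])],
    'negative': [Trajectory(messages=prompt_messages
                            + [Message(role='assistant', content=row['rejected'])])],
}
```

`batch_encode` recognizes the trajectory-list columns and encodes each one separately, so the same `positive`/`negative` layout flows unchanged from preprocessing through encoding to `prepare_dpo_batch`.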
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" index f7ed4f12..5281999c 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" @@ -5,12 +5,14 @@ ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: Message消息的列表,代表模型实际进行的多轮对话,通常是`user`和`assistant`交替出现。 -- extend_message: 在DPO、PPO等训练中通常需要不可用轨迹,或低分轨迹,该轨迹会放在extend_message中 - tools: 模型在本次调用中的所有可用工具列表 +- user_data: 用户自定义数据,如KTO训练中的label + +对于DPO等偏好对齐训练,预处理器返回`{'positive': List[Trajectory], 'negative': List[Trajectory]}`格式。 Trajectory是twinkle中所有数据集预处理输出,模板输入的标准接口。格式转换为由原始数据集转换为Trajectory,再到InputFeature。 diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py index d6910eaf..51a21fc5 100644 --- a/src/twinkle/data_format/trajectory.py +++ b/src/twinkle/data_format/trajectory.py @@ -13,6 +13,5 @@ class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[List[Message]] tools: List[Tool] user_data: List[Tuple[str, Any]] diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index 00adc984..fc14a030 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -87,7 +87,7 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): with processing_lock('dataset'): # use a default lock because encode is to all datasets self.dataset = self.dataset.map(encode_fn, - **kwargs).filter(lambda batch: [len(x) > 0 for x in batch['input_ids']], + **kwargs).filter(lambda batch: [True] * len(next(iter(batch.values()))) if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs) @remote_function() diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py index 7c35b676..0a03c4ad 100644 --- a/src/twinkle/preprocessor/dpo.py +++ b/src/twinkle/preprocessor/dpo.py @@ -3,11 +3,11 @@ DPO (Direct Preference Optimization) Data Preprocessors. These preprocessors convert various preference dataset formats into the standard -Trajectory format required by Twinkle for DPO training. +format required by Twinkle for DPO training. -DPO Trajectory format: - - messages: List[Message] - chosen response messages - - extend_message: [('rejected_messages', List[Message])] - rejected response messages +DPO output format: + - positive: Trajectory - chosen response trajectory + - negative: Trajectory - rejected response trajectory """ from typing import Any, Dict, List, Optional, Union @@ -18,7 +18,7 @@ class DPOProcessor(Preprocessor): """Generic DPO preference data preprocessor. - Converts preference data with chosen/rejected pairs into Trajectory format. + Converts preference data with chosen/rejected pairs into positive/negative Trajectories. Supports multiple common dataset formats. Expected input format (one of): @@ -27,9 +27,9 @@ class DPOProcessor(Preprocessor): 3. {'messages': List[Message], 'chosen': str, 'rejected': str} 4. 
{'chosen': List[Message], 'rejected': List[Message]} (full conversations) - Output Trajectory format: - - messages: chosen response (prompt + chosen assistant message) - - extend_message: [('rejected_messages', rejected_messages)] + Output format: + - positive: Trajectory with chosen response + - negative: Trajectory with rejected response Args: system: Optional system prompt to prepend. @@ -99,11 +99,11 @@ def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Trajectory: - """Process a single row into a DPO Trajectory. + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row into positive/negative Trajectories. Returns: - Trajectory with chosen in messages and rejected in extend_message. + Dict with 'positive' and 'negative' Trajectory. """ # Build prompt messages prompt_messages = self._build_prompt_messages(row) @@ -120,18 +120,22 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + chosen_response rejected_messages = prompt_messages + rejected_response - # Return Trajectory with rejected in extend_message - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Process batched data into DPO trajectories.""" + """Process batched data into DPO format.""" rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + # Collect all positive and negative trajectories + positive_list = [r['positive'] for r in results] + negative_list = [r['negative'] for r in results] + return { + 'positive': positive_list, + 'negative': negative_list, + } class HHRLHFProcessor(Preprocessor): @@ -173,7 +177,7 @@ def _parse_hh_conversation(self, text: str) -> List[Message]: return messages - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process HH-RLHF format row.""" chosen_text = row.get('chosen', '') rejected_text = row.get('rejected', '') @@ -181,16 +185,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = self._parse_hh_conversation(chosen_text) rejected_messages = self._parse_hh_conversation(rejected_text) - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class UltraFeedbackProcessor(Preprocessor): @@ -222,7 +228,7 @@ def __init__( self.response_key = response_key self.score_key = score_key - def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: """Process UltraFeedback format row.""" instruction = row.get(self.instruction_key, '') completions = row.get(self.completions_key, []) @@ 
-253,22 +259,21 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Trajectory]: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - trajectories = [] - for row in rows: - result = self.preprocess(row) - if result is not None: - trajectories.append(result) - if not trajectories: + results = [self.preprocess(row) for row in rows] + results = [r for r in results if r is not None] + if not results: return {} - rows = self.map_row_to_col(trajectories) - return rows + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class ShareGPTDPOProcessor(Preprocessor): @@ -303,7 +308,7 @@ def _parse_sharegpt_message(self, msg: Dict) -> Message: content = msg.get('value', '') or msg.get('content', '') return Message(role=role, content=content) - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process ShareGPT DPO format row.""" conversations = row.get('conversations', []) @@ -336,16 +341,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_content)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_content)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class IntelOrcaDPOProcessor(Preprocessor): @@ -363,7 +370,7 @@ class IntelOrcaDPOProcessor(Preprocessor): def __init__(self, default_system: Optional[str] = None): self.default_system = default_system - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process Intel ORCA DPO format row.""" system = row.get('system', self.default_system) question = row.get('question', '') @@ -378,16 +385,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in 
results], + } class EmojiDPOProcessor(Preprocessor): @@ -400,9 +409,9 @@ class EmojiDPOProcessor(Preprocessor): 'answer_en': str, # rejected response (English) } - Output Trajectory format: - - messages: prompt + chosen (answer_zh) - - extend_message: [('rejected_messages', prompt + rejected (answer_en))] + Output format: + - positive: Trajectory with chosen (answer_zh) + - negative: Trajectory with rejected (answer_en) Args: system: Optional system prompt. @@ -423,7 +432,7 @@ def __init__( self.rejected_key = rejected_key self.prompt_key = prompt_key - def preprocess(self, row: Dict[str, Any]) -> Trajectory: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: """Process a single row.""" prompt = row.get(self.prompt_key, '') chosen = row.get(self.chosen_key, '') @@ -437,16 +446,18 @@ def preprocess(self, row: Dict[str, Any]) -> Trajectory: chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] - return Trajectory( - messages=chosen_messages, - extend_message=[rejected_messages] - ) + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: rows = self.map_col_to_row(rows) - rows = [self.preprocess(row) for row in rows] - rows = self.map_row_to_col(rows) - return rows + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } class UltraFeedbackKTOProcessor(Preprocessor): diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index cd783a9d..31c04b6a 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -179,9 +179,6 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: if self.use_chat_template and self.default_system: if trajectory['messages'][0]['role'] == 'user': trajectory['messages'].insert(0, Message(role='system', content=self.default_system)) - for messages in trajectory.get('extend_message', []): - if messages and messages[0]['role'] == 'user': - messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] def _to_standard_reasoning_content(self, trajectory: Trajectory) -> List[Trajectory]: @@ -206,12 +203,6 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: return result trajectory['messages'] = _extract_reasoning_content(trajectory['messages']) - extra_messages = trajectory.get('extend_message', []) - if extra_messages: - result = [] - for extra_message in trajectory.get('extend_message', []): - result.append(_extract_reasoning_content(extra_message)) - trajectory['extend_message'] = result return [trajectory] def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeature: @@ -238,10 +229,8 @@ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: strategy = self.truncation_strategy - # Split strategy: doesn't support extend_message + # Split strategy if strategy == 'split': - if input_feature.get('extend_message'): - raise ValueError('Split strategy does not support extend_message.') results = [] for start in range(0, len(input_feature['input_ids']), self.max_length): end = min(start + self.max_length, len(input_feature['input_ids'])) @@ -252,32 +241,18 @@ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: 
results.append(InputFeature(**feat)) return results - # left/right/raise: apply to main and extend_message - result = self._truncate_feature(input_feature, strategy) - if input_feature.get('extend_message'): - result['extend_message'] = [self._truncate_feature(f, strategy) for f in input_feature['extend_message']] - return [result] - - def _add_attention_to_feature(self, feature: InputFeature) -> InputFeature: - """Add attention fields to a single InputFeature.""" - input_ids = feature['input_ids'] - feature['attention_mask'] = np.ones_like(input_ids) - feature['position_ids'] = np.arange(len(input_ids)) - feature['length'] = len(input_ids) - return feature + # left/right/raise + return [self._truncate_feature(input_feature, strategy)] def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]: - self._add_attention_to_feature(input_feature) - if input_feature.get('extend_message'): - for f in input_feature['extend_message']: - self._add_attention_to_feature(f) + input_ids = input_feature['input_ids'] + input_feature['attention_mask'] = np.ones_like(input_ids) + input_feature['position_ids'] = np.arange(len(input_ids)) + input_feature['length'] = len(input_ids) return [input_feature] def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1) - if input_feature.get('extend_message'): - for f in input_feature['extend_message']: - f['labels'] = np.roll(f['labels'], -1, axis=-1) return [input_feature] def _process_mm_messages(self, messages: List) -> List: @@ -305,12 +280,6 @@ def _process_mm_messages(self, messages: List) -> List: def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: trajectory['messages'] = self._process_mm_messages(trajectory['messages']) - if trajectory.get('extend_message'): - new_extend_message = [] - for msgs in trajectory['extend_message']: - new_extend_message.append(self._process_mm_messages(msgs)) - trajectory['extend_message'] = new_extend_message - return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): @@ -361,20 +330,7 @@ def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = return trajectory def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: - # Encode main messages - result = self._encode_messages(trajectory, add_generation_prompt) - - # Encode extend_message (e.g., rejected messages in DPO) - if trajectory.get('extend_message'): - encoded_extend = [] - for msgs in trajectory['extend_message']: - # Create a temporary trajectory with the extended messages - ext_trajectory = Trajectory(messages=msgs) - ext_feature = self._encode_messages(ext_trajectory, add_generation_prompt) - encoded_extend.append(ext_feature) - result['extend_message'] = encoded_extend - - return result + return self._encode_messages(trajectory, add_generation_prompt) @staticmethod def map_col_to_row(trajectories: Dict[str, Any]): @@ -402,21 +358,90 @@ def map_row_to_col(rows: List[Union[Dict[str, Any], InputFeature]]) -> Dict[str, return columns - def batch_encode(self, - trajectories: Union[Dict[str, Any], List[Trajectory]], - add_generation_prompt: bool = False) -> List[InputFeature]: - output = [] - _transfer = False + def _is_trajectory(self, obj: Any) -> bool: + """Check if an object is a Trajectory (has 'messages' key).""" + return isinstance(obj, Mapping) and 'messages' in obj + + def _is_trajectory_dict(self, obj: Any) 
-> bool: + """Check if obj is Dict[str, Trajectory] - all values are Trajectories.""" + if not isinstance(obj, Mapping) or not obj: + return False + return all(self._is_trajectory(v) for v in obj.values()) + + def _get_trajectory_keys(self, obj: Mapping) -> List[str]: + """Get keys in a dict whose values are Trajectories.""" + return [k for k, v in obj.items() if self._is_trajectory(v)] + + def _is_columnar_format(self, obj: Any) -> bool: + """Check if obj is columnar format: Dict[str, List[Any]] but NOT a Trajectory.""" + if not isinstance(obj, Mapping) or not obj: + return False + # Trajectory has 'messages' key with list of Message dicts - not columnar + if self._is_trajectory(obj): + return False + # Dict[str, Trajectory] - not columnar + if self._is_trajectory_dict(obj): + return False + # Check if all values are non-empty lists of same length + first_val = next(iter(obj.values())) + if not isinstance(first_val, list) or len(first_val) == 0: + return False + length = len(first_val) + return all(isinstance(v, list) and len(v) == length for v in obj.values()) + + def batch_encode( + self, + trajectories: Union[Dict[str, Any], List[Trajectory], Trajectory], + add_generation_prompt: bool = False, + ) -> Union[Dict[str, Any], List[InputFeature], InputFeature]: + """Encode trajectories into InputFeatures. + + Supports three input formats: + 1. Trajectory -> InputFeature + 2. List[Trajectory] -> List[InputFeature] + 3. Dict containing Trajectories -> Dict with Trajectories encoded + + Also handles columnar format (Dict[str, List]) by converting to rows first. + """ + # Handle columnar format: convert to rows first + if self._is_columnar_format(trajectories): + rows = self.map_col_to_row(trajectories) + encoded = self.batch_encode(rows, add_generation_prompt=add_generation_prompt) + if isinstance(encoded, list) and encoded: + return self.map_row_to_col(encoded) + return encoded + + # Case 1: Single Trajectory + if self._is_trajectory(trajectories) and not self._is_trajectory_dict(trajectories): + processed = self._invoke_pre_pipeline([trajectories]) + output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] + output = self._invoke_post_pipeline(output) + return output[0] if len(output) == 1 else output + + # Case 2: List (Trajectory or Dict containing Trajectories) + if isinstance(trajectories, list): + if not trajectories: + return [] + first = trajectories[0] + if isinstance(first, Mapping) and self._get_trajectory_keys(first): + # List of dicts containing Trajectories + return [self.batch_encode(row, add_generation_prompt=add_generation_prompt) for row in trajectories] + else: + # List[Trajectory] + processed = self._invoke_pre_pipeline(trajectories) + output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] + return self._invoke_post_pipeline(output) + + # Case 3: Dict containing Trajectories (encode only Trajectory values) if isinstance(trajectories, Mapping): - _transfer = True - trajectories = self.map_col_to_row(trajectories) - trajectories = self._invoke_pre_pipeline(trajectories) - for trajectory in trajectories: - output.append(self.encode(trajectory, add_generation_prompt=add_generation_prompt)) - output = self._invoke_post_pipeline(output) - if _transfer: - output = self.map_row_to_col(output) - return output + traj_keys = self._get_trajectory_keys(trajectories) + if traj_keys: + result = dict(trajectories) # Copy non-trajectory keys + for key in traj_keys: + result[key] = self.batch_encode(trajectories[key], 
add_generation_prompt=add_generation_prompt) + return result + + raise ValueError(f'Unsupported input type: {type(trajectories)}') def check(self, trajectory: Trajectory) -> Optional[Trajectory]: encoded = None From 3a25caa7ad20d55066f0e77a098cd3704c0beabc Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Mar 2026 17:01:22 +0800 Subject: [PATCH 10/14] wip --- cookbook/rl/dpo.py | 91 ++++++++++------------------ src/twinkle/loss/dpo.py | 91 ++++++++++++++-------------- src/twinkle/metric/loss.py | 1 - src/twinkle/template/base.py | 114 +++++++++++++++-------------------- 4 files changed, 124 insertions(+), 173 deletions(-) diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py index a9937aa2..00311c63 100644 --- a/cookbook/rl/dpo.py +++ b/cookbook/rl/dpo.py @@ -51,7 +51,6 @@ import os from typing import Any, Dict, List, Optional -import torch from peft import LoraConfig import twinkle @@ -63,7 +62,6 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import EmojiDPOProcessor from twinkle.processor import InputProcessor -from twinkle.template import Template logger = get_logger() @@ -75,8 +73,8 @@ REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4)) NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs -MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) LEARNING_RATE = float(os.environ.get('LR', 5e-6)) @@ -100,31 +98,38 @@ def create_dpo_dataset(): ) # DPO preprocessor returns {'positive': [...], 'negative': [...]} # batch_encode handles this format automatically - dataset.encode() + dataset.encode(load_from_cache_file=True) return dataset -def prepare_dpo_batch( - batch: Dict[str, List[Any]], - template: Template, -) -> List[Dict[str, Any]]: - """Prepare DPO batch: convert encoded batch to list format for training. +def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare DPO batch: reorganize batch for training with DP-safe interleaving. Args: - batch: Dict with 'positive' and 'negative' keys, each containing List[InputFeature] + batch: List of rows, each with 'positive' and 'negative' InputFeatures + and other fields (question, etc.) Returns: - List organized as [positive_1, ..., positive_n, negative_1, ..., negative_n] + List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP + worker gets complete positive/negative pairs after slicing. + Each item contains all original fields plus the InputFeature fields. """ - positive_features = batch.get('positive', []) - negative_features = batch.get('negative', []) + result = [] - # Convert to list of dicts - positive_samples = [dict(f) for f in positive_features] - negative_samples = [dict(f) for f in negative_features] + for row in batch: + # Get base fields (excluding positive/negative) + base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')} - # Return [positive..., negative...] 
- return positive_samples + negative_samples + # Positive sample: merge base fields with positive InputFeature + pos_sample = {**base_fields, **row['positive']} + # Negative sample: merge base fields with negative InputFeature + neg_sample = {**base_fields, **row['negative']} + + # Interleave: [pos, neg] per pair for DP-safe slicing + result.append(pos_sample) + result.append(neg_sample) + + return result # ── Loss Factory ────────────────────────────────────────────────────────────── @@ -196,9 +201,6 @@ def main(): policy_model.set_processor(InputProcessor) policy_model.set_template('Template', model_id=MODEL_ID) - # Get template for encoding rejected messages - template = Template(model_id=MODEL_ID, max_length=MAX_LENGTH) - # ── Reference Model Setup ───────────────────────────────────────────────── ref_model = None if not reference_free: @@ -223,50 +225,19 @@ def main(): if optim_step >= MAX_STEPS: break - # batch is Dict[str, List[Trajectory]] with 'positive' and 'negative' keys - dpo_batch = prepare_dpo_batch(batch, template) + # batch is List[Dict] with 'positive' and 'negative' keys + dpo_batch = prepare_dpo_batch(batch) - # Compute reference log probabilities if using reference model - # We compute sequence-level logps here to avoid alignment issues with micro-batching - ref_chosen_logps = None - ref_rejected_logps = None + # Get reference outputs (lazy - not collected to driver) + ref_outputs = None if ref_model is not None: - with torch.no_grad(): - ref_outputs = ref_model.forward_only(inputs=dpo_batch) - ref_logps = ref_outputs.get('logps') # [batch, seq_len] - if ref_logps is not None: - # Get labels and pad to same length for stacking - label_tensors = [torch.as_tensor(s['labels']) for s in dpo_batch] - max_len = max(t.shape[0] for t in label_tensors) - # Pad labels with -100 (ignore_index) to max length - padded_labels = [] - for t in label_tensors: - if t.shape[0] < max_len: - pad_size = max_len - t.shape[0] - t = torch.cat([torch.full((pad_size,), -100, dtype=t.dtype), t]) - padded_labels.append(t) - ref_labels = torch.stack(padded_labels) - if ref_labels.device != ref_logps.device: - ref_labels = ref_labels.to(ref_logps.device) - # Align sequence lengths if needed - if ref_logps.shape[1] != ref_labels.shape[1]: - min_len = min(ref_logps.shape[1], ref_labels.shape[1]) - ref_logps = ref_logps[:, -min_len:] - ref_labels = ref_labels[:, -min_len:] - # Compute sequence-level logps (sum of valid token logps) - loss_mask = (ref_labels != -100).float() - seq_logps = (ref_logps * loss_mask).sum(dim=-1) # [batch] - - # Split into chosen and rejected - half = seq_logps.shape[0] // 2 - ref_chosen_logps = seq_logps[:half] - ref_rejected_logps = seq_logps[half:] + ref_outputs = ref_model.forward_only(inputs=dpo_batch) # Forward-backward pass with DPO loss + # ref_outputs is passed to loss which extracts logps internally policy_model.forward_backward( inputs=dpo_batch, - ref_chosen_logps=ref_chosen_logps, - ref_rejected_logps=ref_rejected_logps, + ref_outputs=ref_outputs, ) # Gradient clipping and optimizer step diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py index 7f89d446..52fc28b0 100644 --- a/src/twinkle/loss/dpo.py +++ b/src/twinkle/loss/dpo.py @@ -88,9 +88,13 @@ def _split_chosen_rejected( self, tensor: 'torch.Tensor', ) -> tuple: - """Split tensor into chosen (first half) and rejected (second half).""" - half = tensor.shape[0] // 2 - return tensor[:half], tensor[half:] + """Split interleaved tensor into chosen and rejected. 
+ + Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + # Even indices = chosen (positive), odd indices = rejected (negative) + return tensor[0::2], tensor[1::2] class DPOLoss(PreferenceLossBase): @@ -131,20 +135,18 @@ def __init__( self.loss_type = loss_type self.reference_free = reference_free - def _pad_and_align_logps( + def _align_logps( self, - logps: Union['torch.Tensor', List[List[float]]], + logps: 'torch.Tensor', target_shape: tuple, - loss_mask: 'torch.Tensor', device: 'torch.device', dtype: 'torch.dtype', ) -> 'torch.Tensor': - """Pad and align log probabilities to target shape. + """Align log probabilities to target shape. Args: - logps: Input log probabilities (tensor or ragged list) + logps: Input log probabilities tensor target_shape: Target (batch, seq_len) shape - loss_mask: Boolean mask for valid positions device: Target device dtype: Target dtype @@ -153,40 +155,32 @@ def _pad_and_align_logps( """ import torch - if torch.is_tensor(logps): - if logps.dim() == 1: - logps = logps.unsqueeze(0) - if logps.shape == target_shape: - return logps.to(device=device, dtype=dtype) - # Handle tensor with different sequence length - align to target shape - if logps.dim() == 2 and logps.shape[0] == target_shape[0]: - batch_size, target_seq_len = target_shape - src_seq_len = logps.shape[1] - logps = logps.to(device=device, dtype=dtype) - if src_seq_len > target_seq_len: - # Truncate: take the last target_seq_len tokens (response part) - return logps[:, -target_seq_len:] - else: - # Pad: add zeros at the beginning - padded = torch.zeros(target_shape, device=device, dtype=dtype) - padded[:, -src_seq_len:] = logps - return padded - - # Handle ragged list input - if isinstance(logps, (list, tuple)): - batch_size, seq_len = target_shape - padded = torch.zeros(target_shape, device=device, dtype=dtype) - for i, row in enumerate(logps): - if row is None: - continue - row_t = torch.as_tensor(row, device=device, dtype=dtype) - valid_positions = loss_mask[i].nonzero(as_tuple=True)[0] - length = min(len(row_t), len(valid_positions)) - if length > 0: - padded[i, valid_positions[:length]] = row_t[:length] - return padded - - return logps.to(device=device, dtype=dtype) + if not torch.is_tensor(logps): + raise TypeError(f'Expected torch.Tensor, got {type(logps)}') + + if logps.dim() == 1: + logps = logps.unsqueeze(0) + + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + + # Handle tensor with different sequence length + if logps.dim() == 2 and logps.shape[0] == target_shape[0]: + batch_size, target_seq_len = target_shape + src_seq_len = logps.shape[1] + logps = logps.to(device=device, dtype=dtype) + if src_seq_len > target_seq_len: + # Truncate right (keep left part) - may happen in Ray result merging + return logps[:, :target_seq_len] + else: + raise ValueError( + f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). ' + f'This should not happen when both models process the same batch.' 
+ ) + + raise ValueError( + f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}' + ) def _compute_dpo_loss( self, @@ -254,6 +248,7 @@ def __call__( inputs: Dict, outputs: Dict, *, + ref_outputs: Optional[Dict] = None, ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None, ref_chosen_logps: Optional['torch.Tensor'] = None, ref_rejected_logps: Optional['torch.Tensor'] = None, @@ -271,6 +266,7 @@ def __call__( outputs: Dict containing either: - 'logps': [batch, seq_len] pre-computed log probs, OR - 'logits': [batch, seq_len, vocab] from which logps will be computed + ref_outputs: Dict from reference model forward, containing 'logps'. ref_logps: [batch, seq_len] or List[List[float]] reference model log probs. Can also be provided as separate ref_chosen_logps and ref_rejected_logps. ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen. @@ -282,6 +278,10 @@ def __call__( """ import torch + # Extract ref_logps from ref_outputs if provided + if ref_outputs is not None and ref_logps is None: + ref_logps = ref_outputs.get('logps') + labels = inputs.get('labels') assert labels is not None, "inputs must contain 'labels'" if not torch.is_tensor(labels): @@ -312,9 +312,8 @@ def __call__( reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) elif ref_logps is not None: # Per-token reference log probs provided, need to align and sum - loss_mask = (labels != self.ignore_index).bool() - ref_logps_aligned = self._pad_and_align_logps( - ref_logps, labels.shape, loss_mask, device, dtype + ref_logps_aligned = self._align_logps( + ref_logps, labels.shape, device, dtype ) ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index 52f50fdd..7c4a8b93 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -52,7 +52,6 @@ def calculate(self): 'grad_norm': self.grad_norm, 'num_tokens': self.num_tokens }] - all_results = self.gather_results(local_results) total_loss = sum(r['loss'] for r in all_results) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 31c04b6a..a3e53056 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -362,86 +362,68 @@ def _is_trajectory(self, obj: Any) -> bool: """Check if an object is a Trajectory (has 'messages' key).""" return isinstance(obj, Mapping) and 'messages' in obj - def _is_trajectory_dict(self, obj: Any) -> bool: - """Check if obj is Dict[str, Trajectory] - all values are Trajectories.""" - if not isinstance(obj, Mapping) or not obj: - return False - return all(self._is_trajectory(v) for v in obj.values()) - - def _get_trajectory_keys(self, obj: Mapping) -> List[str]: - """Get keys in a dict whose values are Trajectories.""" - return [k for k, v in obj.items() if self._is_trajectory(v)] - - def _is_columnar_format(self, obj: Any) -> bool: - """Check if obj is columnar format: Dict[str, List[Any]] but NOT a Trajectory.""" - if not isinstance(obj, Mapping) or not obj: - return False - # Trajectory has 'messages' key with list of Message dicts - not columnar - if self._is_trajectory(obj): - return False - # Dict[str, Trajectory] - not columnar - if self._is_trajectory_dict(obj): - return False - # Check if all values are non-empty lists of same length - first_val = next(iter(obj.values())) - if not isinstance(first_val, list) or len(first_val) == 0: - return False 
- length = len(first_val) - return all(isinstance(v, list) and len(v) == length for v in obj.values()) + def _get_trajectory_keys(self, columnar: Mapping) -> List[str]: + """Get keys whose values are lists of Trajectories in columnar format.""" + keys = [] + for k, v in columnar.items(): + if isinstance(v, list) and v and self._is_trajectory(v[0]): + keys.append(k) + return keys def batch_encode( self, - trajectories: Union[Dict[str, Any], List[Trajectory], Trajectory], + trajectories: Union[Dict[str, Any], List[Trajectory]], add_generation_prompt: bool = False, - ) -> Union[Dict[str, Any], List[InputFeature], InputFeature]: + ) -> Union[Dict[str, Any], List[InputFeature]]: """Encode trajectories into InputFeatures. - Supports three input formats: - 1. Trajectory -> InputFeature - 2. List[Trajectory] -> List[InputFeature] - 3. Dict containing Trajectories -> Dict with Trajectories encoded + Args: + trajectories: Either List[Trajectory] or columnar Dict[str, List]. + For DPO, columnar format with 'positive'/'negative' keys containing + List[Trajectory] is supported. + add_generation_prompt: Whether to add generation prompt. - Also handles columnar format (Dict[str, List]) by converting to rows first. + Returns: + List[InputFeature] or columnar Dict[str, List[InputFeature]]. """ - # Handle columnar format: convert to rows first - if self._is_columnar_format(trajectories): - rows = self.map_col_to_row(trajectories) - encoded = self.batch_encode(rows, add_generation_prompt=add_generation_prompt) - if isinstance(encoded, list) and encoded: - return self.map_row_to_col(encoded) - return encoded - - # Case 1: Single Trajectory - if self._is_trajectory(trajectories) and not self._is_trajectory_dict(trajectories): - processed = self._invoke_pre_pipeline([trajectories]) - output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] - output = self._invoke_post_pipeline(output) - return output[0] if len(output) == 1 else output - - # Case 2: List (Trajectory or Dict containing Trajectories) - if isinstance(trajectories, list): - if not trajectories: - return [] - first = trajectories[0] - if isinstance(first, Mapping) and self._get_trajectory_keys(first): - # List of dicts containing Trajectories - return [self.batch_encode(row, add_generation_prompt=add_generation_prompt) for row in trajectories] - else: - # List[Trajectory] - processed = self._invoke_pre_pipeline(trajectories) - output = [self.encode(t, add_generation_prompt=add_generation_prompt) for t in processed] - return self._invoke_post_pipeline(output) + _transfer = False - # Case 3: Dict containing Trajectories (encode only Trajectory values) if isinstance(trajectories, Mapping): + _transfer = True + # Check if it has trajectory list columns (DPO format) traj_keys = self._get_trajectory_keys(trajectories) if traj_keys: - result = dict(trajectories) # Copy non-trajectory keys - for key in traj_keys: - result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt) + # DPO format: encode each trajectory list separately, keep other columns + result = {} + for key in trajectories: + if key in traj_keys: + # Encode this trajectory list + result[key] = self.batch_encode( + trajectories[key], add_generation_prompt=add_generation_prompt + ) + else: + # Keep non-trajectory columns as-is + result[key] = trajectories[key] return result + else: + # Standard columnar format + trajectories = self.map_col_to_row(trajectories) + + # Process List[Trajectory] + trajectories = 
From 0cf1ac38d5aebba51d7a0f4753e814037432f048 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Fri, 27 Mar 2026 20:08:33 +0800
Subject: [PATCH 11/14] wip

---
 cookbook/rl/dpo.py                             |  25 +--
 src/twinkle/loss/dpo.py                        |  19 +-
 src/twinkle/metric/__init__.py                 |   1 +
 src/twinkle/metric/dpo.py                      | 177 ++++++++++++++++++
 src/twinkle/metric/loss.py                     |   4 +-
 .../model/transformers/transformers.py         |   7 +-
 6 files changed, 217 insertions(+), 16 deletions(-)
 create mode 100644 src/twinkle/metric/dpo.py

diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py
index 00311c63..7c5ffe8a 100644
--- a/cookbook/rl/dpo.py
+++ b/cookbook/rl/dpo.py
@@ -59,6 +59,7 @@
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
+from twinkle.metric import DPOMetric
 from twinkle.model import TransformersModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor
@@ -66,7 +67,7 @@
 logger = get_logger()
 
 # ── Configuration ─────────────────────────────────────────────────────────────
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
 
 MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
@@ -75,12 +76,13 @@
 
 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))  # Number of preference pairs
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8))
 MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-LEARNING_RATE = float(os.environ.get('LR', 5e-6))
+LEARNING_RATE = float(os.environ.get('LR', 5e-5))
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
+SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1))  # SFT loss weight for regularization
 LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
 MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
 ADAPTER_NAME = 'default'
 SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
@@ -88,7 +90,7 @@
 
 def create_dpo_dataset():
     """Create DPO dataset with positive/negative format."""
-    dataset = Dataset(DatasetMeta(DATASET_ID))
+    dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000)))
     dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
     dataset.map(
         EmojiDPOProcessor,
@@ -134,7 +136,7 @@ def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 # ── Loss Factory ──────────────────────────────────────────────────────────────
 
-def create_loss(loss_type: str, beta: float, reference_free: bool = False):
+def create_loss(loss_type: str, beta: float, sft_weight: float = 0.0, reference_free: bool = False):
     """Create the appropriate loss function based on configuration."""
     if loss_type == 'simpo':
         return SimPOLoss(beta=beta, gamma=0.5)
@@ -148,6 +150,7 @@
         beta=beta,
         loss_type=loss_type,
         reference_free=reference_free,
+        sft_weight=sft_weight,
     )
 
@@ -174,10 +177,7 @@ def main():
 
     # ── Policy Model Setup ────────────────────────────────────────────────────
     lora_config = LoraConfig(
-        target_modules=[
-            'q_proj', 'k_proj', 'v_proj', 'o_proj',
-            'gate_proj', 'up_proj', 'down_proj',
-        ],
+        target_modules='all-linear',
         r=16,
         lora_alpha=32,
         lora_dropout=0.05,
@@ -195,9 +195,10 @@
     # Determine if we need reference model based on loss type
     reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']
 
-    # Set up loss function
-    loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False)
+    # Set up loss function and metrics
+    loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=False)
     policy_model.set_loss(loss_fn)
+    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
 
     policy_model.set_processor(InputProcessor)
     policy_model.set_template('Template', model_id=MODEL_ID)
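Before the loss changes below, it helps to pin down the objective that the 'sigmoid' default selects. A self-contained sketch in plain PyTorch, independent of twinkle's DPOLoss class (the per-sequence logps would come from summing per-token logps over non-ignored labels):

    import torch
    import torch.nn.functional as F

    def dpo_sigmoid_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
        """-log sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r))), averaged over pairs."""
        logits = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
        return -F.logsigmoid(beta * logits).mean()

    # Toy check: a positive implied-reward margin gives a loss below log(2) ~= 0.693.
    loss = dpo_sigmoid_loss(torch.tensor([-5.0]), torch.tensor([-9.0]),
                            torch.tensor([-6.0]), torch.tensor([-8.0]))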
diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py
index 52fc28b0..0ad69ed9 100644
--- a/src/twinkle/loss/dpo.py
+++ b/src/twinkle/loss/dpo.py
@@ -118,6 +118,7 @@ class DPOLoss(PreferenceLossBase):
         ignore_index: Index to ignore in labels (default: -100).
         loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid').
         reference_free: Whether to use reference-free DPO (default: False).
+        sft_weight: Weight for SFT loss on chosen responses to prevent likelihood displacement (default: 0.0).
     """
 
     def __init__(
         self,
@@ -127,6 +128,7 @@ def __init__(
         ignore_index: int = -100,
         loss_type: str = 'sigmoid',
         reference_free: bool = False,
+        sft_weight: float = 0.0,
         **kwargs,
     ):
         super().__init__(ignore_index=ignore_index)
@@ -134,6 +136,7 @@ def __init__(
         self.label_smoothing = label_smoothing
         self.loss_type = loss_type
         self.reference_free = reference_free
+        self.sft_weight = sft_weight
 
     def _align_logps(
         self,
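The new sft_weight mixes a plain NLL term on the chosen half into the preference objective. A sketch of the combination (assuming, as the hunk below does, that _compute_nll_loss is a masked mean NLL over the chosen target tokens; the helper name here is illustrative):

    import torch

    def combined_loss(dpo_loss: torch.Tensor,
                      chosen_token_logps: torch.Tensor,
                      chosen_mask: torch.Tensor,
                      sft_weight: float) -> torch.Tensor:
        # SFT term: negative mean log prob of the chosen target tokens
        sft_loss = -(chosen_token_logps * chosen_mask).sum() / chosen_mask.sum().clamp(min=1)
        return dpo_loss + sft_weight * sft_loss

With sft_weight = 0 this reduces to vanilla DPO; larger values anchor the policy to the chosen responses, countering likelihood displacement (both halves of the pair drifting down together).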
""" def __init__( @@ -127,6 +128,7 @@ def __init__( ignore_index: int = -100, loss_type: str = 'sigmoid', reference_free: bool = False, + sft_weight: float = 0.0, **kwargs, ): super().__init__(ignore_index=ignore_index) @@ -134,6 +136,7 @@ def __init__( self.label_smoothing = label_smoothing self.loss_type = loss_type self.reference_free = reference_free + self.sft_weight = sft_weight def _align_logps( self, @@ -329,14 +332,26 @@ def __call__( ) # Compute DPO loss - loss = self._compute_dpo_loss( + dpo_loss = self._compute_dpo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, ) - return LossOutput(loss=loss, num_tokens=0) + # Add SFT loss on chosen responses to prevent likelihood displacement + if self.sft_weight > 0: + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + loss = dpo_loss + self.sft_weight * sft_loss + else: + loss = dpo_loss + + # Return sample count for gradient normalization (not token count) + # DPO loss is already per-sample mean, so we just count samples for accumulation + import torch + num_samples = torch.tensor(chosen_labels.shape[0], device=loss.device) + + return LossOutput(loss=loss, num_tokens=num_samples) class SimPOLoss(PreferenceLossBase): diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py index 739c7a0d..59d5bbeb 100644 --- a/src/twinkle/metric/__init__.py +++ b/src/twinkle/metric/__init__.py @@ -2,5 +2,6 @@ from .accuracy import Accuracy from .base import Metric from .completion_and_reward import CompletionRewardMetric +from .dpo import DPOMetric from .loss import LossMetric from .train_metric import TrainMetric diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py new file mode 100644 index 00000000..c3f3e6cf --- /dev/null +++ b/src/twinkle/metric/dpo.py @@ -0,0 +1,177 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO-specific metrics for preference optimization training.""" +from typing import List, Union + +from twinkle.data_format import InputFeature, ModelOutput +from .base import Metric + + +class DPOMetric(Metric): + """Metrics for DPO (Direct Preference Optimization) training. + + Computes TRL-style metrics: + - logps/chosen: Average sequence-level log prob of chosen responses + - logps/rejected: Average sequence-level log prob of rejected responses + - rewards/chosen: β * (policy_chosen - ref_chosen) + - rewards/rejected: β * (policy_rejected - ref_rejected) + - rewards/margins: chosen_reward - rejected_reward + - rewards/accuracies: Percentage where chosen_reward > rejected_reward + + Args: + device_mesh: The device mesh + process_group: The process group to collect data from + ignore_index: Label index to ignore (default: -100) + beta: DPO beta parameter for reward scaling (default: 0.1) + """ + + def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs): + super().__init__(device_mesh, process_group, **kwargs) + self.ignore_index = ignore_index + self.beta = beta + self.reset() + + def _compute_sequence_logps(self, per_token_logps, labels): + """Compute sequence-level log probs by summing valid token logps.""" + import torch + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _split_chosen_rejected(self, tensor): + """Split interleaved tensor into chosen and rejected. + + Input format: [pos_1, neg_1, pos_2, neg_2, ...] 
diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py
new file mode 100644
index 00000000..c3f3e6cf
--- /dev/null
+++ b/src/twinkle/metric/dpo.py
@@ -0,0 +1,177 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""DPO-specific metrics for preference optimization training."""
+from typing import List, Union
+
+from twinkle.data_format import InputFeature, ModelOutput
+from .base import Metric
+
+
+class DPOMetric(Metric):
+    """Metrics for DPO (Direct Preference Optimization) training.
+
+    Computes TRL-style metrics:
+    - logps/chosen: Average sequence-level log prob of chosen responses
+    - logps/rejected: Average sequence-level log prob of rejected responses
+    - rewards/chosen: β * (policy_chosen - ref_chosen)
+    - rewards/rejected: β * (policy_rejected - ref_rejected)
+    - rewards/margins: chosen_reward - rejected_reward
+    - rewards/accuracies: Percentage where chosen_reward > rejected_reward
+
+    Args:
+        device_mesh: The device mesh
+        process_group: The process group to collect data from
+        ignore_index: Label index to ignore (default: -100)
+        beta: DPO beta parameter for reward scaling (default: 0.1)
+    """
+
+    def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs):
+        super().__init__(device_mesh, process_group, **kwargs)
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.reset()
+
+    def _compute_sequence_logps(self, per_token_logps, labels):
+        """Compute sequence-level log probs by summing valid token logps."""
+        import torch
+        loss_mask = (labels != self.ignore_index).float()
+        return (per_token_logps * loss_mask).sum(dim=-1)
+
+    def _split_chosen_rejected(self, tensor):
+        """Split interleaved tensor into chosen and rejected.
+
+        Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing)
+        Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...])
+        """
+        return tensor[0::2], tensor[1::2]
+
+    def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+        """Accumulate DPO metrics from model outputs.
+
+        Expects:
+        - outputs['logps']: [batch, seq_len] per-token log probabilities
+        - inputs['labels']: [batch, seq_len] labels with ignore_index for non-target tokens
+        - kwargs['ref_outputs']: Optional reference model outputs with 'logps'
+        """
+        import torch
+
+        logps = outputs.get('logps')
+        if logps is None:
+            return
+
+        # Get labels from inputs
+        if isinstance(inputs, list):
+            # Stack labels from list of inputs
+            labels_list = [torch.as_tensor(inp['labels']) for inp in inputs]
+            max_len = max(l.shape[0] for l in labels_list)
+            padded = []
+            for l in labels_list:
+                if l.shape[0] < max_len:
+                    pad = torch.full((max_len - l.shape[0],), self.ignore_index, dtype=l.dtype)
+                    l = torch.cat([pad, l])
+                padded.append(l)
+            labels = torch.stack(padded)
+        else:
+            labels = torch.as_tensor(inputs['labels'])
+            if labels.dim() == 1:
+                labels = labels.unsqueeze(0)
+
+        # Ensure logps and labels are on the same device
+        if logps.device != labels.device:
+            labels = labels.to(logps.device)
+
+        # Align sequence lengths if needed (truncate right)
+        if logps.shape[1] != labels.shape[1]:
+            min_len = min(logps.shape[1], labels.shape[1])
+            logps = logps[:, :min_len]
+            labels = labels[:, :min_len]
+
+        # Compute sequence-level logps
+        seq_logps = self._compute_sequence_logps(logps, labels)
+
+        # Split into chosen and rejected (interleaved format)
+        chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps)
+        chosen_labels, rejected_labels = self._split_chosen_rejected(labels)
+
+        # Accumulate policy logps
+        self.total_chosen_logps += chosen_logps.sum().item()
+        self.total_rejected_logps += rejected_logps.sum().item()
+
+        # Compute rewards if ref_outputs available
+        ref_outputs = kwargs.get('ref_outputs')
+        if ref_outputs is not None:
+            ref_logps = ref_outputs.get('logps')
+            if ref_logps is not None:
+                # Align ref_logps
+                if ref_logps.device != labels.device:
+                    ref_logps = ref_logps.to(labels.device)
+                if ref_logps.shape[1] != labels.shape[1]:
+                    min_len = min(ref_logps.shape[1], labels.shape[1])
+                    ref_logps = ref_logps[:, :min_len]
+
+                ref_seq_logps = self._compute_sequence_logps(ref_logps, labels)
+                ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps)
+
+                # Compute rewards: β * (policy - ref)
+                chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps)
+                rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps)
+
+                self.total_chosen_rewards += chosen_rewards.sum().item()
+                self.total_rejected_rewards += rejected_rewards.sum().item()
+                margins = chosen_rewards - rejected_rewards
+                self.total_reward_margin += margins.sum().item()
+                self.total_reward_correct += (margins > 0).sum().item()
+                self.has_rewards = True
+
+        self.total_count += chosen_logps.shape[0]
+
+    def reset(self):
+        """Reset all accumulated values."""
+        self.total_chosen_logps = 0.0
+        self.total_rejected_logps = 0.0
+        self.total_chosen_rewards = 0.0
+        self.total_rejected_rewards = 0.0
+        self.total_reward_margin = 0.0
+        self.total_reward_correct = 0
+        self.total_count = 0
+        self.has_rewards = False
+
+    def calculate(self):
+        """Calculate and return aggregated metrics."""
+        local_results = [{
+            'chosen_logps': self.total_chosen_logps,
+            'rejected_logps': self.total_rejected_logps,
+            'chosen_rewards': self.total_chosen_rewards,
+            'rejected_rewards': self.total_rejected_rewards,
+            'reward_margin': self.total_reward_margin,
+            'reward_correct': self.total_reward_correct,
+            'count': self.total_count,
+            'has_rewards': self.has_rewards,
+        }]
+        all_results = self.gather_results(local_results)
+
+        total_chosen_logps = sum(r['chosen_logps'] for r in all_results)
+        total_rejected_logps = sum(r['rejected_logps'] for r in all_results)
+        total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results)
+        total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results)
+        total_reward_margin = sum(r['reward_margin'] for r in all_results)
+        total_reward_correct = sum(r['reward_correct'] for r in all_results)
+        total_count = sum(r['count'] for r in all_results)
+        has_rewards = any(r['has_rewards'] for r in all_results)
+
+        self.reset()
+
+        if total_count == 0:
+            return {}
+
+        results = {
+            'logps/chosen': f'{total_chosen_logps / total_count:.2f}',
+            'logps/rejected': f'{total_rejected_logps / total_count:.2f}',
+        }
+
+        if has_rewards:
+            results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}'
+            results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}'
+            results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}'
+            results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%'
+
+        return results
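A toy pass through the reward formulas implemented above, with invented numbers and β = 0.1, to show what the logged quantities look like:

    beta = 0.1
    policy_chosen, policy_rejected = [-12.0, -30.0], [-15.0, -28.0]
    ref_chosen, ref_rejected = [-13.0, -31.0], [-14.0, -27.0]

    chosen_rewards = [beta * (p - r) for p, r in zip(policy_chosen, ref_chosen)]        # [0.1, 0.1]
    rejected_rewards = [beta * (p - r) for p, r in zip(policy_rejected, ref_rejected)]  # [-0.1, -0.1]
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]                 # [0.2, 0.2]
    accuracy = sum(m > 0 for m in margins) / len(margins)                               # 1.0 -> '100.0%'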
diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py
index 7c4a8b93..8f4ad0c9 100644
--- a/src/twinkle/metric/loss.py
+++ b/src/twinkle/metric/loss.py
@@ -60,8 +60,10 @@ def calculate(self):
         num_tokens = sum(r['num_tokens'] for r in all_results)
         if num_tokens > 0:
             avg_loss = total_loss / num_tokens
-        else:
+        elif total_count > 0:
             avg_loss = total_loss / total_count
+        else:
+            avg_loss = 0.0
         self.reset()
         results = {}
         if avg_loss is not None:
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 520aaf9f..cdafbf59 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -121,6 +121,8 @@ def accumulate_metrics(self, is_training):
             metrics = self.train_metrics
         else:
             metrics = self.eval_metrics
+        # Get the forward_kwargs stored by the previous forward pass
+        forward_kwargs = getattr(self, 'forward_kwargs', None) or {}
         if len(metrics) > 0 and self.inputs is not None and self.outputs is not None:
             for metric in metrics:
                 metric.accumulate(
@@ -130,7 +132,8 @@ def accumulate_metrics(self, is_training):
                     step=self.cur_step - 1,
                     gradient_accumulation_steps=self.gradient_accumulation_steps,
                     grad_norm=self._last_grad_norm,
-                    loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'))
+                    loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'),
+                    **forward_kwargs)
 
     def calculate_metrics(self, is_training):
         self.accumulate_metrics(is_training)
@@ -405,6 +408,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
             inputs['labels'] = labels
         optimizer_config.inputs = inputs
         optimizer_config.outputs = outputs
+        optimizer_config.forward_kwargs = kwargs  # Store for the next metric accumulation
         optimizer_config.loss_value = outputs.get('aux_loss', 0)
         if labels is not None:
             loss_mask = (labels != -100).bool()
@@ -1086,6 +1090,7 @@ def set_grad_scaler(self, **kwargs):
         grad_scaler_config.update(kwargs)
         optimizer_config.scaler = GradScaler(**grad_scaler_config)
 
+    @remote_function()
     def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
         """Add an eval metric
 
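The forward_kwargs plumbing above is what routes reference-model outputs into DPOMetric.accumulate as kwargs['ref_outputs']. A sketch of the driver-side pattern it enables; forward_only and forward_backward are the method names from this cookbook's architecture notes, while the exact keyword signature is an assumption:

    # Each step: frozen reference pass, then policy update on the same batch.
    ref_out = ref_model.forward_only(inputs=batch)      # dict containing 'logps'
    policy_model.forward_backward(inputs=batch,
                                  ref_outputs=ref_out)  # seen by DPOLoss and,
                                                        # via forward_kwargs, DPOMetric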
From 8c662f047519342f0c8eafed205557b97bd0c69a Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Fri, 27 Mar 2026 22:03:07 +0800
Subject: [PATCH 12/14] wip

---
 cookbook/rl/dpo.py      | 10 +++++-----
 src/twinkle/loss/dpo.py | 10 ++++------
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py
index 7c5ffe8a..f3454c2f 100644
--- a/cookbook/rl/dpo.py
+++ b/cookbook/rl/dpo.py
@@ -76,11 +76,10 @@
 
 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))  # Number of preference pairs
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
-GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 8))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
-LEARNING_RATE = float(os.environ.get('LR', 5e-5))
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4))
+LEARNING_RATE = float(os.environ.get('LR', 5e-6))  # TRL default for DPO is 5e-7 to 5e-6
 DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
-SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 0.1))  # SFT loss weight for regularization
+SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
 LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
 SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
 MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
@@ -90,7 +89,7 @@
 
 def create_dpo_dataset():
     """Create DPO dataset with positive/negative format."""
-    dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(15000)))
+    dataset = Dataset(DatasetMeta(DATASET_ID))
     dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
     dataset.map(
         EmojiDPOProcessor,
@@ -188,6 +187,7 @@ def main():
         device_mesh=policy_mesh,
         remote_group='policy',
     )
+    MAX_STEPS = len(dataloader)
     policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
     policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
     policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py
index 0ad69ed9..ce533053 100644
--- a/src/twinkle/loss/dpo.py
+++ b/src/twinkle/loss/dpo.py
@@ -346,12 +346,10 @@ def __call__(
         else:
             loss = dpo_loss
 
-        # Return sample count for gradient normalization (not token count)
-        # DPO loss is already a per-sample mean, so we just count samples for accumulation
-        import torch
-        num_samples = torch.tensor(chosen_labels.shape[0], device=loss.device)
-
-        return LossOutput(loss=loss, num_tokens=num_samples)
+        # Return 0 to skip gradient normalization by num_tokens
+        # DPO loss is already a per-sample mean, unlike SFT, which sums the per-token loss
+        # When num_tokens=0, normalize_and_clip_grad_norm defaults to 1 (no division)
+        return LossOutput(loss=loss, num_tokens=0)
 
 
 class SimPOLoss(PreferenceLossBase):
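The num_tokens convention flips here: SFT losses report a token count so summed gradients can be normalized globally, while DPO reports 0 because its loss is already a per-sample mean. A sketch of the contract the comment refers to (normalize_and_clip_grad_norm is the name it references; this body is an assumption, not the actual implementation):

    def normalize_grads(grad_sum: float, num_tokens: float) -> float:
        """Sketch: num_tokens == 0 means 'already averaged, do not divide'."""
        return grad_sum / (num_tokens if num_tokens > 0 else 1)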
From aa860993c493b80e49d3339d32057b91151d8a3f Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Fri, 27 Mar 2026 23:38:48 +0800
Subject: [PATCH 13/14] wip

---
 cookbook/transformers/fsdp2.py         | 2 +-
 src/twinkle/model/megatron/megatron.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py
index 10d75df6..a8c8bb87 100644
--- a/cookbook/transformers/fsdp2.py
+++ b/cookbook/transformers/fsdp2.py
@@ -10,7 +10,7 @@
 from twinkle.preprocessor import SelfCognitionProcessor
 
 # Construct a device_mesh, dp=2
-device_mesh = DeviceMesh.from_sizes(dp_size=2)
+device_mesh = DeviceMesh.from_sizes(dp_size=8)
 
 # use torchrun mode
 twinkle.initialize(mode='local', global_device_mesh=device_mesh)
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index 4b37973c..ae2ae59f 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -245,7 +245,7 @@ def __init__(
 
     def _construct_default_optimizer_group(self):
         return MegatronOptimizerGroup(
-            loss_instance=CrossEntropyLoss(),
+            loss_instance=CrossEntropyLoss(reduction='sum'),
             template=Template(self.tokenizer_id),
             processor=InputProcessor(self.device_mesh, framework='megatron'),
             _device_mesh=self.device_mesh,

From 6bdaaca88c924b564ce12317e136207ecea6e12c Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Sat, 28 Mar 2026 00:11:18 +0800
Subject: [PATCH 14/14] wip

---
 src/twinkle/model/megatron/megatron.py         |  6 ++++++
 src/twinkle/model/transformers/transformers.py | 13 ++++++++++++-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index ae2ae59f..7c131210 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -479,6 +479,12 @@ def post_loss_function(output_tensor, inputs, logps):
             losses = result['loss']
             counts = result['num_tokens']
             if not counts:
+                # This value is gathered later, so it becomes:
+                # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens
+                # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps) = gradient_accumulation_steps * world_size
+                # Then the grad will be divided by this value:
+                # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad / global_num_tokens
+                # 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) / (gradient_accumulation_steps * world_size) = avg_per_token_grad
                 counts = torch.tensor(1, device=losses.device)
             return self.strategy.reduce_loss(losses, counts, output_tensor, logps)
 
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index cdafbf59..a6be74e8 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -500,7 +500,18 @@ def calculate_loss(self, **kwargs):
             loss_value = result['loss']
             counts = result['num_tokens']
             if not counts:
-                counts = torch.tensor(0, device=loss_value.device)
+                counts = torch.tensor(1, device=loss_value.device)
+            # This value is gathered later, so it becomes:
+            # 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size
+            # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size) = gradient_accumulation_steps
+            # Then the grad will be divided by this value:
+            # 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size)
+            #    = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size)
+            #    = global_sum_grad / global_num_tokens
+            # 2. PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) / gradient_accumulation_steps
+            #    = (global_per_token_grad * gradient_accumulation_steps / dp_world_size) / gradient_accumulation_steps
+            #    = global_per_token_grad / dp_world_size = avg_per_token_grad
+            counts = counts / self.device_mesh.data_world_size
             optimizer_config = self.optimizer_group[adapter_name]
             optimizer_config.num_tokens += counts.item()
             if self.sp_strategy is not None and 'labels' in inputs:
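As a closing check on the arithmetic in the two comment blocks above, a toy verification of the SUM-loss path with dp_world_size = 2 (plain Python, no framework; the numbers are invented):

    # Each rank reports local_num_tokens / dp_world_size; gather_sum of that
    # equals global_num_tokens / dp_world_size.
    dp_world_size = 2
    local_num_tokens = [100, 60]
    gathered = sum(n / dp_world_size for n in local_num_tokens)   # 80.0

    # Gradients are averaged across ranks (gather_mean), so dividing the mean
    # gradient by the gathered count recovers global_sum_grad / global_num_tokens.
    local_grad_sums = [50.0, 30.0]
    mean_grad = sum(local_grad_sums) / dp_world_size              # 40.0
    assert mean_grad / gathered == sum(local_grad_sums) / sum(local_num_tokens)  # 0.5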