diff --git a/cookbook/rl/dpo.py b/cookbook/rl/dpo.py
new file mode 100644
index 00000000..f3454c2f
--- /dev/null
+++ b/cookbook/rl/dpo.py
@@ -0,0 +1,263 @@
+"""DPO (Direct Preference Optimization) Training via Ray.
+
+Off-policy preference alignment: trains the model to prefer chosen responses
+over rejected responses using preference data, without explicit reward modeling.
+
+Pipeline:
+    1. Load preference dataset with chosen/rejected pairs.
+    2. Encode positive and negative trajectories separately.
+    3. Compute reference model log probabilities (frozen).
+    4. Train the policy model using the DPO loss.
+
+Architecture (Ray):
+    ┌─────────────────────────────────────────────────────────────────┐
+    │                         Driver (CPU)                            │
+    │   dataloader ──► batched preference pairs                       │
+    │   ref_model.forward_only() ──► reference log probs              │
+    │   policy_model.forward_backward() ──► DPO loss + gradient       │
+    └─────────────────────────────────────────────────────────────────┘
+            │                   │                     │
+        DataLoader        RefModel (frozen)    PolicyModel (trainable)
+                            (ref GPUs)            (policy GPUs)
+
+DPO data format (after preprocessing):
+    - positive: List[Trajectory] - chosen responses
+    - negative: List[Trajectory] - rejected responses
+
+For SimPO/ORPO/CPO variants that don't require a reference model,
+set REF_MODEL_GPUS=0 to skip reference model computation.
+
+Environment variables (all optional):
+    MODEL_ID         – (default: ms://Qwen/Qwen2.5-7B-Instruct)
+    DATASET_ID       – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
+    MODEL_GPUS       – GPUs for policy model (default: 4)
+    REF_MODEL_GPUS   – GPUs for reference model (default: 4, 0 to disable)
+    BATCH_SIZE       – global batch size (preference pairs) (default: 4)
+    MICRO_BATCH_SIZE – per-device micro batch size (default: 4)
+    GRADIENT_ACCUMULATION_STEPS – gradient accumulation steps (default: 4)
+    MAX_STEPS        – total optimization steps (default: one epoch over the dataset)
+    LR               – learning rate (default: 5e-6)
+    DPO_BETA         – DPO temperature parameter (default: 0.1)
+    SFT_WEIGHT       – SFT regularization weight on chosen responses (default: 1.0)
+    LOSS_TYPE        – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid)
+    SAVE_STEPS       – checkpoint save interval (default: 200)
+    MAX_LENGTH       – max sequence length (default: 2048)
+
+    Dataset field mapping (for custom datasets):
+    PROMPT_KEY       – key for prompt field (default: 'prompt')
+    CHOSEN_KEY       – key for chosen response (default: 'answer_zh')
+    REJECTED_KEY     – key for rejected response (default: 'answer_en')
+    SYSTEM_PROMPT    – system prompt to prepend (default: 'You are a helpful assistant.')
+"""
+
+import os
+from typing import Any, Dict, List
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
+from twinkle.metric import DPOMetric
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import EmojiDPOProcessor
+from twinkle.processor import InputProcessor
+
+logger = get_logger()
+
+# ── Configuration ─────────────────────────────────────────────────────────────
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
+
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4))
+NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS
+
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))  # Number of preference pairs
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
+GRADIENT_ACCUMULATION_STEPS = 
int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4))
+LEARNING_RATE = float(os.environ.get('LR', 5e-6))  # TRL's typical DPO range is 5e-7 to 5e-6
+DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
+SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
+LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # sigmoid, hinge, ipo, simpo, orpo, cpo
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
+MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
+ADAPTER_NAME = 'default'
+SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
+
+
+def create_dpo_dataset():
+    """Create DPO dataset with positive/negative format."""
+    dataset = Dataset(DatasetMeta(DATASET_ID))
+    dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
+    dataset.map(
+        EmojiDPOProcessor,
+        init_args={
+            'system': SYSTEM_PROMPT,
+            'prompt_key': os.environ.get('PROMPT_KEY', 'prompt'),
+            'chosen_key': os.environ.get('CHOSEN_KEY', 'answer_zh'),
+            'rejected_key': os.environ.get('REJECTED_KEY', 'answer_en'),
+        }
+    )
+    # DPO preprocessor returns {'positive': [...], 'negative': [...]}
+    # batch_encode handles this format automatically
+    dataset.encode(load_from_cache_file=True)
+    return dataset
+
+
+def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Prepare DPO batch: reorganize batch for training with DP-safe interleaving.
+
+    Args:
+        batch: List of rows, each with 'positive' and 'negative' InputFeatures
+            and other fields (question, etc.)
+
+    Returns:
+        List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP
+        worker gets complete positive/negative pairs after slicing.
+        Each item contains all original fields plus the InputFeature fields.
+    """
+    result = []
+
+    for row in batch:
+        # Get base fields (excluding positive/negative)
+        base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}
+
+        # Positive sample: merge base fields with positive InputFeature
+        pos_sample = {**base_fields, **row['positive']}
+        # Negative sample: merge base fields with negative InputFeature
+        neg_sample = {**base_fields, **row['negative']}
+
+        # Interleave: [pos, neg] per pair for DP-safe slicing
+        result.append(pos_sample)
+        result.append(neg_sample)
+
+    return result
+
+
+# ── Loss Factory ──────────────────────────────────────────────────────────────
+
+def create_loss(loss_type: str, beta: float, sft_weight: float = 0.0, reference_free: bool = False):
+    """Create the appropriate loss function based on configuration."""
+    if loss_type == 'simpo':
+        return SimPOLoss(beta=beta, gamma=0.5)
+    elif loss_type == 'orpo':
+        return ORPOLoss(lambda_orpo=beta)
+    elif loss_type == 'cpo':
+        return CPOLoss(beta=beta, bc_coef=1.0)
+    else:
+        # Standard DPO variants: sigmoid, hinge, ipo
+        return DPOLoss(
+            beta=beta,
+            loss_type=loss_type,
+            reference_free=reference_free,
+            sft_weight=sft_weight,
+        )
+
+
+# ── Main Training Loop ────────────────────────────────────────────────────────
+
+def main():
+    # Set up device groups
+    device_groups = [
+        DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+        DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
+    ]
+
+    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)
+
+    # ── DataLoader Setup ──────────────────────────────────────────────────────
+    dataloader = DataLoader(
+        dataset=create_dpo_dataset,
+        batch_size=BATCH_SIZE,
+        min_batch_size=BATCH_SIZE,
+        device_mesh=policy_mesh,
+    )
+    length = len(dataloader)
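+
+    # Note: each batch of BATCH_SIZE preference pairs expands to 2 * BATCH_SIZE
+    # interleaved rows after prepare_dpo_batch, so every DP rank must receive an
+    # even-length contiguous slice to keep chosen/rejected pairs together.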
+
+    # ── Policy Model Setup ────────────────────────────────────────────────────
+    lora_config = LoraConfig(
+        target_modules='all-linear',
+        r=16,
+        lora_alpha=32,
+        lora_dropout=0.05,
+    )
+
+    policy_model = TransformersModel(
+        model_id=MODEL_ID,
+        device_mesh=policy_mesh,
+        remote_group='policy',
+    )
+    # One epoch over the dataset by default; override with MAX_STEPS
+    MAX_STEPS = int(os.environ.get('MAX_STEPS', length))
+    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+    policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
+    policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
+
+    # Reference-free variants (or REF_MODEL_GPUS=0) skip the reference model
+    reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo'] or REF_MODEL_GPUS == 0
+
+    # Set up loss function and metrics
+    loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=reference_free)
+    policy_model.set_loss(loss_fn)
+    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
+    policy_model.set_processor(InputProcessor)
+    policy_model.set_template('Template', model_id=MODEL_ID)
+
+    # ── Reference Model Setup ─────────────────────────────────────────────────
+    ref_model = None
+    if not reference_free:
+        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS)
+        ref_model = TransformersModel(
+            model_id=MODEL_ID,
+            device_mesh=ref_mesh,
+            remote_group='reference',
+        )
+        ref_model.set_processor(InputProcessor)
+        ref_model.set_template('Template', model_id=MODEL_ID)
+        logger.info('Reference model initialized for DPO training')
+    else:
+        logger.info(f'Training without reference model (loss_type={LOSS_TYPE})')
+
+    optim_step = 0
+    logger.info(get_device_placement())
+    logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}')
+
+    # ── Training Loop ─────────────────────────────────────────────────────────
+    for batch in dataloader:
+        if optim_step >= MAX_STEPS:
+            break
+
+        # batch is List[Dict] with 'positive' and 'negative' keys
+        dpo_batch = prepare_dpo_batch(batch)
+
+        # Get reference outputs (lazy - not collected to driver)
+        ref_outputs = None
+        if ref_model is not None:
+            ref_outputs = ref_model.forward_only(inputs=dpo_batch)
+
+        # Forward-backward pass with DPO loss
+        # ref_outputs is passed to the loss, which extracts logps internally
+        policy_model.forward_backward(
+            inputs=dpo_batch,
+            ref_outputs=ref_outputs,
+        )
+
+        # Gradient clipping and optimizer step
+        policy_model.clip_grad_and_step()
+        optim_step += 1
+
+        # Logging
+        if optim_step % 10 == 0:
+            metrics = policy_model.calculate_metric(is_training=True)
+            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}')
+
+        # Checkpointing
+        if optim_step % SAVE_STEPS == 0:
+            policy_model.save(f'dpo-checkpoint-{optim_step}')
+
+    # ── Save Final Checkpoint ─────────────────────────────────────────────────
+    logger.info(f'Training completed. Total steps: {optim_step}')
+    policy_model.save('dpo-final-checkpoint')
+
+
+if __name__ == '__main__':
+    main()
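Note: a minimal sketch of the DP-safe interleaving contract above — `prepare_dpo_batch` produces even/odd chosen/rejected rows that `PreferenceLossBase._split_chosen_rejected` (added below) later undoes. The sample rows here are invented:

```python
batch = [
    {'question': 'q1', 'positive': {'input_ids': [1, 2]}, 'negative': {'input_ids': [3, 4]}},
    {'question': 'q2', 'positive': {'input_ids': [5, 6]}, 'negative': {'input_ids': [7, 8]}},
]
flat = prepare_dpo_batch(batch)  # [pos_1, neg_1, pos_2, neg_2]

# Any contiguous slice of even length holds complete pairs ...
shard = flat[0:2]
assert shard[0]['question'] == shard[1]['question'] == 'q1'

# ... and the loss recovers the two halves by parity, mirroring
# PreferenceLossBase._split_chosen_rejected: even = chosen, odd = rejected.
chosen, rejected = flat[0::2], flat[1::2]
assert [x['input_ids'] for x in chosen] == [[1, 2], [5, 6]]
assert [x['input_ids'] for x in rejected] == [[3, 4], [7, 8]]
```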
diff --git a/cookbook/rl/dpo.sh b/cookbook/rl/dpo.sh
new file mode 100644
index 00000000..65206839
--- /dev/null
+++ b/cookbook/rl/dpo.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+# DPO Training Script for Ray Mode
+#
+# This script launches DPO (Direct Preference Optimization) training using Ray
+# for distributed training across multiple GPUs.
+#
+# Usage:
+#   ./dpo.sh          # Default settings (8 GPUs: 4 policy + 4 ref)
+#   ./dpo.sh simpo    # Use SimPO (no reference model needed)
+#   ./dpo.sh orpo     # Use ORPO (no reference model needed)
+#
+# Environment variables can be set to customize training:
+#   MODEL_ID            - Model to train (default: ms://Qwen/Qwen2.5-7B-Instruct)
+#   DATASET_ID          - Preference dataset (default: shareAI-Llama3-DPO-zh-en-emoji)
+#   MODEL_GPUS          - GPUs for policy model (default: 4)
+#   REF_MODEL_GPUS      - GPUs for reference model (default: 4)
+#   USE_REFERENCE_MODEL - Use reference model (default: 1)
+#   BATCH_SIZE          - Global batch size (default: 8)
+#   MAX_STEPS           - Training steps (default: 1000)
+#   LR                  - Learning rate (default: 5e-6)
+#   DPO_BETA            - DPO beta parameter (default: 0.1)
+#   LOSS_TYPE           - Loss variant: sigmoid/hinge/ipo/simpo/orpo/cpo (default: sigmoid)
+
+set -e
+
+# Parse command line argument for loss type
+LOSS_TYPE_ARG=${1:-sigmoid}
+
+# Set default environment variables if not already set
+# (defaults match cookbook/rl/dpo.py, which wires in EmojiDPOProcessor)
+export MODEL_ID=${MODEL_ID:-"ms://Qwen/Qwen2.5-7B-Instruct"}
+export DATASET_ID=${DATASET_ID:-"ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji"}
+export MODEL_GPUS=${MODEL_GPUS:-4}
+export BATCH_SIZE=${BATCH_SIZE:-8}
+export MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2}
+export MAX_STEPS=${MAX_STEPS:-1000}
+export LR=${LR:-5e-6}
+export DPO_BETA=${DPO_BETA:-0.1}
+export SAVE_STEPS=${SAVE_STEPS:-100}
+export MAX_LENGTH=${MAX_LENGTH:-2048}
+
+# Set loss type from argument or environment
+export LOSS_TYPE=${LOSS_TYPE:-$LOSS_TYPE_ARG}
+
+# Reference-free losses don't need a reference model
+if [[ "$LOSS_TYPE" == "simpo" || "$LOSS_TYPE" == "orpo" || "$LOSS_TYPE" == "cpo" ]]; then
+    export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-0}
+    export REF_MODEL_GPUS=${REF_MODEL_GPUS:-0}
+    echo "Using $LOSS_TYPE loss (reference-free)"
+else
+    export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-1}
+    export REF_MODEL_GPUS=${REF_MODEL_GPUS:-4}
+    echo "Using $LOSS_TYPE loss with reference model"
+fi
+
+# Calculate total GPUs
+if [[ "$USE_REFERENCE_MODEL" == "1" && "$REF_MODEL_GPUS" -gt 0 ]]; then
+    TOTAL_GPUS=$((MODEL_GPUS + REF_MODEL_GPUS))
+else
+    TOTAL_GPUS=$MODEL_GPUS
+fi
+
+echo "=========================================="
+echo "DPO Training Configuration"
+echo "=========================================="
+echo "Model: $MODEL_ID"
+echo "Dataset: $DATASET_ID"
+echo "Loss Type: $LOSS_TYPE"
+echo "DPO Beta: $DPO_BETA"
+echo "Policy GPUs: $MODEL_GPUS"
+echo "Reference GPUs: $REF_MODEL_GPUS"
+echo "Total GPUs: $TOTAL_GPUS"
+echo "Batch Size: $BATCH_SIZE"
+echo "Micro Batch Size: $MICRO_BATCH_SIZE"
+echo "Max Steps: $MAX_STEPS"
+echo "Learning Rate: $LR"
+echo "Max Length: $MAX_LENGTH"
+echo "Save Steps: $SAVE_STEPS"
+echo "=========================================="
+
+# Get script directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Run training
+python "$SCRIPT_DIR/dpo.py"
diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py
index 10d75df6..a8c8bb87 100644
--- a/cookbook/transformers/fsdp2.py
+++ b/cookbook/transformers/fsdp2.py
@@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor
 
-# Construct a device_mesh, dp=2
-device_mesh = DeviceMesh.from_sizes(dp_size=2)
+# Construct a device_mesh, dp=8
+device_mesh = DeviceMesh.from_sizes(dp_size=8)
 
 # use torchrun mode
 twinkle.initialize(mode='local', global_device_mesh=device_mesh)
diff --git a/docs/source_en/Components/Data Format/Trajectory.md b/docs/source_en/Components/Data Format/Trajectory.md
index d0c14aec..0efc6b6e 100644
--- a/docs/source_en/Components/Data Format/Trajectory.md
+++ 
b/docs/source_en/Components/Data Format/Trajectory.md @@ -5,12 +5,14 @@ The raw data structure input to Template after dataset ETL is `Trajectory` (traj ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: A list of Message messages, representing the multi-turn conversations actually conducted by the model, usually alternating between `user` and `assistant`. -- extend_message: In training such as DPO and PPO, unusable trajectories or low-score trajectories are usually needed, which will be placed in extend_message - tools: A list of all available tools for the model in this call +- user_data: User-defined data, such as labels in KTO training + +For preference alignment training like DPO, preprocessors return `{'positive': List[Trajectory], 'negative': List[Trajectory]}` format. Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in Twinkle. The format conversion goes from the original dataset to Trajectory, and then to InputFeature. diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" index f7ed4f12..5281999c 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" @@ -5,12 +5,14 @@ ```python class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] + user_data: List[Tuple[str, Any]] ``` - messages: Message消息的列表,代表模型实际进行的多轮对话,通常是`user`和`assistant`交替出现。 -- extend_message: 在DPO、PPO等训练中通常需要不可用轨迹,或低分轨迹,该轨迹会放在extend_message中 - tools: 模型在本次调用中的所有可用工具列表 +- user_data: 用户自定义数据,如KTO训练中的label + +对于DPO等偏好对齐训练,预处理器返回`{'positive': List[Trajectory], 'negative': List[Trajectory]}`格式。 Trajectory是twinkle中所有数据集预处理输出,模板输入的标准接口。格式转换为由原始数据集转换为Trajectory,再到InputFeature。 diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py index c7742d75..51a21fc5 100644 --- a/src/twinkle/data_format/trajectory.py +++ b/src/twinkle/data_format/trajectory.py @@ -13,6 +13,5 @@ class Trajectory(TypedDict, total=False): messages: List[Message] - extend_message: List[Tuple[str, List[Message]]] tools: List[Tool] user_data: List[Tuple[str, Any]] diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py index 00adc984..fc14a030 100644 --- a/src/twinkle/dataset/base.py +++ b/src/twinkle/dataset/base.py @@ -87,7 +87,7 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): with processing_lock('dataset'): # use a default lock because encode is to all datasets self.dataset = self.dataset.map(encode_fn, - **kwargs).filter(lambda batch: [len(x) > 0 for x in batch['input_ids']], + **kwargs).filter(lambda batch: [True] * len(next(iter(batch.values()))) if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs) @remote_function() diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 7870f5a4..4e4d0e82 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -2,6 +2,7 @@ from .base import Loss from .chunked_cross_entropy import ChunkedCrossEntropyLoss from .cross_entropy import CrossEntropyLoss +from .dpo import 
CPOLoss, DPOLoss, ORPOLoss, SimPOLoss from .gkd import GKDLoss from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss from .mse import MSELoss @@ -19,4 +20,9 @@ 'cispo': CISPOLoss, 'bnpo': BNPOLoss, 'dr_grpo': DRGRPOLoss, + # DPO family losses + 'dpo': DPOLoss, + 'simpo': SimPOLoss, + 'cpo': CPOLoss, + 'orpo': ORPOLoss, } diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py new file mode 100644 index 00000000..ce533053 --- /dev/null +++ b/src/twinkle/loss/dpo.py @@ -0,0 +1,559 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Loss Implementation. + +Reference: + "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" + (https://arxiv.org/abs/2305.18290) +""" +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from twinkle.data_format import LossOutput +from twinkle.utils.torch_utils import selective_log_softmax +from twinkle.loss.base import Loss + +if TYPE_CHECKING: + import torch + + +class PreferenceLossBase(Loss): + """Base class for preference optimization losses with shared utilities.""" + + def __init__(self, ignore_index: int = -100): + self.ignore_index = ignore_index + + def _compute_logps_from_logits( + self, + logits: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute per-token log probabilities from logits. + + Args: + logits: [batch, seq_len, vocab_size] model logits + labels: [batch, seq_len] target token ids + + Returns: + logps: [batch, seq_len] per-token log probabilities + """ + loss_mask = (labels != self.ignore_index).bool() + masked_labels = labels.clone() + masked_labels[~loss_mask] = 0 + return selective_log_softmax(logits, masked_labels) + + def _compute_sequence_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute sequence-level log probabilities by summing valid token logps.""" + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _compute_avg_logps( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute length-normalized (average) log probabilities.""" + loss_mask = (labels != self.ignore_index).float() + seq_lengths = loss_mask.sum(dim=-1).clamp(min=1) + return (per_token_logps * loss_mask).sum(dim=-1) / seq_lengths + + def _compute_nll_loss( + self, + per_token_logps: 'torch.Tensor', + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute negative log likelihood loss.""" + loss_mask = (labels != self.ignore_index).float() + return -(per_token_logps * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + def _get_logps_from_outputs( + self, + outputs: Dict, + labels: 'torch.Tensor', + ) -> 'torch.Tensor': + """Extract or compute log probabilities from model outputs.""" + logps = outputs.get('logps') + if logps is None: + logits = outputs.get('logits') + assert logits is not None, "outputs must contain 'logps' or 'logits'" + if logits.shape[1] != labels.shape[1]: + logits = logits[:, -labels.shape[1]:] + logps = self._compute_logps_from_logits(logits, labels) + return logps + + def _split_chosen_rejected( + self, + tensor: 'torch.Tensor', + ) -> tuple: + """Split interleaved tensor into chosen and rejected. + + Input format: [pos_1, neg_1, pos_2, neg_2, ...] 
(interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + # Even indices = chosen (positive), odd indices = rejected (negative) + return tensor[0::2], tensor[1::2] + + +class DPOLoss(PreferenceLossBase): + """Direct Preference Optimization (DPO) Loss. + + DPO directly optimizes the policy using preference data without explicit reward modeling. + The loss function is derived from the Bradley-Terry preference model: + + L_DPO = -log(σ(β * (log π(y_w|x)/π_ref(y_w|x) - log π(y_l|x)/π_ref(y_l|x)))) + + where: + - y_w is the preferred (chosen) response + - y_l is the dispreferred (rejected) response + - β is the temperature parameter controlling deviation from reference + - π is the current policy + - π_ref is the reference policy (frozen) + + Args: + beta: Temperature parameter controlling how much to deviate from ref policy (default: 0.1). + label_smoothing: Label smoothing parameter for soft labels (default: 0.0). + ignore_index: Index to ignore in labels (default: -100). + loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid'). + reference_free: Whether to use reference-free DPO (default: False). + sft_weight: Weight for SFT loss on chosen responses to prevent likelihood displacement (default: 0.0). + """ + + def __init__( + self, + beta: float = 0.1, + label_smoothing: float = 0.0, + ignore_index: int = -100, + loss_type: str = 'sigmoid', + reference_free: bool = False, + sft_weight: float = 0.0, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + self.reference_free = reference_free + self.sft_weight = sft_weight + + def _align_logps( + self, + logps: 'torch.Tensor', + target_shape: tuple, + device: 'torch.device', + dtype: 'torch.dtype', + ) -> 'torch.Tensor': + """Align log probabilities to target shape. + + Args: + logps: Input log probabilities tensor + target_shape: Target (batch, seq_len) shape + device: Target device + dtype: Target dtype + + Returns: + Aligned tensor of shape target_shape + """ + import torch + + if not torch.is_tensor(logps): + raise TypeError(f'Expected torch.Tensor, got {type(logps)}') + + if logps.dim() == 1: + logps = logps.unsqueeze(0) + + if logps.shape == target_shape: + return logps.to(device=device, dtype=dtype) + + # Handle tensor with different sequence length + if logps.dim() == 2 and logps.shape[0] == target_shape[0]: + batch_size, target_seq_len = target_shape + src_seq_len = logps.shape[1] + logps = logps.to(device=device, dtype=dtype) + if src_seq_len > target_seq_len: + # Truncate right (keep left part) - may happen in Ray result merging + return logps[:, :target_seq_len] + else: + raise ValueError( + f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). ' + f'This should not happen when both models process the same batch.' + ) + + raise ValueError( + f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}' + ) + + def _compute_dpo_loss( + self, + policy_chosen_logps: 'torch.Tensor', + policy_rejected_logps: 'torch.Tensor', + reference_chosen_logps: 'torch.Tensor', + reference_rejected_logps: 'torch.Tensor', + ) -> 'torch.Tensor': + """Compute the DPO loss. 
+
+        Args:
+            policy_chosen_logps: [batch/2] log probs of chosen under current policy
+            policy_rejected_logps: [batch/2] log probs of rejected under current policy
+            reference_chosen_logps: [batch/2] log probs of chosen under reference policy
+            reference_rejected_logps: [batch/2] log probs of rejected under reference policy
+
+        Returns:
+            loss: Scalar DPO loss
+        """
+        import torch
+        import torch.nn.functional as F
+
+        # Compute log ratios
+        if self.reference_free:
+            # Reference-free: only use policy log probs
+            chosen_logratios = policy_chosen_logps
+            rejected_logratios = policy_rejected_logps
+        else:
+            chosen_logratios = policy_chosen_logps - reference_chosen_logps
+            rejected_logratios = policy_rejected_logps - reference_rejected_logps
+
+        # Compute preference margin
+        logits = self.beta * (chosen_logratios - rejected_logratios)
+
+        if self.loss_type == 'sigmoid':
+            # Standard DPO loss: -log(sigmoid(beta * margin))
+            losses = -F.logsigmoid(logits)
+        elif self.loss_type == 'hinge':
+            # Hinge loss variant
+            losses = torch.relu(1 - logits)
+        elif self.loss_type == 'ipo':
+            # IPO (Identity Preference Optimization) loss
+            # Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback"
+            losses = (logits - 1 / (2 * self.beta)) ** 2
+        elif self.loss_type == 'kto_pair':
+            # KTO pair loss (simplified version)
+            chosen_logratios_scaled = self.beta * chosen_logratios
+            rejected_logratios_scaled = self.beta * rejected_logratios
+            chosen_losses = 1 - torch.sigmoid(chosen_logratios_scaled)
+            rejected_losses = torch.sigmoid(rejected_logratios_scaled)
+            losses = chosen_losses + rejected_losses
+        else:
+            raise ValueError(f"Unknown loss_type: {self.loss_type}")
+
+        # Apply label smoothing if specified
+        if self.label_smoothing > 0:
+            # Soft labels: (1 - eps) * loss_chosen + eps * loss_rejected
+            smooth_losses = -F.logsigmoid(-logits)  # Loss for flipped preference
+            losses = (1 - self.label_smoothing) * losses + self.label_smoothing * smooth_losses
+
+        return losses.mean()
+
+    def __call__(
+        self,
+        inputs: Dict,
+        outputs: Dict,
+        *,
+        ref_outputs: Optional[Dict] = None,
+        ref_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None,
+        ref_chosen_logps: Optional['torch.Tensor'] = None,
+        ref_rejected_logps: Optional['torch.Tensor'] = None,
+        **kwargs,
+    ) -> LossOutput:
+        """Compute DPO loss.
+
+        The inputs contain interleaved chosen/rejected pairs:
+            - Even batch indices: chosen responses
+            - Odd batch indices: rejected responses
+
+        Args:
+            inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len].
+                Batch should be interleaved as [chosen_1, rejected_1, ..., chosen_n, rejected_n]
+            outputs: Dict containing either:
+                - 'logps': [batch, seq_len] pre-computed log probs, OR
+                - 'logits': [batch, seq_len, vocab] from which logps will be computed
+            ref_outputs: Dict from reference model forward, containing 'logps'.
+            ref_logps: [batch, seq_len] or List[List[float]] reference model log probs.
+                Can also be provided as separate ref_chosen_logps and ref_rejected_logps.
+            ref_chosen_logps: [batch/2] pre-computed reference log probs for chosen.
+            ref_rejected_logps: [batch/2] pre-computed reference log probs for rejected.
+            **kwargs: Additional arguments.
+
+        Returns:
+            LossOutput with DPO loss and metrics.
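+
+        Example (illustrative numbers, a single interleaved pair, beta=0.1):
+            policy logps:    chosen -10.0, rejected -12.0
+            reference logps: chosen -10.5, rejected -11.5
+            logits = 0.1 * ((-10.0 - (-10.5)) - (-12.0 - (-11.5))) = 0.1 * 1.0 = 0.1
+            sigmoid loss = -log(sigmoid(0.1)) ≈ 0.644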
+ """ + import torch + + # Extract ref_logps from ref_outputs if provided + if ref_outputs is not None and ref_logps is None: + ref_logps = ref_outputs.get('logps') + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + batch_size = labels.shape[0] + assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + + # Get log probabilities from outputs + logps = self._get_logps_from_outputs(outputs, labels) + device = logps.device + dtype = logps.dtype + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # Compute sequence-level log probs for policy + policy_chosen_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + policy_rejected_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Handle reference log probs + if ref_chosen_logps is not None and ref_rejected_logps is not None: + # Pre-computed sequence-level reference log probs provided + reference_chosen_logps = ref_chosen_logps.to(device=device, dtype=dtype) + reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype) + elif ref_logps is not None: + # Per-token reference log probs provided, need to align and sum + ref_logps_aligned = self._align_logps( + ref_logps, labels.shape, device, dtype + ) + ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned) + reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels) + reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels) + elif self.reference_free: + # Reference-free mode: no reference model needed + reference_chosen_logps = torch.zeros_like(policy_chosen_logps) + reference_rejected_logps = torch.zeros_like(policy_rejected_logps) + else: + raise ValueError( + "ref_logps or (ref_chosen_logps, ref_rejected_logps) must be provided " + "unless reference_free=True" + ) + + # Compute DPO loss + dpo_loss = self._compute_dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + + # Add SFT loss on chosen responses to prevent likelihood displacement + if self.sft_weight > 0: + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + loss = dpo_loss + self.sft_weight * sft_loss + else: + loss = dpo_loss + + # Return 0 to skip gradient normalization by num_tokens + # DPO loss is already per-sample mean, unlike SFT which sums per-token loss + # When num_tokens=0, normalize_and_clip_grad_norm defaults to 1 (no division) + return LossOutput(loss=loss, num_tokens=0) + + +class SimPOLoss(PreferenceLossBase): + """SimPO (Simple Preference Optimization) Loss. + + SimPO is a simpler variant of DPO that doesn't require a reference model. + It uses length-normalized log probabilities. + + Reference: + "SimPO: Simple Preference Optimization with a Reference-Free Reward" + (https://arxiv.org/abs/2405.14734) + + Args: + beta: Temperature parameter (default: 2.5). + gamma: Target reward margin (default: 0.5). + ignore_index: Index to ignore in labels (default: -100). 
+ """ + + def __init__( + self, + beta: float = 2.5, + gamma: float = 0.5, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.beta = beta + self.gamma = gamma + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute SimPO loss.""" + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + assert labels.shape[0] % 2 == 0, "Batch size must be even (chosen + rejected pairs)" + + # Get log probabilities + logps = self._get_logps_from_outputs(outputs, labels) + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # Compute length-normalized log probs + chosen_rewards = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_rewards = self._compute_avg_logps(rejected_logps, rejected_labels) + + # SimPO loss: -log(sigmoid(beta * (r_w - r_l) - gamma)) + logits = self.beta * (chosen_rewards - rejected_rewards) - self.gamma + loss = -F.logsigmoid(logits).mean() + + return LossOutput(loss=loss, num_tokens=0) + + +class CPOLoss(PreferenceLossBase): + """CPO (Contrastive Preference Optimization) Loss. + + CPO adds a behavior cloning term to preference optimization. + + Reference: + "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation" + (https://arxiv.org/abs/2401.08417) + + Args: + beta: Temperature parameter for preference (default: 0.1). + bc_coef: Behavior cloning coefficient (default: 1.0). + ignore_index: Index to ignore in labels (default: -100). + """ + + def __init__( + self, + beta: float = 0.1, + bc_coef: float = 1.0, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.beta = beta + self.bc_coef = bc_coef + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute CPO loss.""" + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + assert labels.shape[0] % 2 == 0, "Batch size must be even" + + # Get log probabilities + logps = self._get_logps_from_outputs(outputs, labels) + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # Compute sequence-level log probs + chosen_seq_logps = self._compute_sequence_logps(chosen_logps, chosen_labels) + rejected_seq_logps = self._compute_sequence_logps(rejected_logps, rejected_labels) + + # Preference loss (reference-free DPO) + logits = self.beta * (chosen_seq_logps - rejected_seq_logps) + preference_loss = -F.logsigmoid(logits).mean() + + # Behavior cloning loss on chosen + bc_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Combined loss + loss = preference_loss + self.bc_coef * bc_loss + + return LossOutput(loss=loss, num_tokens=0) + + +class ORPOLoss(PreferenceLossBase): + """ORPO (Odds Ratio Preference Optimization) Loss. + + ORPO combines SFT and preference alignment in a single objective using odds ratios. 
+ + Reference: + "ORPO: Monolithic Preference Optimization without Reference Model" + (https://arxiv.org/abs/2403.07691) + + Args: + lambda_orpo: Weight for the odds ratio term (default: 0.1). + ignore_index: Index to ignore in labels (default: -100). + """ + + def __init__( + self, + lambda_orpo: float = 0.1, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(ignore_index=ignore_index) + self.lambda_orpo = lambda_orpo + + def __call__( + self, + inputs: Dict, + outputs: Dict, + **kwargs, + ) -> LossOutput: + """Compute ORPO loss.""" + import torch + import torch.nn.functional as F + + labels = inputs.get('labels') + assert labels is not None, "inputs must contain 'labels'" + if not torch.is_tensor(labels): + labels = torch.as_tensor(labels) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + assert labels.shape[0] % 2 == 0, "Batch size must be even" + + # Get log probabilities + logps = self._get_logps_from_outputs(outputs, labels) + + # Split into chosen and rejected + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + chosen_logps, rejected_logps = self._split_chosen_rejected(logps) + + # SFT loss on chosen + sft_loss = self._compute_nll_loss(chosen_logps, chosen_labels) + + # Compute average log probs for odds ratio + chosen_avg_logps = self._compute_avg_logps(chosen_logps, chosen_labels) + rejected_avg_logps = self._compute_avg_logps(rejected_logps, rejected_labels) + + # Odds ratio: log(odds_chosen / odds_rejected) + # log_odds = log(p/(1-p)) = log(p) - log(1-p) + # Use numerically stable computation + prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1-1e-7) + prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1-1e-7) + log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen) + log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected) + + # ORPO odds ratio loss + odds_ratio = log_odds_chosen - log_odds_rejected + orpo_loss = -F.logsigmoid(odds_ratio).mean() + + # Combined loss + loss = sft_loss + self.lambda_orpo * orpo_loss + + return LossOutput(loss=loss, num_tokens=0) diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py index 739c7a0d..59d5bbeb 100644 --- a/src/twinkle/metric/__init__.py +++ b/src/twinkle/metric/__init__.py @@ -2,5 +2,6 @@ from .accuracy import Accuracy from .base import Metric from .completion_and_reward import CompletionRewardMetric +from .dpo import DPOMetric from .loss import LossMetric from .train_metric import TrainMetric diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py new file mode 100644 index 00000000..c3f3e6cf --- /dev/null +++ b/src/twinkle/metric/dpo.py @@ -0,0 +1,177 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""DPO-specific metrics for preference optimization training.""" +from typing import List, Union + +from twinkle.data_format import InputFeature, ModelOutput +from .base import Metric + + +class DPOMetric(Metric): + """Metrics for DPO (Direct Preference Optimization) training. 
+ + Computes TRL-style metrics: + - logps/chosen: Average sequence-level log prob of chosen responses + - logps/rejected: Average sequence-level log prob of rejected responses + - rewards/chosen: β * (policy_chosen - ref_chosen) + - rewards/rejected: β * (policy_rejected - ref_rejected) + - rewards/margins: chosen_reward - rejected_reward + - rewards/accuracies: Percentage where chosen_reward > rejected_reward + + Args: + device_mesh: The device mesh + process_group: The process group to collect data from + ignore_index: Label index to ignore (default: -100) + beta: DPO beta parameter for reward scaling (default: 0.1) + """ + + def __init__(self, device_mesh, process_group, ignore_index: int = -100, beta: float = 0.1, **kwargs): + super().__init__(device_mesh, process_group, **kwargs) + self.ignore_index = ignore_index + self.beta = beta + self.reset() + + def _compute_sequence_logps(self, per_token_logps, labels): + """Compute sequence-level log probs by summing valid token logps.""" + import torch + loss_mask = (labels != self.ignore_index).float() + return (per_token_logps * loss_mask).sum(dim=-1) + + def _split_chosen_rejected(self, tensor): + """Split interleaved tensor into chosen and rejected. + + Input format: [pos_1, neg_1, pos_2, neg_2, ...] (interleaved for DP-safe slicing) + Output: (chosen [pos_1, pos_2, ...], rejected [neg_1, neg_2, ...]) + """ + return tensor[0::2], tensor[1::2] + + def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs): + """Accumulate DPO metrics from model outputs. + + Expects: + - outputs['logps']: [batch, seq_len] per-token log probabilities + - inputs['labels']: [batch, seq_len] labels with ignore_index for non-target tokens + - kwargs['ref_outputs']: Optional reference model outputs with 'logps' + """ + import torch + + logps = outputs.get('logps') + if logps is None: + return + + # Get labels from inputs + if isinstance(inputs, list): + # Stack labels from list of inputs + labels_list = [torch.as_tensor(inp['labels']) for inp in inputs] + max_len = max(l.shape[0] for l in labels_list) + padded = [] + for l in labels_list: + if l.shape[0] < max_len: + pad = torch.full((max_len - l.shape[0],), self.ignore_index, dtype=l.dtype) + l = torch.cat([pad, l]) + padded.append(l) + labels = torch.stack(padded) + else: + labels = torch.as_tensor(inputs['labels']) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + # Ensure logps and labels have same device + if logps.device != labels.device: + labels = labels.to(logps.device) + + # Align sequence lengths if needed (truncate right) + if logps.shape[1] != labels.shape[1]: + min_len = min(logps.shape[1], labels.shape[1]) + logps = logps[:, :min_len] + labels = labels[:, :min_len] + + # Compute sequence-level logps + seq_logps = self._compute_sequence_logps(logps, labels) + + # Split into chosen and rejected (interleaved format) + chosen_logps, rejected_logps = self._split_chosen_rejected(seq_logps) + chosen_labels, rejected_labels = self._split_chosen_rejected(labels) + + # Accumulate policy logps + self.total_chosen_logps += chosen_logps.sum().item() + self.total_rejected_logps += rejected_logps.sum().item() + + # Compute rewards if ref_outputs available + ref_outputs = kwargs.get('ref_outputs') + if ref_outputs is not None: + ref_logps = ref_outputs.get('logps') + if ref_logps is not None: + # Align ref_logps + if ref_logps.device != labels.device: + ref_logps = ref_logps.to(labels.device) + if ref_logps.shape[1] != labels.shape[1]: + min_len = 
min(ref_logps.shape[1], labels.shape[1]) + ref_logps = ref_logps[:, :min_len] + + ref_seq_logps = self._compute_sequence_logps(ref_logps, labels) + ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps) + + # Compute rewards: β * (policy - ref) + chosen_rewards = self.beta * (chosen_logps - ref_chosen_logps) + rejected_rewards = self.beta * (rejected_logps - ref_rejected_logps) + + self.total_chosen_rewards += chosen_rewards.sum().item() + self.total_rejected_rewards += rejected_rewards.sum().item() + margins = chosen_rewards - rejected_rewards + self.total_reward_margin += margins.sum().item() + self.total_reward_correct += (margins > 0).sum().item() + self.has_rewards = True + + self.total_count += chosen_logps.shape[0] + + def reset(self): + """Reset all accumulated values.""" + self.total_chosen_logps = 0.0 + self.total_rejected_logps = 0.0 + self.total_chosen_rewards = 0.0 + self.total_rejected_rewards = 0.0 + self.total_reward_margin = 0.0 + self.total_reward_correct = 0 + self.total_count = 0 + self.has_rewards = False + + def calculate(self): + """Calculate and return aggregated metrics.""" + local_results = [{ + 'chosen_logps': self.total_chosen_logps, + 'rejected_logps': self.total_rejected_logps, + 'chosen_rewards': self.total_chosen_rewards, + 'rejected_rewards': self.total_rejected_rewards, + 'reward_margin': self.total_reward_margin, + 'reward_correct': self.total_reward_correct, + 'count': self.total_count, + 'has_rewards': self.has_rewards, + }] + all_results = self.gather_results(local_results) + + total_chosen_logps = sum(r['chosen_logps'] for r in all_results) + total_rejected_logps = sum(r['rejected_logps'] for r in all_results) + total_chosen_rewards = sum(r['chosen_rewards'] for r in all_results) + total_rejected_rewards = sum(r['rejected_rewards'] for r in all_results) + total_reward_margin = sum(r['reward_margin'] for r in all_results) + total_reward_correct = sum(r['reward_correct'] for r in all_results) + total_count = sum(r['count'] for r in all_results) + has_rewards = any(r['has_rewards'] for r in all_results) + + self.reset() + + if total_count == 0: + return {} + + results = { + 'logps/chosen': f'{total_chosen_logps / total_count:.2f}', + 'logps/rejected': f'{total_rejected_logps / total_count:.2f}', + } + + if has_rewards: + results['rewards/chosen'] = f'{total_chosen_rewards / total_count:.4f}' + results['rewards/rejected'] = f'{total_rejected_rewards / total_count:.4f}' + results['rewards/margins'] = f'{total_reward_margin / total_count:.4f}' + results['rewards/accuracies'] = f'{total_reward_correct / total_count * 100:.1f}%' + + return results diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index 52f50fdd..8f4ad0c9 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -52,7 +52,6 @@ def calculate(self): 'grad_norm': self.grad_norm, 'num_tokens': self.num_tokens }] - all_results = self.gather_results(local_results) total_loss = sum(r['loss'] for r in all_results) @@ -61,8 +60,10 @@ def calculate(self): num_tokens = sum(r['num_tokens'] for r in all_results) if num_tokens > 0: avg_loss = total_loss / num_tokens - else: + elif total_count > 0: avg_loss = total_loss / total_count + else: + avg_loss = 0.0 self.reset() results = {} if avg_loss is not None: diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4b37973c..7c131210 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -245,7 +245,7 @@ def 
__init__(
 
 def _construct_default_optimizer_group(self):
     return MegatronOptimizerGroup(
-        loss_instance=CrossEntropyLoss(),
+        loss_instance=CrossEntropyLoss(reduction='sum'),
         template=Template(self.tokenizer_id),
         processor=InputProcessor(self.device_mesh, framework='megatron'),
         _device_mesh=self.device_mesh,
@@ -479,6 +479,12 @@ def post_loss_function(output_tensor, inputs, logps):
         losses = result['loss']
         counts = result['num_tokens']
         if not counts:
+            # This value is gathered later, so it becomes:
+            # 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens
+            # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps) = gradient_accumulation_steps * world_size
+            # Then, the grad will be divided by this value:
+            # 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad / global_num_tokens
+            # 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) / (gradient_accumulation_steps * world_size) = avg_per_token_grad
             counts = torch.tensor(1, device=losses.device)
         return self.strategy.reduce_loss(losses, counts, output_tensor, logps)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 520aaf9f..a6be74e8 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -121,6 +121,8 @@ def accumulate_metrics(self, is_training):
         metrics = self.train_metrics
     else:
         metrics = self.eval_metrics
+    # Get stored forward_kwargs from previous forward
+    forward_kwargs = getattr(self, 'forward_kwargs', None) or {}
     if len(metrics) > 0 and self.inputs is not None and self.outputs is not None:
         for metric in metrics:
             metric.accumulate(
@@ -130,7 +132,8 @@ def accumulate_metrics(self, is_training):
                 step=self.cur_step - 1,
                 gradient_accumulation_steps=self.gradient_accumulation_steps,
                 grad_norm=self._last_grad_norm,
-                loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'))
+                loss_reduction=getattr(self.loss_instance, 'reduction', 'mean'),
+                **forward_kwargs)
 
 def calculate_metrics(self, is_training):
     self.accumulate_metrics(is_training)
@@ -405,6 +408,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
         inputs['labels'] = labels
     optimizer_config.inputs = inputs
     optimizer_config.outputs = outputs
+    optimizer_config.forward_kwargs = kwargs  # Store for next metric accumulation
     optimizer_config.loss_value = outputs.get('aux_loss', 0)
     if labels is not None:
         loss_mask = (labels != -100).bool()
@@ -496,7 +500,18 @@ def calculate_loss(self, **kwargs):
         loss_value = result['loss']
         counts = result['num_tokens']
         if not counts:
-            counts = torch.tensor(0, device=loss_value.device)
+            counts = torch.tensor(1, device=loss_value.device)
+        # This value is gathered later, so it becomes:
+        # 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size
+        # 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size) = gradient_accumulation_steps
+        # Then, the grad will be divided by this value:
+        # 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size)
+        #    = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size)
+        #    = global_sum_grad / global_num_tokens
+        # 2. 
PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) / gradient_accumulation_steps + # = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) / gradient_accumulation_steps + # = global_per_token_grad / dp_world_size = avg_per_token_grad + counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] optimizer_config.num_tokens += counts.item() if self.sp_strategy is not None and 'labels' in inputs: @@ -1086,6 +1101,7 @@ def set_grad_scaler(self, **kwargs): grad_scaler_config.update(kwargs) optimizer_config.scaler = GradScaler(**grad_scaler_config) + @remote_function() def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs): """Add an eval metric diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 13b52d99..6d9f6dd7 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor +from .dpo import (DPOProcessor, EmojiDPOProcessor, HHRLHFProcessor, IntelOrcaDPOProcessor, ShareGPTDPOProcessor, + UltraFeedbackKTOProcessor, UltraFeedbackProcessor) from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/preprocessor/dpo.py b/src/twinkle/preprocessor/dpo.py new file mode 100644 index 00000000..0a03c4ad --- /dev/null +++ b/src/twinkle/preprocessor/dpo.py @@ -0,0 +1,504 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +DPO (Direct Preference Optimization) Data Preprocessors. + +These preprocessors convert various preference dataset formats into the standard +format required by Twinkle for DPO training. + +DPO output format: + - positive: Trajectory - chosen response trajectory + - negative: Trajectory - rejected response trajectory +""" +from typing import Any, Dict, List, Optional, Union + +from twinkle.data_format import Message, Trajectory +from .base import Preprocessor + + +class DPOProcessor(Preprocessor): + """Generic DPO preference data preprocessor. + + Converts preference data with chosen/rejected pairs into positive/negative Trajectories. + Supports multiple common dataset formats. + + Expected input format (one of): + 1. {'prompt': str, 'chosen': str, 'rejected': str} + 2. {'prompt': str, 'chosen': List[Message], 'rejected': List[Message]} + 3. {'messages': List[Message], 'chosen': str, 'rejected': str} + 4. {'chosen': List[Message], 'rejected': List[Message]} (full conversations) + + Output format: + - positive: Trajectory with chosen response + - negative: Trajectory with rejected response + + Args: + system: Optional system prompt to prepend. + chosen_key: Key for chosen response (default: 'chosen'). + rejected_key: Key for rejected response (default: 'rejected'). + prompt_key: Key for prompt/question (default: 'prompt'). + messages_key: Key for conversation messages (default: 'messages'). 
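+    # Example (invented row, format 1):
+    #   {'prompt': 'What is 2+2?', 'chosen': '4', 'rejected': '5'} becomes
+    #   positive = Trajectory(user 'What is 2+2?' -> assistant '4') and
+    #   negative = Trajectory(user 'What is 2+2?' -> assistant '5').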
+ """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'chosen', + rejected_key: str = 'rejected', + prompt_key: str = 'prompt', + messages_key: str = 'messages', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + self.messages_key = messages_key + + def _parse_response(self, response: Union[str, List[Dict], List[Message]]) -> List[Message]: + """Parse response into list of Messages.""" + if isinstance(response, str): + return [Message(role='assistant', content=response)] + elif isinstance(response, list): + messages = [] + for msg in response: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role', 'assistant'), content=msg.get('content', ''))) + return messages + return [Message(role='assistant', content=str(response))] + + def _build_prompt_messages(self, row: Dict[str, Any]) -> List[Message]: + """Build prompt messages from row data.""" + messages = [] + + # Add system message if provided + if self.system: + messages.append(Message(role='system', content=self.system)) + + # Check for messages field (conversation format) + if self.messages_key in row and row[self.messages_key]: + raw_messages = row[self.messages_key] + for msg in raw_messages: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) + return messages + + # Check for prompt field + if self.prompt_key in row and row[self.prompt_key]: + prompt = row[self.prompt_key] + if isinstance(prompt, str): + messages.append(Message(role='user', content=prompt)) + elif isinstance(prompt, list): + for msg in prompt: + if isinstance(msg, Message): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(Message(role=msg.get('role'), content=msg.get('content', ''))) + + return messages + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row into positive/negative Trajectories. + + Returns: + Dict with 'positive' and 'negative' Trajectory. + """ + # Build prompt messages + prompt_messages = self._build_prompt_messages(row) + + # Get chosen response + chosen_raw = row.get(self.chosen_key, '') + chosen_response = self._parse_response(chosen_raw) + + # Get rejected response + rejected_raw = row.get(self.rejected_key, '') + rejected_response = self._parse_response(rejected_raw) + + # Build full message lists + chosen_messages = prompt_messages + chosen_response + rejected_messages = prompt_messages + rejected_response + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Process batched data into DPO format.""" + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + # Collect all positive and negative trajectories + positive_list = [r['positive'] for r in results] + negative_list = [r['negative'] for r in results] + return { + 'positive': positive_list, + 'negative': negative_list, + } + + +class HHRLHFProcessor(Preprocessor): + """Preprocessor for Anthropic HH-RLHF dataset format. + + HH-RLHF format: + {'chosen': "Human: ... Assistant: ...", 'rejected': "Human: ... Assistant: ..."} + + The conversations use "Human:" and "Assistant:" prefixes. 
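+
+    Example (invented snippet):
+        '\n\nHuman: Hi\n\nAssistant: Hello!' parses into
+        [Message(role='user', content='Hi'), Message(role='assistant', content='Hello!')].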
+ """ + + def __init__(self, system: Optional[str] = None): + self.system = system + + def _parse_hh_conversation(self, text: str) -> List[Message]: + """Parse HH-RLHF style conversation text into Messages.""" + messages = [] + + if self.system: + messages.append(Message(role='system', content=self.system)) + + # Split by Human/Assistant markers + parts = text.split('\n\nHuman: ') + for i, part in enumerate(parts): + if i == 0 and not part.startswith('Human: '): + if part.strip(): + if part.startswith('Human: '): + part = part[7:] + messages.append(Message(role='user', content=part.strip())) + continue + + # Split Human and Assistant parts + if '\n\nAssistant: ' in part: + human_part, assistant_part = part.split('\n\nAssistant: ', 1) + messages.append(Message(role='user', content=human_part.strip())) + messages.append(Message(role='assistant', content=assistant_part.strip())) + else: + messages.append(Message(role='user', content=part.strip())) + + return messages + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process HH-RLHF format row.""" + chosen_text = row.get('chosen', '') + rejected_text = row.get('rejected', '') + + chosen_messages = self._parse_hh_conversation(chosen_text) + rejected_messages = self._parse_hh_conversation(rejected_text) + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } + + +class UltraFeedbackProcessor(Preprocessor): + """Preprocessor for UltraFeedback dataset format. + + UltraFeedback format: + { + 'instruction': str, + 'completions': [ + {'response': str, 'overall_score': float, ...}, + ... + ] + } + + Selects highest and lowest scored completions as chosen/rejected. 
+ """ + + def __init__( + self, + system: Optional[str] = None, + instruction_key: str = 'instruction', + completions_key: str = 'completions', + response_key: str = 'response', + score_key: str = 'overall_score', + ): + self.system = system + self.instruction_key = instruction_key + self.completions_key = completions_key + self.response_key = response_key + self.score_key = score_key + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Trajectory]]: + """Process UltraFeedback format row.""" + instruction = row.get(self.instruction_key, '') + completions = row.get(self.completions_key, []) + + if len(completions) < 2: + return None + + # Sort by score + scored_completions = [ + (c.get(self.score_key, 0), c.get(self.response_key, '')) + for c in completions + if c.get(self.response_key) + ] + + if len(scored_completions) < 2: + return None + + scored_completions.sort(key=lambda x: x[0], reverse=True) + chosen_response = scored_completions[0][1] + rejected_response = scored_completions[-1][1] + + # Build messages + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=instruction)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_response)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_response)] + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + results = [r for r in results if r is not None] + if not results: + return {} + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } + + +class ShareGPTDPOProcessor(Preprocessor): + """Preprocessor for ShareGPT-style DPO datasets. + + Expected format: + { + 'conversations': [ + {'from': 'human', 'value': '...'}, + {'from': 'gpt', 'value': '...'}, + ... 
+ ], + 'chosen': {'from': 'gpt', 'value': '...'}, + 'rejected': {'from': 'gpt', 'value': '...'} + } + """ + + ROLE_MAPPING = { + 'human': 'user', + 'gpt': 'assistant', + 'system': 'system', + 'user': 'user', + 'assistant': 'assistant', + } + + def __init__(self, system: Optional[str] = None): + self.system = system + + def _parse_sharegpt_message(self, msg: Dict) -> Message: + """Parse ShareGPT format message.""" + role = self.ROLE_MAPPING.get(msg.get('from', ''), 'user') + content = msg.get('value', '') or msg.get('content', '') + return Message(role=role, content=content) + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process ShareGPT DPO format row.""" + conversations = row.get('conversations', []) + + # Build prompt messages + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + + for msg in conversations: + prompt_messages.append(self._parse_sharegpt_message(msg)) + + # Remove last message if it's assistant (will be replaced) + if prompt_messages and prompt_messages[-1]['role'] == 'assistant': + prompt_messages = prompt_messages[:-1] + + # Get chosen and rejected + chosen_msg = row.get('chosen', {}) + rejected_msg = row.get('rejected', {}) + + if isinstance(chosen_msg, dict): + chosen_content = chosen_msg.get('value', '') or chosen_msg.get('content', '') + else: + chosen_content = str(chosen_msg) + + if isinstance(rejected_msg, dict): + rejected_content = rejected_msg.get('value', '') or rejected_msg.get('content', '') + else: + rejected_content = str(rejected_msg) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen_content)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected_content)] + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } + + +class IntelOrcaDPOProcessor(Preprocessor): + """Preprocessor for Intel ORCA DPO dataset format. 
+ + Expected format: + { + 'system': str, + 'question': str, + 'chosen': str, + 'rejected': str + } + """ + + def __init__(self, default_system: Optional[str] = None): + self.default_system = default_system + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process Intel ORCA DPO format row.""" + system = row.get('system', self.default_system) + question = row.get('question', '') + chosen = row.get('chosen', '') + rejected = row.get('rejected', '') + + prompt_messages = [] + if system: + prompt_messages.append(Message(role='system', content=system)) + prompt_messages.append(Message(role='user', content=question)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } + + +class EmojiDPOProcessor(Preprocessor): + """Preprocessor for shareAI/DPO-zh-en-emoji dataset format. + + Dataset format: + { + 'prompt': str, + 'answer_zh': str, # chosen response (Chinese) + 'answer_en': str, # rejected response (English) + } + + Output format: + - positive: Trajectory with chosen (answer_zh) + - negative: Trajectory with rejected (answer_en) + + Args: + system: Optional system prompt. + chosen_key: Key for chosen response (default: 'answer_zh'). + rejected_key: Key for rejected response (default: 'answer_en'). + prompt_key: Key for prompt (default: 'prompt'). + """ + + def __init__( + self, + system: Optional[str] = None, + chosen_key: str = 'answer_zh', + rejected_key: str = 'answer_en', + prompt_key: str = 'prompt', + ): + self.system = system + self.chosen_key = chosen_key + self.rejected_key = rejected_key + self.prompt_key = prompt_key + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Trajectory]: + """Process a single row.""" + prompt = row.get(self.prompt_key, '') + chosen = row.get(self.chosen_key, '') + rejected = row.get(self.rejected_key, '') + + prompt_messages = [] + if self.system: + prompt_messages.append(Message(role='system', content=self.system)) + prompt_messages.append(Message(role='user', content=prompt)) + + chosen_messages = prompt_messages + [Message(role='assistant', content=chosen)] + rejected_messages = prompt_messages + [Message(role='assistant', content=rejected)] + + return { + 'positive': Trajectory(messages=chosen_messages), + 'negative': Trajectory(messages=rejected_messages), + } + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + results = [self.preprocess(row) for row in rows] + return { + 'positive': [r['positive'] for r in results], + 'negative': [r['negative'] for r in results], + } + + +class UltraFeedbackKTOProcessor(Preprocessor): + """Preprocessor for ultrafeedback-binarized-preferences-cleaned-kto dataset. + + Dataset format: + { + 'prompt': str, + 'completion': str, + 'label': bool, # True for chosen, False for rejected + } + + For KTO training, we need (prompt, completion, label) format. + The label is stored in user_data. + + Args: + system: Optional system prompt. 
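+
+    Example (illustrative values, not drawn from the real dataset):
+
+        {'prompt': 'Say hi.', 'completion': 'Hi!', 'label': True}
+        # -> Trajectory whose messages end with assistant 'Hi!' and whose
+        #    user_data carries [('kto_label', True)]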
+ """ + + def __init__(self, system: Optional[str] = None): + self.system = system + + def preprocess(self, row: Dict[str, Any]) -> Trajectory: + """Process a single row for KTO.""" + prompt = row.get('prompt', '') + completion = row.get('completion', '') + label = row.get('label', True) + + messages = [] + if self.system: + messages.append(Message(role='system', content=self.system)) + messages.append(Message(role='user', content=prompt)) + messages.append(Message(role='assistant', content=completion)) + + return Trajectory( + messages=messages, + user_data=[('kto_label', label)] + ) + + def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + rows = self.map_row_to_col(rows) + return rows diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 167a459d..a3e53056 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -179,9 +179,6 @@ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]: if self.use_chat_template and self.default_system: if trajectory['messages'][0]['role'] == 'user': trajectory['messages'].insert(0, Message(role='system', content=self.default_system)) - for (_, messages) in trajectory.get('extend_message', []): - if messages and messages[0]['role'] == 'user': - messages.insert(0, Message(role='system', content=self.default_system)) return [trajectory] def _to_standard_reasoning_content(self, trajectory: Trajectory) -> List[Trajectory]: @@ -206,33 +203,47 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: return result trajectory['messages'] = _extract_reasoning_content(trajectory['messages']) - extra_messages = trajectory.get('extend_message', []) - if extra_messages: - result = [] - for key, extra_message in trajectory.get('extend_message', []): - result.append((key, _extract_reasoning_content(extra_message))) - trajectory['extend_message'] = result return [trajectory] + def _truncate_feature(self, feature: InputFeature, strategy: str) -> InputFeature: + """Truncate input_ids and labels in a single InputFeature.""" + length = len(feature['input_ids']) + if length <= self.max_length: + return feature + if strategy == 'raise': + raise ValueError(f'Input length {length} exceeds max_length {self.max_length}') + result = dict(feature) + if strategy == 'left': + result['input_ids'] = result['input_ids'][-self.max_length:] + if 'labels' in result: + result['labels'] = result['labels'][-self.max_length:] + elif strategy == 'right': + result['input_ids'] = result['input_ids'][:self.max_length] + if 'labels' in result: + result['labels'] = result['labels'][:self.max_length] + return InputFeature(**result) + def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: - if self.max_length and len(input_feature['input_ids']) > self.max_length: - if self.truncation_strategy == 'raise': - raise ValueError(f'An input message(length: {len(input_feature["input_ids"])} ' - f'exceeds the maximum length({self.max_length})') - elif self.truncation_strategy == 'left': - return [InputFeature(**{key: value[-self.max_length:] for key, value in input_feature.items()})] - elif self.truncation_strategy == 'right': - return [InputFeature(**{key: value[:self.max_length] for key, value in input_feature.items()})] - else: # split - result = [] - total_length = len(input_feature['input_ids']) - for start in range(0, total_length, self.max_length): - end = min(start + self.max_length, 
total_length) - result.append(InputFeature(**{key: value[start:end] for key, value in input_feature.items()})) - return result - else: + if not self.max_length: return [input_feature] + strategy = self.truncation_strategy + + # Split strategy + if strategy == 'split': + results = [] + for start in range(0, len(input_feature['input_ids']), self.max_length): + end = min(start + self.max_length, len(input_feature['input_ids'])) + feat = dict(input_feature) + feat['input_ids'] = feat['input_ids'][start:end] + if 'labels' in feat: + feat['labels'] = feat['labels'][start:end] + results.append(InputFeature(**feat)) + return results + + # left/right/raise + return [self._truncate_feature(input_feature, strategy)] + def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]: input_ids = input_feature['input_ids'] input_feature['attention_mask'] = np.ones_like(input_ids) @@ -244,8 +255,8 @@ def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]: input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1) return [input_feature] - def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: - messages = trajectory['messages'] + def _process_mm_messages(self, messages: List) -> List: + """Process multimodal content in a list of messages.""" new_messages = [] for message in messages: message = copy(message) @@ -265,8 +276,10 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: new_messages.append( transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, self.audio_placeholder, self.is_mm)) + return new_messages - trajectory['messages'] = new_messages + def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: + trajectory['messages'] = self._process_mm_messages(trajectory['messages']) return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): @@ -283,7 +296,8 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo **kwargs) return inputs - def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + """Encode a single trajectory's messages into InputFeature.""" if self.use_chat_template: if add_generation_prompt: # For inference: just get input_ids with generation prompt, no labels needed @@ -306,11 +320,17 @@ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> input_ids = self.tokenizer.encode(text) encoded = {} labels = deepcopy(input_ids) - return InputFeature( + + input_feature = InputFeature( input_ids=np.array(input_ids), labels=np.array(labels), **encoded, ) + trajectory.update(input_feature) + return trajectory + + def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature: + return self._encode_messages(trajectory, add_generation_prompt) @staticmethod def map_col_to_row(trajectories: Dict[str, Any]): @@ -338,18 +358,69 @@ def map_row_to_col(rows: List[Union[Dict[str, Any], InputFeature]]) -> Dict[str, return columns - def batch_encode(self, - trajectories: Union[Dict[str, Any], List[Trajectory]], - add_generation_prompt: bool = False) -> List[InputFeature]: - output = [] + def _is_trajectory(self, obj: Any) -> bool: + """Check if an object is a Trajectory (has 'messages' key).""" + return isinstance(obj, Mapping) and 'messages' in obj + + def 
_get_trajectory_keys(self, columnar: Mapping) -> List[str]: + """Get keys whose values are lists of Trajectories in columnar format.""" + keys = [] + for k, v in columnar.items(): + if isinstance(v, list) and v and self._is_trajectory(v[0]): + keys.append(k) + return keys + + def batch_encode( + self, + trajectories: Union[Dict[str, Any], List[Trajectory]], + add_generation_prompt: bool = False, + ) -> Union[Dict[str, Any], List[InputFeature]]: + """Encode trajectories into InputFeatures. + + Args: + trajectories: Either List[Trajectory] or columnar Dict[str, List]. + For DPO, columnar format with 'positive'/'negative' keys containing + List[Trajectory] is supported. + add_generation_prompt: Whether to add generation prompt. + + Returns: + List[InputFeature] or columnar Dict[str, List[InputFeature]]. + """ _transfer = False + if isinstance(trajectories, Mapping): _transfer = True - trajectories = self.map_col_to_row(trajectories) + # Check if it has trajectory list columns (DPO format) + traj_keys = self._get_trajectory_keys(trajectories) + if traj_keys: + # DPO format: encode each trajectory list separately, keep other columns + result = {} + for key in trajectories: + if key in traj_keys: + # Encode this trajectory list + result[key] = self.batch_encode( + trajectories[key], add_generation_prompt=add_generation_prompt + ) + else: + # Keep non-trajectory columns as-is + result[key] = trajectories[key] + return result + else: + # Standard columnar format + trajectories = self.map_col_to_row(trajectories) + + # Process List[Trajectory] trajectories = self._invoke_pre_pipeline(trajectories) - for trajectory in trajectories: - output.append(self.encode(trajectory, add_generation_prompt=add_generation_prompt)) + + # Use thread pool for parallel encoding + from concurrent.futures import ThreadPoolExecutor + from functools import partial + encode_fn = partial(self.encode, add_generation_prompt=add_generation_prompt) + with ThreadPoolExecutor() as executor: + output = list(executor.map(encode_fn, trajectories)) + output = self._invoke_post_pipeline(output) + if _transfer: output = self.map_row_to_col(output) return output
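
A minimal usage sketch of the new columnar batch_encode path (illustrative
only: `template` is assumed to be an already-constructed instance of this
Template class, whose constructor is not part of this diff, and the Message
import path is assumed to match Trajectory's):

    from twinkle.data_format import Message, Trajectory

    # Columnar DPO batch, in the {'positive': [...], 'negative': [...]} shape
    # emitted by the preprocessors above; 'question' is an extra column.
    batch = {
        'question': ['Say hi.'],
        'positive': [Trajectory(messages=[
            Message(role='user', content='Say hi.'),
            Message(role='assistant', content='Hi there!'),
        ])],
        'negative': [Trajectory(messages=[
            Message(role='user', content='Say hi.'),
            Message(role='assistant', content='No.'),
        ])],
    }

    # _get_trajectory_keys finds 'positive' and 'negative'; each list is
    # encoded separately and 'question' is passed through unchanged.
    encoded = template.batch_encode(batch)
    chosen, rejected = encoded['positive'][0], encoded['negative'][0]
    # chosen['input_ids'] and chosen['labels'] are now populated arrays,
    # ready to be interleaved into [pos, neg, pos, neg, ...] pairs downstream.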