263 changes: 263 additions & 0 deletions cookbook/rl/dpo.py
@@ -0,0 +1,263 @@
"""DPO (Direct Preference Optimization) Training via Ray.

Off-policy preference alignment: trains the model to prefer chosen responses
over rejected responses using preference data, without explicit reward modeling.

Pipeline:
1. Load preference dataset with chosen/rejected pairs.
2. Encode positive and negative separately.
3. Compute reference model log probabilities (frozen).
4. Train policy model using DPO loss.

Architecture (Ray):
┌─────────────────────────────────────────────────────────────────┐
│ Driver (CPU) │
│ dataloader ──► batched preference pairs │
│ ref_model.forward_only() ──► reference log probs │
│ policy_model.forward_backward() ──► DPO loss + gradient │
└─────────────────────────────────────────────────────────────────┘
│ │ │
DataLoader RefModel (frozen) PolicyModel (trainable)
(ref GPUs) (policy GPUs)

DPO data format (after preprocessing):
- positive: List[Trajectory] - chosen responses
- negative: List[Trajectory] - rejected responses

For SimPO/ORPO variants that don't require a reference model,
set REF_MODEL_GPUS=0 to skip reference model computation.

Environment variables (all optional):
MODEL_ID – (default: ms://Qwen/Qwen2.5-7B-Instruct)
DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
MODEL_GPUS – GPUs for policy model (default: 4)
REF_MODEL_GPUS – GPUs for reference model (default: 4, 0 to disable)
BATCH_SIZE – global batch size (preference pairs) (default: 4)
MICRO_BATCH_SIZE – per-device micro batch size (default: 4)
GRADIENT_ACCUMULATION_STEPS – gradient accumulation steps (default: 4)
Contributor comment on lines +35 to +36 (medium):
The default BATCH_SIZE and MICRO_BATCH_SIZE values specified in the docstring (8 and 2) are inconsistent with the defaults actually set in the code (lines 76-77: BATCH_SIZE 4, MICRO_BATCH_SIZE 4). Please update the docstring to reflect the correct default values.
MAX_STEPS – total optimization steps (default: number of dataloader batches)
LR – learning rate (default: 5e-6)
DPO_BETA – DPO temperature parameter (default: 0.1)
SFT_WEIGHT – SFT regularization weight added to the DPO loss (default: 1.0)
LOSS_TYPE – DPO variant (sigmoid/hinge/ipo/simpo/orpo/cpo) (default: sigmoid)
SAVE_STEPS – checkpoint save interval (default: 200)
MAX_LENGTH – max sequence length (default: 2048)

Dataset field mapping (for custom datasets):
PROMPT_KEY – key for prompt field (default: 'prompt')
CHOSEN_KEY – key for chosen response (default: 'answer_zh')
REJECTED_KEY – key for rejected response (default: 'answer_en')
Contributor comment on lines +44 to +47 (medium):
The docstring documents PROMPT_KEY, CHOSEN_KEY, and REJECTED_KEY as configurable environment variables for custom datasets. However, the EmojiDPOProcessor used in create_dpo_dataset (lines 94-97) neither reads these environment variables nor accepts them as init_args, so they have no effect with EmojiDPOProcessor.
SYSTEM_PROMPT – system prompt to prepend (default: 'You are a helpful assistant.')
"""

import os
from typing import Any, Dict, List

from peft import LoraConfig

import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
from twinkle.data_format import Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import CPOLoss, DPOLoss, ORPOLoss, SimPOLoss
from twinkle.metric import DPOMetric
from twinkle.model import TransformersModel
from twinkle.preprocessor import EmojiDPOProcessor
from twinkle.processor import InputProcessor

logger = get_logger()

# ── Configuration ─────────────────────────────────────────────────────────────
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
REF_MODEL_GPUS = int(os.environ.get('REF_MODEL_GPUS', 4))
NUM_GPUS = MODEL_GPUS + REF_MODEL_GPUS

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) # Number of preference pairs
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 4))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 4))
LEARNING_RATE = float(os.environ.get('LR', 5e-6)) # TRL default for DPO is 5e-7 to 5e-6
DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization
LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo, simpo, orpo, cpo
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 200))
MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
ADAPTER_NAME = 'default'
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')


def create_dpo_dataset():
    """Create DPO dataset with positive/negative format."""
    dataset = Dataset(DatasetMeta(DATASET_ID))
    dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
    dataset.map(
        EmojiDPOProcessor,
        init_args={
            'system': SYSTEM_PROMPT,
        },
    )
    # DPO preprocessor returns {'positive': [...], 'negative': [...]}
    # batch_encode handles this format automatically
    dataset.encode(load_from_cache_file=True)
    return dataset


def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Prepare DPO batch: reorganize batch for training with DP-safe interleaving.

    Args:
        batch: List of rows, each with 'positive' and 'negative' InputFeatures
            and other fields (question, etc.)

    Returns:
        List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] so that each DP
        worker gets complete positive/negative pairs after slicing.
        Each item contains all original fields plus the InputFeature fields.
    """
    result = []

    for row in batch:
        # Base fields shared by both samples (everything except positive/negative)
        base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}

        # Merge base fields with the chosen / rejected InputFeature
        pos_sample = {**base_fields, **row['positive']}
        neg_sample = {**base_fields, **row['negative']}

        # Interleave: [pos, neg] per pair for DP-safe slicing
        result.append(pos_sample)
        result.append(neg_sample)

    return result
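Contributor note: the interleaving in prepare_dpo_batch can be sketched with plain dicts. The field names below ('question', 'input_ids') are hypothetical stand-ins for real row fields and encoded InputFeatures:

```python
# Toy illustration of the [pos_1, neg_1, pos_2, neg_2, ...] interleaving.
batch = [
    {'question': 'q1', 'positive': {'input_ids': [1, 2]}, 'negative': {'input_ids': [3]}},
    {'question': 'q2', 'positive': {'input_ids': [4]}, 'negative': {'input_ids': [5, 6]}},
]

result = []
for row in batch:
    base = {k: v for k, v in row.items() if k not in ('positive', 'negative')}
    result.append({**base, **row['positive']})  # pos_i
    result.append({**base, **row['negative']})  # neg_i

# Slicing two consecutive items hands one DP worker a complete pair.
pair_for_worker_0 = result[0:2]
```

Because each pair occupies adjacent slots, any even-sized contiguous slice keeps chosen and rejected responses together on the same worker.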


# ── Loss Factory ──────────────────────────────────────────────────────────────

def create_loss(loss_type: str, beta: float, sft_weight: float = 0.0, reference_free: bool = False):
    """Create the appropriate loss function based on configuration."""
    if loss_type == 'simpo':
        return SimPOLoss(beta=beta, gamma=0.5)
    elif loss_type == 'orpo':
        return ORPOLoss(lambda_orpo=beta)
    elif loss_type == 'cpo':
        return CPOLoss(beta=beta, bc_coef=1.0)
    else:
        # Standard DPO variants: sigmoid, hinge, ipo
        return DPOLoss(
            beta=beta,
            loss_type=loss_type,
            reference_free=reference_free,
            sft_weight=sft_weight,
        )


# ── Main Training Loop ────────────────────────────────────────────────────────

def main():
    # Set up device groups
    device_groups = [
        DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
        DeviceGroup(name='reference', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
    ]

    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups)

    # ── DataLoader Setup ──────────────────────────────────────────────────────
    dataloader = DataLoader(
        dataset=create_dpo_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=policy_mesh,
    )
    # Honor the MAX_STEPS env var; default to one pass over the dataloader
    max_steps = int(os.environ.get('MAX_STEPS', len(dataloader)))

    # ── Policy Model Setup ────────────────────────────────────────────────────
    lora_config = LoraConfig(
        target_modules='all-linear',
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
    )

    policy_model = TransformersModel(
        model_id=MODEL_ID,
        device_mesh=policy_mesh,
        remote_group='policy',
    )
    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
    policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
    policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=max_steps, eta_min=LEARNING_RATE * 0.1)

    # Reference-free losses (SimPO/ORPO/CPO) do not need a reference model
    reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']

    # Set up loss function and metrics
    loss_fn = create_loss(LOSS_TYPE, DPO_BETA, sft_weight=SFT_WEIGHT, reference_free=reference_free)
    policy_model.set_loss(loss_fn)
    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
    policy_model.set_processor(InputProcessor)
    policy_model.set_template('Template', model_id=MODEL_ID)

    # ── Reference Model Setup ─────────────────────────────────────────────────
    ref_model = None
    if not reference_free:
        ref_mesh = DeviceMesh.from_sizes(world_size=REF_MODEL_GPUS, dp_size=REF_MODEL_GPUS)
        ref_model = TransformersModel(
            model_id=MODEL_ID,
            device_mesh=ref_mesh,
            remote_group='reference',
        )
        ref_model.set_processor(InputProcessor)
        ref_model.set_template('Template', model_id=MODEL_ID)
        logger.info('Reference model initialized for DPO training')
    else:
        logger.info(f'Training without reference model (loss_type={LOSS_TYPE})')

    optim_step = 0
    logger.info(get_device_placement())
    logger.info(f'Starting DPO training: loss_type={LOSS_TYPE}, beta={DPO_BETA}')

    # ── Training Loop ─────────────────────────────────────────────────────────
    for batch in dataloader:
        if optim_step >= max_steps:
            break

        # batch is List[Dict] with 'positive' and 'negative' keys
        dpo_batch = prepare_dpo_batch(batch)

        # Get reference outputs (lazy - not collected to driver)
        ref_outputs = None
        if ref_model is not None:
            ref_outputs = ref_model.forward_only(inputs=dpo_batch)

        # Forward-backward pass with DPO loss
        # ref_outputs is passed to the loss, which extracts logps internally
        policy_model.forward_backward(
            inputs=dpo_batch,
            ref_outputs=ref_outputs,
        )

        # Gradient clipping and optimizer step
        policy_model.clip_grad_and_step()
        optim_step += 1

        # Logging
        if optim_step % 10 == 0:
            metrics = policy_model.calculate_metric(is_training=True)
            logger.info(f'[Step {optim_step}/{max_steps}] {metrics}')

        # Checkpointing
        if optim_step % SAVE_STEPS == 0:
            policy_model.save(f'dpo-checkpoint-{optim_step}')

    # ── Save Final Checkpoint ─────────────────────────────────────────────────
    logger.info(f'Training completed. Total steps: {optim_step}')
    policy_model.save('dpo-final-checkpoint')


if __name__ == '__main__':
    main()
84 changes: 84 additions & 0 deletions cookbook/rl/dpo.sh
@@ -0,0 +1,84 @@
#!/bin/bash
# DPO Training Script for Ray Mode
#
# This script launches DPO (Direct Preference Optimization) training using Ray
# for distributed training across multiple GPUs.
#
# Usage:
# ./dpo.sh # Default settings (8 GPUs: 4 policy + 4 ref)
# ./dpo.sh simpo # Use SimPO (no reference model needed)
# ./dpo.sh orpo # Use ORPO (no reference model needed)
#
# Environment variables can be set to customize training:
# MODEL_ID - Model to train (default: ms://Qwen/Qwen3.5-4B)
# DATASET_ID - Preference dataset (default: UltraFeedback)
# MODEL_GPUS - GPUs for policy model (default: 4)
# REF_MODEL_GPUS - GPUs for reference model (default: 4)
# USE_REFERENCE_MODEL - Use reference model (default: 1)
# BATCH_SIZE - Global batch size (default: 8)
# MAX_STEPS - Training steps (default: 1000)
# LR - Learning rate (default: 5e-6)
# DPO_BETA - DPO beta parameter (default: 0.1)
# LOSS_TYPE - Loss variant: sigmoid/hinge/ipo/simpo/orpo/cpo (default: sigmoid)

set -e

# Parse command line argument for loss type
LOSS_TYPE_ARG=${1:-sigmoid}

# Set default environment variables if not already set
export MODEL_ID=${MODEL_ID:-"ms://Qwen/Qwen3.5-4B"}
export DATASET_ID=${DATASET_ID:-"ms://argilla/ultrafeedback-binarized-preferences-cleaned"}
export MODEL_GPUS=${MODEL_GPUS:-4}
export BATCH_SIZE=${BATCH_SIZE:-8}
export MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-2}
export MAX_STEPS=${MAX_STEPS:-1000}
export LR=${LR:-5e-6}
export DPO_BETA=${DPO_BETA:-0.1}
export SAVE_STEPS=${SAVE_STEPS:-100}
export MAX_LENGTH=${MAX_LENGTH:-2048}

# Set loss type from argument or environment
export LOSS_TYPE=${LOSS_TYPE:-$LOSS_TYPE_ARG}

# Reference-free losses don't need reference model
if [[ "$LOSS_TYPE" == "simpo" || "$LOSS_TYPE" == "orpo" || "$LOSS_TYPE" == "cpo" ]]; then
    export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-0}
    export REF_MODEL_GPUS=${REF_MODEL_GPUS:-0}
    echo "Using $LOSS_TYPE loss (reference-free)"
else
    export USE_REFERENCE_MODEL=${USE_REFERENCE_MODEL:-1}
    export REF_MODEL_GPUS=${REF_MODEL_GPUS:-4}
    echo "Using $LOSS_TYPE loss with reference model"
fi

# Calculate total GPUs
if [[ "$USE_REFERENCE_MODEL" == "1" && "$REF_MODEL_GPUS" -gt 0 ]]; then
    TOTAL_GPUS=$((MODEL_GPUS + REF_MODEL_GPUS))
else
    TOTAL_GPUS=$MODEL_GPUS
fi

echo "=========================================="
echo "DPO Training Configuration"
echo "=========================================="
echo "Model: $MODEL_ID"
echo "Dataset: $DATASET_ID"
echo "Loss Type: $LOSS_TYPE"
echo "DPO Beta: $DPO_BETA"
echo "Policy GPUs: $MODEL_GPUS"
echo "Reference GPUs: $REF_MODEL_GPUS"
echo "Total GPUs: $TOTAL_GPUS"
echo "Batch Size: $BATCH_SIZE"
echo "Micro Batch Size: $MICRO_BATCH_SIZE"
echo "Max Steps: $MAX_STEPS"
echo "Learning Rate: $LR"
echo "Max Length: $MAX_LENGTH"
echo "Save Steps: $SAVE_STEPS"
echo "=========================================="

# Get script directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Run training
python "$SCRIPT_DIR/dpo.py"
2 changes: 1 addition & 1 deletion cookbook/transformers/fsdp2.py
@@ -10,7 +10,7 @@
from twinkle.preprocessor import SelfCognitionProcessor

# Construct a device_mesh, dp=2
-device_mesh = DeviceMesh.from_sizes(dp_size=2)
+device_mesh = DeviceMesh.from_sizes(dp_size=8)
# use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

6 changes: 4 additions & 2 deletions docs/source_en/Components/Data Format/Trajectory.md
@@ -5,12 +5,14 @@
```python
class Trajectory(TypedDict, total=False):
    messages: List[Message]
    extend_message: List[Tuple[str, List[Message]]]
    tools: List[Tool]
    user_data: List[Tuple[str, Any]]
```

- messages: A list of Message objects representing the multi-turn conversation actually conducted by the model, usually alternating between `user` and `assistant`.
- extend_message: Training methods such as DPO and PPO usually also need unusable or low-scoring trajectories; these are stored in extend_message.
- tools: A list of all tools available to the model in this call.
- user_data: User-defined data, such as labels in KTO training.

For preference alignment training like DPO, preprocessors return `{'positive': List[Trajectory], 'negative': List[Trajectory]}` format.
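For example, a DPO preprocessor might emit the following for a single preference example (the concrete messages here are illustrative only):

```python
row = {
    'positive': [{  # chosen Trajectory
        'messages': [
            {'role': 'user', 'content': 'How are you?'},
            {'role': 'assistant', 'content': 'Doing great, thanks! 😄'},
        ],
    }],
    'negative': [{  # rejected Trajectory
        'messages': [
            {'role': 'user', 'content': 'How are you?'},
            {'role': 'assistant', 'content': 'Fine.'},
        ],
    }],
}
```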

Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in Twinkle. Format conversion goes from the original dataset to Trajectory, and then to InputFeature.
6 changes: 4 additions & 2 deletions docs/source_zh/组件/数据格式/Trajectory.md
@@ -5,12 +5,14 @@
```python
class Trajectory(TypedDict, total=False):
    messages: List[Message]
    extend_message: List[Tuple[str, List[Message]]]
    tools: List[Tool]
    user_data: List[Tuple[str, Any]]
```

- messages: A list of Message objects representing the multi-turn conversation actually conducted by the model, usually alternating between `user` and `assistant`.
- extend_message: Training methods such as DPO and PPO usually also need unusable or low-scoring trajectories; these are stored in extend_message.
- tools: A list of all tools available to the model in this call.
- user_data: User-defined data, such as labels in KTO training.

For preference alignment training such as DPO, preprocessors return the `{'positive': List[Trajectory], 'negative': List[Trajectory]}` format.

Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in twinkle. Format conversion goes from the original dataset to Trajectory, and then to InputFeature.
1 change: 0 additions & 1 deletion src/twinkle/data_format/trajectory.py
@@ -13,6 +13,5 @@

class Trajectory(TypedDict, total=False):
    messages: List[Message]
    extend_message: List[Tuple[str, List[Message]]]
    tools: List[Tool]
    user_data: List[Tuple[str, Any]]