
Commit ef12c2e

tyler-griggs and claude committed
[skyrl-train] Add SFT support via forward_backward(loss_fn="cross_entropy")
Enables supervised fine-tuning using the Tinker-compatible API.

Changes:
- ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function
- worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss
- worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward()
- dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config)
- replay_buffer.py: Make action_log_probs optional in Experience
- worker_utils.py: Use .get() for optional fields; handle non-scalar metrics

New:
- examples/sft/: Minimal SFT example demonstrating the API

This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 5342484 commit ef12c2e

8 files changed

Lines changed: 425 additions & 84 deletions
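(Editor's sketch, not part of the commit: the call shape the new parameters enable. `dispatch` and `batch` are assumed to be set up as in examples/sft/sft_trainer.py below, and passing `loss_fn_config=None` is an assumption based on the commit message; the SFT example itself passes only `loss_fn`.)

```python
# Editor's sketch of the new API surface (not part of the commit).
# Assumes `dispatch` is a WorkerDispatch and `batch` a TrainingInputBatch,
# built as in examples/sft/sft_trainer.py below.
metrics = dispatch.forward_backward(
    "policy",                 # actor group to run on
    batch,                    # sequences / attention_mask / loss_mask
    loss_fn="cross_entropy",  # new: select the SFT loss instead of the PPO default
    loss_fn_config=None,      # new: optional loss configuration (assumed to accept None)
)
grad_norm = dispatch.optim_step("policy")
```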


skyrl-train/examples/sft/README.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# SFT (Supervised Fine-Tuning) Example

This example demonstrates how to use SkyRL's training infrastructure for supervised fine-tuning (SFT).

## Usage

```bash
uv run --isolated --extra vllm python examples/sft/sft_trainer.py
```

## How It Works

1. **Load Dataset**: Uses a small subset of the Alpaca dataset
2. **Tokenize**: Converts instruction/output pairs into token sequences
3. **Create Batch**: Builds a `TrainingInputBatch` with:
   - `sequences`: Token IDs (left-padded)
   - `attention_mask`: 1 for real tokens, 0 for padding
   - `loss_mask`: 1 for response tokens to compute loss on
4. **Train**: Calls `forward_backward(loss_fn="cross_entropy")` for SFT

## Loss Functions

The `loss_fn` parameter supports:

| Loss Function | Use Case |
|--------------|----------|
| `cross_entropy` | Supervised fine-tuning |
| `regular` / `ppo` | PPO with clipping |
| `gspo` | Group Sequence Policy Optimization |
| ... | See `PolicyLossRegistry` for all options |
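(Editor's sketch, not part of the README above: the batch layout its "How It Works" section describes, built by hand for two toy examples. Token IDs are made up, and it assumes `TrainingInputBatch` accepts a plain tensor dict, as the trainer below does.)

```python
# Editor's sketch: the SFT batch layout for two toy examples (token IDs are made up).
import torch
from skyrl_train.training_batch import TrainingInputBatch

PAD = 0
batch = TrainingInputBatch(
    {
        # Left-padded token IDs; example 0 has 5 real tokens, example 1 has 6.
        "sequences": torch.tensor([[PAD, 11, 12, 13, 14, 15],
                                   [21, 22, 23, 24, 25, 26]], dtype=torch.long),
        "attention_mask": torch.tensor([[0, 1, 1, 1, 1, 1],
                                        [1, 1, 1, 1, 1, 1]], dtype=torch.long),
        # Loss only on response tokens: 2 for example 0, 3 for example 1,
        # left-padded to the max response length of 3.
        "loss_mask": torch.tensor([[0, 1, 1],
                                   [1, 1, 1]], dtype=torch.long),
    }
)
batch.metadata = {"response_length": 3}
```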
skyrl-train/examples/sft/sft_trainer.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
"""
Minimal SFT (Supervised Fine-Tuning) trainer using WorkerDispatch.

This script demonstrates SFT using the same forward_backward interface as RL training,
but with loss_fn="cross_entropy" to compute simple negative log-likelihood loss.

Usage:
    # First, make sure you have Ray installed and a GPU available
    uv run --isolated --extra vllm python examples/sft/sft_trainer.py

This example:
1. Loads a small subset of the Alpaca dataset
2. Tokenizes examples into prompt + completion format
3. Uses WorkerDispatch.forward_backward(loss_fn="cross_entropy") for SFT
4. Demonstrates the Tinker-compatible API for supervised fine-tuning
"""

import ray
import hydra
import torch
from datasets import load_dataset
from loguru import logger
from omegaconf import DictConfig
from transformers import AutoTokenizer
from tqdm import tqdm

from ray.util.placement_group import placement_group

from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.entrypoints.main_base import config_dir
from skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl_train.workers.worker import PPORayActorGroup
from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
from skyrl_train.utils.utils import initialize_ray, validate_cfg
from skyrl_train.utils import get_ray_pg_ready_with_timeout


def get_sft_config() -> DictConfig:
    """Get config with SFT-specific overrides."""
    with hydra.initialize_config_dir(config_dir=config_dir):
        cfg = hydra.compose(config_name="ppo_base_config")

    # Use a small model for testing
    cfg.trainer.policy.model.path = "Qwen/Qwen2.5-0.5B-Instruct"
    cfg.trainer.placement.policy_num_gpus_per_node = 1
    cfg.generator.inference_engine_tensor_parallel_size = 1
    cfg.trainer.logger = "console"
    cfg.trainer.micro_train_batch_size_per_gpu = 2

    validate_cfg(cfg)
    return cfg


def tokenize_sft_example(example: dict, tokenizer, max_length: int = 512) -> dict | None:
    """Tokenize a single SFT example (instruction + output).

    Returns dict with input_ids, attention_mask, num_actions (response length),
    or None if the example was fully truncated.
    """
    instruction = example.get("instruction", "")
    input_text = example.get("input", "")
    output = example.get("output", "")

    # Combine instruction and input
    if input_text:
        prompt = f"{instruction}\n\n{input_text}"
    else:
        prompt = instruction

    # Tokenize prompt and full sequence separately to find boundary
    prompt_tokens = tokenizer(prompt, add_special_tokens=True, truncation=True, max_length=max_length)
    full_text = f"{prompt}\n\n{output}"
    full_tokens = tokenizer(full_text, add_special_tokens=True, truncation=True, max_length=max_length)

    prompt_len = len(prompt_tokens["input_ids"])
    full_len = len(full_tokens["input_ids"])
    num_actions = full_len - prompt_len

    # Skip examples where response was fully truncated
    if num_actions <= 0:
        return None

    return {
        "input_ids": full_tokens["input_ids"],
        "attention_mask": full_tokens["attention_mask"],
        "num_actions": num_actions,
    }


def collate_sft_batch(examples: list, tokenizer) -> TrainingInputBatch:
    """Collate tokenized examples into a TrainingInputBatch.

    Creates the batch format expected by forward_backward with cross_entropy loss:
    - sequences: [batch_size, seq_len] - token IDs (left-padded)
    - attention_mask: [batch_size, seq_len] - 1 for real tokens, 0 for padding
    - loss_mask: [batch_size, num_actions] - 1 for tokens to compute loss on
    """
    max_len = max(len(ex["input_ids"]) for ex in examples)
    max_num_actions = max(ex["num_actions"] for ex in examples)

    sequences = []
    attention_masks = []
    loss_masks = []

    for ex in examples:
        pad_len = max_len - len(ex["input_ids"])
        # Left-pad sequences (SkyRL convention)
        sequences.append([tokenizer.pad_token_id] * pad_len + ex["input_ids"])
        attention_masks.append([0] * pad_len + ex["attention_mask"])
        # Per-example loss_mask: 0s for padding, 1s only for this example's response tokens
        action_pad = max_num_actions - ex["num_actions"]
        loss_masks.append([0] * action_pad + [1] * ex["num_actions"])

    batch = TrainingInputBatch(
        {
            "sequences": torch.tensor(sequences, dtype=torch.long),
            "attention_mask": torch.tensor(attention_masks, dtype=torch.long),
            "loss_mask": torch.tensor(loss_masks, dtype=torch.long),
        }
    )
    batch.metadata = {"response_length": max_num_actions}
    return batch


def main():
    """Run a minimal SFT training loop."""
    cfg = get_sft_config()
    initialize_ray(cfg)

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(cfg.trainer.policy.model.path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading dataset...")
    # Use a small subset for demonstration
    dataset = load_dataset("yahma/alpaca-cleaned", split="train[:100]")

    logger.info("Tokenizing dataset...")
    tokenized = [tokenize_sft_example(ex, tokenizer) for ex in dataset]
    tokenized = [ex for ex in tokenized if ex is not None]  # Filter truncated
    logger.info(f"Kept {len(tokenized)} examples after filtering truncated")

    logger.info("Initializing policy worker...")
    num_gpus = cfg.trainer.placement.policy_num_gpus_per_node
    pg = placement_group([{"GPU": num_gpus, "CPU": num_gpus}], strategy="PACK")
    get_ray_pg_ready_with_timeout(pg, timeout=30)

    actor_group = PPORayActorGroup(
        cfg,
        num_nodes=1,
        num_gpus_per_node=num_gpus,
        ray_actor_type=PolicyWorker,
        pg=pg,
        num_gpus_per_actor=0.75,
        colocate_all=False,
        sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
    )
    ray.get(actor_group.async_init_model(cfg.trainer.policy.model.path))

    dispatch = WorkerDispatch(cfg, policy_actor_group=actor_group)

    # Training loop
    batch_size = 4
    num_steps = 10
    logger.info(f"Starting SFT training for {num_steps} steps...")

    for step in tqdm(range(num_steps)):
        # Create batch from tokenized examples
        start_idx = (step * batch_size) % len(tokenized)
        batch_examples = tokenized[start_idx : start_idx + batch_size]
        if len(batch_examples) < batch_size:
            batch_examples = tokenized[:batch_size]  # Wrap around

        batch = collate_sft_batch(batch_examples, tokenizer)

        # Forward-backward with cross-entropy loss (Tinker API style)
        metrics = dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")

        # Optimizer step
        grad_norm = dispatch.optim_step("policy")

        if step % 5 == 0:
            loss_val = metrics.get("final_loss", metrics.get("loss"))
            # Guard the formatting: loss_val may be absent rather than a float
            loss_str = f"{loss_val:.4f}" if isinstance(loss_val, (int, float)) else str(loss_val)
            logger.info(f"Step {step}: loss={loss_str}, grad_norm={grad_norm}")

    logger.info("SFT training complete!")
    ray.shutdown()


if __name__ == "__main__":
    main()
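(Editor's sketch, not part of the commit: a quick way to exercise the tokenize/collate helpers above on a single made-up example. It assumes the Qwen2.5-0.5B-Instruct tokenizer can be downloaded and that the snippet runs in the same module or session as the helpers.)

```python
# Editor's sketch: sanity-check tokenize_sft_example / collate_sft_batch in isolation
# (assumes the Qwen/Qwen2.5-0.5B-Instruct tokenizer is available for download).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

ex = tokenize_sft_example(
    {"instruction": "Add the numbers.", "input": "2 and 3", "output": "The answer is 5."}, tok
)
print("total tokens:", len(ex["input_ids"]), "response tokens:", ex["num_actions"])

batch = collate_sft_batch([ex], tok)
print("response_length:", batch.metadata["response_length"])  # equals ex["num_actions"]
```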

skyrl-train/skyrl_train/dataset/replay_buffer.py

Lines changed: 5 additions & 3 deletions
@@ -57,7 +57,7 @@ class Experience:
     """
 
     sequences: Integer[torch.Tensor, "batch seq_len"]
-    action_log_probs: Float[torch.Tensor, "batch response_len"]
+    action_log_probs: Optional[Float[torch.Tensor, "batch response_len"]]
     base_action_log_probs: Optional[Float[torch.Tensor, "batch response_len"]]
     values: Optional[Float[torch.Tensor, "batch response_len"]]
     returns: Optional[Float[torch.Tensor, "batch response_len"]]
@@ -74,7 +74,8 @@ class Experience:
     @torch.no_grad()
     def to_device(self, device: torch.device) -> None:
         self.sequences = to(self.sequences, device)
-        self.action_log_probs = to(self.action_log_probs, device)
+        if self.action_log_probs is not None:
+            self.action_log_probs = to(self.action_log_probs, device)
         if self.base_action_log_probs is not None:
             self.base_action_log_probs = to(self.base_action_log_probs, device)
         if self.values is not None:
@@ -94,7 +95,8 @@ def to_device(self, device: torch.device) -> None:
 
     def pin_memory(self):
         self.sequences = pin_memory(self.sequences)
-        self.action_log_probs = pin_memory(self.action_log_probs)
+        if self.action_log_probs is not None:
+            self.action_log_probs = pin_memory(self.action_log_probs)
         if self.base_action_log_probs is not None:
             self.base_action_log_probs = pin_memory(self.base_action_log_probs)
         if self.values is not None:

skyrl-train/skyrl_train/distributed/dispatch.py

Lines changed: 11 additions & 19 deletions
@@ -121,7 +121,9 @@ class MeshDispatch(Dispatch):
     """
 
     @classmethod
-    def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch) -> List[ObjectRef]:
+    def dispatch(
+        cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs
+    ) -> List[ObjectRef]:
         assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
         object_refs = []
         dp_size = actor_infos[0].rank.dp_size
@@ -134,7 +136,7 @@ def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInput
         for actor_info in actor_infos:
             # index into tensordict to get the correct data to send
             data_to_send = data_chunks[actor_info.rank.dp]
-            object_refs.append(getattr(actor_info.handle, method).remote(data_to_send))
+            object_refs.append(getattr(actor_info.handle, method).remote(data_to_send, **kwargs))
         return object_refs
 
     @classmethod
@@ -159,24 +161,14 @@ def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
 
     @classmethod
     def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
-        sig = inspect.signature(cls.dispatch)
-        # pass dummy actor_infos and method_name
-        bound_args = sig.bind([], "dummy", *args, **kwargs)
-        bound_args.apply_defaults()
-
-        # Check if there are any extra arguments
-        if len(bound_args.arguments) > 3:  # data, actor_infos, method_name
-            # remove actor_infos and method_name - not added by user
-            bound_args.arguments.pop("actor_infos")
-            bound_args.arguments.pop("method")
-            raise ValueError(f"MeshDispatch only accepts 'data' as an argument, got extra args: {bound_args.arguments}")
-
-        data = bound_args.arguments.get("data")
+        # First positional arg must be data (TrainingInputBatch)
+        if not args:
+            raise ValueError("MeshDispatch requires 'data' as first positional argument")
+        data = args[0]
         if not isinstance(data, TrainingInputBatch):
-            raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInput`, got {data}")
-        args = (data,)
-        kwargs = {}
-        return args, kwargs
+            raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
+        # Pass through data as positional arg, and any kwargs (e.g., loss_fn, loss_fn_config)
+        return (data,), kwargs
 
 
 class PassThroughDispatch(Dispatch):
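(Editor's sketch, not part of the commit: how the relaxed validation now forwards extra keyword arguments instead of rejecting them. Constructing a TrainingInputBatch from a single-tensor dict is an assumption borrowed from the SFT example above.)

```python
# Editor's sketch: validate_dispatch_args now passes kwargs through to dispatch().
import torch
from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.distributed.dispatch import MeshDispatch

data = TrainingInputBatch({"sequences": torch.zeros(2, 8, dtype=torch.long)})
args, kwargs = MeshDispatch.validate_dispatch_args(data, loss_fn="cross_entropy", loss_fn_config=None)
print(kwargs)  # {'loss_fn': 'cross_entropy', 'loss_fn_config': None}
# dispatch() then calls: getattr(actor_info.handle, method).remote(data_chunk, **kwargs)
```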

skyrl-train/skyrl_train/utils/ppo_utils.py

Lines changed: 44 additions & 0 deletions
@@ -471,6 +471,7 @@ class PolicyLossType(StrEnum):
     CLIP_COV = "clip_cov"
     KL_COV = "kl_cov"
     SAPO = "sapo"
+    CROSS_ENTROPY = "cross_entropy"
 
 
 class PolicyLossRegistry(BaseFunctionRegistry):
@@ -500,6 +501,7 @@ def repopulate_registry(cls):
             "clip_cov": [PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov],
             "kl_cov": [PolicyLossType.KL_COV, compute_policy_loss_kl_cov],
             "sapo": [PolicyLossType.SAPO, sapo_policy_loss],
+            "cross_entropy": [PolicyLossType.CROSS_ENTROPY, cross_entropy_loss],
         }
 
         for pl_name, (pl_type, pl_func) in pl_types.items():
@@ -878,6 +880,48 @@ def compute_policy_loss_kl_cov(
     return pg_loss, 0.0
 
 
+@register_policy_loss(PolicyLossType.CROSS_ENTROPY)
+def cross_entropy_loss(
+    log_probs: torch.Tensor,
+    old_log_probs: torch.Tensor,
+    advantages: torch.Tensor,
+    config: DictConfig,
+    loss_mask: Optional[torch.Tensor] = None,
+    rollout_logprobs: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, float]:
+    """
+    Cross-entropy loss for supervised fine-tuning (SFT).
+
+    This loss function computes the negative log-likelihood of the target tokens,
+    ignoring the old_log_probs and advantages which are only used for RL.
+
+    The loss is computed as: -log_probs * loss_mask, summed over all tokens.
+    This matches Tinker's cross_entropy semantics where the loss is a simple sum.
+
+    Args:
+        log_probs: Log probabilities from the model for each token
+        old_log_probs: Ignored (only used for RL losses)
+        advantages: Ignored (only used for RL losses)
+        config: Algorithm configuration
+        loss_mask: Mask indicating which tokens to include in loss (1=include, 0=ignore)
+        rollout_logprobs: Ignored (only used for RL losses)
+
+    Returns:
+        Tuple of (loss, clip_ratio) where clip_ratio is always 0.0 for SFT
+    """
+    # Simple negative log-likelihood: -log p(token)
+    elementwise_loss = -log_probs
+
+    # Apply loss mask and sum (matching Tinker's SUM reduction semantics)
+    if loss_mask is not None:
+        loss = (elementwise_loss * loss_mask).sum()
+    else:
+        loss = elementwise_loss.sum()
+
+    # No clipping in cross-entropy loss
+    return loss, 0.0
+
+
 def reduce_loss(
     loss: torch.Tensor,
     loss_mask: Optional[torch.Tensor],
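(Editor's sketch, not part of the commit: the sum-reduction semantics of cross_entropy_loss above, reproduced with plain torch on made-up values.)

```python
# Editor's sketch: cross_entropy_loss sums -log_probs over masked (response) tokens.
import torch

log_probs = torch.tensor([[-0.5, -1.0, -2.0],
                          [-0.2, -0.3, -4.0]])
loss_mask = torch.tensor([[0.0, 1.0, 1.0],   # loss only on the last two tokens
                          [1.0, 1.0, 0.0]])  # loss only on the first two tokens
loss = (-log_probs * loss_mask).sum()
print(loss)  # tensor(3.5000) = 1.0 + 2.0 + 0.2 + 0.3
```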
