23 changes: 7 additions & 16 deletions skyrl-train/skyrl_train/trainer.py
@@ -1073,26 +1073,17 @@ def train_critic_and_policy(self, data: TrainingInputBatch):
"""
Run the training step for the policy and critic models.

For Megatron strategy: uses ppo_train (training loop inside worker)
For FSDP strategy: uses forward_backward + optim_step (training loop in trainer)
Uses forward_backward + optim_step for both FSDP and Megatron strategies.
"""
data.metadata["global_step"] = self.global_step
critic_status = None

if self.cfg.trainer.strategy == "megatron":
# Megatron: training loop inside worker via ppo_train
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self.dispatch.ppo_train("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self.dispatch.ppo_train("policy", data)
else:
# FSDP: training loop in trainer via forward_backward + optim_step
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data)
# Unified training interface for both FSDP and Megatron
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data)

# Update metrics
if critic_status is not None:
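For orientation, a minimal sketch of what the unified path above implies, assuming the dispatch handle exposes forward_backward and optim_step per worker group; names and structure are illustrative, not the actual _execute_training_step implementation.

# Illustrative sketch only: assumes self.dispatch exposes forward_backward()
# and optim_step(); the real _execute_training_step may additionally handle
# micro-batch splitting and metric aggregation.
def _execute_training_step(self, role: str, data):
    # Training loop lives in the trainer: run forward/backward over the
    # mini-batch on the "policy" or "critic" workers, then step the optimizer.
    status = self.dispatch.forward_backward(role, data)
    self.dispatch.optim_step(role)
    return status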
6 changes: 0 additions & 6 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -113,9 +113,6 @@ def init_model(self, model_path, num_training_steps: int = None):

self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0

# Update per-gpu mini batch size based on device mesh
self._normalize_mini_batch_size()

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
@@ -276,9 +273,6 @@ def init_model(self, model_path, num_training_steps: int = None):
strategy.setup_distributed()
self.strategy = strategy

# Update per-gpu mini batch size based on device mesh
self._normalize_mini_batch_size()

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
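The calls removed above adjusted the per-GPU mini-batch size from the device mesh; a rough sketch of what such a normalization typically computes, with assumed attribute and mesh-dimension names (not the actual FSDP worker code):

# Illustrative only: attribute and mesh-dimension names are assumptions.
def _normalize_mini_batch_size(self):
    dp_size = self.strategy.device_mesh["dp"].size()  # data-parallel world size
    # Per-GPU mini-batch = global mini-batch split across data-parallel ranks.
    self.mini_batch_size_per_gpu = self.cfg.trainer.mini_batch_size // dp_size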
78 changes: 67 additions & 11 deletions skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py
@@ -1,15 +1,16 @@
from typing import Optional, Callable, List
from typing import Optional, Callable, List, Dict, Any
from functools import partial
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from megatron.core.pipeline_parallel import get_forward_backward_func
import megatron.core.parallel_state as mpu
from megatron.core.distributed import finalize_model_grads

from skyrl_train.distributed.megatron.model_utils import from_parallel_logits_to_logprobs, vocab_parallel_entropy
from skyrl_train.distributed.megatron.megatron_utils import get_model_config
from skyrl_train.utils.ppo_utils import compute_approx_kl, masked_mean
from skyrl_train.utils.ppo_utils import compute_approx_kl, masked_mean, PolicyLossRegistry

from skyrl_train.distributed.megatron.megatron_utils import (
make_batch_generator,
@@ -171,6 +172,8 @@ def forward_backward_mini_batch(
seq_len: int,
micro_batch_size: int,
temperature: float = 1.0,
loss_fn: Optional[str] = None,
loss_fn_config: Optional[Dict[str, Any]] = None,
) -> List[dict]:
"""
Run forward-backward over a full mini-batch consisting of multiple micro-batches.
@@ -183,12 +186,27 @@
seq_len: Sequence length (tokens) per sample (assumed same across micros after padding).
micro_batch_size: Micro-batch size per forward pass.
temperature: Optional temperature for logits scaling.
loss_fn: Optional loss function name (e.g., "cross_entropy", "ppo").
If provided, overrides the config's policy_loss_type.
loss_fn_config: Optional config overrides for the loss function.

Returns:
List[dict]: one metrics dict per micro-batch in order.
"""
forward_backward_func = get_forward_backward_func()

# Resolve loss function
resolved_loss_name = loss_fn if loss_fn is not None else self.cfg.trainer.algorithm.policy_loss_type
if loss_fn is not None:
current_loss_fn = PolicyLossRegistry.get(loss_fn)
else:
current_loss_fn = self.policy_loss_fn

# Build config for loss function, applying any overrides
loss_config = self.cfg.trainer.algorithm
if loss_fn_config is not None:
loss_config = OmegaConf.merge(loss_config, OmegaConf.create(loss_fn_config))

def loss_func(logits, data):
sequences = data["sequences"]
num_actions = data["num_actions"]
@@ -197,6 +215,7 @@ def loss_func(logits, data):
advantages = data["advantages"]
loss_mask = data["loss_mask"]
rollout_action_logprobs = data["rollout_action_logprobs"]
action_mask = data.get("action_mask")

tp_grp = mpu.get_tensor_model_parallel_group()
tp_rank = mpu.get_tensor_model_parallel_rank()
@@ -218,37 +237,74 @@

action_log_probs = token_logprobs[:, -num_actions:]

# policy loss should be calculated based on the selected token logprobs
policy_loss, clip_ratio = self.policy_loss_fn(
policy_loss, clip_ratio = current_loss_fn(
action_log_probs,
old_action_log_probs,
advantages,
config=self.cfg.trainer.algorithm,
config=loss_config,
loss_mask=loss_mask,
rollout_logprobs=rollout_action_logprobs,
)

with torch.set_grad_enabled(self.cfg.trainer.algorithm.use_entropy_loss):
# SFT path: cross_entropy loss (negative log likelihood)
if resolved_loss_name == "cross_entropy":
loss = policy_loss

# Compute elementwise loss for Tinker API (per-token NLL)
with torch.no_grad():
elementwise_loss = -action_log_probs
if loss_mask is not None:
elementwise_loss = elementwise_loss * loss_mask

# Build per-sequence loss_fn_outputs
batch_size = action_log_probs.shape[0]
loss_fn_outputs = []
for i in range(batch_size):
if action_mask is not None:
valid_len = int(action_mask[i].sum().item())
elif loss_mask is not None:
valid_len = int(loss_mask[i].sum().item())
else:
valid_len = action_log_probs.shape[1]

start = max(action_log_probs.shape[1] - valid_len, 0)
loss_fn_outputs.append(
{
"logprobs": action_log_probs[i, start:].detach().cpu().tolist(),
"elementwise_loss": elementwise_loss[i, start:].detach().cpu().tolist(),
}
)

metrics = {
"loss": loss.detach().item(),
"response_length": num_actions,
"loss_fn_outputs": loss_fn_outputs,
}
return loss, metrics

# RL path: add optional KL/entropy terms
# entropy loss
with torch.set_grad_enabled(loss_config.use_entropy_loss):
action_logits = logits[:, -num_actions - 1 : -1, :]
entropy_BS = vocab_parallel_entropy(action_logits)
entropy = masked_mean(entropy_BS, loss_mask)

if self.cfg.trainer.algorithm.use_entropy_loss:
entropy_loss_term = entropy * self.cfg.trainer.algorithm.entropy_loss_coef
if loss_config.use_entropy_loss:
entropy_loss_term = entropy * loss_config.entropy_loss_coef
else:
entropy_loss_term = torch.tensor(0.0)

if self.cfg.trainer.algorithm.use_kl_loss:
if loss_config.use_kl_loss:
kl_loss = compute_approx_kl(
action_log_probs,
base_action_log_probs,
loss_mask=loss_mask,
kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
kl_estimator_type=loss_config.kl_estimator_type,
)
kl_loss = masked_mean(kl_loss, loss_mask, dim=-1).mean()
else:
kl_loss = torch.tensor(0.0)
kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef
kl_loss_term = kl_loss * loss_config.kl_loss_coef

loss = policy_loss + kl_loss_term - entropy_loss_term

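A hedged usage sketch of the extended interface: the wrapper handle, leading data arguments, and batch shapes are assumptions (they are elided in the hunks above); only loss_fn and loss_fn_config come from this diff.

# Illustrative calls only -- handle, leading arguments, and values are assumed.
# SFT-style step: per-token NLL, metrics include per-sequence loss_fn_outputs.
sft_metrics = model_wrapper.forward_backward_mini_batch(
    micro_batches, seq_len=2048, micro_batch_size=2,
    loss_fn="cross_entropy",
)
# RL step with an override merged over cfg.trainer.algorithm via OmegaConf.
rl_metrics = model_wrapper.forward_backward_mini_batch(
    micro_batches, seq_len=2048, micro_batch_size=2,
    loss_fn_config={"kl_loss_coef": 0.0},  # e.g. zero out the KL penalty term
)
# With loss_fn="cross_entropy", each per-micro-batch metrics dict also carries
# "loss_fn_outputs": per-sequence {"logprobs", "elementwise_loss"} lists trimmed
# to the valid (unmasked) action length.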
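And a minimal sketch of a cross-entropy policy loss with the callable contract the wrapper expects; this is an assumed implementation for illustration, and the registered "cross_entropy" loss in skyrl_train.utils.ppo_utils may differ.

import torch

# Assumed implementation sketch -- matches the (loss, clip_ratio) contract used
# by loss_func above, not necessarily the registered skyrl_train loss.
def cross_entropy_loss(action_log_probs, old_action_log_probs, advantages,
                       config=None, loss_mask=None, rollout_logprobs=None):
    nll = -action_log_probs  # per-token negative log likelihood
    if loss_mask is not None:
        loss = (nll * loss_mask).sum() / loss_mask.sum().clamp(min=1)
    else:
        loss = nll.mean()
    return loss, torch.tensor(0.0)  # no clipping ratio in the SFT path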