[skyrl-train] Add SFT support via forward_backward(loss_fn="cross_entropy")#961

Merged
tyler-griggs merged 8 commits into main from tyler/sft-support on Jan 27, 2026
Conversation


tyler-griggs (Member) commented on Jan 26, 2026

Summary

Enables SFT using the Tinker-compatible API:

metrics = dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")

Key Changes

  • Add loss_fn parameter to forward_backward() (overrides config's policy_loss_type)
  • Implement cross_entropy loss in PolicyLossRegistry
  • Return per-sequence loss_fn_outputs for Tinker API compatibility:
    metrics["loss_fn_outputs"] = [
        {"logprobs": [...], "elementwise_loss": [...]},  # sequence 1
        {"logprobs": [...], "elementwise_loss": [...]},  # sequence 2
        ...
    ]
  • Make action_log_probs optional in Experience (SFT batches don't have rollout log probs)
  • Update MeshDispatch to pass through kwargs to worker methods
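The per-sequence output shape above can be illustrated with a minimal, self-contained sketch of a masked cross-entropy computation. This is a hypothetical stand-in (the function name `cross_entropy_outputs` and pure-Python implementation are not from skyrl-train); it only mirrors the `{"logprobs": [...], "elementwise_loss": [...]}` dict structure the PR returns for one sequence:

```python
import math

def cross_entropy_outputs(logits, labels, loss_mask):
    """Hypothetical sketch: per-token logprobs and elementwise loss for
    one sequence, in the same dict shape the PR's loss_fn_outputs uses."""
    logprobs, elementwise_loss = [], []
    for row, label, mask in zip(logits, labels, loss_mask):
        # Numerically stable log-softmax over the vocab dimension.
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        lp = row[label] - lse
        logprobs.append(lp)
        # Masked-out tokens (e.g. prompt tokens) contribute zero loss.
        elementwise_loss.append(-lp * mask)
    return {"logprobs": logprobs, "elementwise_loss": elementwise_loss}

out = cross_entropy_outputs(
    logits=[[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]],  # 3 tokens, vocab size 2
    labels=[0, 1, 0],
    loss_mask=[1, 1, 0],
)
```

In the real API these values come back per sequence inside `metrics["loss_fn_outputs"]` from `forward_backward`.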



gemini-code-assist (bot) left a comment


Code Review

This pull request introduces supervised fine-tuning (SFT) support, which is a great addition to the training framework. The changes are well-structured, and the new SFT example is clear and helpful. The implementation correctly adds a cross_entropy loss function and adapts the forward_backward path to support it, including returning per-token outputs for Tinker API compatibility. I've identified one high-severity issue regarding an in-place modification of the shared configuration object, which could lead to unexpected side effects, and a related medium-severity style issue. Overall, this is a solid contribution that significantly enhances the framework's capabilities.

Comment thread: skyrl-train/skyrl_train/workers/worker.py
…ropy")

Enables supervised fine-tuning using the Tinker-compatible API.

Changes:
- ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function
- worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss
- worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward()
- dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config)
- replay_buffer.py: Make action_log_probs optional in Experience
- worker_utils.py: Use .get() for optional fields; handle non-scalar metrics

New:
- examples/sft/: Minimal SFT example demonstrating the API

This enables PR #871 (SkyRL-train backend for Tinker) to return proper
per-token values instead of placeholder data.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- validate_dispatch_args now accepts data as positional or keyword arg
- worker_dispatch only passes loss_fn/loss_fn_config when non-None
  (critic worker doesn't accept these params)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…sion error

loss_fn_outputs is a list of dicts (per-sequence data for Tinker API),
not a tensor/scalar. Extract before all_reduce and add back after.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
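The fix described in this commit can be sketched as follows. This is a hypothetical simplification (the function name `reduce_metrics` and the injected `all_reduce` callable are stand-ins, not skyrl-train APIs): pop the non-reducible list out of the metrics dict before the collective, then re-attach it.

```python
def reduce_metrics(metrics, all_reduce):
    """Hypothetical sketch: loss_fn_outputs is a list of per-sequence
    dicts, not a tensor/scalar, so it cannot pass through all_reduce.
    Extract it first, reduce the remaining scalars, then add it back."""
    loss_fn_outputs = metrics.pop("loss_fn_outputs", None)
    reduced = all_reduce(metrics)  # e.g. mean across data-parallel ranks
    if loss_fn_outputs is not None:
        reduced["loss_fn_outputs"] = loss_fn_outputs
    return reduced

# Single-process stand-in for the collective: an identity "reduce".
m = reduce_metrics(
    {"loss": 0.5, "loss_fn_outputs": [{"logprobs": [-0.1]}]},
    all_reduce=lambda d: dict(d),
)
```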
With DP>1, each rank returns loss_fn_outputs for its data chunk.
Previously only statuses[0] was returned, dropping other ranks' outputs.
Now concatenate all loss_fn_outputs in rank order.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
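The rank-order concatenation described above can be sketched like this (a hypothetical simplification; `gather_loss_fn_outputs` and the `statuses` list-of-dicts shape are stand-ins for the worker status handling in the PR):

```python
def gather_loss_fn_outputs(statuses):
    """Hypothetical sketch: with DP > 1 each rank's status dict holds
    loss_fn_outputs for its own data chunk. Concatenate every rank's
    list in rank order instead of returning only statuses[0]."""
    merged = dict(statuses[0])
    merged["loss_fn_outputs"] = [
        out for status in statuses for out in status.get("loss_fn_outputs", [])
    ]
    return merged

merged = gather_loss_fn_outputs([
    {"loss": 0.4, "loss_fn_outputs": [{"logprobs": [-0.1]}]},  # rank 0
    {"loss": 0.6, "loss_fn_outputs": [{"logprobs": [-0.2]}]},  # rank 1
])
```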
Tinker expects variable-length arrays that align with input weights,
not padded to batch max. Use loss_mask to determine valid length
per sample and slice accordingly.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
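The trimming step above can be sketched as follows. This is a hypothetical illustration (`trim_output` is a stand-in name), and it assumes the loss mask is 1 for valid tokens, 0 for padding, with padding at the end of each row:

```python
def trim_output(output, loss_mask):
    """Hypothetical sketch: use the loss mask to find this sample's
    valid (unpadded) length, then slice every per-token array to it so
    lengths align with the input weights rather than the batch max."""
    n = int(sum(loss_mask))
    return {key: values[:n] for key, values in output.items()}

trimmed = trim_output(
    {"logprobs": [-0.1, -0.2, 0.0, 0.0],
     "elementwise_loss": [0.1, 0.2, 0.0, 0.0]},
    loss_mask=[1, 1, 0, 0],
)
```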
Verifies:
- loss_fn="cross_entropy" returns loss_fn_outputs
- Each DP rank returns outputs for its data chunk
- Output structure has logprobs and elementwise_loss keys
- Arrays are trimmed to valid length

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tyler-griggs tyler-griggs merged commit 3683ceb into main Jan 27, 2026
4 of 5 checks passed
erictang000 added a commit that referenced this pull request Jan 31, 2026
…ptim_step (#901)

## Summary
- Add `forward_backward()` and `optim_step()` methods to
`MegatronPolicyWorkerBase` to match FSDP worker interface
- Update trainer to use unified interface for both Megatron and FSDP
strategies (removes strategy branching)
- Mark `ppo_train()` as deprecated (kept for backward compatibility)
- Update `test_megatron_worker.py` to use the new interface
- Add `get_lr` and `set_lr` to the megatron worker to be in line with
behavior from #978
- Add SFT behavior from #961, allowing the Megatron backend to be used
with the TX SkyRL-Train integration

This brings Megatron up to parity with FSDP following the refactoring in
PR #859.
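The unified interface described above can be sketched as a backend-agnostic training step. `_StubWorker` and `train_step` are hypothetical stand-ins (not skyrl-train classes); the point is that the trainer calls the same two methods on either backend, so no strategy branching is needed:

```python
class _StubWorker:
    """Stand-in worker used only to exercise the sketch; a real FSDP or
    Megatron worker would run the actual forward/backward pass."""
    def __init__(self):
        self.stepped = False

    def forward_backward(self, batch, loss_fn=None):
        return {"loss": 0.0, "loss_fn": loss_fn}

    def optim_step(self):
        self.stepped = True

def train_step(worker, batch):
    """One training step via the unified interface: forward/backward,
    then an optimizer step, identical for both backends."""
    metrics = worker.forward_backward(batch, loss_fn="cross_entropy")
    worker.optim_step()
    return metrics

worker = _StubWorker()
metrics = train_step(worker, batch=[])
```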

## Test plan
- [x] Run `test_megatron_worker.py` to verify forward_backward +
optim_step works correctly
- [x] Verify metrics match between Megatron and FSDP implementations

Co-Authored-By: Eric Tang <erictang000@gmail.com>

---------

Co-authored-by: Eric Tang <erictang000@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>