[skyrl-train] Add SFT support via forward_backward(loss_fn="cross_entropy") (#961)
tyler-griggs merged 8 commits into main.
Conversation
Code Review
This pull request introduces supervised fine-tuning (SFT) support. The changes are well-structured, and the new SFT example is clear and helpful. The implementation adds a cross_entropy loss function and adapts the forward_backward path to support it, including returning per-token outputs for Tinker API compatibility. I've identified one high-severity issue: an in-place modification of the shared configuration object, which could lead to unexpected side effects, along with a related medium-severity style issue. Overall, this is a solid contribution that significantly enhances the framework's capabilities.
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- validate_dispatch_args now accepts data as a positional or keyword arg
- worker_dispatch only passes loss_fn/loss_fn_config when non-None (the critic worker doesn't accept these params)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…sion error

loss_fn_outputs is a list of dicts (per-sequence data for the Tinker API), not a tensor/scalar. Extract it before all_reduce and add it back after.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
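The fix described above can be sketched as follows. The status-dict shape and the `reduce_metrics` callable are hypothetical stand-ins for the actual all_reduce over scalar metrics in worker_utils.py.

```python
def reduce_status(status, reduce_metrics):
    """Pop the non-tensor loss_fn_outputs payload out before reducing,
    then reattach it unchanged afterwards."""
    outputs = status.pop("loss_fn_outputs", None)  # list of dicts, not reducible
    status = reduce_metrics(status)                # safe now: scalars/tensors only
    if outputs is not None:
        status["loss_fn_outputs"] = outputs        # add back untouched
    return status
```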
With DP>1, each rank returns loss_fn_outputs for its data chunk. Previously only statuses[0] was returned, dropping the other ranks' outputs. Now all loss_fn_outputs are concatenated in rank order.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
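A minimal sketch of the concatenation, assuming each entry of `statuses` is a per-rank dict with an optional "loss_fn_outputs" list (one dict per sequence in that rank's chunk); the real return shape in skyrl-train may differ.

```python
def merge_loss_fn_outputs(statuses):
    """Combine per-rank results: keep rank-0 metrics as before, but
    concatenate loss_fn_outputs from every rank in rank order."""
    merged = []
    for status in statuses:  # statuses are ordered by DP rank
        merged.extend(status.get("loss_fn_outputs", []))
    result = dict(statuses[0])
    result["loss_fn_outputs"] = merged
    return result
```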
Tinker expects variable-length arrays that align with the input weights, not arrays padded to the batch max. Use the loss_mask to determine each sample's valid length and slice accordingly.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
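The trimming can be sketched like this. It is an assumed implementation: it takes mask.sum() as the valid length, which presumes valid tokens form a contiguous prefix before padding.

```python
import torch

def trim_to_valid_length(per_token_values, loss_mask):
    """Slice each row of a (batch, max_len) tensor down to its
    unpadded length, returning variable-length Python lists."""
    trimmed = []
    for row, mask in zip(per_token_values, loss_mask):
        valid_len = int(mask.sum().item())  # count of non-padding positions
        trimmed.append(row[:valid_len].tolist())
    return trimmed
```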
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Verifies:
- loss_fn="cross_entropy" returns loss_fn_outputs
- Each DP rank returns outputs for its data chunk
- Output structure has logprobs and elementwise_loss keys
- Arrays are trimmed to valid length

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ptim_step (#901)

## Summary
- Add `forward_backward()` and `optim_step()` methods to `MegatronPolicyWorkerBase` to match the FSDP worker interface
- Update the trainer to use the unified interface for both Megatron and FSDP strategies (removes strategy branching)
- Mark `ppo_train()` as deprecated (kept for backward compatibility)
- Update `test_megatron_worker.py` to use the new interface
- Add `get_lr` and `set_lr` to the Megatron worker to be in line with behavior from #978
- Add SFT behavior from #961, allowing the Megatron backend to be used with the TX SkyRL-Train integration

This brings Megatron up to parity with FSDP following the refactoring in PR #859.

## Test plan
- [x] Run `test_megatron_worker.py` to verify forward_backward + optim_step works correctly
- [x] Verify metrics match between Megatron and FSDP implementations

Co-Authored-By: Eric Tang <erictang000@gmail.com>

---------

Co-authored-by: Eric Tang <erictang000@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Summary
Enables SFT using the Tinker-compatible API:
Key Changes
- Add a `loss_fn` parameter to `forward_backward()` (overrides config's `policy_loss_type`)
- Add a `cross_entropy` loss in `PolicyLossRegistry`
- Return per-token `loss_fn_outputs` for Tinker API compatibility
- Make `action_log_probs` optional in `Experience` (SFT batches don't have rollout log probs)
- Update `MeshDispatch` to pass through kwargs to worker methods
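To make the mechanics above concrete, here is a self-contained toy that mirrors what the SFT path does conceptually: forward pass, per-token cross-entropy with a loss mask, backward, and an optimizer step. The `nn.Linear` stand-in model and all tensor names are illustrative; this is not the examples/sft/ script itself.

```python
import torch
import torch.nn as nn

vocab, seq_len = 11, 6
model = nn.Linear(vocab, vocab)  # stand-in for a language model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

inputs = torch.randn(2, seq_len, vocab)
labels = torch.randint(0, vocab, (2, seq_len))
loss_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1, 1]], dtype=torch.float)

logits = model(inputs)
logprobs = torch.log_softmax(logits, dim=-1)
token_logprobs = logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
elementwise_loss = -token_logprobs          # per-token NLL, as in loss_fn_outputs
loss = (elementwise_loss * loss_mask).sum() / loss_mask.sum()

loss.backward()   # what forward_backward() does: accumulate grads
opt.step()        # what optim_step() does: apply them
opt.zero_grad()
```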