
Conversation

@tyler-griggs
Member

Summary

Enables supervised fine-tuning (SFT) using the Tinker-compatible forward_backward() API.

  • Adds loss_fn and loss_fn_config parameters to forward_backward()
  • Implements cross_entropy loss function
  • Returns per-token logprobs and elementwise_loss for Tinker API compatibility
  • Makes various fields optional to support SFT batches (no advantages, old_log_probs, etc.)

Changes

File                 Description
ppo_utils.py         Add CROSS_ENTROPY loss type and cross_entropy_loss() function (see the sketch below)
worker.py            Add SFT code path that returns per-token outputs
worker_dispatch.py   Add loss_fn and loss_fn_config params
dispatch.py          Update MeshDispatch to pass through kwargs
replay_buffer.py     Make action_log_probs optional in Experience
worker_utils.py      Use .get() for optional fields; handle non-scalar metrics
examples/sft/        New SFT example demonstrating the API
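
For readers skimming the table, here is a minimal sketch of what a token-level cross_entropy_loss() could look like. It assumes a standard masked negative log-likelihood; the argument names (logits, target_ids, loss_mask) are illustrative and not taken from ppo_utils.py.

import torch
import torch.nn.functional as F

def cross_entropy_loss(
    logits: torch.Tensor,       # [batch, seq_len, vocab]
    target_ids: torch.Tensor,   # [batch, seq_len], token ids to predict
    loss_mask: torch.Tensor,    # [batch, seq_len], 1.0 for supervised tokens
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Masked token-level cross entropy; returns (loss, elementwise_loss, logprobs)."""
    log_probs = F.log_softmax(logits, dim=-1)
    # Per-token log-probability of the target token
    token_logprobs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    elementwise_loss = -token_logprobs
    # Mask out prompt/padding tokens and average over supervised tokens
    loss = (elementwise_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
    return loss, elementwise_loss, token_logprobs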

Usage

# Tinker-compatible SFT API
metrics = dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")

# Returns per-token outputs:
# - metrics["logprobs"] - per-token log probabilities
# - metrics["elementwise_loss"] - per-token loss values
# - metrics["loss"] - scalar loss

grad_norm = dispatch.optim_step("policy")
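
A slightly fuller sketch of how a training loop might consume these outputs. The dataloader and batch contents are placeholders (the real example lives in examples/sft/sft_trainer.py), and it assumes the per-token outputs are tensors.

# Illustrative loop; `dataloader` and the batch contents are placeholders.
for step, batch in enumerate(dataloader):
    metrics = dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
    grad_norm = dispatch.optim_step("policy")

    # Per-token tensors can feed token-level logging or analysis
    mean_logprob = metrics["logprobs"].mean().item()
    print(f"step={step} loss={metrics['loss']:.4f} "
          f"grad_norm={grad_norm:.4f} mean_logprob={mean_logprob:.4f}")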

Motivation

This PR unblocks #871 (SkyRL-train backend for Tinker), which needs proper per-token outputs instead of placeholder values.

Test plan

  • Run examples/sft/sft_trainer.py to verify SFT training works
  • Verify existing RL training is unaffected (the SFT code path only activates when loss_fn="cross_entropy")

🤖 Generated with Claude Code

tyler-griggs and others added 2 commits January 26, 2026 02:07
- Add SKYRL_LOG_LEVEL to env_vars.py for centralized log level control
- Update configure_ray_worker_logging to use SKYRL_LOG_LEVEL
- Add log_to_driver=False to suppress worker/raylet log forwarding
- Set RAY_BACKEND_LOG_LEVEL=fatal to suppress C++ metrics errors
- Enable verbose logging when SKYRL_LOG_LEVEL=DEBUG

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add SKYRL_LOG_DIR env var for log file location (default: /tmp/skyrl-logs)
- Create new logging module (skyrl_train/utils/logging.py) with:
  - setup_logging(): Routes all logs to file, training progress to stdout
  - configure_worker_logging(): Simplified worker logging setup
- Training progress loggers (trainer, fully_async_trainer, tracking, evaluate)
  go to stdout by default
- Infrastructure loggers (vllm, ray, workers, inference_engines, etc.)
  only go to log file by default
- Set SKYRL_LOG_LEVEL=DEBUG to show all logs on stdout
- Update main_base.py to call setup_logging at startup

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
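
A rough sketch of the routing described in the commit message above. The logger names and SKYRL_* defaults are taken from the message; everything else (handler setup, log file name) is an assumption, not the actual contents of skyrl_train/utils/logging.py.

import logging
import os
import sys

def setup_logging() -> None:
    """Route all logs to a file; training-progress loggers also stream to stdout."""
    log_dir = os.environ.get("SKYRL_LOG_DIR", "/tmp/skyrl-logs")
    os.makedirs(log_dir, exist_ok=True)
    level = os.environ.get("SKYRL_LOG_LEVEL", "INFO").upper()

    file_handler = logging.FileHandler(os.path.join(log_dir, "skyrl.log"))
    stdout_handler = logging.StreamHandler(sys.stdout)

    root = logging.getLogger()
    root.setLevel(level)
    root.addHandler(file_handler)  # everything goes to the log file

    if level == "DEBUG":
        # Show all logs on stdout as well
        root.addHandler(stdout_handler)
    else:
        # Only training-progress loggers also stream to stdout
        for name in ("trainer", "fully_async_trainer", "tracking", "evaluate"):
            logging.getLogger(name).addHandler(stdout_handler)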

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces significant new functionality by adding support for Supervised Fine-Tuning (SFT) through a flexible forward_backward API. The changes are well-structured, including a new cross_entropy loss function, updates to data structures to support SFT batches, and a clear example demonstrating the new feature. The logging infrastructure has also been substantially improved, providing better control and clarity in a distributed environment. I've identified a critical issue with incorrect type hints that could lead to a runtime error, along with a few medium-severity suggestions to improve code consistency and style. Overall, this is a solid contribution that expands the framework's capabilities.

I was unable to create individual review comments, so the feedback is included inline below.

skyrl-train/skyrl_train/utils/ppo_utils.py (886-887)

critical

The type hints for old_log_probs and advantages should be Optional[torch.Tensor]. In the SFT workflow these values are None because they are not present in the SFT batch, so the current torch.Tensor annotations are inaccurate and misleading for callers and static type checkers.

    old_log_probs: Optional[torch.Tensor],
    advantages: Optional[torch.Tensor],

skyrl-train/skyrl_train/workers/worker.py (741)

medium

This local import should be moved to the top of the file to follow standard Python style guidelines (PEP 8). This improves readability and makes dependencies clear at a glance. You can add OmegaConf to the existing import from omegaconf at the top of the file.
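
For instance (assuming worker.py already has a top-level omegaconf import to extend; the other imported names are illustrative):

# Add OmegaConf to the existing `from omegaconf import ...` line at the top of worker.py
from omegaconf import DictConfig, OmegaConf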

skyrl-train/skyrl_train/workers/worker.py (817)

medium

For consistency, it would be better to use the same key for the final loss metric in both the RL and SFT code paths. The SFT path uses the key "loss", while the RL path uses "final_loss". Unifying this to "loss" would make metric consumption simpler and less error-prone.

                "loss": loss.item(),

skyrl-train/examples/sft/sft_trainer.py (184)

medium

This line handles both final_loss (from RL) and loss (from SFT) due to an inconsistency in the metric names returned by the worker. If the metric name is unified to loss in worker.py as suggested in another comment, this line can be simplified to only get "loss".

            loss_val = metrics.get("loss", "N/A")

@tyler-griggs
Member Author

Closing: this PR contained unrelated commits. Reopening with a clean branch.
