[tx] Add experimental SkyRL-train backend that supports SFT #871
base: main
Conversation
Code Review
This pull request introduces a new SkyRL-train backend for supervised training. The changes update project dependencies in pyproject.toml and add the new backend implementation in skyrl-tx/tx/tinker/backends/skyrl_train.py. While this is a good starting point for the new backend, the review identified several issues that need to be addressed. The most critical is in the forward_backward method, which is currently a stub: it performs no backward pass and returns placeholder rather than actual losses, so no training can occur. Other significant issues include hardcoded paths and hyperparameters, potentially incorrect token padding, and broken encapsulation from accessing private members of a library class. Addressing these points is necessary for the backend to be functional and maintainable.
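To make the stub concrete: a functional forward_backward would need to dispatch the loss computation and backward pass to the training actors and surface their real per-token losses. A minimal sketch of that shape, mirroring the Ray fan-out pattern used elsewhere in this backend (the forward_backward actor method and the result keys here are hypothetical, not confirmed skyrl-train APIs):

```python
import ray

def forward_backward(self, batch, loss_fn: str = "cross_entropy") -> dict:
    # Hypothetical sketch: fan the forward + backward pass out to every
    # training actor and collect the losses they actually computed,
    # instead of returning placeholder values.
    results = ray.get([
        actor.forward_backward.remote(batch, loss_fn=loss_fn)
        for actor in self._actor_group._actor_handlers
    ])
    # Average the scalar losses across actors and keep per-token losses.
    return {
        "loss": sum(r["loss"] for r in results) / len(results),
        "per_token_losses": [r["elementwise_loss"] for r in results],
    }
```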
```python
    ray.get([actor.save_checkpoint.remote(output_path) for actor in self._actor_group._actor_handlers])

def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    ray.get([actor.load_checkpoint.remote(Path(checkpoint_path)) for actor in self._actor_group._actor_handlers])
```
Accessing the private member _actor_handlers of PPORayActorGroup breaks encapsulation and makes the code dependent on the internal implementation of the skyrl-train library. This could lead to breakages if the library is updated. It would be more robust to use a public API from PPORayActorGroup for this purpose, or request one if it doesn't exist.
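For illustration, a public broadcast helper on PPORayActorGroup could absorb this pattern; the method name run_on_all_actors below is a suggestion, not an existing skyrl-train API:

```python
import ray

class PPORayActorGroup:
    # ... existing implementation ...

    def run_on_all_actors(self, method_name: str, *args, **kwargs) -> list:
        # Hypothetical public helper: invoke `method_name` on every actor in
        # the group and block until all remote calls have completed.
        refs = [
            getattr(actor, method_name).remote(*args, **kwargs)
            for actor in self._actor_handlers
        ]
        return ray.get(refs)
```

The backend could then call self._actor_group.run_on_all_actors("load_checkpoint", Path(checkpoint_path)) without reaching into the group's private state.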
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ropy") Enables supervised fine-tuning using the Tinker-compatible API. Changes: - ppo_utils.py: Add CROSS_ENTROPY loss type and cross_entropy_loss() function - worker.py: Add SFT code path that returns per-token logprobs and elementwise_loss - worker_dispatch.py: Add loss_fn and loss_fn_config params to forward_backward() - dispatch.py: Update MeshDispatch to pass through kwargs (loss_fn, loss_fn_config) - replay_buffer.py: Make action_log_probs optional in Experience - worker_utils.py: Use .get() for optional fields; handle non-scalar metrics New: - examples/sft/: Minimal SFT example demonstrating the API This enables PR #871 (SkyRL-train backend for Tinker) to return proper per-token values instead of placeholder data. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The engine can, for example, be started with:

and then you can run, for example: