Add loss_fn parameterization to forward_backward #924
base: main
Conversation
Code Review
This pull request introduces parameterization for loss_fn to forward_backward for Tinker API compatibility and refactors weight synchronization into a new save_weights_for_sampler method. The changes are generally positive, simplifying the API and improving test coverage. However, I've identified a critical issue in trainer.py where the critic training path appears to be broken due to incorrect arguments being passed to _execute_training_step. Additionally, there's an inconsistency in a new test file regarding a test utility function call.
with Timer("critic_train", self.all_timings):
    critic_status = self._execute_training_step("critic", data, "critic")
with Timer("policy_train", self.all_timings):
    policy_status = self._execute_training_step("policy", data, "policy")
There appears to be a critical issue with the arguments passed to _execute_training_step for both the critic and policy models. The function signature for _execute_training_step is (self, model: str, data: TrainingInputBatch), but it's being called with three arguments here (e.g., self._execute_training_step("critic", data, "critic")). This will result in a TypeError at runtime.
While the intent seems to be to pass a loss_fn, the implementation appears incomplete. Specifically:
- The signature of _execute_training_step hasn't been updated to accept a third argument.
- Even if it were updated, the critic training path would likely fail. The loss_fn would be "critic", which is not handled by PolicyWorkerBase._get_loss_fn, and CriticWorkerBase doesn't have a comparable method to handle a parameterized loss function.
To fix this, you'll need to update the signature of _execute_training_step and ensure that both policy and critic workers can correctly handle the new loss_fn parameter. For the critic, you might want to pass None as the loss_fn if it's not meant to be parameterized, and handle that case in _execute_training_step.
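A minimal sketch of one possible shape for that fix, assuming the trainer holds policy and critic worker groups that expose forward_backward(); everything here other than _execute_training_step itself is a placeholder, not the actual trainer code:

```python
from typing import Any, Optional


class TrainerSketch:
    # Placeholder attributes standing in for the real trainer's worker groups.
    policy_group: Any
    critic_group: Any

    def _execute_training_step(
        self, model: str, data: Any, loss_fn: Optional[str] = None
    ) -> Any:
        # Only the policy worker understands a parameterized loss_fn; the critic
        # is called without one so it falls back to its default loss.
        worker_group = self.policy_group if model == "policy" else self.critic_group
        if model == "policy" and loss_fn is not None:
            return worker_group.forward_backward(data, loss_fn=loss_fn)
        return worker_group.forward_backward(data)
```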
# === Step 1: Do a training step ===
dp_size = policy_group.actor_infos[0].rank.dp_size
dummy_batch = make_dummy_training_batch(batch_size=dp_size)
The call to make_dummy_training_batch here and on line 190 seems inconsistent with changes in other test files. In other files like test_save_load_checkpoint.py and test_training_step.py, the batch_size argument was removed from this call (e.g., make_dummy_training_batch()).
If the signature of make_dummy_training_batch has changed, this could lead to test failures. For consistency across the test suite, please update this call to match the new pattern.
Suggested change:
- dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+ dummy_batch = make_dummy_training_batch()
- Remove ppo_train() from PolicyWorkerBase and CriticWorkerBase
- Workers now use forward_backward() + optim_step() with gradient scaling
- Trainer branches on strategy: Megatron uses ppo_train, FSDP uses forward_backward + optim_step
- WorkerDispatch forward_backward no longer takes Tinker params (loss_fn, loss_fn_config)
- Update tests to use TrainingInputBatch and remove ppo_train tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…anges
- Megatron: Remove redundant batch_to_experience call (iterator already yields Experience)
- test_save_load_model.py: Use TrainingInputBatch, remove extra forward_backward arg
- test_worker_offload.py: Use TrainingInputBatch, remove extra forward_backward arg

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
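As a rough, non-authoritative illustration of the strategy branching these commits describe (strategy, worker, and the micro-batch splitting helper below are placeholders, not the real trainer code):

```python
def run_policy_training_step(strategy, worker, data, split_into_micro_batches):
    """Illustrative only: mirrors the branching described in the commit messages."""
    if strategy == "megatron":
        # Megatron keeps the existing ppo_train() path.
        return worker.ppo_train(data)
    # FSDP path: accumulate gradients over micro batches, then take one optimizer
    # step; gradient scaling happens at optim_step based on the accumulated count.
    for micro_batch in split_into_micro_batches(data):
        worker.forward_backward(micro_batch)
    return worker.optim_step()
```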
Force-pushed from 7a0f4c3 to 70fb844
Summary
- Add loss_fn and loss_fn_config parameters to forward_backward() for Tinker API compatibility
- Remove ppo_train() from FSDP workers - uses gradient scaling at optim_step instead
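Roughly, the new call shape looks like the sketch below; only the loss_fn and loss_fn_config parameter names come from this PR, while the loss name and config keys are illustrative placeholders.

```python
# Hypothetical usage; worker_dispatch and batch stand in for the real objects.
status = worker_dispatch.forward_backward(
    batch,
    loss_fn="ppo",                                          # placeholder loss name
    loss_fn_config={"ratio_low": 0.8, "ratio_high": 1.2},   # Tinker-style absolute bounds
)
```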
Changes

WorkerDispatch (worker_dispatch.py):
- Add loss_fn and loss_fn_config parameters to forward_backward()

PolicyWorkerBase (worker.py):
- Add convert_tinker_loss_config() static method to convert Tinker's absolute ratio bounds to SkyRL's offset format
- Scale gradients at optim_step time based on accumulated micro batches
- Remove ppo_train() path for FSDP workers
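A hedged sketch of the kind of mapping convert_tinker_loss_config() performs, based on the description above (Tinker expresses PPO ratio clipping as absolute bounds around 1.0, SkyRL stores them as offsets from 1.0); the key names are assumptions, not the real config schema:

```python
def convert_tinker_loss_config_sketch(loss_fn_config: dict) -> dict:
    """Illustrative only: map Tinker-style absolute ratio bounds to offset form."""
    converted = dict(loss_fn_config)
    # e.g. ratio bounds (0.8, 1.2) become clip offsets (0.2, 0.2).
    if "ratio_low" in converted:
        converted["eps_clip_low"] = round(1.0 - converted.pop("ratio_low"), 6)
    if "ratio_high" in converted:
        converted["eps_clip_high"] = round(converted.pop("ratio_high") - 1.0, 6)
    return converted
```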
Tests:
- Add test_convert_tinker_loss_config for Tinker config conversion
- Update existing tests for pass_through routing and positional batch parameters
Test Plan
- test_normalize_mini_batch_size, test_convert_tinker_loss_config
- pytest tests/gpu/gpu_ci/test_training_step.py

Stack
🤖 Generated with Claude Code