[skyrl-train] Fix loss reduction by moving normalization to the advantage computation #925
Summary
The previous implementation of PPO policy loss reduction had a "mean of means" bias: when computing a token-mean loss across micro-batches and workers with varying token counts, naively averaging the per-micro-batch means gives the wrong result. For example:

- Micro-batch 1: 100 tokens, average loss = 0.5
- Micro-batch 2: 900 tokens, average loss = 0.3
- Naive mean: (0.5 + 0.3) / 2 = 0.4
- Correct token-mean: (100×0.5 + 900×0.3) / 1000 = 0.32
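A tiny standalone snippet reproducing the arithmetic above:

```python
# Numbers taken from the example above.
token_counts = [100, 900]
mean_losses = [0.5, 0.3]

naive_mean = sum(mean_losses) / len(mean_losses)  # 0.4 -- biased "mean of means"
token_mean = sum(n * l for n, l in zip(token_counts, mean_losses)) / sum(token_counts)  # 0.32
```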
After this PR, `ppo_policy_loss` used within `forward_backward` simply sums the per-token loss over all sequences and relies on the advantages passed in by the user to handle loss normalization. This aligns with Tinker semantics:
Example for `loss_reduction="token_mean"`, where the `1/num_minibatch_tokens` normalization is pushed into the advantage:

`loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens`

is equivalent to

`sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )`
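A minimal sketch of what this looks like on the user side, assuming `advantages` and `loss_mask` tensors for one minibatch; the helper name and the commented `forward_backward` call are illustrative, not the actual skyrl-train API:

```python
import torch

def scale_advantages_token_mean(advantages: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Fold the 1/num_minibatch_tokens normalization into the advantages.

    After this scaling, a policy loss that simply sums -advantage_i * ratio_i
    over all tokens in the minibatch equals the token-mean loss.
    """
    num_minibatch_tokens = loss_mask.sum()
    return advantages / num_minibatch_tokens

# Hypothetical usage per minibatch, before the forward/backward pass:
# advantages = scale_advantages_token_mean(advantages, loss_mask)
# trainer.forward_backward(batch)  # placeholder call, not the real signature
```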
DDP all-reduce

DDP/FSDP defaults to a mean all-reduce for gradients across workers. This PR counteracts that by multiplying the loss by the DP world size.
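A minimal sketch of this counter-scaling, assuming `summed_loss` is the un-normalized per-token loss summed on one DP rank (names here are illustrative, not the PR's actual code): a mean all-reduce divides gradients by the DP world size, so multiplying the local loss by that size recovers the intended global sum.

```python
import torch
import torch.distributed as dist

def scale_loss_for_mean_allreduce(summed_loss: torch.Tensor) -> torch.Tensor:
    """Counteract the 1/dp_world_size factor from DDP/FSDP's gradient mean all-reduce."""
    dp_world_size = dist.get_world_size() if dist.is_initialized() else 1
    return summed_loss * dp_world_size
```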
Additional details
The first attempt was #909: it tracked the total token count and did one big normalization at `optim_step` in order to get an average per-token loss. But we decided to align with Tinker's approach of just summing up the loss at the end and pushing any loss normalization into the user's advantage calculation.

The benefit is that users have full control over customizing their loss reduction strategy, rather than having it happen inside our opaque `forward_backward`/`optim_step` implementation, which would require a configuration argument that diverges from Tinker's API. For example, we would need to add a config somewhere to determine how to average or sum the loss:
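A purely hypothetical sketch of such a knob (not something this PR adds); `loss_reduction` is the only name taken from the example above, the rest is made up for illustration:

```python
from dataclasses import dataclass

@dataclass
class TrainerConfig:
    # Hypothetical: how forward_backward/optim_step would reduce the loss if
    # normalization lived inside them instead of in the user's advantages.
    loss_reduction: str = "token_mean"  # e.g. "token_mean", "seq_mean", "sum"
```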
Follow-up work

`ppo_critic_loss` has the same problem, but it is less important than the policy loss.