Skip to content

Conversation

@justinvyu
Copy link
Contributor

@justinvyu justinvyu commented Jan 23, 2026

Summary

The previous implementation for ppo policy loss reduction had a "mean of means" bias — when computing token-mean loss across micro-batches and workers with varying token counts, the naive averaging gave incorrect results where:

  • microbatches with fewer tokens are weighted more heavily (since we take a mean across microbatches within a minibatch)
    • Micro-batch 1: 100 tokens, average loss = 0.5, micro-batch 2: 900 tokens, average loss = 0.3
    • -> Naive mean: (0.5 + 0.3) / 2 = 0.4, Correct token-mean: (100×0.5 + 900×0.3) / 1000 = 0.32
  • worker minibatches with fewer tokens were weighted more heavily (since DDP all-reduce takes a mean across minibatches)
    • Same example as above, but the average is across workers instead.

After this PR, ppo_policy_loss used within forward_backward now just sums the per-token loss for all sequences and relies on the advantages passed in by the user to handle the loss normalization.

This aligns with Tinker semantics:

Notice that for all objectives we sum the token-level losses over the sequence length unlike some other loss implementations. If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation.

Example for loss_reduction="token_mean":

  • Move the 1/num_minibatch_tokens normalization into the advantage: loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
  • -> sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )

DDP all-reduce

DDP/FSDP defaults to a mean all-reduce for gradients across workers. This PR counteracts this by multiplying by the DP world size.

Additional details

This was the first attempt: #909

This method was to track total tokens and then do one big normalization at the optim_step in order to get an average per-token loss. But, we decided to align with Tinker's way of just summing up the loss at the end, and pushing any loss normalization to the user's advantage calculation.

The benefit is that users have full control of customizing their loss reduction strategy, rather than having it happen in our opaque forward_backward, optim_step implementation which would require some configuration argument that diverges from tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:

client.forward_backward(...)
client.optim_step(..., loss_reduction="token_mean")  # no longer tinker compatible

Follow-up work

The ppo_critic_loss has the same problem but is not as important as the policy loss.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Comment on lines +787 to +789
for param in self.model.parameters():
if param.grad is not None:
param.grad.mul_(self.strategy.world_size)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could do this at the advantage computation level, but i thought it was a bit weird to have ddp all-reduce implementation details there so i separated it to be here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant