Add ReduceType.weighted_mean for weighted metric reduction #604
finbarrtimbers wants to merge 3 commits into main
Enables recording metrics with per-rank weights, so that the reduction computes sum(value * weight) / sum(weight) instead of a simple mean. This is useful when ranks process different amounts of data (e.g. different token counts).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
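To make the difference from a simple mean concrete, here is a minimal sketch of the arithmetic in plain Python. The `weighted_mean` function and the per-rank numbers are hypothetical and only illustrate the formula above; this is not the reduction code from this PR.

```python
def weighted_mean(values, weights):
    """Return sum(v * w) / sum(w) for per-rank values and weights."""
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("weights must not sum to zero")
    return sum(v * w for v, w in zip(values, weights)) / total_weight


# Hypothetical per-rank mean losses and the token counts behind them.
losses = [2.0, 4.0]
tokens = [300, 100]

simple = sum(losses) / len(losses)        # treats both ranks equally
weighted = weighted_mean(losses, tokens)  # rank 0 counts three times as much

print(simple, weighted)
```

With these numbers the simple mean is 3.0 while the weighted mean is 2.5, because the rank that processed more tokens contributes proportionally more.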
💡 Codex Review
Reviewed commit: 612bfdebf2
    if not isinstance(weight, torch.Tensor):
        weight = torch.tensor(weight)
    else:
        weight = get_local_tensor(weight.detach()).float()
    value = value * weight
Move weight to metric device before multiplication
record_metric() converts non-tensor weights with torch.tensor(weight), which creates a CPU tensor, and then immediately multiplies it with value. If value is on CUDA (common in training) and weight is passed as a Python scalar (the new API allows this), this path raises a device-mismatch runtime error when logging weighted metrics. This makes ReduceType.weighted_mean unusable in typical GPU runs unless callers manually wrap weights as tensors on the right device.
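One possible way to address this is to create or move the weight on the same device as the metric before multiplying. The sketch below is a standalone illustration, not the PR's code: `apply_weight` is a hypothetical helper, and it leaves out the `get_local_tensor` call from the diff so that it only depends on `torch`.

```python
import torch


def apply_weight(value: torch.Tensor, weight) -> torch.Tensor:
    """Hypothetical helper: multiply a metric by its weight on the metric's device."""
    if isinstance(weight, torch.Tensor):
        # Existing tensors are detached, cast to float, and moved to value's device.
        weight = weight.detach().float().to(value.device)
    else:
        # Python scalars are created directly on value's device rather than on CPU.
        weight = torch.tensor(weight, dtype=torch.float32, device=value.device)
    return value * weight


value = torch.tensor(2.0)        # would typically live on a CUDA device in training
out = apply_weight(value, 3)     # Python scalar weight
out_tensor = apply_weight(value, torch.tensor(0.5))
```

Because the weight always ends up on `value.device`, the same call works whether the metric lives on CPU or on a GPU.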
Summary
- ReduceType.weighted_mean, which computes sum(value * weight) / sum(weight) across ranks
- weight parameter to Trainer.record_metric() and TrainModule.record_metric()

Test plan
- reduce_metrics with weighted_mean
- weighted_mean case to existing distributed reduce_metrics test

🤖 Generated with Claude Code