Skip to content

Add ReduceType.weighted_mean for weighted metric reduction#604

Open
finbarrtimbers wants to merge 3 commits intomainfrom
finbarr/weighted-mean-reduce-type
Open

Add ReduceType.weighted_mean for weighted metric reduction#604
finbarrtimbers wants to merge 3 commits intomainfrom
finbarr/weighted-mean-reduce-type

Conversation

@finbarrtimbers
Copy link
Copy Markdown
Contributor

Summary

  • Adds ReduceType.weighted_mean which computes sum(value * weight) / sum(weight) across ranks
  • Useful when ranks process different amounts of data (e.g. different token counts per rank)
  • Adds weight parameter to Trainer.record_metric() and TrainModule.record_metric()

Test plan

  • Added non-distributed unit test for reduce_metrics with weighted_mean
  • Added weighted_mean case to existing distributed reduce_metrics test
  • Linting and style checks pass

🤖 Generated with Claude Code

finbarrtimbers and others added 2 commits February 11, 2026 14:14
Enables recording metrics with per-rank weights so the reduction computes
sum(value * weight) / sum(weight) instead of a simple mean. This is useful
when ranks process different amounts of data (e.g. different token counts).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 612bfdebf2

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +1020 to +1024
if not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight)
else:
weight = get_local_tensor(weight.detach()).float()
value = value * weight
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Move weight to metric device before multiplication

record_metric() converts non-tensor weights with torch.tensor(weight), which creates a CPU tensor, and then immediately multiplies it with value. If value is on CUDA (common in training) and weight is passed as a Python scalar (the new API allows this), this path raises a device-mismatch runtime error when logging weighted metrics. This makes ReduceType.weighted_mean unusable in typical GPU runs unless callers manually wrap weights as tensors on the right device.

Useful? React with 👍 / 👎.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant