Skip to content

portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum#20090

Open
vacu9708 wants to merge 2 commits into
pytorch:mainfrom
vacu9708:fp32-accumulation-bfloat16
Open

portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum#20090
vacu9708 wants to merge 2 commits into
pytorch:mainfrom
vacu9708:fp32-accumulation-bfloat16

Conversation

@vacu9708

@vacu9708 vacu9708 commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Motivation

softmax, log_softmax, mean, and sum all accumulate their reduction in the input dtype. For BFloat16, that sum saturates around 256. Once it gets there, adding 1.0 rounds away and the total gets stuck. A uniform softmax over 512 elements in BFloat16 gives ~1/256 per output instead of 1/512.

Why FP32 accumulation is needed

BFloat16 has the same exponent width as Float32, so it has a similar range. However, it has far fewer fraction bits, which makes its representable spacing much coarser as values grow.

Type Exponent bits Fraction bits Practical effect
BFloat16 8 7 Similar range to Float32, but coarse spacing
Float32 8 23 Similar range, much finer spacing

For BFloat16, the gap between consecutive representable values (i.e, the smallest step size) increases at each power-of-two range:

Range BFloat16 step size Representable examples
[128, 256) 1 128, 129, 130, ..., 255
[256, 512) 2 256, 258, 260, ..., 510

As a result, once a BFloat16 running sum reaches 256, adding 1.0 no longer changes the value:

Operation Exact result BFloat16 result Reason
256 + 1 257 256 257 is not representable and rounds back to 256 (according to IEEE 754; round-to-nearest-even)

This directly affects all four ops for large inputs. For a softmax over 512 zeros, each exp(0) contributes 1.0, so the denominator should be 512. If the BFloat16 accumulation gets stuck at 256, the output becomes approximately 1/256 instead of the correct 1/512.

Case Expected denominator BFloat16 accumulated denominator Output
Correct accumulation 512 512 1/512
BFloat16 accumulation 512 ~256 ~1/256

ATen accumulates reductions in float for Half/BFloat16 (via acc_type). This PR does the same, following the pattern already established in op_grid_sampler_2d (#19117).

Tests

$ cmake --build cmake-out --target portable_kernels_test -j$(nproc)
[100%] Built target portable_kernels_test

# Post-fix — new tests:
[ OK ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat

# Pre-fix (reverted op files):
[ FAILED ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat

$ lintrunner op_softmax.cpp op_log_softmax.cpp op_mean.cpp op_sum.cpp \
             op_softmax_test.cpp op_log_softmax_test.cpp op_mean_test.cpp op_sum_test.cpp
ok  No lint issues.

@vacu9708 vacu9708 requested a review from manuelcandales as a code owner June 8, 2026 04:09
@pytorch-bot

pytorch-bot Bot commented Jun 8, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20090

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 8, 2026
@github-actions

github-actions Bot commented Jun 8, 2026

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@vacu9708 vacu9708 force-pushed the fp32-accumulation-bfloat16 branch from 88fb15d to 3fb0012 Compare June 8, 2026 04:13
vacu9708 added 2 commits June 8, 2026 15:01
…tmax

Problem:
Softmax and log_softmax accumulated exp(x - max) in the tensor dtype.
For BFloat16, the running sum saturates around 256 — adding 1.0 stops
changing the total — so a uniform softmax over N=512 elements outputs
~1/256 instead of 1/512.

Changes:
Accumulate the exp-sum in float for Half/BFloat16 by threading an ACC
type through the map-reduce calls. Loads and stores remain in the tensor
dtype.

Continues the fp32-accumulation work in pytorch#19117.
Problem:
The fast-path and generic reduction loops in mean.out and sum.IntList_out
accumulated the running sum in the tensor dtype. For BFloat16, the sum
saturates around 256, so a mean over N=512 all-ones elements gives 0.5
instead of 1.0, and summing 512 all-ones elements gives 256 instead of
512.

Changes:
Accumulate in float for Half/BFloat16 by promoting the loop accumulator
to ACC in both the fast path and the generic path. The final result is
cast back to the tensor dtype on store.

Continues the fp32-accumulation work in pytorch#19117.
@vacu9708 vacu9708 force-pushed the fp32-accumulation-bfloat16 branch from 3fb0012 to d56aa5a Compare June 8, 2026 06:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants