portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum #20090
Open
vacu9708 wants to merge 2 commits into
…tmax

Problem: Softmax and log_softmax accumulated exp(x - max) in the tensor dtype. For BFloat16, the running sum saturates around 256 (adding 1.0 stops changing the total), so a uniform softmax over N=512 elements outputs ~1/256 instead of 1/512.

Changes: Accumulate the exp-sum in float for Half/BFloat16 by threading an ACC type through the map-reduce calls. Loads and stores remain in the tensor dtype. Continues the fp32-accumulation work in pytorch#19117.
Problem: The fast-path and generic reduction loops in mean.out and sum.IntList_out accumulated the running sum in the tensor dtype. For BFloat16, the sum saturates around 256, so a mean over N=512 all-ones elements gives 0.5 instead of 1.0, and summing 512 all-ones elements gives 256 instead of 512.

Changes: Accumulate in float for Half/BFloat16 by promoting the loop accumulator to ACC in both the fast path and the generic path. The final result is cast back to the tensor dtype on store. Continues the fp32-accumulation work in pytorch#19117.
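The effect described in the two commit messages above can be reproduced without any tensor library. The following is a minimal sketch, not the PR's C++ code: `to_bf16`, `identity`, and `reduce_sum` are hypothetical helpers written for this illustration, BFloat16 is emulated by rounding float32 bit patterns, and a Python float stands in for the wider `ACC` accumulator.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest BFloat16 value, ties to even (NaN/inf not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def identity(x: float) -> float:
    """Stand-in for a wide accumulator: no rounding after each add."""
    return x


def reduce_sum(values, round_acc):
    """Sum values, rounding the running total with round_acc after every add."""
    acc = 0.0
    for v in values:
        acc = round_acc(acc + v)
    return acc


n = 512
ones = [1.0] * n

# sum: accumulator kept in BFloat16 vs. wide accumulator with one cast on store
sum_bf16_acc = reduce_sum(ones, to_bf16)
sum_wide_acc = to_bf16(reduce_sum(ones, identity))

# mean: same two accumulation strategies, divided by the element count
mean_bf16_acc = to_bf16(sum_bf16_acc / n)
mean_wide_acc = to_bf16(reduce_sum(ones, identity) / n)

# softmax over 512 zeros: every exp(x - max) term is exp(0) = 1.0
exps = [to_bf16(math.exp(0.0 - 0.0)) for _ in range(n)]
softmax_bf16_acc = to_bf16(exps[0] / reduce_sum(exps, to_bf16))
softmax_wide_acc = to_bf16(exps[0] / reduce_sum(exps, identity))

print(sum_bf16_acc, sum_wide_acc)          # 256.0 512.0
print(mean_bf16_acc, mean_wide_acc)        # 0.5 1.0
print(softmax_bf16_acc, softmax_wide_acc)  # 0.00390625 0.001953125
```

In both strategies the inputs and the stored result are BFloat16 values; only the running total differs, which mirrors the "loads and stores remain in the tensor dtype" constraint stated in the commit messages.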
Motivation
softmax, log_softmax, mean, and sum all accumulate their reduction in the input dtype. For BFloat16, that sum saturates around 256. Once it gets there, an added 1.0 is lost to rounding and the total gets stuck. A uniform softmax over 512 elements in BFloat16 gives `~1/256` per output instead of `1/512`.

Why FP32 accumulation is needed
BFloat16 has the same exponent width as Float32, so it has a similar range. However, it has far fewer fraction bits, which makes its representable spacing much coarser as values grow.
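To make the coarser spacing concrete, here is a small sketch that measures the gap to the next representable value in each format. It assumes BFloat16 is the top 16 bits of a float32 bit pattern; `spacing` is a hypothetical helper for this illustration and is only valid for positive, finite values that are exactly representable in the format being measured.

```python
import struct


def spacing(x: float, dropped_bits: int) -> float:
    """Gap between x and the next value up in a format that keeps the top
    (32 - dropped_bits) bits of a float32: 0 for Float32, 16 for BFloat16."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    next_bits = bits + (1 << dropped_bits)  # bump the lowest kept fraction bit
    return struct.unpack("<f", struct.pack("<I", next_bits))[0] - x


for x in (1.0, 128.0, 256.0, 1024.0):
    print(x, "bf16:", spacing(x, 16), "fp32:", spacing(x, 0))
```

At 256.0 the BFloat16 gap is 2.0 while the Float32 gap is 2**-15, which is why a Float32 accumulator still registers an added 1.0 long after a BFloat16 one has stopped moving.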
For BFloat16, the gap between consecutive representable values (i.e., the smallest step size) doubles at each power of two:

| Range | Spacing | Representable values |
| --- | --- | --- |
| [128, 256) | 1 | 128, 129, 130, ..., 255 |
| [256, 512) | 2 | 256, 258, 260, ..., 510 |

As a result, once a BFloat16 running sum reaches `256`, adding `1.0` no longer changes the value:

| Operation | Exact result | BFloat16 result | Reason |
| --- | --- | --- | --- |
| 256 + 1 | 257 | 256 | 257 is not representable and rounds back to 256 (IEEE 754 round-to-nearest-even) |

This directly affects all four ops for large inputs. For a softmax over 512 zeros, each `exp(0)` contributes `1.0`, so the denominator should be `512`. If the BFloat16 accumulation gets stuck at `256`, the output becomes approximately `1/256` instead of the correct `1/512`.

| Case | N | Denominator | Output |
| --- | --- | --- | --- |
| Correct | 512 | 512 | 1/512 |
| BFloat16 accumulation | 512 | ~256 | ~1/256 |

ATen accumulates reductions in float for Half/BFloat16 (via `acc_type`). This PR does the same, following the pattern already established in `op_grid_sampler_2d` (#19117).

Tests