Fix bug in momentum type declaration in HIP TBE kernel #147

Open

aryaman-gupta wants to merge 5 commits into aryaman/upstream from aryaman/hip-tbe-momentum-fix
Conversation

@aryaman-gupta
This PR fixes a bug in the HIP TBE kernel, which declared p_momentum with the data type cache_t. When cache_t is half, p_momentum is declared as half, which is incorrect and inconsistent with the Python code (https://github.com/ROCm/FBGEMM/blob/aryaman/hip-tbe-momentum-fix/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py#L1232-L1240). Before pytorch@ea2a302 was merged, cache_t was always float, so the bug never surfaced.

With this change, the data type of p_momentum is determined using PyTorch's acc_type, which yields float when cache_t is half. This matches how the CUDA codepath handles the data type of the momentum buffer (references: https://github.com/ROCm/FBGEMM/blob/aryaman/hip-tbe-momentum-fix/fbgemm_gpu/codegen/genscript/optimizer_args.py#L956-L961, https://github.com/ROCm/FBGEMM/blob/aryaman/hip-tbe-momentum-fix/fbgemm_gpu/codegen/genscript/optimizers.py#L266).

@aryaman-gupta aryaman-gupta marked this pull request as ready for review March 19, 2026 22:16