diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu index d01bc429f2..7edff90c91 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu @@ -82,10 +82,24 @@ struct Vec4Type { using type = float2; }; +// Use bfloat16_4 instead of float2 to distinguish BF16 from FP16 at the +// store call site: both would otherwise resolve to float2 and incorrectly +// use __float22half2_rn (FP16) instead of __float2bfloat16_rn (BF16). +// bfloat16_4 is only available on ROCm or CUDA >= 11 / SM80+; fall back +// to float2 (broken but unchanged) on legacy platforms. +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +template <> +struct Vec4Type { + using type = bfloat16_4; +}; +#else template <> struct Vec4Type { using type = float2; }; +#endif template <> struct Vec4Type { diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/vec4acc.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/vec4acc.cuh index 16c0d665fa..0dcfd7b6ef 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/vec4acc.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/vec4acc.cuh @@ -10,6 +10,7 @@ #include #include "fbgemm_gpu/utils/cuda_prelude.cuh" +#include "fbgemm_gpu/utils/float.cuh" namespace fbgemm_gpu { @@ -69,6 +70,14 @@ struct Vec4AccT { *dst = *reinterpret_cast(vals_h); } +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void store_(const float4* src, bfloat16_4* dst) { + *dst = to_bfloat16_4(*src); + } +#endif + DEVICE_INLINE void store(float4* ptr) { this->store_(reinterpret_cast(acc), ptr); } @@ -78,6 +87,15 @@ struct Vec4AccT { this->store_(reinterpret_cast(acc), ptr); } +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + // Store to bfloat16 + DEVICE_INLINE void store(bfloat16_4* ptr) { + this->store_(reinterpret_cast(acc), ptr); + } +#endif + DEVICE_INLINE void store(uint8_t* ptr) { CUDA_KERNEL_ASSERT(false); } @@ -188,6 +206,14 @@ struct Vec4StepT : Vec4AccT { this->store_(&loaded_vals[idx], ptr); } +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) { + this->store_(reinterpret_cast(&loaded_vals[idx]), ptr); + } +#endif + DEVICE_INLINE void index_store(uint32_t idx, uint8_t* ptr) { CUDA_KERNEL_ASSERT(false); } @@ -213,6 +239,21 @@ struct Vec4StepT : Vec4AccT { this->store_(reinterpret_cast(vals_f), ptr); } +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void + index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) { + const float* vals = reinterpret_cast(&loaded_vals[idx]); + float4 vals_f = make_float4( + __fmul_rn(vals[0], weight), + __fmul_rn(vals[1], weight), + __fmul_rn(vals[2], weight), + __fmul_rn(vals[3], weight)); + this->store_(&vals_f, ptr); + } +#endif + DEVICE_INLINE void index_weighted_store(uint32_t idx, uint8_t* ptr, const float weight) { CUDA_KERNEL_ASSERT(false); @@ -297,6 +338,18 @@ struct Vec4StepT : Vec4AccT { *ptr = *reinterpret_cast(&loaded_vals[idx]); } +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) { + const half2* vals_h = reinterpret_cast(&loaded_vals[idx]); + float2 vals_f[2]; + vals_f[0] = __half22float2(vals_h[0]); + vals_f[1] = __half22float2(vals_h[1]); + this->store_(reinterpret_cast(vals_f), ptr); + } +#endif + DEVICE_INLINE void index_store(uint32_t idx, uint8_t* ptr) { CUDA_KERNEL_ASSERT(false); } @@ -322,6 +375,21 @@ struct Vec4StepT : Vec4AccT { this->store_(reinterpret_cast(vals_f), ptr); } +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void + index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) { + const half* vals = reinterpret_cast(&loaded_vals[idx]); + float4 vals_f = make_float4( + __fmul_rn(__half2float(vals[0]), weight), + __fmul_rn(__half2float(vals[1]), weight), + __fmul_rn(__half2float(vals[2]), weight), + __fmul_rn(__half2float(vals[3]), weight)); + this->store_(&vals_f, ptr); + } +#endif + DEVICE_INLINE void index_weighted_store(uint32_t idx, uint8_t* ptr, const float weight) { CUDA_KERNEL_ASSERT(false); @@ -383,6 +451,19 @@ struct Vec4StepT : Vec4AccT { index_weighted_store(uint32_t idx, uint8_t* ptr, const float weight) { CUDA_KERNEL_ASSERT(false); } + +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) { + CUDA_KERNEL_ASSERT(false); + } + + DEVICE_INLINE void + index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) { + CUDA_KERNEL_ASSERT(false); + } +#endif }; template @@ -447,6 +528,19 @@ struct Vec4StepT : Vec4AccT { const float weight) { CUDA_KERNEL_ASSERT(false); } + +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) { + CUDA_KERNEL_ASSERT(false); + } + + DEVICE_INLINE void + index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) { + CUDA_KERNEL_ASSERT(false); + } +#endif }; template @@ -511,6 +605,19 @@ struct Vec4StepT : Vec4AccT { const float weight) { CUDA_KERNEL_ASSERT(false); } + +#if defined(USE_ROCM) || \ + !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) + DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) { + CUDA_KERNEL_ASSERT(false); + } + + DEVICE_INLINE void + index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) { + CUDA_KERNEL_ASSERT(false); + } +#endif }; } // namespace fbgemm_gpu