Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,24 @@ struct Vec4Type<at::Half> {
using type = float2;
};

// Use bfloat16_4 instead of float2 to distinguish BF16 from FP16 at the
// store call site: both would otherwise resolve to float2 and incorrectly
// use __float22half2_rn (FP16) instead of __float2bfloat16_rn (BF16).
// bfloat16_4 is only available on ROCm or CUDA >= 11 / SM80+; fall back
// to float2 (broken but unchanged) on legacy platforms.
#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// Native path: ROCm, or CUDA 11+ where __CUDA_ARCH__ is either not defined
// (host compilation pass) or >= 800 (SM80+). Maps BF16 to the dedicated
// bfloat16_4 vector type so store overloads can distinguish it from FP16.
template <>
struct Vec4Type<at::BFloat16> {
using type = bfloat16_4;
};
#else
// Legacy fallback (pre-CUDA-11 or pre-SM80): keep the historical float2
// mapping so older builds compile unchanged (per the note above, this path
// was already broken for BF16 and is deliberately left as-is).
template <>
struct Vec4Type<at::BFloat16> {
using type = float2;
};
#endif

template <>
struct Vec4Type<uint8_t> {
Expand Down
107 changes: 107 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/vec4acc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <ATen/ATen.h>
#include "fbgemm_gpu/utils/cuda_prelude.cuh"
#include "fbgemm_gpu/utils/float.cuh"

namespace fbgemm_gpu {

Expand Down Expand Up @@ -69,6 +70,14 @@ struct Vec4AccT {
*dst = *reinterpret_cast<float2*>(vals_h);
}

#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// Convert a float4 (4 fp32 lanes) to 4 bfloat16 values and write them to
// dst in one vector store. to_bfloat16_4 comes from utils/float.cuh —
// presumably __float2bfloat16_rn per element; confirm there.
DEVICE_INLINE void store_(const float4* src, bfloat16_4* dst) {
*dst = to_bfloat16_4(*src);
}
#endif

// Store the fp32 accumulator to ptr unconverted (fp32 -> fp32).
DEVICE_INLINE void store(float4* ptr) {
this->store_(reinterpret_cast<float4*>(acc), ptr);
}
Expand All @@ -78,6 +87,15 @@ struct Vec4AccT {
this->store_(reinterpret_cast<const float4*>(acc), ptr);
}

#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// Store the fp32 accumulator as 4 bfloat16 values. Only compiled where
// bfloat16_4 exists (ROCm, or CUDA 11+ / SM80+); dispatches to the
// float4 -> bfloat16_4 store_ overload for the conversion.
DEVICE_INLINE void store(bfloat16_4* ptr) {
this->store_(reinterpret_cast<const float4*>(acc), ptr);
}
#endif

// uint8 output is not supported by this accumulator; trap at runtime if
// a kernel ever reaches this overload.
DEVICE_INLINE void store(uint8_t* ptr) {
CUDA_KERNEL_ASSERT(false);
}
Expand Down Expand Up @@ -188,6 +206,14 @@ struct Vec4StepT<STEP, float> : Vec4AccT {
this->store_(&loaded_vals[idx], ptr);
}

#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// Store slot `idx` of loaded_vals (4 fp32 lanes) as 4 bfloat16 values,
// via the float4 -> bfloat16_4 store_ overload.
DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) {
this->store_(reinterpret_cast<const float4*>(&loaded_vals[idx]), ptr);
}
#endif

// uint8 output is not supported for fp32 loads; trap at runtime.
DEVICE_INLINE void index_store(uint32_t idx, uint8_t* ptr) {
CUDA_KERNEL_ASSERT(false);
}
Expand All @@ -213,6 +239,21 @@ struct Vec4StepT<STEP, float> : Vec4AccT {
this->store_(reinterpret_cast<float4*>(vals_f), ptr);
}

#if defined(USE_ROCM) || \
    !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
       (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
  // Scale the 4 fp32 lanes at slot `idx` by `weight` using round-to-nearest
  // multiplies, then store the result as 4 bfloat16 values.
  DEVICE_INLINE void
  index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) {
    const float* src = reinterpret_cast<const float*>(&loaded_vals[idx]);
    float4 scaled;
    scaled.x = __fmul_rn(src[0], weight);
    scaled.y = __fmul_rn(src[1], weight);
    scaled.z = __fmul_rn(src[2], weight);
    scaled.w = __fmul_rn(src[3], weight);
    this->store_(&scaled, ptr);
  }
#endif

DEVICE_INLINE void
index_weighted_store(uint32_t idx, uint8_t* ptr, const float weight) {
CUDA_KERNEL_ASSERT(false);
Expand Down Expand Up @@ -297,6 +338,18 @@ struct Vec4StepT<STEP, at::Half> : Vec4AccT {
*ptr = *reinterpret_cast<float2*>(&loaded_vals[idx]);
}

#if defined(USE_ROCM) || \
    !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
       (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
  // Store the 4 fp16 lanes at slot `idx` as bfloat16: widen each half2 pair
  // to float2, then let the float4 -> bfloat16_4 store_ overload convert.
  DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) {
    const half2* packed = reinterpret_cast<const half2*>(&loaded_vals[idx]);
    float2 widened[2] = {__half22float2(packed[0]), __half22float2(packed[1])};
    this->store_(reinterpret_cast<const float4*>(widened), ptr);
  }
#endif

// uint8 output is not supported for fp16 loads; trap at runtime.
DEVICE_INLINE void index_store(uint32_t idx, uint8_t* ptr) {
CUDA_KERNEL_ASSERT(false);
}
Expand All @@ -322,6 +375,21 @@ struct Vec4StepT<STEP, at::Half> : Vec4AccT {
this->store_(reinterpret_cast<float4*>(vals_f), ptr);
}

#if defined(USE_ROCM) || \
    !(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
       (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
  // Widen the 4 fp16 lanes at slot `idx` to fp32, scale each by `weight`
  // with a round-to-nearest multiply, and store as 4 bfloat16 values.
  DEVICE_INLINE void
  index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) {
    const half* src = reinterpret_cast<const half*>(&loaded_vals[idx]);
    float4 scaled;
    scaled.x = __fmul_rn(__half2float(src[0]), weight);
    scaled.y = __fmul_rn(__half2float(src[1]), weight);
    scaled.z = __fmul_rn(__half2float(src[2]), weight);
    scaled.w = __fmul_rn(__half2float(src[3]), weight);
    this->store_(&scaled, ptr);
  }
#endif

DEVICE_INLINE void
index_weighted_store(uint32_t idx, uint8_t* ptr, const float weight) {
CUDA_KERNEL_ASSERT(false);
Expand Down Expand Up @@ -383,6 +451,19 @@ struct Vec4StepT<STEP, uint8_t> : Vec4AccT {
index_weighted_store(uint32_t idx, uint8_t* ptr, const float weight) {
CUDA_KERNEL_ASSERT(false);
}

#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// bfloat16 output is not supported for uint8 loads; these overloads exist
// only to satisfy the bfloat16_4 call sites and trap if ever reached.
DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) {
CUDA_KERNEL_ASSERT(false);
}

DEVICE_INLINE void
index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) {
CUDA_KERNEL_ASSERT(false);
}
#endif
};

template <uint32_t STEP>
Expand Down Expand Up @@ -447,6 +528,19 @@ struct Vec4StepT<STEP, c10::Float8_e4m3fn> : Vec4AccT {
const float weight) {
CUDA_KERNEL_ASSERT(false);
}

#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// bfloat16 output is not supported for fp8 (e4m3fn) loads; these overloads
// exist only to satisfy the bfloat16_4 call sites and trap if ever reached.
DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) {
CUDA_KERNEL_ASSERT(false);
}

DEVICE_INLINE void
index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) {
CUDA_KERNEL_ASSERT(false);
}
#endif
};

template <uint32_t STEP>
Expand Down Expand Up @@ -511,6 +605,19 @@ struct Vec4StepT<STEP, c10::Float8_e4m3fnuz> : Vec4AccT {
const float weight) {
CUDA_KERNEL_ASSERT(false);
}

#if defined(USE_ROCM) || \
!(((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
// bfloat16 output is not supported for fp8 (e4m3fnuz) loads; these overloads
// exist only to satisfy the bfloat16_4 call sites and trap if ever reached.
DEVICE_INLINE void index_store(uint32_t idx, bfloat16_4* ptr) {
CUDA_KERNEL_ASSERT(false);
}

DEVICE_INLINE void
index_weighted_store(uint32_t idx, bfloat16_4* ptr, const float weight) {
CUDA_KERNEL_ASSERT(false);
}
#endif
};

} // namespace fbgemm_gpu