Commit b00b75f

Authored by Artem Kuzmitckii; co-authored by Olatunji Ruwase and Logan Adams.
[AMD][ROCm] Improve support of AMD (#7448)
The patch delivers several fixes for build issues in the CUDA part of the DeepSpeed library. The percentage of passing unit tests improved (tested on RDNA hardware, gfx110x and gfx12x):

- Before: collected 5298 items / 15 skipped; 2773 failed, 862 passed, 1665 skipped, 13 errors
- After: collected 5851 items / 11 skipped; 4187 failed, 1373 passed, 292 skipped, 10 errors

Regarding testing of **fp_quantizer** (`DS_BUILD_FP_QUANTIZER`) via `tests/unit/ops/fp_quantizer/test_fp_quant.py`: this test depends on QPyTorch, which must be patched before running on AMD; please apply Tiiiger/QPyTorch#71.

---------

Signed-off-by: Artem Kuzmitckii <artem.kuzmitckii@amd.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent 0e77020 · commit b00b75f

9 files changed

Lines changed: 135 additions & 8 deletions

csrc/fp_quantizer/fp_quantize.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 // DeepSpeed Team

 #include <stdexcept>
-#include "context.h"
+#include "fp_context.h"
 #include "fp_quantize.h"
 #include "memory_access_utils.h"
 #include "reduction_utils.h"
```

csrc/fp_quantizer/includes/fp_quantize.h

Lines changed: 10 additions & 2 deletions
```diff
@@ -9,10 +9,18 @@
 #include <stdint.h>

 #include <cuda_fp16.h>
-
-#ifdef BF16_AVAILABLE
+// Note: BF16 support on AMD but we have to exclude here cuda_bf16.h (which turn to
+// <hip/hip_bfloat16.h> after hipifying), because this header is pulled into .cpp translation units
+// that are compiled by a host-only compiler, which triggers build errors. Added forward declaration
+// instead, see code block below
+#if defined(BF16_AVAILABLE)
+#if !defined(__HIP_PLATFORM_AMD__)
 #include <cuda_bf16.h>
+#else
+struct __hip_bfloat16;
+#endif
 #endif
+
 #include <cuda_runtime_api.h>
 #include <stdio.h>
```
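
The same forward-declaration approach is used in `csrc/includes/ds_kernel_utils.h` below. It works because host-compiled `.cpp` translation units only ever touch `__hip_bfloat16` through pointers and references, so the incomplete type is enough. A minimal sketch of the idea (the function names are illustrative, not taken from the DeepSpeed sources):

```cpp
// Forward declaration only; <hip/hip_bfloat16.h> (or <cuda_bf16.h>) is never
// included in this host-only translation unit.
struct __hip_bfloat16;

// Declared in a shared header; defined in a .cu/.hip file that does include
// the full bfloat16 header and is compiled by the device compiler.
void launch_quantize_bf16(__hip_bfloat16* data, int num_elems);

// A host-compiled .cpp caller can pass the incomplete type around by pointer
// without ever needing its definition (or the GPU builtins it drags in).
void quantize_from_host(__hip_bfloat16* device_buffer, int num_elems)
{
    launch_quantize_bf16(device_buffer, num_elems);
}
```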

csrc/includes/conversion_utils.h

Lines changed: 68 additions & 0 deletions
```diff
@@ -363,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
 template <>
 DS_D_INLINE __nv_bfloat16 to(int64_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __double2bfloat16(__ll2double_rn(val));
+#else
     return __ll2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int32_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __int2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int16_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __short2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int8_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __int2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint64_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __double2bfloat16(__ull2double_rn(val));
+#else
     return __ull2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint32_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __uint2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint16_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __ushort2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint8_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __uint2bfloat16_rn(val);
+#endif
 }
 #endif

@@ -412,7 +444,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
 template <>
 DS_D_INLINE __nv_bfloat162 to(float val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __bfloat162bfloat162(__float2bfloat16(val));
+#else
     return __float2bfloat162_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat162 to(__half2 val)
@@ -444,7 +480,11 @@ DS_D_INLINE int64_t to(__half val)
 template <>
 DS_D_INLINE int64_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2ll_rn(__bfloat162float(val));
+#else
     return __bfloat162ll_rn(val);
+#endif
 }
 #endif

@@ -471,7 +511,11 @@ DS_D_INLINE int32_t to(__half val)
 template <>
 DS_D_INLINE int32_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif

@@ -498,7 +542,11 @@ DS_D_INLINE int16_t to(__half val)
 template <>
 DS_D_INLINE int16_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif

@@ -525,7 +573,11 @@ DS_D_INLINE int8_t to(__half val)
 template <>
 DS_D_INLINE int8_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif

@@ -552,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val)
 template <>
 DS_D_INLINE uint64_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2ull_rn(__bfloat162float(val));
+#else
     return __bfloat162ull_rn(val);
+#endif
 }
 #endif

@@ -579,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val)
 template <>
 DS_D_INLINE uint32_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif

@@ -606,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val)
 template <>
 DS_D_INLINE uint16_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif

@@ -633,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val)
 template <>
 DS_D_INLINE uint8_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
```
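
Callers are unaffected by these fallbacks: `conversion::to<>` keeps the same signature on CUDA and ROCm, and only the intrinsic behind it changes (on AMD the conversion routes through `float`, or `double` for 64-bit sources). A rough usage sketch, assuming a BF16-enabled build (`BF16_AVAILABLE` defined); the kernel itself is illustrative and not part of this patch:

```cpp
#include <cstdint>

#include "conversion_utils.h"

// Illustrative kernel: cast an int32 buffer to bfloat16 through the portable
// conversion::to<> entry point; the #ifdef blocks above pick the intrinsic path.
__global__ void cast_to_bf16(const int32_t* in, __nv_bfloat16* out, int n)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) { out[idx] = conversion::to<__nv_bfloat16>(in[idx]); }
}
```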

csrc/includes/ds_kernel_utils.h

Lines changed: 8 additions & 2 deletions
```diff
@@ -13,15 +13,21 @@ used throughout the codebase.
 #include <cuda.h>
 #include <cuda_fp16.h>

-#ifdef BF16_AVAILABLE
+// Note: BF16 support on AMD but we have to exclude here cuda_bf16.h (which turn to
+// <hip/hip_bfloat16.h> after hipifying), because this header is pulled into .cpp translation units
+// that are compiled by a host-only compiler, which triggers build errors. Added forward declaration
+// instead, see code block below
+#if defined(BF16_AVAILABLE) && !defined(__HIP_PLATFORM_AMD__)
 #include <cuda_bf16.h>
 #endif

 #define DS_HD_INLINE __host__ __device__ __forceinline__
 #define DS_D_INLINE __device__ __forceinline__

 #ifdef __HIP_PLATFORM_AMD__
-
+#if BF16_AVAILABLE
+struct __hip_bfloat16;
+#endif
 // constexpr variant of warpSize for templating
 constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;
 #define HALF_PRECISION_AVAILABLE = 1
```

csrc/includes/reduction_utils.h

Lines changed: 34 additions & 1 deletion
```diff
@@ -9,6 +9,10 @@
 #include "ds_kernel_utils.h"
 #include "memory_access_utils.h"

+#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__)
+#include <hip/hip_bfloat16.h>
+#endif
+
 namespace cg = cooperative_groups;

 namespace reduce {
@@ -374,7 +378,11 @@ DS_D_INLINE __half init<ROpType::Max>()
 template <>
 DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
 {
+#ifdef __HIP_PLATFORM_AMD__
+    constexpr __hip_bfloat16_raw neg_inf = {0xFF80};
+#else
     constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
+#endif
     return __nv_bfloat16(neg_inf);
 }
 #endif
@@ -573,6 +581,24 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
     }
 }

+#if defined(__HIP_PLATFORM_AMD__)
+template <int reduce_width, typename T, ROpType... Ops>
+DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile<hw_warp_size>& warp_arg, T* data)
+{
+    constexpr int elems = sizeof...(Ops);
+    if constexpr (!(std::is_integral<T>::value || std::is_floating_point<T>::value)) {
+        float temp_data[elems];
+#pragma unroll
+        for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to<float>(data[i]); }
+        _warp<float, Ops...>(warp_arg, temp_data);
+#pragma unroll
+        for (int i = 0; i < elems; i++) { data[i] = conversion::to<T>(temp_data[i]); }
+    } else {
+        _warp<T, Ops...>(warp_arg, data);
+    }
+}
+#endif  // defined(__HIP_PLATFORM_AMD__)
+
 /*
 Implementation for primary block reduction that serves both `block` and
 `partitioned_block`.
@@ -600,7 +626,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
 #endif

     // Always perform warp-scope reduction
+#ifdef __HIP_PLATFORM_AMD__
+    _warp_with_type_conversion<hw_warp_size, T, Ops...>(warp_arg, data);
+#else
     _warp<T, Ops...>(warp_arg, data);
+#endif

     // If max_warps == 1 let's skip the runtime check
     if (total_warps != 1) {
@@ -624,8 +654,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
         } else {
             init<Ops...>(data);
         }
-
+#ifdef __HIP_PLATFORM_AMD__
+        _warp_with_type_conversion<total_warps, T, Ops...>(warp_arg, data);
+#else
         _warp<T, Ops..., total_warps>(warp_arg, data);
+#endif

 #pragma unroll
         for (int i = 0; i < elems; i++) {
```
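
The new `_warp_with_type_conversion` helper widens any non-arithmetic element type (such as the bfloat16 wrapper on ROCm, which the generic shuffle-based `_warp` path presumably cannot handle directly) to `float` around the reduction and narrows the result back, so the core reduction only ever operates on plain arithmetic types. A host-side analog of that dispatch, with hypothetical names, just to show the pattern in isolation:

```cpp
#include <cstdio>
#include <type_traits>

// Toy stand-in for a bfloat16-like wrapper: not an arithmetic type, but it
// converts to and from float, which is all the promotion path requires.
struct FakeBf16 {
    float v;
    FakeBf16() : v(0.f) {}
    explicit FakeBf16(float f) : v(f) {}
    explicit operator float() const { return v; }
};

// Core reduction that only handles plain arithmetic types (stand-in for _warp<>):
// broadcast the maximum to every slot.
template <typename T, int elems>
void arithmetic_only_max(T* data)
{
    T best = data[0];
    for (int i = 1; i < elems; i++) {
        if (data[i] > best) { best = data[i]; }
    }
    for (int i = 0; i < elems; i++) { data[i] = best; }
}

// Promotion wrapper mirroring _warp_with_type_conversion: non-arithmetic types
// are widened to float, reduced, then narrowed back; arithmetic types pass through.
template <typename T, int elems>
void reduce_with_promotion(T* data)
{
    if constexpr (!(std::is_integral<T>::value || std::is_floating_point<T>::value)) {
        float tmp[elems];
        for (int i = 0; i < elems; i++) { tmp[i] = static_cast<float>(data[i]); }
        arithmetic_only_max<float, elems>(tmp);
        for (int i = 0; i < elems; i++) { data[i] = static_cast<T>(tmp[i]); }
    } else {
        arithmetic_only_max<T, elems>(data);
    }
}

int main()
{
    FakeBf16 vals[4] = {FakeBf16(1.f), FakeBf16(7.f), FakeBf16(3.f), FakeBf16(5.f)};
    reduce_with_promotion<FakeBf16, 4>(vals);
    printf("max broadcast to all slots: %f\n", static_cast<float>(vals[0]));  // prints 7.0
    return 0;
}
```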

op_builder/builder.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -779,6 +779,7 @@ def nvcc_args(self):
                 '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
                 '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
             ]
+            self.enable_bf16 = True
         else:
             try:
                 nvcc_threads = int(os.getenv("DS_NVCC_THREADS", ""))
```

op_builder/transformer_inference.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -75,3 +75,14 @@ def extra_ldflags(self):

     def include_paths(self):
         return ['csrc/transformer/inference/includes', 'csrc/includes']
+
+    def nvcc_args(self):
+        args = super().nvcc_args()
+        """BF16 is supported on AMD, but including `cuda_bf16.h` (`<hip/hip_bfloat16.h>` after hipification)
+        in host-only translation units (*.cpp files) fails because GPU-specific builtins are pulled in with the BF16 type.
+        This cannot be avoided via forward declarations for this transformer_inference extension,
+        since `pt_binding.cpp` code explicitly requires the BF16 header, so disable it for now.
+        """
+        if self.is_rocm_pytorch():
+            self.enable_bf16 = False
+        return args
```
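
The limitation described in the docstring is that host-only binding code uses the bfloat16 type by value, so it needs the complete definition, and the forward-declaration workaround from the headers above cannot help. An illustrative (not verbatim) host-side fragment of that kind:

```cpp
// Hypothetical fragment of the kind found in a host-compiled binding .cpp:
// the complete bfloat16 type is required here, so the header must be included,
// and on ROCm that include currently breaks in a host-only translation unit.
#include <cstddef>

#include <cuda_bf16.h>  // hipified to <hip/hip_bfloat16.h>

size_t bf16_buffer_bytes(int num_elems)
{
    // sizeof() needs the complete type; a forward declaration of the struct
    // would not compile here.
    return static_cast<size_t>(num_elems) * sizeof(__nv_bfloat16);
}
```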

tests/unit/ops/fp_quantizer/test_fp_quant.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -57,7 +57,7 @@ def test_fp_quant_meta(dtype):

         qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size)
         qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
-        ds_error = (x_dequantized - x).abs().sum() / x.numel()
+        ds_error = (x_dequantized - ds_x).abs().sum() / x.numel()

         assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"

@@ -129,6 +129,6 @@ def test_fp_quant(dtype, q_bits):
         qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=quant_config.group_size)

         qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
-        ds_error = (x_dequantized - x).abs().sum() / x.numel()
+        ds_error = (x_dequantized - ds_x).abs().sum() / x.numel()

         assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"
```
