From dee600c418f9a13544b478493da13231c85584eb Mon Sep 17 00:00:00 2001 From: sstamenk Date: Wed, 4 Mar 2026 12:07:13 +0000 Subject: [PATCH 1/8] Replace compile-time warp size with runtime query in host code Add bnb_host_warp_size() that queries hipDeviceGetAttribute at runtime with per-device caching (up to 32 GPUs), replacing the compile-time BNB_WARP_SIZE macro in host-side dispatch. This fixes incorrect defaulting to warp size 64 on RDNA and kernel dispatch with proper parameters. --- csrc/common.cuh | 9 +++++++-- csrc/ops.cu | 20 ++++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/csrc/common.cuh b/csrc/common.cuh index 61bef3c27..cf92f93e0 100644 --- a/csrc/common.cuh +++ b/csrc/common.cuh @@ -7,11 +7,16 @@ // Warp size #if BNB_HIP -// CDNA (gfx9xx) = 64, RDNA = 32. +// CDNA (gfx9xx) = 64, RDNA (gfx10xx/gfx11xx/gfx12xx) = 32. +// __AMDGCN_WAVEFRONT_SIZE is not defined by all compiler versions (removed since ROCm 7.0), +// so fall back to architecture-family macros when it is absent. +// This is a macro that is defined by the compiler during each device-code pass and as such should only be used inside kernels. 
#ifdef __AMDGCN_WAVEFRONT_SIZE #define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE +#elif defined(__GFX9__) +#define BNB_WARP_SIZE 64 // CDNA #else -#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA) +#define BNB_WARP_SIZE 32 // RDNA and other #endif #else #define BNB_WARP_SIZE 32 diff --git a/csrc/ops.cu b/csrc/ops.cu index ef13678e4..3262a607e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -10,6 +10,23 @@ #define ERR_NOT_IMPLEMENTED 100 +#if BNB_HIP +#include +static int bnb_host_warp_size() { + constexpr int MAX_DEVICES = 32; + static int cache[MAX_DEVICES] = {}; + int dev; + (void)hipGetDevice(&dev); + if (dev < 0 || dev >= MAX_DEVICES) return 64; + if (cache[dev] == 0) + (void)hipDeviceGetAttribute(&cache[dev], hipDeviceAttributeWarpSize, dev); + return cache[dev]; +} +#else +static constexpr int bnb_host_warp_size() { return 32; } +#endif + + using std::cout; using std::endl; @@ -407,8 +424,7 @@ void gemm_4bit_inference_naive( int num_blocks = (m + 3) / 4; #if BNB_HIP - // On 64-wide warp architectures, each warp processes 2 rows instead of 4 - if (BNB_WARP_SIZE == 64) { + if (bnb_host_warp_size() == 64) { num_blocks = (m + 1) / 2; } #endif From 83892a5546ae9d1da378ce25fec53787a27f4be2 Mon Sep 17 00:00:00 2001 From: sstamenk Date: Wed, 4 Mar 2026 13:59:48 +0100 Subject: [PATCH 2/8] Fix kernel dispatching for RDNA --- csrc/ops.cu | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index 3262a607e..a0662c627 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -52,10 +52,16 @@ void quantizeBlockwise( kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if (blocksize == 64) { #if BNB_HIP - // On HIP with 64-wide warps (CDNA), use specialized kernel for 4-bit types if constexpr (DATA_TYPE > 0) { - kQuantizeBlockwiseSmall - <<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n); + if (bnb_host_warp_size() == 64) { + // CDNA: kQuantizeBlockwiseSmall is compiled with 
THREADS=64 + kQuantizeBlockwiseSmall + <<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n); + } else { + // RDNA: standard kernel (same as CUDA path) + kQuantizeBlockwise + <<>>(code, A, absmax, out, rand, rand_offset, n); + } } else { kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); } From 6d51838d7f8ff5af93e0fb365b1ede3ea2882810 Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Wed, 4 Mar 2026 17:55:07 +0100 Subject: [PATCH 3/8] Fix linting issues --- csrc/ops.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index a0662c627..2dda3394a 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -17,7 +17,8 @@ static int bnb_host_warp_size() { static int cache[MAX_DEVICES] = {}; int dev; (void)hipGetDevice(&dev); - if (dev < 0 || dev >= MAX_DEVICES) return 64; + if (dev < 0 || dev >= MAX_DEVICES) + return 64; if (cache[dev] == 0) (void)hipDeviceGetAttribute(&cache[dev], hipDeviceAttributeWarpSize, dev); return cache[dev]; From 947f64efee4c3c2b4dddce0e8c211e7823f435c5 Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Wed, 4 Mar 2026 19:17:12 +0100 Subject: [PATCH 4/8] Fix linting issues --- csrc/common.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/common.cuh b/csrc/common.cuh index cf92f93e0..03d934270 100644 --- a/csrc/common.cuh +++ b/csrc/common.cuh @@ -10,7 +10,8 @@ // CDNA (gfx9xx) = 64, RDNA (gfx10xx/gfx11xx/gfx12xx) = 32. // __AMDGCN_WAVEFRONT_SIZE is not defined by all compiler versions (removed since ROCm 7.0), // so fall back to architecture-family macros when it is absent. -// This is a macro that is defined by the compiler during each device-code pass and as such should only be used inside kernels. +// This is a macro that is defined by the compiler during each device-code pass and as such +// should only be used inside kernels. 
#ifdef __AMDGCN_WAVEFRONT_SIZE #define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE #elif defined(__GFX9__) From 0c94504914207c0572c1a9022bcbd6c059fae23f Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Wed, 4 Mar 2026 19:17:46 +0100 Subject: [PATCH 5/8] Fix linting issues --- csrc/ops.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index 2dda3394a..bcde04ab6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -12,6 +12,7 @@ #if BNB_HIP #include + static int bnb_host_warp_size() { constexpr int MAX_DEVICES = 32; static int cache[MAX_DEVICES] = {}; @@ -27,7 +28,6 @@ static int bnb_host_warp_size() { static constexpr int bnb_host_warp_size() { return 32; } #endif - using std::cout; using std::endl; From 72e4e39aa1e3b8eb262b2d483ac602fc10dfab41 Mon Sep 17 00:00:00 2001 From: sstamenk Date: Wed, 4 Mar 2026 22:18:04 +0100 Subject: [PATCH 6/8] Revert device array caching and instead only do device 0 --- csrc/ops.cu | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index bcde04ab6..dc4638e6d 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -13,16 +13,14 @@ #if BNB_HIP #include +// NOTE: This queries device 0 once and caches the result. On mixed RDNA+CDNA +// systems (warp size 32 vs 64) this will return the wrong value for whichever +// device doesn't match device 0. 
static int bnb_host_warp_size() { - constexpr int MAX_DEVICES = 32; - static int cache[MAX_DEVICES] = {}; - int dev; - (void)hipGetDevice(&dev); - if (dev < 0 || dev >= MAX_DEVICES) - return 64; - if (cache[dev] == 0) - (void)hipDeviceGetAttribute(&cache[dev], hipDeviceAttributeWarpSize, dev); - return cache[dev]; + static int warp_size = 0; + if (warp_size == 0) + (void)hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0); + return warp_size; } #else static constexpr int bnb_host_warp_size() { return 32; } From fa09fb15259821842ec92d7e2d82a4368dfc7f90 Mon Sep 17 00:00:00 2001 From: sstamenk Date: Thu, 5 Mar 2026 13:36:12 +0100 Subject: [PATCH 7/8] Use atomics to avoid a race condition --- csrc/ops.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index dc4638e6d..2dfbbf086 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -12,15 +12,19 @@ #if BNB_HIP #include +#include // NOTE: This queries device 0 once and caches the result. On mixed RDNA+CDNA // systems (warp size 32 vs 64) this will return the wrong value for whichever // device doesn't match device 0. 
static int bnb_host_warp_size() { - static int warp_size = 0; - if (warp_size == 0) - (void)hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, 0); - return warp_size; + static std::atomic<int> warp_size{0}; + int ws = warp_size.load(std::memory_order_relaxed); + if (ws == 0) { + (void)hipDeviceGetAttribute(&ws, hipDeviceAttributeWarpSize, 0); + warp_size.store(ws, std::memory_order_relaxed); + } + return ws; } #else static constexpr int bnb_host_warp_size() { return 32; } From 30bbc4f849cb99630eb91bf3868aee7e4ae901af Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Thu, 5 Mar 2026 13:41:24 +0100 Subject: [PATCH 8/8] Fix linting issues --- csrc/ops.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index 2dfbbf086..c1f8e65bc 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -11,8 +11,8 @@ #define ERR_NOT_IMPLEMENTED 100 #if BNB_HIP -#include #include <atomic> +#include // NOTE: This queries device 0 once and caches the result. On mixed RDNA+CDNA // systems (warp size 32 vs 64) this will return the wrong value for whichever