diff --git a/.clang-format b/.clang-format index 131e4e6fc..6f3f6fc45 100644 --- a/.clang-format +++ b/.clang-format @@ -22,6 +22,7 @@ StatementMacros: - 'MAKE_OptimizerStatic8bit1StateBlockwise' - 'MAKE_OptimizerStatic8bit2StateBlockwise' - 'MAKE_kQuantizeBlockwise' + - 'MAKE_kQuantizeBlockwiseSmall' - 'MAKE_BLOCKWISE8' - 'MAKE_ELEMENTWISE_FUNC' - 'CMAKE_ELEMENTWISE_FUNC' diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index ab0ffc309..7799645db 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import ROCM_WARP_SIZE_64, lib +from ...cextension import lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -212,10 +212,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor A = A.contiguous() torch._check_is_size(blocksize) - if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") @@ -271,10 +268,7 @@ def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: A = A.contiguous() - if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( @@ -306,10 +300,7 @@ def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: A = A.contiguous() - if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( @@ -389,10 +380,7 @@ def _dequantize_4bit_impl( out: torch.Tensor, ) -> None: A = A.contiguous() - if ROCM_WARP_SIZE_64: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c3cec7281..373a91875 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -14,7 +14,6 @@ get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch, - get_rocm_warpsize, ) logger = logging.getLogger(__name__) @@ -317,7 +316,6 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() -ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False HIP_ENVIRONMENT = False BNB_BACKEND = "CPU" diff --git a/csrc/kernels.cu b/csrc/kernels.cu index cff242316..0d313c8d7 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -374,69 +374,72 @@ __global__ void kQuantizeBlockwise( } } -// Unified small-blocksize kernel for 4-bit quantization -// Processes 2 blocks of BNB_WARP_SIZE values per thread block -// On CUDA (warp=32): blocksize=32, 32 threads, WarpReduce<16> -// On HIP (warp=64): blocksize=64, 64 threads, WarpReduce<32> -// On HIP (warp=32): blocksize=32, 32 threads, WarpReduce<16> -template +// Small-blocksize kernel for 4-bit quantization, parameterized on quantization +// block size (QBLOCK_SIZE). Always launches exactly BNB_WARP_SIZE threads so +// every lane in the wavefront is productive. Multiple quantization blocks are +// packed into one wavefront when QBLOCK_SIZE < BNB_WARP_SIZE * NUM_PER_TH: +// +// CDNA (64), QBLOCK_SIZE=32 -> 4 quant blocks per wavefront +// CDNA (64), QBLOCK_SIZE=64 -> 2 quant blocks per wavefront +// CUDA/RDNA (32), QBLOCK_SIZE=32 -> 2 quant blocks per wavefront +// +// Uses logical-warp WarpReduce so each quantization block's +// threads reduce independently via warp shuffles. +template __global__ void kQuantizeBlockwiseSmall( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ) { - constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // Size of each quantization block - constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) - constexpr int THREADS = BNB_WARP_SIZE; // Total threads (one full warp) - constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2; // Half-warp per quantization block + static_assert(QBLOCK_SIZE <= BNB_WARP_SIZE * 2, "QBLOCK_SIZE too large for one warp"); - const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per thread block + constexpr int NUM_PER_TH = 2; + constexpr int THREADS = BNB_WARP_SIZE; + constexpr int THREADS_PER_QB = QBLOCK_SIZE / NUM_PER_TH; + constexpr int NUM_QB = THREADS / THREADS_PER_QB; + constexpr int TOTAL_VALUES = QBLOCK_SIZE * NUM_QB; + + const int base_idx = blockIdx.x * TOTAL_VALUES; T vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte + unsigned char qvals[NUM_PER_TH / 2]; float local_abs_max = 0.0f; - const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31 - const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15) + const int qb_id = threadIdx.x / THREADS_PER_QB; + const int local_tid = threadIdx.x % THREADS_PER_QB; typedef bnb_cub::BlockLoad LoadT; typedef bnb_cub::BlockStore StoreChar; - typedef bnb_cub::WarpReduce - WarpReduce; // Half-warp logical reduction: each half reduces independently + typedef bnb_cub::WarpReduce WarpReduce; __shared__ typename LoadT::TempStorage loadt; __shared__ typename StoreChar::TempStorage storec; - __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp - __shared__ float smem_absmax_value[2]; + __shared__ typename WarpReduce::TempStorage warp_reduce[NUM_QB]; + __shared__ float smem_absmax_value[NUM_QB]; - const int i = base_idx + block_id * BLOCK_SIZE; - // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative - // operations that require ALL 32 threads to participate - const bool block_valid = (i < n); + const int qi = base_idx + qb_id * QBLOCK_SIZE; + const bool qb_valid = (qi < n); - // All 32 threads participate in the load (out-of-bounds threads get 0.0f) __syncthreads(); - LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f); + LoadT(loadt).Load(&(A[base_idx]), vals, min(TOTAL_VALUES, n - base_idx), (T)0.0f); - // Each thread computes max of its values local_abs_max = -FLT_MAX; #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - // Reduce within each logical warp of 16 threads independently - local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, BNB_MAX_OP); + local_abs_max = WarpReduce(warp_reduce[qb_id]).Reduce(local_abs_max, BNB_MAX_OP); - if (local_thread_id == 0) { - if (block_valid) { - smem_absmax_value[block_id] = 1.0f / local_abs_max; - absmax[blockIdx.x * 2 + block_id] = local_abs_max; + if (local_tid == 0) { + if (qb_valid) { + smem_absmax_value[qb_id] = 1.0f / local_abs_max; + absmax[blockIdx.x * NUM_QB + qb_id] = local_abs_max; } else { - smem_absmax_value[block_id] = 0.0f; + smem_absmax_value[qb_id] = 0.0f; } } __syncthreads(); - local_abs_max = smem_absmax_value[block_id]; + local_abs_max = smem_absmax_value[qb_id]; switch (DATA_TYPE) { case FP4: @@ -455,9 +458,8 @@ __global__ void kQuantizeBlockwiseSmall( break; } - // All 32 threads participate in the store (valid_items limits the actual writes) __syncthreads(); - StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2)); + StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((TOTAL_VALUES + 1) / 2, (n - base_idx + 1) / 2)); } template @@ -1446,15 +1448,15 @@ __global__ void kgemm_4bit_inference_naive( ) { // per threadblock: - // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] - // 4 warps -> 4 loads per iter - // 1x32 * 32x4 -> 1x4 outputs per thread block + // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] + // THREADS/BNB_WARP_SIZE warps -> that many loads per iter + // 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block typedef bnb_cub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE]; - const int warp_idx = threadIdx.x / 32; - const int warp_lane = threadIdx.x % 32; - const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int warp_idx = threadIdx.x / BNB_WARP_SIZE; + const int warp_lane = threadIdx.x % BNB_WARP_SIZE; + const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx; const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit / 2; float local_C = 0.0f; @@ -1473,7 +1475,7 @@ __global__ void kgemm_4bit_inference_naive( // A: [1, K] // B: [N, K] - for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) { const int inner_idx_halved = inner_idx / 2; // Since blocksize will always be a power-of-2, we avoid more expensive @@ -1766,26 +1768,32 @@ MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4) -// Template instantiations for blocksize=32 specialized kernel (4-bit only) -#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \ - template __global__ void kQuantizeBlockwiseSmall( \ +// Template instantiations for kQuantizeBlockwiseSmall (4-bit only) +#define MAKE_kQuantizeBlockwiseSmall(dtype, qblock_size, data_type_name) \ + template __global__ void kQuantizeBlockwiseSmall( \ float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ const int rand_offset, const int n \ ); -// FP4 instantiations for blocksize=32 -MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall( - bnb_bfloat16, FP4 -) - - // NF4 instantiations for blocksize=32 - MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall( - bnb_bfloat16, NF4 - ) - - template __global__ void kDequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n - ); +// QBLOCK_SIZE=32 instantiations +MAKE_kQuantizeBlockwiseSmall(half, 32, FP4) +MAKE_kQuantizeBlockwiseSmall(float, 32, FP4) +MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, FP4) +MAKE_kQuantizeBlockwiseSmall(half, 32, NF4) +MAKE_kQuantizeBlockwiseSmall(float, 32, NF4) +MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, NF4) + +// QBLOCK_SIZE=64 instantiations (blocksize=64, 4-bit) +MAKE_kQuantizeBlockwiseSmall(half, 64, FP4) +MAKE_kQuantizeBlockwiseSmall(float, 64, FP4) +MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, FP4) +MAKE_kQuantizeBlockwiseSmall(half, 64, NF4) +MAKE_kQuantizeBlockwiseSmall(float, 64, NF4) +MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, NF4) + +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n ); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 6ac6732fc..dc511661b 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -14,7 +14,7 @@ __global__ void kQuantizeBlockwise( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ); -template +template __global__ void kQuantizeBlockwiseSmall( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n diff --git a/csrc/ops.cu b/csrc/ops.cu index c1f8e65bc..e76834785 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -54,30 +54,20 @@ void quantizeBlockwise( else if (blocksize == 128) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if (blocksize == 64) { -#if BNB_HIP if constexpr (DATA_TYPE > 0) { - if (bnb_host_warp_size() == 64) { - // CDNA: kQuantizeBlockwiseSmall is compiled with THREADS=64 - kQuantizeBlockwiseSmall - <<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n); - } else { - // RDNA: standard kernel (same as CUDA path) - kQuantizeBlockwise - <<>>(code, A, absmax, out, rand, rand_offset, n); - } + const int ws = bnb_host_warp_size(); + const int num_qb = ws / (64 / 2); + int grid = (num_blocks + num_qb - 1) / num_qb; + kQuantizeBlockwiseSmall<<>>(code, A, absmax, out, rand, rand_offset, n); } else { kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); } -#else - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); -#endif } else if (blocksize == 32) { - // For 4-bit: use specialized kernel that processes 2 blocks per warp - // Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2 if constexpr (DATA_TYPE > 0) { - int num_blocks_adjusted = (num_blocks + 1) / 2; - kQuantizeBlockwiseSmall - <<>>(code, A, absmax, out, rand, rand_offset, n); + const int ws = bnb_host_warp_size(); + const int num_qb = ws / (32 / 2); + int grid = (num_blocks + num_qb - 1) / num_qb; + kQuantizeBlockwiseSmall<<>>(code, A, absmax, out, rand, rand_offset, n); } } diff --git a/tests/test_functional.py b/tests/test_functional.py index dfa8ceb75..5c8af1a05 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -10,7 +10,6 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -96,7 +95,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( "blocksize", - [4096, 2048, 1024, 512, 256, 128, 64] if not ROCM_WARP_SIZE_64 else [4096, 2048, 1024, 512, 256, 128], + [4096, 2048, 1024, 512, 256, 128, 64], ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): @@ -509,7 +508,6 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim")) @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim")) @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) - @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -844,7 +842,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( "blocksize", - [32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512, 1024, 2048, 4096], + [32, 64, 128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -927,9 +925,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize( - "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize") - ) + @pytest.mark.parametrize("blocksize", [32, 64, 128], ids=id_formatter("blocksize")) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -966,9 +962,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device") @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize( - "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize") - ) + @pytest.mark.parametrize("blocksize", [32, 64, 128], ids=id_formatter("blocksize")) def test_4bit_quant_large(self, device, dtype, quant_type, blocksize): """ Test that we can successfully quantize a large tensor. Note that the following limitations apply: @@ -1028,9 +1022,6 @@ def test_bench_4bit_dequant(self, quant_type): # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) - @pytest.mark.skipif( - ROCM_WARP_SIZE_64, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" - ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @@ -1185,7 +1176,6 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) - @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_gemv_eye_4bit(self, device, storage_type, dtype): if device == "hpu" and not is_supported_on_hpu(storage_type, dtype): pytest.skip("This configuration is not supported on HPU.") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 1585ea389..d43656b63 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -11,7 +11,6 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import ( TRUE_FALSE, describe_dtype, @@ -195,7 +194,7 @@ def test_linear_serialization( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -286,7 +285,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -315,7 +314,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 2460099ae..dc4ff4741 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -10,7 +10,6 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, @@ -238,7 +237,6 @@ def test_linear8bit_serialization(linear8bit): @pytest.mark.skipif( torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10" ) -@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): if device == "cuda" and platform.system() == "Windows": pytest.skip("Triton is not officially supported on Windows") diff --git a/tests/test_ops.py b/tests/test_ops.py index 005084c52..1dbeb0a53 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,7 +4,6 @@ import torch import bitsandbytes -from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu # torch.library.opcheck is only available in torch 2.4 and later. @@ -102,7 +101,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -126,7 +125,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -152,7 +151,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -202,7 +201,7 @@ def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_typ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -236,8 +235,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) - @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") diff --git a/tests/test_parametrize.py b/tests/test_parametrize.py index 1055e16d3..50260123b 100644 --- a/tests/test_parametrize.py +++ b/tests/test_parametrize.py @@ -3,7 +3,6 @@ import torch.nn as nn from bitsandbytes import functional as F -from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from bitsandbytes.nn.parametrize import ( Bnb4bitParametrization, replace_parameter_4bit, @@ -336,7 +335,7 @@ def test_multiple_parameters(device, dtype): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize( "blocksize", - [64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256], + [64, 128, 256], ) def test_different_blocksizes(device, dtype, blocksize): """Test parametrization with different block sizes to verify flexibility."""