Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ StatementMacros:
- 'MAKE_OptimizerStatic8bit1StateBlockwise'
- 'MAKE_OptimizerStatic8bit2StateBlockwise'
- 'MAKE_kQuantizeBlockwise'
- 'MAKE_kQuantizeBlockwiseSmall'
- 'MAKE_BLOCKWISE8'
- 'MAKE_ELEMENTWISE_FUNC'
- 'CMAKE_ELEMENTWISE_FUNC'
Expand Down
22 changes: 5 additions & 17 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ...cextension import ROCM_WARP_SIZE_64, lib
from ...cextension import lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
Expand Down Expand Up @@ -212,10 +212,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
A = A.contiguous()
torch._check_is_size(blocksize)

if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

Expand Down Expand Up @@ -271,10 +268,7 @@ def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
A = A.contiguous()
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
Expand Down Expand Up @@ -306,10 +300,7 @@ def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
A = A.contiguous()
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(quant_type in ["fp4", "nf4"])
torch._check(
Expand Down Expand Up @@ -389,10 +380,7 @@ def _dequantize_4bit_impl(
out: torch.Tensor,
) -> None:
A = A.contiguous()
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

torch._check(quant_type in ["fp4", "nf4"])
torch._check(
Expand Down
2 changes: 0 additions & 2 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
get_cuda_specs,
get_cuda_version_tuple,
get_rocm_gpu_arch,
get_rocm_warpsize,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -317,7 +316,6 @@ def get_native_library() -> BNBNativeLibrary:


ROCM_GPU_ARCH = get_rocm_gpu_arch()
ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False

HIP_ENVIRONMENT = False
BNB_BACKEND = "CPU"
Expand Down
126 changes: 67 additions & 59 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -374,69 +374,72 @@ __global__ void kQuantizeBlockwise(
}
}

// Unified small-blocksize kernel for 4-bit quantization
// Processes 2 blocks of BNB_WARP_SIZE values per thread block
// On CUDA (warp=32): blocksize=32, 32 threads, WarpReduce<16>
// On HIP (warp=64): blocksize=64, 64 threads, WarpReduce<32>
// On HIP (warp=32): blocksize=32, 32 threads, WarpReduce<16>
template <typename T, int DATA_TYPE>
// Small-blocksize kernel for 4-bit quantization, parameterized on quantization
// block size (QBLOCK_SIZE). Always launches exactly BNB_WARP_SIZE threads so
// every lane in the wavefront is productive. Multiple quantization blocks are
// packed into one wavefront when QBLOCK_SIZE < BNB_WARP_SIZE * NUM_PER_TH:
//
// CDNA (64), QBLOCK_SIZE=32 -> 4 quant blocks per wavefront
// CDNA (64), QBLOCK_SIZE=64 -> 2 quant blocks per wavefront
// CUDA/RDNA (32), QBLOCK_SIZE=32 -> 2 quant blocks per wavefront
//
// Uses logical-warp WarpReduce<THREADS_PER_QB> so each quantization block's
// threads reduce independently via warp shuffles.
template <typename T, int QBLOCK_SIZE, int DATA_TYPE>
__global__ void kQuantizeBlockwiseSmall(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
) {
constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // Size of each quantization block
constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing)
constexpr int THREADS = BNB_WARP_SIZE; // Total threads (one full warp)
constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2; // Half-warp per quantization block
static_assert(QBLOCK_SIZE <= BNB_WARP_SIZE * 2, "QBLOCK_SIZE too large for one warp");

const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per thread block
constexpr int NUM_PER_TH = 2;
constexpr int THREADS = BNB_WARP_SIZE;
constexpr int THREADS_PER_QB = QBLOCK_SIZE / NUM_PER_TH;
constexpr int NUM_QB = THREADS / THREADS_PER_QB;
constexpr int TOTAL_VALUES = QBLOCK_SIZE * NUM_QB;

const int base_idx = blockIdx.x * TOTAL_VALUES;

T vals[NUM_PER_TH];
unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
unsigned char qvals[NUM_PER_TH / 2];
float local_abs_max = 0.0f;

const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31
const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)
const int qb_id = threadIdx.x / THREADS_PER_QB;
const int local_tid = threadIdx.x % THREADS_PER_QB;

typedef bnb_cub::BlockLoad<T, THREADS, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef bnb_cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef bnb_cub::WarpReduce<float, THREADS_PER_BLOCK>
WarpReduce; // Half-warp logical reduction: each half reduces independently
typedef bnb_cub::WarpReduce<float, THREADS_PER_QB> WarpReduce;

__shared__ typename LoadT::TempStorage loadt;
__shared__ typename StoreChar::TempStorage storec;
__shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
__shared__ float smem_absmax_value[2];
__shared__ typename WarpReduce::TempStorage warp_reduce[NUM_QB];
__shared__ float smem_absmax_value[NUM_QB];

const int i = base_idx + block_id * BLOCK_SIZE;
// Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
// operations that require ALL 32 threads to participate
const bool block_valid = (i < n);
const int qi = base_idx + qb_id * QBLOCK_SIZE;
const bool qb_valid = (qi < n);

// All 32 threads participate in the load (out-of-bounds threads get 0.0f)
__syncthreads();
LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);
LoadT(loadt).Load(&(A[base_idx]), vals, min(TOTAL_VALUES, n - base_idx), (T)0.0f);

// Each thread computes max of its values
local_abs_max = -FLT_MAX;
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++)
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

// Reduce within each logical warp of 16 threads independently
local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, BNB_MAX_OP);
local_abs_max = WarpReduce(warp_reduce[qb_id]).Reduce(local_abs_max, BNB_MAX_OP);

if (local_thread_id == 0) {
if (block_valid) {
smem_absmax_value[block_id] = 1.0f / local_abs_max;
absmax[blockIdx.x * 2 + block_id] = local_abs_max;
if (local_tid == 0) {
if (qb_valid) {
smem_absmax_value[qb_id] = 1.0f / local_abs_max;
absmax[blockIdx.x * NUM_QB + qb_id] = local_abs_max;
} else {
smem_absmax_value[block_id] = 0.0f;
smem_absmax_value[qb_id] = 0.0f;
}
}
__syncthreads();

local_abs_max = smem_absmax_value[block_id];
local_abs_max = smem_absmax_value[qb_id];

switch (DATA_TYPE) {
case FP4:
Expand All @@ -455,9 +458,8 @@ __global__ void kQuantizeBlockwiseSmall(
break;
}

// All 32 threads participate in the store (valid_items limits the actual writes)
__syncthreads();
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((TOTAL_VALUES + 1) / 2, (n - base_idx + 1) / 2));
}

template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
Expand Down Expand Up @@ -1446,15 +1448,15 @@ __global__ void kgemm_4bit_inference_naive(
) {

// per threadblock:
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1x32 * 32x4 -> 1x4 outputs per thread block
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
// THREADS/BNB_WARP_SIZE warps -> that many loads per iter
// 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block
typedef bnb_cub::WarpReduce<float> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE];

const int warp_idx = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32;
const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit / 2;
float local_C = 0.0f;
Expand All @@ -1473,7 +1475,7 @@ __global__ void kgemm_4bit_inference_naive(

// A: [1, K]
// B: [N, K]
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) {
const int inner_idx_halved = inner_idx / 2;

// Since blocksize will always be a power-of-2, we avoid more expensive
Expand Down Expand Up @@ -1766,26 +1768,32 @@ MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4)

// Template instantiations for blocksize=32 specialized kernel (4-bit only)
#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \
template __global__ void kQuantizeBlockwiseSmall<dtype, data_type_name>( \
// Template instantiations for kQuantizeBlockwiseSmall (4-bit only)
#define MAKE_kQuantizeBlockwiseSmall(dtype, qblock_size, data_type_name) \
template __global__ void kQuantizeBlockwiseSmall<dtype, qblock_size, data_type_name>( \
float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
const int rand_offset, const int n \
);

// FP4 instantiations for blocksize=32
MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall(
bnb_bfloat16, FP4
)

// NF4 instantiations for blocksize=32
MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall(
bnb_bfloat16, NF4
)

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
// QBLOCK_SIZE=32 instantiations
MAKE_kQuantizeBlockwiseSmall(half, 32, FP4)
MAKE_kQuantizeBlockwiseSmall(float, 32, FP4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, FP4)
MAKE_kQuantizeBlockwiseSmall(half, 32, NF4)
MAKE_kQuantizeBlockwiseSmall(float, 32, NF4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 32, NF4)

// QBLOCK_SIZE=64 instantiations (blocksize=64, 4-bit)
MAKE_kQuantizeBlockwiseSmall(half, 64, FP4)
MAKE_kQuantizeBlockwiseSmall(float, 64, FP4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, FP4)
MAKE_kQuantizeBlockwiseSmall(half, 64, NF4)
MAKE_kQuantizeBlockwiseSmall(float, 64, NF4)
MAKE_kQuantizeBlockwiseSmall(bnb_bfloat16, 64, NF4)

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
Expand Down
2 changes: 1 addition & 1 deletion csrc/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ __global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
);
template <typename T, int DATA_TYPE>
template <typename T, int QBLOCK_SIZE, int DATA_TYPE>
__global__ void kQuantizeBlockwiseSmall(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
Expand Down
26 changes: 8 additions & 18 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,20 @@ void quantizeBlockwise(
else if (blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if (blocksize == 64) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would mean that for (CUDA, bs = 64, 4-bit dtypes), we would use kQuantizeBlockwiseSmall, while it should directly use kQuantizeBlockwise.
IMO, on CUDA, blocksize=64 should still use the regular kQuantizeBlockwise kernel. Wdyt @matthewdouglas

Copy link
Contributor Author

@sstamenk sstamenk Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I did this intentionally to minimize the kernel selection logic since I believe there shouldn't be much of a performance impact between the 2 but I haven't tested this explicitly. Alternatively, for bs = 64 we can use kQuantizeBlockwiseSmall for CDNA (ws = 64) and the standard kQuantizeBlockwise on CUDA/RDNA (ws = 32).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, me too! Your way of writing it is definitely better for code readability. But I’m not sure how much impact there’d be going from kQuantizeBlockwise to kQuantizeBlockwiseSmall. I think it’s negligible, but I’m not entirely sure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do some benchmarking/profiling for the overhead between kQuantizeBlockwise and kQuantizeBlockwiseSmall for RDNA and CUDA with batch_size == 64 to confirm.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, can you do that please ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are quite the same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthewdouglas It basically does 200 iterations of quantize_4bit and captures the time taken with torch.cuda.Event

"""
Quick A/B benchmark for blocksize=64 quantization kernels.

    # 1. Build baseline branch
    git checkout fix/rocm-runtime-warp-size
    cmake -S . -B build -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=gfx1201
    cmake --build build -j$(nproc) && pip install -e .

    # 2. Run benchmark, results saved to std_results.json
    python benchmarks/bench_quick.py --save std_results.json

    # 3. Build candidate branch
    git checkout fix/rocm-cdna-blocksize-32-64
    cmake -S . -B build -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=gfx1201
    cmake --build build -j$(nproc) && pip install -e .

    # 4. Run benchmark, compare against baseline
    python benchmarks/bench_quick.py --save small_results.json --compare std_results.json
"""

import argparse
import json
import sys

import torch
from bitsandbytes.functional import quantize_4bit

SIZES = [1_048_576, 4_194_304, 16_777_216, 67_108_864]
QUANT_TYPES = ["nf4", "fp4"]
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
DTYPE_NAMES = {torch.float16: "f16", torch.bfloat16: "bf16", torch.float32: "f32"}
BLOCKSIZE = 64
WARMUP = 50
ITERS = 200


def bench_one(n, dtype, quant_type, device):
    """Time quantize_4bit for one (size, dtype, quant_type) combination.

    Returns a (mean_time_us, bandwidth_gbs) pair: average microseconds per
    call and the implied effective memory bandwidth in GB/s.
    """
    tensor = torch.randn(n, dtype=dtype, device=device)

    # Warm-up iterations so lazy init / caching does not skew the timing.
    for _ in range(WARMUP):
        quantize_4bit(tensor, blocksize=BLOCKSIZE, quant_type=quant_type)
    torch.cuda.synchronize(device)

    stream = torch.cuda.current_stream(device)
    ev_begin = torch.cuda.Event(enable_timing=True)
    ev_finish = torch.cuda.Event(enable_timing=True)

    ev_begin.record(stream=stream)
    for _ in range(ITERS):
        quantize_4bit(tensor, blocksize=BLOCKSIZE, quant_type=quant_type)
    ev_finish.record(stream=stream)
    torch.cuda.synchronize(device)

    # Event.elapsed_time() reports milliseconds; convert to us per iteration.
    time_us = ev_begin.elapsed_time(ev_finish) / ITERS * 1000

    # Bytes moved per call: read the input, write the packed 4-bit output,
    # and write one fp32 absmax value per quantization block.
    num_absmax = (n + BLOCKSIZE - 1) // BLOCKSIZE
    total_bytes = n * tensor.element_size() + n // 2 + num_absmax * 4
    bw = (total_bytes / 1e9) / (time_us / 1e6)

    return time_us, bw


def print_header(device):
    """Print the device identity and the benchmark configuration banner."""
    props = torch.cuda.get_device_properties(device)
    # gcnArchName exists only on ROCm builds; warp_size may be absent too,
    # so fall back gracefully on CUDA.
    arch = getattr(props, "gcnArchName", "N/A")
    warp = getattr(props, "warp_size", 32)
    print(f"Device: {props.name} ({arch}, warp_size={warp})")
    print(f"Config: blocksize={BLOCKSIZE}, warmup={WARMUP}, iters={ITERS}")
    print()


def run_all(device):
    """Benchmark every (quant_type, dtype, size) combination on `device`.

    Returns a dict keyed "<quant>_<dtype>_<n>" with rounded per-call time
    (us) and effective bandwidth (GB/s), printing one progress line each.
    """
    # Flatten the sweep so a single loop drives progress reporting.
    combos = [(qt, dt, n) for qt in QUANT_TYPES for dt in DTYPES for n in SIZES]
    results = {}

    for done, (qt, dt, n) in enumerate(combos, start=1):
        tag = f"{qt}_{DTYPE_NAMES[dt]}_{n}"
        time_us, bw = bench_one(n, dt, qt, device)
        results[tag] = {"time_us": round(time_us, 2), "bw_gbs": round(bw, 1)}
        print(f"  [{done}/{len(combos)}] {qt} {DTYPE_NAMES[dt]:>3s} N={n:>12,d}  {time_us:8.1f} us  {bw:7.1f} GB/s")

    return results


def print_comparison(base, cand):
    """Print a side-by-side table of baseline vs. candidate results.

    Entries missing from either dict are skipped; a speedup of >= 1.10x is
    flagged with '**'.
    """
    print()
    hdr = f"{'quant':>5s} {'dtype':>5s} {'N':>12s} | {'Base (us)':>10s} {'Cand (us)':>10s} {'Speedup':>8s} | {'Base GB/s':>10s} {'Cand GB/s':>10s}"
    print(hdr)
    print("-" * len(hdr))

    for qt in QUANT_TYPES:
        for dt in DTYPES:
            for n in SIZES:
                tag = f"{qt}_{DTYPE_NAMES[dt]}_{n}"
                base_row = base.get(tag)
                cand_row = cand.get(tag)
                if not base_row or not cand_row:
                    # Only compare tags present in both result sets.
                    continue
                speedup = base_row["time_us"] / cand_row["time_us"]
                if speedup >= 1.10:
                    marker = "**"
                else:
                    marker = ""
                print(
                    f"{qt:>5s} {DTYPE_NAMES[dt]:>5s} {n:>12,d} | "
                    f"{base_row['time_us']:>10.1f} {cand_row['time_us']:>10.1f} {speedup:>7.2f}x{marker} | "
                    f"{base_row['bw_gbs']:>10.1f} {cand_row['bw_gbs']:>10.1f}"
                )
        print()


def main():
    """CLI entry point: run the sweep, optionally save to and/or compare against JSON."""
    cli = argparse.ArgumentParser(description="Quick A/B benchmark for blocksize=64 quantization")
    cli.add_argument("--device", type=int, default=0, help="GPU device index (default 0; use HIP_VISIBLE_DEVICES to isolate)")
    cli.add_argument("--save", type=str, help="Save results to JSON file")
    cli.add_argument("--compare", type=str, help="Compare against a previously saved JSON (baseline)")
    opts = cli.parse_args()

    target = torch.device(f"cuda:{opts.device}")
    torch.cuda.set_device(target)
    print_header(target)

    print("Running benchmarks...")
    results = run_all(target)

    if opts.save:
        with open(opts.save, "w") as fh:
            json.dump(results, fh, indent=2)
        print(f"\nResults saved to {opts.save}")

    if opts.compare:
        with open(opts.compare) as fh:
            baseline = json.load(fh)
        print_comparison(baseline, results)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Abdennacer-Badaoui Same here, on an RTX 5090 the results are within margin of error:

quant dtype            N |  Base (us)  Cand (us)  Speedup |  Base GB/s  Cand GB/s
---------------------------------------------------------------------------------
  nf4   f16    1,048,576 |       32.5       32.4    1.00x |       82.8       83.0
  nf4   f16    4,194,304 |       55.0       54.9    1.00x |      195.6      195.7
  nf4   f16   16,777,216 |      148.7      148.5    1.00x |      289.1      289.5
  nf4   f16   67,108,864 |      522.4      522.0    1.00x |      329.2      329.4
  nf4  bf16    1,048,576 |       32.3       32.1    1.01x |       83.1       83.6
  nf4  bf16    4,194,304 |       55.0       54.8    1.00x |      195.3      196.1
  nf4  bf16   16,777,216 |      149.0      148.5    1.00x |      288.5      289.4
  nf4  bf16   67,108,864 |      522.9      521.5    1.00x |      328.8      329.7
  nf4   f32    1,048,576 |       32.3       32.2    1.00x |      148.2      148.4
  nf4   f32    4,194,304 |       54.9       54.8    1.00x |      348.7      348.9
  nf4   f32   16,777,216 |      149.0      148.9    1.00x |      513.9      514.1
  nf4   f32   67,108,864 |      523.6      523.1    1.00x |      584.8      585.3

  fp4   f16    1,048,576 |       32.3       32.3    1.00x |       83.2       83.3
  fp4   f16    4,194,304 |       54.9       54.8    1.00x |      195.7      196.0
  fp4   f16   16,777,216 |      148.9      148.8    1.00x |      288.8      288.8
  fp4   f16   67,108,864 |      522.9      522.8    1.00x |      328.9      328.9
  fp4  bf16    1,048,576 |       32.1       32.1    1.00x |       83.6       83.8
  fp4  bf16    4,194,304 |       55.0       54.7    1.01x |      195.4      196.4
  fp4  bf16   16,777,216 |      148.8      149.0    1.00x |      288.9      288.5
  fp4  bf16   67,108,864 |      522.8      522.6    1.00x |      328.9      329.1
  fp4   f32    1,048,576 |       32.2       32.1    1.00x |      148.6      149.1
  fp4   f32    4,194,304 |       54.9       56.4    0.97x |      348.7      339.3
  fp4   f32   16,777,216 |      149.0      148.9    1.00x |      513.9      514.1
  fp4   f32   67,108,864 |      523.1      524.2    1.00x |      585.3      584.1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthewdouglas Data with larger N:

Quantize

quant dtype N Time Base (us) Time Cand (us) BW Base (GB/s) BW Cand (GB/s) Speedup
nf4 f16 268,435,456 3034.3 2280.2 226.7 301.7 1.33x
nf4 f16 536,870,912 6015.5 4511.3 228.7 305.0 1.33x
nf4 f16 1,073,741,824 12055.1 8995.6 228.2 305.9 1.34x
nf4 bf16 268,435,456 3007.6 2290.6 228.7 300.3 1.31x
nf4 bf16 536,870,912 5999.8 4534.6 229.3 303.4 1.32x
nf4 bf16 1,073,741,824 11999.4 9037.4 229.3 304.5 1.33x
nf4 f32 268,435,456 3155.2 2410.3 388.2 508.1 1.31x
nf4 f32 536,870,912 6273.9 4788.1 390.4 511.6 1.31x
nf4 f32 1,073,741,824 12502.6 9478.2 391.8 516.9 1.32x
fp4 f16 268,435,456 2855.7 2135.2 240.9 322.2 1.34x
fp4 f16 536,870,912 5664.7 4230.4 242.9 325.2 1.34x
fp4 f16 1,073,741,824 11328.2 8413.6 242.9 327.0 1.35x
fp4 bf16 268,435,456 2833.5 2137.4 242.8 321.8 1.33x
fp4 bf16 536,870,912 5622.8 4237.9 244.7 324.6 1.33x
fp4 bf16 1,073,741,824 11252.3 8431.7 244.5 326.3 1.33x
fp4 f32 268,435,456 2978.5 2279.3 411.2 537.3 1.31x
fp4 f32 536,870,912 5934.0 4515.9 412.8 542.4 1.31x
fp4 f32 1,073,741,824 11817.6 8945.3 414.5 547.7 1.32x

Dequantize (same kernel on both branches)

quant dtype N Time Base (us) Time Cand (us) BW Base (GB/s) BW Cand (GB/s) Speedup
nf4 f16 268,435,456 1146.1 1145.7 600.2 600.4 1.00x
nf4 f16 536,870,912 2301.8 2300.3 597.7 598.1 1.00x
nf4 f16 1,073,741,824 4597.3 4598.4 598.5 598.4 1.00x
nf4 bf16 268,435,456 1147.4 1147.6 599.5 599.4 1.00x
nf4 bf16 536,870,912 2283.5 2283.9 602.5 602.3 1.00x
nf4 bf16 1,073,741,824 4551.5 4552.6 604.5 604.4 1.00x
nf4 f32 268,435,456 2070.8 2070.9 591.4 591.4 1.00x
nf4 f32 536,870,912 4134.7 4134.0 592.4 592.5 1.00x
nf4 f32 1,073,741,824 8265.9 8265.8 592.7 592.7 1.00x
fp4 f16 268,435,456 1146.4 1146.6 600.0 599.9 1.00x
fp4 f16 536,870,912 2298.0 2295.0 598.7 599.4 1.00x
fp4 f16 1,073,741,824 4588.7 4587.6 599.6 599.8 1.00x
fp4 bf16 268,435,456 1151.1 1151.9 597.6 597.2 1.00x
fp4 bf16 536,870,912 2290.7 2292.0 600.6 600.2 1.00x
fp4 bf16 1,073,741,824 4567.6 4569.5 602.4 602.1 1.00x
fp4 f32 268,435,456 2070.5 2070.6 591.5 591.5 1.00x
fp4 f32 536,870,912 4134.4 4137.3 592.5 592.0 1.00x
fp4 f32 1,073,741,824 8267.7 8265.2 592.5 592.7 1.00x

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!
LGTM! I think we can keep your changes. Thanks for working on this!

#if BNB_HIP
if constexpr (DATA_TYPE > 0) {
if (bnb_host_warp_size() == 64) {
// CDNA: kQuantizeBlockwiseSmall is compiled with THREADS=64
kQuantizeBlockwiseSmall<T, DATA_TYPE>
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
} else {
// RDNA: standard kernel (same as CUDA path)
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>
<<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
}
const int ws = bnb_host_warp_size();
const int num_qb = ws / (64 / 2);
int grid = (num_blocks + num_qb - 1) / num_qb;
kQuantizeBlockwiseSmall<T, 64, DATA_TYPE><<<grid, ws>>>(code, A, absmax, out, rand, rand_offset, n);
} else {
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Observation here that kQuantizeBlockwise in combination with General8bit data type only utilizes half of the warp since it launches 32 threads on CDNA which has 64.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can discuss this further in another PR. For now, let’s focus on merging this one.

}
#else
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
#endif
} else if (blocksize == 32) {
// For 4-bit: use specialized kernel that processes 2 blocks per warp
// Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2
if constexpr (DATA_TYPE > 0) {
int num_blocks_adjusted = (num_blocks + 1) / 2;
kQuantizeBlockwiseSmall<T, DATA_TYPE>
<<<num_blocks_adjusted, 32>>>(code, A, absmax, out, rand, rand_offset, n);
const int ws = bnb_host_warp_size();
const int num_qb = ws / (32 / 2);
int grid = (num_blocks + num_qb - 1) / num_qb;
kQuantizeBlockwiseSmall<T, 32, DATA_TYPE><<<grid, ws>>>(code, A, absmax, out, rand, rand_offset, n);
}
}

Expand Down
Loading