From 784203f0db24a1effe1b98462d301e9d756857fc Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 23 Feb 2026 19:30:09 -0800 Subject: [PATCH 1/6] Add 1-bit affine quantization support --- benchmarks/python/comparative/bench_mlx.py | 31 ++- benchmarks/python/comparative/compare.py | 35 ++++ mlx/backend/cpu/quantized.cpp | 34 +++- mlx/backend/metal/kernels/quantized.h | 177 ++++++++++++++---- mlx/backend/metal/kernels/quantized.metal | 1 + mlx/backend/metal/kernels/quantized_nax.h | 162 +++++++++++++--- mlx/backend/metal/kernels/quantized_nax.metal | 1 + mlx/ops.cpp | 28 ++- python/src/ops.cpp | 16 +- python/tests/cuda_skip.py | 1 + python/tests/test_quantized.py | 96 +++++++++- 11 files changed, 484 insertions(+), 98 deletions(-) diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 4e6ba04f8e..1f12064ebe 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -72,12 +72,17 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): quant_matmul = { + "quant_matmul_32_1": partial(_quant_matmul, transpose=False, group_size=32, bits=1), "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2), "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4), "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8), + "quant_matmul_64_1": partial(_quant_matmul, transpose=False, group_size=64, bits=1), "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2), "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4), "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8), + "quant_matmul_128_1": partial( + _quant_matmul, transpose=False, group_size=128, bits=1 + ), "quant_matmul_128_2": partial( _quant_matmul, transpose=False, group_size=128, bits=2 ), @@ -87,6 +92,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): "quant_matmul_128_8": partial( _quant_matmul, transpose=False, group_size=128, bits=8 ), + "quant_matmul_t_32_1": partial( + _quant_matmul, transpose=True, group_size=32, bits=1 + ), "quant_matmul_t_32_2": partial( _quant_matmul, transpose=True, group_size=32, bits=2 ), @@ -96,6 +104,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): "quant_matmul_t_32_8": partial( _quant_matmul, transpose=True, group_size=32, bits=8 ), + "quant_matmul_t_64_1": partial( + _quant_matmul, transpose=True, group_size=64, bits=1 + ), "quant_matmul_t_64_2": partial( _quant_matmul, transpose=True, group_size=64, bits=2 ), @@ -105,6 +116,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): "quant_matmul_t_64_8": partial( _quant_matmul, transpose=True, group_size=64, bits=8 ), + "quant_matmul_t_128_1": partial( + _quant_matmul, transpose=True, group_size=128, bits=1 + ), "quant_matmul_t_128_2": partial( _quant_matmul, transpose=True, group_size=128, bits=2 ), @@ -420,7 +434,22 @@ def selu(x): print(bench(matmul, *xs)) elif args.benchmark.startswith("quant_matmul"): - print(bench(quant_matmul[args.benchmark], *xs)) + # Parse group_size and bits from the benchmark name, e.g. 
+ # "quant_matmul_128_4" or "quant_matmul_t_128_4" + fn = quant_matmul[args.benchmark] + gs = fn.keywords["group_size"] + bits = fn.keywords["bits"] + transpose = fn.keywords["transpose"] + + # xs[0] = activation x, xs[1] = original (float) weight matrix + # Quantize the weight internally so the caller only needs: + # --size MxK --size NxK (transpose=True) or --size MxK --size KxN + w_float = xs[1].astype(mx.float16) + w_q, scales, biases = mx.quantize(w_float, group_size=gs, bits=bits) + mx.eval(w_q, scales, biases) + x_input = xs[0].astype(mx.float16) + mx.eval(x_input) + print(bench(_quant_matmul, x_input, w_q, scales, biases, transpose, gs, bits)) elif args.benchmark == "linear": if args.fused: diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index 68b4a5bd32..abfaddb268 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -28,6 +28,18 @@ def compare(args): print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t") +def compare_mlx_quant(args_base, bits_list): + """Compare quantized matmul across bit widths (MLX only, no PyTorch).""" + results = {} + for bits in bits_list: + bench_args = args_base.replace("{bits}", str(bits)).split() + results[bits] = run_or_raise(["python", BENCH_MLX] + bench_args) + baseline = max(results.values()) + for bits in bits_list: + speedup = (baseline - results[bits]) / baseline if baseline > 0 else 0 + print(f"{speedup:.4f}\t{args_base.replace('{bits}', str(bits))}") + + def compare_mlx_dtypes(args, dt1, dt2): t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1]) t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2]) @@ -282,3 +294,26 @@ def predicate(x): compare_filtered("topk --size 32768x128 --axis 1") compare_filtered("topk --size 128x128 --axis 0 --cpu") compare_filtered("topk --size 128x128 --axis 1 --cpu") + + # Quantized matmul ops (MLX only — compare across bit widths) + # qmv path (M=1, token generation, memory-bandwidth bound) + for gs in [64, 128]: + compare_mlx_quant( + f"quant_matmul_t_{gs}_{{bits}} --size 1x4096 --size 4096x4096", + [1, 2, 4, 8], + ) + compare_mlx_quant( + f"quant_matmul_t_{gs}_{{bits}} --size 1x4096 --size 11008x4096", + [1, 2, 4, 8], + ) + # qmm path (prompt processing, more compute bound) + for gs in [64, 128]: + for M in [32, 512]: + compare_mlx_quant( + f"quant_matmul_t_{gs}_{{bits}} --size {M}x4096 --size 4096x4096", + [1, 2, 4, 8], + ) + compare_mlx_quant( + f"quant_matmul_t_{gs}_{{bits}} --size {M}x4096 --size 11008x4096", + [1, 2, 4, 8], + ) diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 3c3643a320..d2cb4e8797 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -359,6 +359,10 @@ void _qmm_dispatch_typed( int bits, bool transposed_w) { switch (bits) { + case 1: + _qmm_dispatch_group( + result, x, w, scales, biases, M, N, K, group_size, transposed_w); + break; case 2: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); @@ -384,7 +388,8 @@ void _qmm_dispatch_typed( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; default: - throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8."); + throw std::invalid_argument( + "Quantization bits must be 1, 2, 3, 4, 5, 6 or 8."); } } @@ -1180,15 +1185,24 @@ void quantize( w_min = std::min(w_min, (float)w[w_idx + j]); } bool mask = std::abs(w_min) > std::abs(w_max); - float scale = std::max((w_max - w_min) / 
n_bins, eps); - scale = mask ? scale : -scale; - - float edge = mask ? w_min : w_max; - float q0 = std::rint(edge / scale); - float bias = 0; - if (q0 != 0) { - scale = edge / q0; - bias = edge; + float scale; + float bias; + + if (bits == 1) { + // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max + scale = std::max(w_max - w_min, eps); + bias = w_min; + } else { + scale = std::max((w_max - w_min) / n_bins, eps); + scale = mask ? scale : -scale; + + float edge = mask ? w_min : w_max; + float q0 = std::rint(edge / scale); + bias = 0; + if (q0 != 0) { + scale = edge / q0; + bias = edge; + } } size_t out_idx = i * int_per_group; for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 5ac4c6e165..76c614ca93 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -28,13 +28,28 @@ inline constexpr short get_bytes_per_pack() { template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + x_thread[i + 4] = x[i + 4]; + x_thread[i + 5] = x[i + 5]; + x_thread[i + 6] = x[i + 6]; + x_thread[i + 7] = x[i + 7]; + } + } + + else if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -107,13 +122,28 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + x_thread[i + 4] = x[i + 4]; + x_thread[i + 5] = x[i + 5]; + x_thread[i + 6] = x[i + 6]; + x_thread[i + 7] = x[i + 7]; + } + } + + else if (bits == 2) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -196,13 +226,27 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t wb = w[i]; + accum += select(U(0), x_thread[8 * i], bool(wb & 0x01)); + accum += select(U(0), x_thread[8 * i + 1], bool(wb & 
0x02)); + accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04)); + accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08)); + accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10)); + accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20)); + accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40)); + accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -298,13 +342,27 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (N / 8); i++) { + uint8_t wb = w[i]; + accum += select(U(0), x_thread[8 * i], bool(wb & 0x01)); + accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02)); + accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04)); + accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08)); + accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10)); + accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20)); + accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40)); + accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -395,11 +453,25 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t wb = w[i]; + result[8 * i] += x * (select(U(0), scale, bool(wb & 0x01)) + bias); + result[8 * i + 1] += x * (select(U(0), scale, bool(wb & 0x02)) + bias); + result[8 * i + 2] += x * (select(U(0), scale, bool(wb & 0x04)) + bias); + result[8 * i + 3] += x * (select(U(0), scale, bool(wb & 0x08)) + bias); + result[8 * i + 4] += x * (select(U(0), scale, bool(wb & 0x10)) + bias); + result[8 * i + 5] += x * (select(U(0), scale, bool(wb & 0x20)) + bias); + result[8 * i + 6] += x * (select(U(0), scale, bool(wb & 0x40)) + bias); + result[8 * i + 7] += x * (select(U(0), scale, bool(wb & 0x80)) + bias); + } + } + + else if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); @@ -484,11 +556,33 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); + + if (bits == 1) { + U s[8] = { + scale, + scale / static_cast(2.0f), + scale / static_cast(4.0f), + scale / static_cast(8.0f), + scale / static_cast(16.0f), + scale / 
static_cast(32.0f), + scale / static_cast(64.0f), + scale / static_cast(128.0f)}; + for (int i = 0; i < (N / 8); i++) { + w_local[8 * i] = s[0] * (w[i] & 0x01) + bias; + w_local[8 * i + 1] = s[1] * (w[i] & 0x02) + bias; + w_local[8 * i + 2] = s[2] * (w[i] & 0x04) + bias; + w_local[8 * i + 3] = s[3] * (w[i] & 0x08) + bias; + w_local[8 * i + 4] = s[4] * (w[i] & 0x10) + bias; + w_local[8 * i + 5] = s[5] * (w[i] & 0x20) + bias; + w_local[8 * i + 6] = s[6] * (w[i] & 0x40) + bias; + w_local[8 * i + 7] = s[7] * (w[i] & 0x80) + bias; + } + } - if (bits == 2) { + else if (bits == 2) { U s[4] = { scale, scale / static_cast(4.0f), @@ -577,9 +671,9 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); @@ -2384,14 +2478,23 @@ template w_min = simd_min(w_min); w_max = simd_max(w_max); - float scale = max((w_max - w_min) / n_bins, eps); - bool side = abs(w_min) > abs(w_max); - scale = side ? scale : -scale; - float edge = side ? w_min : w_max; - float q0 = round(edge / scale); - bool at_zero = q0 == 0.0f; - scale = at_zero ? scale : edge / q0; - float bias = at_zero ? 0 : edge; + float scale; + float bias; + + if (bits == 1) { + // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max + scale = max(w_max - w_min, eps); + bias = w_min; + } else { + scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + bias = at_zero ? 
0 : edge; + } // Write out the scales and biases size_t gindex = in_index / group_size; @@ -2495,7 +2598,9 @@ template #pragma clang loop unroll(full) for (int i = 0; i < pack_factor; i++) { uint8_t d; - if (bits == 2) { + if (bits == 1) { + d = (val >> i) & 0x01; + } else if (bits == 2) { d = (val >> (bits * i)) & 0x03; } else if (bits == 4) { d = (val >> (bits * i)) & 0x0f; diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index f734b9bcee..255de89188 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -134,6 +134,7 @@ instantiate_quantized_types(32, bits) #define instantiate_quantized_all() \ + instantiate_quantized_groups(1) \ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h index c26ff646bb..63e864b7df 100644 --- a/mlx/backend/metal/kernels/quantized_nax.h +++ b/mlx/backend/metal/kernels/quantized_nax.h @@ -31,13 +31,28 @@ inline constexpr short get_bytes_per_pack() { template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -110,13 +125,28 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -199,13 +229,27 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for 
bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + accum += + (x_thread[8 * i] * (w[i] & 0x01) + + x_thread[8 * i + 1] * (w[i] & 0x02) + + x_thread[8 * i + 2] * (w[i] & 0x04) + + x_thread[8 * i + 3] * (w[i] & 0x08) + + x_thread[8 * i + 4] * (w[i] & 0x10) + + x_thread[8 * i + 5] * (w[i] & 0x20) + + x_thread[8 * i + 6] * (w[i] & 0x40) + + x_thread[8 * i + 7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -301,13 +345,27 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (N / 8); i++) { + accum += + (x_thread[8 * i] * (w[i] & 0x01) + + x_thread[8 * i + 1] * (w[i] & 0x02) + + x_thread[8 * i + 2] * (w[i] & 0x04) + + x_thread[8 * i + 3] * (w[i] & 0x08) + + x_thread[8 * i + 4] * (w[i] & 0x10) + + x_thread[8 * i + 5] * (w[i] & 0x20) + + x_thread[8 * i + 6] * (w[i] & 0x40) + + x_thread[8 * i + 7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -398,11 +456,33 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); - if (bits == 2) { + if (bits == 1) { + U s[8] = { + scale, + scale / 2.0f, + scale / 4.0f, + scale / 8.0f, + scale / 16.0f, + scale / 32.0f, + scale / 64.0f, + scale / 128.0f}; + for (int i = 0; i < (values_per_thread / 8); i++) { + result[8 * i] += x * (s[0] * (w[i] & 0x01) + bias); + result[8 * i + 1] += x * (s[1] * (w[i] & 0x02) + bias); + result[8 * i + 2] += x * (s[2] * (w[i] & 0x04) + bias); + result[8 * i + 3] += x * (s[3] * (w[i] & 0x08) + bias); + result[8 * i + 4] += x * (s[4] * (w[i] & 0x10) + bias); + result[8 * i + 5] += x * (s[5] * (w[i] & 0x20) + bias); + result[8 * i + 6] += x * (s[6] * (w[i] & 0x40) + bias); + result[8 * i + 7] += x * (s[7] * (w[i] & 0x80) + bias); + } + } + + else if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); @@ -487,11 +567,33 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); + + if (bits == 1) { + U s[8] = { + scale, + scale / static_cast(2.0f), + scale / static_cast(4.0f), + scale / static_cast(8.0f), + scale / static_cast(16.0f), + scale / static_cast(32.0f), + scale / static_cast(64.0f), + scale / static_cast(128.0f)}; + for (int i = 0; i < (N / 8); i++) { + w_local[8 
* i] = s[0] * (w[i] & 0x01) + bias; + w_local[8 * i + 1] = s[1] * (w[i] & 0x02) + bias; + w_local[8 * i + 2] = s[2] * (w[i] & 0x04) + bias; + w_local[8 * i + 3] = s[3] * (w[i] & 0x08) + bias; + w_local[8 * i + 4] = s[4] * (w[i] & 0x10) + bias; + w_local[8 * i + 5] = s[5] * (w[i] & 0x20) + bias; + w_local[8 * i + 6] = s[6] * (w[i] & 0x40) + bias; + w_local[8 * i + 7] = s[7] * (w[i] & 0x80) + bias; + } + } - if (bits == 2) { + else if (bits == 2) { U s[4] = { scale, scale / static_cast(4.0f), @@ -580,9 +682,9 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); @@ -715,9 +817,9 @@ struct QuantizedBlockLoader< BCOLS % group_size == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); diff --git a/mlx/backend/metal/kernels/quantized_nax.metal b/mlx/backend/metal/kernels/quantized_nax.metal index 5a9c9fb874..fb68e535a9 100644 --- a/mlx/backend/metal/kernels/quantized_nax.metal +++ b/mlx/backend/metal/kernels/quantized_nax.metal @@ -96,6 +96,7 @@ instantiate_quantized_types(32, bits) #define instantiate_quantized_all() \ + instantiate_quantized_groups(1) \ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ diff --git a/mlx/ops.cpp b/mlx/ops.cpp index deb1c27036..a5a4976361 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4540,10 +4540,10 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { throw std::invalid_argument(msg.str()); } - if (bits < 2 || bits > 8 || bits == 7) { + if (bits < 1 || bits > 8 || bits == 7) { std::ostringstream msg; msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8."; + << " is not supported. 
The supported bits are 1, 2, 3, 4, 5, 6 and 8."; throw std::invalid_argument(msg.str()); } @@ -4564,14 +4564,22 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { w_max = astype(w_max, float32, s); w_min = astype(w_min, float32, s); - array mask = greater(abs(w_min, s), abs(w_max, s), s); - array scales = - maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales, s), s); - array edge = where(mask, w_min, w_max, s); - array q0 = round(divide(edge, scales, s), s); - scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge, s); + array scales(0, float32); + array biases(0, float32); + + if (bits == 1) { + // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max + scales = maximum(subtract(w_max, w_min, s), eps, s); + biases = w_min; + } else { + array mask = greater(abs(w_min, s), abs(w_max, s), s); + scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); + scales = where(mask, scales, negative(scales, s), s); + array edge = where(mask, w_min, w_max, s); + array q0 = round(divide(edge, scales, s), s); + scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); + biases = where(equal(q0, zero, s), zero, edge, s); + } packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index dac4b8f7f7..f3909ddea3 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4348,14 +4348,14 @@ void init_ops(nb::module_& m) { .. table:: Quantization modes - ====== ====================== ========================== ============= ===== - mode group size bits scale type bias - ====== ====================== ========================== ============= ===== - affine 32, 64\ :sup:`*`, 128 2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes - mxfp4 32\ :sup:`*` 4\ :sup:`*` e8m0 no - mxfp8 32\ :sup:`*` 8\ :sup:`*` e8m0 no - nvfp4 16\ :sup:`*` 4\ :sup:`*` e4m3 no - ====== ====================== ========================== ============= ===== + ====== ====================== ============================== ============= ===== + mode group size bits scale type bias + ====== ====================== ============================== ============= ===== + affine 32, 64\ :sup:`*`, 128 1, 2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes + mxfp4 32\ :sup:`*` 4\ :sup:`*` e8m0 no + mxfp8 32\ :sup:`*` 8\ :sup:`*` e8m0 no + nvfp4 16\ :sup:`*` 4\ :sup:`*` e4m3 no + ====== ====================== ============================== ============= ===== :sup:`*` indicates the default value when unspecified. 
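For reference, a minimal sketch of the 1-bit round trip this table now documents, using the same mx.quantize / mx.dequantize calls the tests below exercise (the shapes and group size are illustrative):

    import mlx.core as mx

    # Per group: scale = max(w_max - w_min, eps) and bias = w_min, so each
    # stored bit decodes to bias (bit 0) or bias + scale (bit 1).
    w = mx.random.normal(shape=(128, 512))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=1)
    w_hat = mx.dequantize(w_q, scales, biases, 64, 1)

    # Every dequantized entry is one of its group's two endpoints, so the
    # round-trip error is bounded by that group's (w_max - w_min) range.
    max_err = (w - w_hat).abs().max()

Unlike the >=2-bit path, there is no edge-snapping step: with a single bin the affine formula maps the group range directly, as in the ops.cpp branch above.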
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 20793d5c91..3d37d1facc 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -53,6 +53,7 @@ "TestQuantized.test_qmv_small_non_multiples", "TestQuantized.test_small_matrix", "TestQuantized.test_throw", + "TestQuantized.test_1bit_quantize_dequantize", "TestQuantized.test_vjp_scales_biases", "TestExportImport.test_export_quantized_model", # Masked scatter diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index a7472e9920..6049c71ff8 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 5, 6, 4, 8]: + for b in [1, 2, 3, 5, 6, 4, 8]: with self.subTest(gs=gs, b=b): w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) w_hat = mx.dequantize(w_q, scales, biases, gs, b) @@ -22,7 +22,7 @@ def test_quantize_dequantize(self): # test quantize/dequantize 0s a = mx.zeros((256, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 4, 5, 6, 8]: + for b in [1, 2, 3, 4, 5, 6, 8]: w_q, scales, biases = mx.quantize(a, gs, b) a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) @@ -173,6 +173,96 @@ def test_nvfp4_quantize_dequantize(self): ) self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) + def test_1bit_quantize_dequantize(self): + """Test 1-bit affine quantization.""" + + # Symmetric binary weights {-0.5, +0.5} should round-trip perfectly + # (affine formula gives scale=1.0, bias=-0.5) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="pack_symmetric_weights"): + signs = (mx.random.uniform(shape=(128, 512)) > 0.5).astype(mx.float32) + w = signs * 1.0 - (1 - signs) * 1.0 # {-1.0, +1.0} + w = w * 0.5 # {-0.5, +0.5} + + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + + self.assertLess((w - w_hat).abs().max(), 1e-5) + + # Asymmetric binary weights {0.1, 0.9} should round-trip perfectly + # (affine formula gives scale=0.8, bias=0.1) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="pack_asymmetric_weights"): + bits = (mx.random.uniform(shape=(128, 512)) > 0.5).astype(mx.float32) + w = bits * 0.9 + (1 - bits) * 0.1 # {0.1, 0.9} + + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + + self.assertLess((w - w_hat).abs().max(), 1e-5) + + # Verify dequantized values are exactly {bias, bias + scale} + w = mx.random.normal(shape=(64, 256)) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="dequant_values"): + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + + for i in range(scales.shape[0]): + for j in range(scales.shape[1]): + s = scales[i, j].item() + b = biases[i, j].item() + row_start = j * gs + row_end = row_start + gs + vals = w_hat[i, row_start:row_end] + mx.eval(vals) + for v in vals.tolist(): + self.assertTrue( + abs(v - b) < 1e-5 or abs(v - (b + s)) < 1e-5, + f"Value {v} not in {{bias={b}, bias+scale={b+s}}}", + ) + + # 1-bit quantize/dequantize zeros — scale floors to eps, bias=0 + a = mx.zeros((256, 512)) + for gs in [32, 64, 128]: + w_q, scales, biases = mx.quantize(a, gs, 1) + a_hat = mx.dequantize(w_q, scales, biases, gs, 1) + self.assertLess(a_hat.abs().max(), 1e-5) + + # Quantized matmul with symmetric binary weights + 
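        # (with w in {-0.3, +0.3}: scale = 0.6, bias = -0.3, so bit 1 decodes
        # to +0.3 and bit 0 to -0.3)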
key = mx.random.key(42) + k1, k2 = mx.random.split(key) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="quantized_matmul_symmetric"): + x = mx.random.normal(shape=(4, 256), key=k1) + signs = (mx.random.uniform(shape=(128, 256), key=k2) > 0.5).astype( + mx.float32 + ) + w = signs * 0.3 - (1 - signs) * 0.3 # {-0.3, +0.3} + + w_q, scales, biases = mx.quantize(w, gs, 1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + y_q = mx.quantized_matmul(x, w_q, scales, biases, True, gs, 1) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-5) + + # Quantized matmul with asymmetric binary weights + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="quantized_matmul_asymmetric"): + x = mx.random.normal(shape=(4, 256), key=k1) + bits = (mx.random.uniform(shape=(128, 256), key=k2) > 0.5).astype( + mx.float32 + ) + w = bits * 0.7 + (1 - bits) * 0.1 # {0.1, 0.7} + + w_q, scales, biases = mx.quantize(w, gs, 1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + y_q = mx.quantized_matmul(x, w_q, scales, biases, True, gs, 1) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-5) + def test_qqmv(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) @@ -211,7 +301,7 @@ def test_qmm(self): dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32 tests = product( [128, 64, 32], # group_size - [2, 4, 8], # bits + [1, 2, 4, 8], # bits [8, 32, 33, 64], # M [128, 256], # N [128, 256], # K From d155e95634b6ddbda848281377ede921fc2247cc Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Fri, 6 Mar 2026 18:06:31 -0800 Subject: [PATCH 2/6] Guard fast-path Metal kernel dispatch for 1-bit quantization --- mlx/backend/metal/quantized.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index cb67c74f73..c344af84a0 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1282,7 +1282,7 @@ void dispatch_qmv( const Stream& s, const std::string& mode) { // It is a qmv with a small inner dimension so route to qmv_quad kernel - if ((K == 128 || K == 64) && is_power_of_2(bits)) { + if ((K == 128 || (K == 64 && bits >= 2)) && is_power_of_2(bits)) { qmv_quad(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); return; } From 644a8cd0b358d483e888d8cc7783b45f8920af71 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Tue, 24 Mar 2026 21:47:04 -0700 Subject: [PATCH 3/6] fix qmv_fast tail iteration for non-aligned K --- mlx/backend/metal/kernels/quantized.h | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 76c614ca93..4812a9cbe2 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -880,7 +880,9 @@ METAL_FUNC void qmv_fast_impl( x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + out_row; - for (int k = 0; k < in_vec_size; k += block_size) { + const int aligned_end = (in_vec_size / block_size) * block_size; + + for (int k = 0; k < aligned_end; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { @@ -899,6 +901,28 @@ METAL_FUNC void qmv_fast_impl( x += block_size; } + if (aligned_end < in_vec_size) { + bool in_bounds = + (aligned_end + simd_lid * values_per_thread) < in_vec_size; + U sum = 0; + 
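    // Lanes whose tail slice starts past in_vec_size zero x_thread and read
    // s = b = 0 below, so their qdot term is exactly zero.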
if (in_bounds) { + sum = load_vector(x, x_thread); + } else { + for (int i = 0; i < values_per_thread; i++) + x_thread[i] = 0; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = in_bounds ? (U)sl[0] : (U)0; + U b = in_bounds ? (U)bl[0] : (U)0; + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { From a386acc2a92408825091c50153e755bf0d911d14 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 1 Apr 2026 03:03:52 -0700 Subject: [PATCH 4/6] address few PR comments, plus fix linter format --- benchmarks/python/comparative/bench_mlx.py | 17 +---------- benchmarks/python/comparative/compare.py | 35 ---------------------- mlx/backend/metal/kernels/quantized.h | 3 +- 3 files changed, 2 insertions(+), 53 deletions(-) diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 1f12064ebe..0d4ac5cbd1 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -434,22 +434,7 @@ def selu(x): print(bench(matmul, *xs)) elif args.benchmark.startswith("quant_matmul"): - # Parse group_size and bits from the benchmark name, e.g. - # "quant_matmul_128_4" or "quant_matmul_t_128_4" - fn = quant_matmul[args.benchmark] - gs = fn.keywords["group_size"] - bits = fn.keywords["bits"] - transpose = fn.keywords["transpose"] - - # xs[0] = activation x, xs[1] = original (float) weight matrix - # Quantize the weight internally so the caller only needs: - # --size MxK --size NxK (transpose=True) or --size MxK --size KxN - w_float = xs[1].astype(mx.float16) - w_q, scales, biases = mx.quantize(w_float, group_size=gs, bits=bits) - mx.eval(w_q, scales, biases) - x_input = xs[0].astype(mx.float16) - mx.eval(x_input) - print(bench(_quant_matmul, x_input, w_q, scales, biases, transpose, gs, bits)) + print(bench(quant_matmul[args.benchmark], *xs)) elif args.benchmark == "linear": if args.fused: diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index abfaddb268..68b4a5bd32 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -28,18 +28,6 @@ def compare(args): print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t") -def compare_mlx_quant(args_base, bits_list): - """Compare quantized matmul across bit widths (MLX only, no PyTorch).""" - results = {} - for bits in bits_list: - bench_args = args_base.replace("{bits}", str(bits)).split() - results[bits] = run_or_raise(["python", BENCH_MLX] + bench_args) - baseline = max(results.values()) - for bits in bits_list: - speedup = (baseline - results[bits]) / baseline if baseline > 0 else 0 - print(f"{speedup:.4f}\t{args_base.replace('{bits}', str(bits))}") - - def compare_mlx_dtypes(args, dt1, dt2): t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1]) t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2]) @@ -294,26 +282,3 @@ def predicate(x): compare_filtered("topk --size 32768x128 --axis 1") compare_filtered("topk --size 128x128 --axis 0 --cpu") compare_filtered("topk --size 128x128 --axis 1 --cpu") - - # Quantized matmul ops (MLX only — compare across bit widths) - # qmv path (M=1, token generation, memory-bandwidth bound) - for gs in [64, 128]: 
- compare_mlx_quant( - f"quant_matmul_t_{gs}_{{bits}} --size 1x4096 --size 4096x4096", - [1, 2, 4, 8], - ) - compare_mlx_quant( - f"quant_matmul_t_{gs}_{{bits}} --size 1x4096 --size 11008x4096", - [1, 2, 4, 8], - ) - # qmm path (prompt processing, more compute bound) - for gs in [64, 128]: - for M in [32, 512]: - compare_mlx_quant( - f"quant_matmul_t_{gs}_{{bits}} --size {M}x4096 --size 4096x4096", - [1, 2, 4, 8], - ) - compare_mlx_quant( - f"quant_matmul_t_{gs}_{{bits}} --size {M}x4096 --size 11008x4096", - [1, 2, 4, 8], - ) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index c7806c6e9c..3aafabd314 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -902,8 +902,7 @@ METAL_FUNC void qmv_fast_impl( } if (aligned_end < in_vec_size) { - bool in_bounds = - (aligned_end + simd_lid * values_per_thread) < in_vec_size; + bool in_bounds = (aligned_end + simd_lid * values_per_thread) < in_vec_size; U sum = 0; if (in_bounds) { sum = load_vector(x, x_thread); From 99514c0ee7471004894932669d793ac50260802b Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 20 Apr 2026 10:06:07 -0700 Subject: [PATCH 5/6] fix merge conflict --- python/tests/cuda_skip.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 8b50bb3483..361b5c0bed 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,8 +1,4 @@ cuda_skip = { - # Gather matmul NYI - "TestBlas.test_gather_matmul", - "TestBlas.test_gather_matmul_grad", - "TestBlas.test_gather_mm_sorted_vjp", # Lapack ops NYI "TestLinalg.test_cholesky", "TestLinalg.test_cholesky_inv", @@ -23,15 +19,6 @@ "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", "TestQuantized.test_gather_qmm_grad", - "TestQuantized.test_non_multiples", - "TestQuantized.test_qmm_shapes", - "TestQuantized.test_fp_qvm", - "TestQuantized.test_qvm", - "TestQuantized.test_qvm_splitk", - "TestQuantized.test_qmv_small_non_multiples", - "TestQuantized.test_small_matrix", - "TestQuantized.test_throw", + # 1-bit quantization NYI on CUDA "TestQuantized.test_1bit_quantize_dequantize", - "TestQuantized.test_vjp_scales_biases", - "TestExportImport.test_export_quantized_model", } From 63eda4333ccebd06fbcbbdc71fc0cd891fd10db5 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 20 Apr 2026 15:46:40 -0700 Subject: [PATCH 6/6] copilot suggestion --- mlx/backend/metal/kernels/quantized.h | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 3aafabd314..608111f81a 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -901,24 +901,24 @@ METAL_FUNC void qmv_fast_impl( x += block_size; } + // Tail block: handles K that is a multiple of 512 (dispatch gate) but not + // of block_size. Only 1-bit has block_size > 512 (it uses 2048), so for + // >=2-bit this is a no-op. Out-of-bounds lanes skip the qdot call entirely; + // their result[row] stays at 0 and contributes 0 to the simd_sum below. 
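  // e.g. with the 1-bit block_size of 2048 and K = 4608 (a multiple of 512):
  //   aligned_end = (4608 / 2048) * 2048 = 4096, leaving one 512-wide tail.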
if (aligned_end < in_vec_size) {
    bool in_bounds =
        (aligned_end + simd_lid * values_per_thread) < in_vec_size;
    if (in_bounds) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }
    }
  }
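
A closing note on a pattern shared by the dequantize helpers in quantized.h and quantized_nax.h above: rather than shifting each bit down before scaling, they multiply the raw masked value by a pre-divided scale (s[i] = scale / 2^i). A plain-Python check of that identity, with a made-up byte and group parameters:

    # (w & (1 << i)) is either 0 or 2**i, so pre-dividing the scale by 2**i
    # is equivalent to shifting the bit down first; scaling by powers of two
    # is exact in floating point, so the two forms agree bit-for-bit.
    scale, bias = 0.8, 0.1
    w = 0b10110001
    a = [(scale / (1 << i)) * (w & (1 << i)) + bias for i in range(8)]
    b = [scale * ((w >> i) & 1) + bias for i in range(8)]
    assert a == b  # every entry is an endpoint: bias or bias + scale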