diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 4e6ba04f8e..0d4ac5cbd1 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -72,12 +72,17 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): quant_matmul = { + "quant_matmul_32_1": partial(_quant_matmul, transpose=False, group_size=32, bits=1), "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2), "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4), "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8), + "quant_matmul_64_1": partial(_quant_matmul, transpose=False, group_size=64, bits=1), "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2), "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4), "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8), + "quant_matmul_128_1": partial( + _quant_matmul, transpose=False, group_size=128, bits=1 + ), "quant_matmul_128_2": partial( _quant_matmul, transpose=False, group_size=128, bits=2 ), @@ -87,6 +92,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): "quant_matmul_128_8": partial( _quant_matmul, transpose=False, group_size=128, bits=8 ), + "quant_matmul_t_32_1": partial( + _quant_matmul, transpose=True, group_size=32, bits=1 + ), "quant_matmul_t_32_2": partial( _quant_matmul, transpose=True, group_size=32, bits=2 ), @@ -96,6 +104,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): "quant_matmul_t_32_8": partial( _quant_matmul, transpose=True, group_size=32, bits=8 ), + "quant_matmul_t_64_1": partial( + _quant_matmul, transpose=True, group_size=64, bits=1 + ), "quant_matmul_t_64_2": partial( _quant_matmul, transpose=True, group_size=64, bits=2 ), @@ -105,6 +116,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits): "quant_matmul_t_64_8": partial( _quant_matmul, transpose=True, group_size=64, bits=8 ), + "quant_matmul_t_128_1": partial( + _quant_matmul, transpose=True, group_size=128, bits=1 + ), "quant_matmul_t_128_2": partial( _quant_matmul, transpose=True, group_size=128, bits=2 ), diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index c0f1a3c315..73eac098ac 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -351,6 +351,10 @@ void _qmm_dispatch_typed( int bits, bool transposed_w) { switch (bits) { + case 1: + _qmm_dispatch_group( + result, x, w, scales, biases, M, N, K, group_size, transposed_w); + break; case 2: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); @@ -376,7 +380,8 @@ void _qmm_dispatch_typed( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; default: - throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8."); + throw std::invalid_argument( + "Quantization bits must be 1, 2, 3, 4, 5, 6 or 8."); } } @@ -1172,15 +1177,24 @@ void quantize( w_min = std::min(w_min, (float)w[w_idx + j]); } bool mask = std::abs(w_min) > std::abs(w_max); - float scale = std::max((w_max - w_min) / n_bins, eps); - scale = mask ? scale : -scale; - - float edge = mask ? w_min : w_max; - float q0 = std::rint(edge / scale); - float bias = 0; - if (q0 != 0) { - scale = edge / q0; - bias = edge; + float scale; + float bias; + + if (bits == 1) { + // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max + scale = std::max(w_max - w_min, eps); + bias = w_min; + } else { + scale = std::max((w_max - w_min) / n_bins, eps); + scale = mask ? scale : -scale; + + float edge = mask ? w_min : w_max; + float q0 = std::rint(edge / scale); + bias = 0; + if (q0 != 0) { + scale = edge / q0; + bias = edge; + } } size_t out_idx = i * int_per_group; for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index d4a5d28bcb..608111f81a 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -28,13 +28,28 @@ inline constexpr short get_bytes_per_pack() { template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + x_thread[i + 4] = x[i + 4]; + x_thread[i + 5] = x[i + 5]; + x_thread[i + 6] = x[i + 6]; + x_thread[i + 7] = x[i + 7]; + } + } + + else if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -107,13 +122,28 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + x_thread[i + 4] = x[i + 4]; + x_thread[i + 5] = x[i + 5]; + x_thread[i + 6] = x[i + 6]; + x_thread[i + 7] = x[i + 7]; + } + } + + else if (bits == 2) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -196,13 +226,27 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t wb = w[i]; + accum += select(U(0), x_thread[8 * i], bool(wb & 0x01)); + accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02)); + accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04)); + accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08)); + accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10)); + accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20)); + accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40)); + accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -298,13 +342,27 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (N / 8); i++) { + uint8_t wb = w[i]; + accum += select(U(0), x_thread[8 * i], bool(wb & 0x01)); + accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02)); + accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04)); + accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08)); + accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10)); + accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20)); + accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40)); + accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -395,11 +453,25 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); + + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t wb = w[i]; + result[8 * i] += x * (select(U(0), scale, bool(wb & 0x01)) + bias); + result[8 * i + 1] += x * (select(U(0), scale, bool(wb & 0x02)) + bias); + result[8 * i + 2] += x * (select(U(0), scale, bool(wb & 0x04)) + bias); + result[8 * i + 3] += x * (select(U(0), scale, bool(wb & 0x08)) + bias); + result[8 * i + 4] += x * (select(U(0), scale, bool(wb & 0x10)) + bias); + result[8 * i + 5] += x * (select(U(0), scale, bool(wb & 0x20)) + bias); + result[8 * i + 6] += x * (select(U(0), scale, bool(wb & 0x40)) + bias); + result[8 * i + 7] += x * (select(U(0), scale, bool(wb & 0x80)) + bias); + } + } - if (bits == 2) { + else if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); @@ -484,11 +556,33 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); - if (bits == 2) { + if (bits == 1) { + U s[8] = { + scale, + scale / static_cast(2.0f), + scale / static_cast(4.0f), + scale / static_cast(8.0f), + scale / static_cast(16.0f), + scale / static_cast(32.0f), + scale / static_cast(64.0f), + scale / static_cast(128.0f)}; + for (int i = 0; i < (N / 8); i++) { + w_local[8 * i] = s[0] * (w[i] & 0x01) + bias; + w_local[8 * i + 1] = s[1] * (w[i] & 0x02) + bias; + w_local[8 * i + 2] = s[2] * (w[i] & 0x04) + bias; + w_local[8 * i + 3] = s[3] * (w[i] & 0x08) + bias; + w_local[8 * i + 4] = s[4] * (w[i] & 0x10) + bias; + w_local[8 * i + 5] = s[5] * (w[i] & 0x20) + bias; + w_local[8 * i + 6] = s[6] * (w[i] & 0x40) + bias; + w_local[8 * i + 7] = s[7] * (w[i] & 0x80) + bias; + } + } + + else if (bits == 2) { U s[4] = { scale, scale / static_cast(4.0f), @@ -577,9 +671,9 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); @@ -786,7 +880,9 @@ METAL_FUNC void qmv_fast_impl( x += tid.x * in_vec_size + simd_lid * values_per_thread; y += tid.x * out_vec_size + out_row; - for (int k = 0; k < in_vec_size; k += block_size) { + const int aligned_end = (in_vec_size / block_size) * block_size; + + for (int k = 0; k < aligned_end; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { @@ -805,6 +901,27 @@ METAL_FUNC void qmv_fast_impl( x += block_size; } + // Tail block: handles K that is a multiple of 512 (dispatch gate) but not + // of block_size. Only 1-bit has block_size > 512 (it uses 2048), so for + // >=2-bit this is a no-op. Out-of-bounds lanes skip the qdot call entirely; + // their result[row] stays at 0 and contributes 0 to the simd_sum below. + if (aligned_end < in_vec_size) { + bool in_bounds = (aligned_end + simd_lid * values_per_thread) < in_vec_size; + if (in_bounds) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + } + for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { @@ -2472,14 +2589,23 @@ template w_min = simd_min(w_min); w_max = simd_max(w_max); - float scale = max((w_max - w_min) / n_bins, eps); - bool side = abs(w_min) > abs(w_max); - scale = side ? scale : -scale; - float edge = side ? w_min : w_max; - float q0 = round(edge / scale); - bool at_zero = q0 == 0.0f; - scale = at_zero ? scale : edge / q0; - float bias = at_zero ? 0 : edge; + float scale; + float bias; + + if (bits == 1) { + // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max + scale = max(w_max - w_min, eps); + bias = w_min; + } else { + scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + bias = at_zero ? 0 : edge; + } // Write out the scales and biases size_t gindex = in_index / group_size; @@ -2583,7 +2709,9 @@ template #pragma clang loop unroll(full) for (int i = 0; i < pack_factor; i++) { uint8_t d; - if (bits == 2) { + if (bits == 1) { + d = (val >> i) & 0x01; + } else if (bits == 2) { d = (val >> (bits * i)) & 0x03; } else if (bits == 4) { d = (val >> (bits * i)) & 0x0f; diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index d632cfbda5..2266ade5b2 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -148,6 +148,7 @@ instantiate_quantized_types(32, bits) #define instantiate_quantized_all() \ + instantiate_quantized_groups(1) \ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h index 8814fafa33..3f8c5d2495 100644 --- a/mlx/backend/metal/kernels/quantized_nax.h +++ b/mlx/backend/metal/kernels/quantized_nax.h @@ -31,13 +31,28 @@ inline constexpr short get_bytes_per_pack() { template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -110,13 +125,28 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U sum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; @@ -199,13 +229,27 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + accum += + (x_thread[8 * i] * (w[i] & 0x01) + + x_thread[8 * i + 1] * (w[i] & 0x02) + + x_thread[8 * i + 2] * (w[i] & 0x04) + + x_thread[8 * i + 3] * (w[i] & 0x08) + + x_thread[8 * i + 4] * (w[i] & 0x10) + + x_thread[8 * i + 5] * (w[i] & 0x20) + + x_thread[8 * i + 6] * (w[i] & 0x40) + + x_thread[8 * i + 7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -301,13 +345,27 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); U accum = 0; - if (bits == 2) { + if (bits == 1) { + for (int i = 0; i < (N / 8); i++) { + accum += + (x_thread[8 * i] * (w[i] & 0x01) + + x_thread[8 * i + 1] * (w[i] & 0x02) + + x_thread[8 * i + 2] * (w[i] & 0x04) + + x_thread[8 * i + 3] * (w[i] & 0x08) + + x_thread[8 * i + 4] * (w[i] & 0x10) + + x_thread[8 * i + 5] * (w[i] & 0x20) + + x_thread[8 * i + 6] * (w[i] & 0x40) + + x_thread[8 * i + 7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { for (int i = 0; i < (N / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + @@ -398,11 +456,33 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); - if (bits == 2) { + if (bits == 1) { + U s[8] = { + scale, + scale / 2.0f, + scale / 4.0f, + scale / 8.0f, + scale / 16.0f, + scale / 32.0f, + scale / 64.0f, + scale / 128.0f}; + for (int i = 0; i < (values_per_thread / 8); i++) { + result[8 * i] += x * (s[0] * (w[i] & 0x01) + bias); + result[8 * i + 1] += x * (s[1] * (w[i] & 0x02) + bias); + result[8 * i + 2] += x * (s[2] * (w[i] & 0x04) + bias); + result[8 * i + 3] += x * (s[3] * (w[i] & 0x08) + bias); + result[8 * i + 4] += x * (s[4] * (w[i] & 0x10) + bias); + result[8 * i + 5] += x * (s[5] * (w[i] & 0x20) + bias); + result[8 * i + 6] += x * (s[6] * (w[i] & 0x40) + bias); + result[8 * i + 7] += x * (s[7] * (w[i] & 0x80) + bias); + } + } + + else if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); @@ -487,11 +567,33 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); + + if (bits == 1) { + U s[8] = { + scale, + scale / static_cast(2.0f), + scale / static_cast(4.0f), + scale / static_cast(8.0f), + scale / static_cast(16.0f), + scale / static_cast(32.0f), + scale / static_cast(64.0f), + scale / static_cast(128.0f)}; + for (int i = 0; i < (N / 8); i++) { + w_local[8 * i] = s[0] * (w[i] & 0x01) + bias; + w_local[8 * i + 1] = s[1] * (w[i] & 0x02) + bias; + w_local[8 * i + 2] = s[2] * (w[i] & 0x04) + bias; + w_local[8 * i + 3] = s[3] * (w[i] & 0x08) + bias; + w_local[8 * i + 4] = s[4] * (w[i] & 0x10) + bias; + w_local[8 * i + 5] = s[5] * (w[i] & 0x20) + bias; + w_local[8 * i + 6] = s[6] * (w[i] & 0x40) + bias; + w_local[8 * i + 7] = s[7] * (w[i] & 0x80) + bias; + } + } - if (bits == 2) { + else if (bits == 2) { U s[4] = { scale, scale / static_cast(4.0f), @@ -580,9 +682,9 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); @@ -715,9 +817,9 @@ struct QuantizedBlockLoader< BCOLS % group_size == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || + bits == 6 || bits == 8, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); MLX_MTL_CONST short pack_factor = get_pack_factor(); MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); diff --git a/mlx/backend/metal/kernels/quantized_nax.metal b/mlx/backend/metal/kernels/quantized_nax.metal index 5a9c9fb874..fb68e535a9 100644 --- a/mlx/backend/metal/kernels/quantized_nax.metal +++ b/mlx/backend/metal/kernels/quantized_nax.metal @@ -96,6 +96,7 @@ instantiate_quantized_types(32, bits) #define instantiate_quantized_all() \ + instantiate_quantized_groups(1) \ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index c8d5a31cb4..f8b52fe481 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1377,7 +1377,7 @@ void dispatch_qmv( const Stream& s, const std::string& mode) { // It is a qmv with a small inner dimension so route to qmv_quad kernel - if ((K == 128 || K == 64) && is_power_of_2(bits)) { + if ((K == 128 || (K == 64 && bits >= 2)) && is_power_of_2(bits)) { qmv_quad(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); return; } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index defcc2f6e0..b569888975 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4742,10 +4742,10 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { throw std::invalid_argument(msg.str()); } - if (bits < 2 || bits > 8 || bits == 7) { + if (bits < 1 || bits > 8 || bits == 7) { std::ostringstream msg; msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8."; + << " is not supported. The supported bits are 1, 2, 3, 4, 5, 6 and 8."; throw std::invalid_argument(msg.str()); } @@ -4766,14 +4766,22 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { w_max = astype(w_max, float32, s); w_min = astype(w_min, float32, s); - array mask = greater(abs(w_min, s), abs(w_max, s), s); - array scales = - maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales, s), s); - array edge = where(mask, w_min, w_max, s); - array q0 = round(divide(edge, scales, s), s); - scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge, s); + array scales(0, float32); + array biases(0, float32); + + if (bits == 1) { + // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max + scales = maximum(subtract(w_max, w_min, s), eps, s); + biases = w_min; + } else { + array mask = greater(abs(w_min, s), abs(w_max, s), s); + scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); + scales = where(mask, scales, negative(scales, s), s); + array edge = where(mask, w_min, w_max, s); + array q0 = round(divide(edge, scales, s), s); + scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); + biases = where(equal(q0, zero, s), zero, edge, s); + } packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 9bbe0ff751..775bdb012e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4400,14 +4400,14 @@ void init_ops(nb::module_& m) { .. table:: Quantization modes - ====== ====================== ========================== ============= ===== - mode group size bits scale type bias - ====== ====================== ========================== ============= ===== - affine 32, 64\ :sup:`*`, 128 2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes - mxfp4 32\ :sup:`*` 4\ :sup:`*` e8m0 no - mxfp8 32\ :sup:`*` 8\ :sup:`*` e8m0 no - nvfp4 16\ :sup:`*` 4\ :sup:`*` e4m3 no - ====== ====================== ========================== ============= ===== + ====== ====================== ============================== ============= ===== + mode group size bits scale type bias + ====== ====================== ============================== ============= ===== + affine 32, 64\ :sup:`*`, 128 1, 2, 3, 4\ :sup:`*`, 5, 6, 8 same as input yes + mxfp4 32\ :sup:`*` 4\ :sup:`*` e8m0 no + mxfp8 32\ :sup:`*` 8\ :sup:`*` e8m0 no + nvfp4 16\ :sup:`*` 4\ :sup:`*` e4m3 no + ====== ====================== ============================== ============= ===== :sup:`*` indicates the default value when unspecified. diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 31a657d713..361b5c0bed 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -19,4 +19,6 @@ "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", "TestQuantized.test_gather_qmm_grad", + # 1-bit quantization NYI on CUDA + "TestQuantized.test_1bit_quantize_dequantize", } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index a7472e9920..6049c71ff8 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 5, 6, 4, 8]: + for b in [1, 2, 3, 5, 6, 4, 8]: with self.subTest(gs=gs, b=b): w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) w_hat = mx.dequantize(w_q, scales, biases, gs, b) @@ -22,7 +22,7 @@ def test_quantize_dequantize(self): # test quantize/dequantize 0s a = mx.zeros((256, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 4, 5, 6, 8]: + for b in [1, 2, 3, 4, 5, 6, 8]: w_q, scales, biases = mx.quantize(a, gs, b) a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) @@ -173,6 +173,96 @@ def test_nvfp4_quantize_dequantize(self): ) self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) + def test_1bit_quantize_dequantize(self): + """Test 1-bit affine quantization.""" + + # Symmetric binary weights {-0.5, +0.5} should round-trip perfectly + # (affine formula gives scale=1.0, bias=-0.5) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="pack_symmetric_weights"): + signs = (mx.random.uniform(shape=(128, 512)) > 0.5).astype(mx.float32) + w = signs * 1.0 - (1 - signs) * 1.0 # {-1.0, +1.0} + w = w * 0.5 # {-0.5, +0.5} + + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + + self.assertLess((w - w_hat).abs().max(), 1e-5) + + # Asymmetric binary weights {0.1, 0.9} should round-trip perfectly + # (affine formula gives scale=0.8, bias=0.1) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="pack_asymmetric_weights"): + bits = (mx.random.uniform(shape=(128, 512)) > 0.5).astype(mx.float32) + w = bits * 0.9 + (1 - bits) * 0.1 # {0.1, 0.9} + + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + + self.assertLess((w - w_hat).abs().max(), 1e-5) + + # Verify dequantized values are exactly {bias, bias + scale} + w = mx.random.normal(shape=(64, 256)) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="dequant_values"): + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + + for i in range(scales.shape[0]): + for j in range(scales.shape[1]): + s = scales[i, j].item() + b = biases[i, j].item() + row_start = j * gs + row_end = row_start + gs + vals = w_hat[i, row_start:row_end] + mx.eval(vals) + for v in vals.tolist(): + self.assertTrue( + abs(v - b) < 1e-5 or abs(v - (b + s)) < 1e-5, + f"Value {v} not in {{bias={b}, bias+scale={b+s}}}", + ) + + # 1-bit quantize/dequantize zeros — scale floors to eps, bias=0 + a = mx.zeros((256, 512)) + for gs in [32, 64, 128]: + w_q, scales, biases = mx.quantize(a, gs, 1) + a_hat = mx.dequantize(w_q, scales, biases, gs, 1) + self.assertLess(a_hat.abs().max(), 1e-5) + + # Quantized matmul with symmetric binary weights + key = mx.random.key(42) + k1, k2 = mx.random.split(key) + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="quantized_matmul_symmetric"): + x = mx.random.normal(shape=(4, 256), key=k1) + signs = (mx.random.uniform(shape=(128, 256), key=k2) > 0.5).astype( + mx.float32 + ) + w = signs * 0.3 - (1 - signs) * 0.3 # {-0.3, +0.3} + + w_q, scales, biases = mx.quantize(w, gs, 1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + y_q = mx.quantized_matmul(x, w_q, scales, biases, True, gs, 1) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-5) + + # Quantized matmul with asymmetric binary weights + for gs in [32, 64, 128]: + with self.subTest(gs=gs, case="quantized_matmul_asymmetric"): + x = mx.random.normal(shape=(4, 256), key=k1) + bits = (mx.random.uniform(shape=(128, 256), key=k2) > 0.5).astype( + mx.float32 + ) + w = bits * 0.7 + (1 - bits) * 0.1 # {0.1, 0.7} + + w_q, scales, biases = mx.quantize(w, gs, 1) + w_hat = mx.dequantize(w_q, scales, biases, gs, 1) + y_q = mx.quantized_matmul(x, w_q, scales, biases, True, gs, 1) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-5) + def test_qqmv(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) @@ -211,7 +301,7 @@ def test_qmm(self): dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32 tests = product( [128, 64, 32], # group_size - [2, 4, 8], # bits + [1, 2, 4, 8], # bits [8, 32, 33, 64], # M [128, 256], # N [128, 256], # K