31 changes: 30 additions & 1 deletion benchmarks/python/comparative/bench_mlx.py
@@ -72,12 +72,17 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):


quant_matmul = {
"quant_matmul_32_1": partial(_quant_matmul, transpose=False, group_size=32, bits=1),
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_1": partial(_quant_matmul, transpose=False, group_size=64, bits=1),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
"quant_matmul_128_1": partial(
_quant_matmul, transpose=False, group_size=128, bits=1
),
"quant_matmul_128_2": partial(
_quant_matmul, transpose=False, group_size=128, bits=2
),
@@ -87,6 +92,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_1": partial(
_quant_matmul, transpose=True, group_size=32, bits=1
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
@@ -96,6 +104,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_1": partial(
_quant_matmul, transpose=True, group_size=64, bits=1
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
@@ -105,6 +116,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
"quant_matmul_t_64_8": partial(
_quant_matmul, transpose=True, group_size=64, bits=8
),
"quant_matmul_t_128_1": partial(
_quant_matmul, transpose=True, group_size=128, bits=1
),
"quant_matmul_t_128_2": partial(
_quant_matmul, transpose=True, group_size=128, bits=2
),
@@ -420,7 +434,22 @@ def selu(x):
print(bench(matmul, *xs))

elif args.benchmark.startswith("quant_matmul"):
print(bench(quant_matmul[args.benchmark], *xs))
# The registered partial (e.g. "quant_matmul_128_4" or "quant_matmul_t_128_4")
# already carries group_size, bits, and transpose as keyword arguments.
fn = quant_matmul[args.benchmark]
gs = fn.keywords["group_size"]
bits = fn.keywords["bits"]
transpose = fn.keywords["transpose"]

# xs[0] = activation x, xs[1] = original (float) weight matrix
# Quantize the weight internally so the caller only needs:
# --size MxK --size NxK (transpose=True) or --size MxK --size KxN
w_float = xs[1].astype(mx.float16)
w_q, scales, biases = mx.quantize(w_float, group_size=gs, bits=bits)
mx.eval(w_q, scales, biases)
x_input = xs[0].astype(mx.float16)
mx.eval(x_input)
print(bench(_quant_matmul, x_input, w_q, scales, biases, transpose, gs, bits))

elif args.benchmark == "linear":
if args.fused:
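For context, here is a minimal sketch of the round trip these benchmark entries exercise, assuming the mx.quantize and mx.quantized_matmul APIs from mlx.core. The shapes and the group_size value are illustrative, and bits=1 assumes the 1-bit support this PR adds:

import mlx.core as mx

# Activations (M x K) and float weights (N x K) for the transposed path.
x = mx.random.normal((1, 4096)).astype(mx.float16)
w = mx.random.normal((4096, 4096)).astype(mx.float16)

# Pack the weights; scales and biases are per group of `group_size` values.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=1)

# Computes x @ dequantize(w_q).T
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=64, bits=1
)
mx.eval(y)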
35 changes: 35 additions & 0 deletions benchmarks/python/comparative/compare.py
@@ -28,6 +28,18 @@ def compare(args):
print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t")


def compare_mlx_quant(args_base, bits_list):
"""Compare quantized matmul across bit widths (MLX only, no PyTorch)."""
results = {}
for bits in bits_list:
bench_args = args_base.replace("{bits}", str(bits)).split()
results[bits] = run_or_raise(["python", BENCH_MLX] + bench_args)
baseline = max(results.values())
for bits in bits_list:
speedup = (baseline - results[bits]) / baseline if baseline > 0 else 0
print(f"{speedup:.4f}\t{args_base.replace('{bits}', str(bits))}")


def compare_mlx_dtypes(args, dt1, dt2):
t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1])
t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2])
@@ -282,3 +294,26 @@ def predicate(x):
compare_filtered("topk --size 32768x128 --axis 1")
compare_filtered("topk --size 128x128 --axis 0 --cpu")
compare_filtered("topk --size 128x128 --axis 1 --cpu")

# Quantized matmul ops (MLX only — compare across bit widths)
# qmv path (M=1, token generation, memory-bandwidth bound)
for gs in [64, 128]:
compare_mlx_quant(
f"quant_matmul_t_{gs}_{{bits}} --size 1x4096 --size 4096x4096",
[1, 2, 4, 8],
)
compare_mlx_quant(
f"quant_matmul_t_{gs}_{{bits}} --size 1x4096 --size 11008x4096",
[1, 2, 4, 8],
)
# qmm path (prompt processing, more compute bound)
for gs in [64, 128]:
for M in [32, 512]:
compare_mlx_quant(
f"quant_matmul_t_{gs}_{{bits}} --size {M}x4096 --size 4096x4096",
[1, 2, 4, 8],
)
compare_mlx_quant(
f"quant_matmul_t_{gs}_{{bits}} --size {M}x4096 --size 11008x4096",
[1, 2, 4, 8],
)
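To make the printed metric concrete, a worked example with made-up timings (these numbers are invented for illustration, not measured):

# Hypothetical timings in seconds for one size at bit widths 1/2/4/8.
results = {1: 0.40, 2: 0.55, 4: 0.70, 8: 1.00}
baseline = max(results.values())  # slowest configuration (8-bit here)
for bits, t in results.items():
    print(f"{(baseline - t) / baseline:.4f}\tbits={bits}")
# The 1-bit line prints 0.6000: that run took 60% less time than the slowest.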
34 changes: 24 additions & 10 deletions mlx/backend/cpu/quantized.cpp
@@ -351,6 +351,10 @@ void _qmm_dispatch_typed(
int bits,
bool transposed_w) {
switch (bits) {
case 1:
_qmm_dispatch_group<T, 1>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -376,7 +380,8 @@
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
throw std::invalid_argument(
"Quantization bits must be 1, 2, 3, 4, 5, 6 or 8.");
}
}

@@ -1172,15 +1177,24 @@ void quantize(
w_min = std::min(w_min, (float)w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
float scale = std::max((w_max - w_min) / n_bins, eps);
scale = mask ? scale : -scale;

float edge = mask ? w_min : w_max;
float q0 = std::rint(edge / scale);
float bias = 0;
if (q0 != 0) {
scale = edge / q0;
bias = edge;
float scale;
float bias;

if (bits == 1) {
// Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max
scale = std::max(w_max - w_min, eps);
bias = w_min;
} else {
scale = std::max((w_max - w_min) / n_bins, eps);
scale = mask ? scale : -scale;

float edge = mask ? w_min : w_max;
float q0 = std::rint(edge / scale);
bias = 0;
if (q0 != 0) {
scale = edge / q0;
bias = edge;
}
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
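The bits == 1 branch above collapses the affine scheme to two representable values per group: bit 0 decodes to w_min and bit 1 to w_max. A standalone NumPy sketch of that mapping, assuming the usual round((w - bias) / scale) quantization step; the helper names are ours, not from the PR:

import numpy as np

def quantize_1bit(group, eps=1e-7):
    w_min, w_max = float(group.min()), float(group.max())
    scale = max(w_max - w_min, eps)  # bit 1 decodes to w_max
    bias = w_min                     # bit 0 decodes to w_min
    q = np.rint((group - bias) / scale).clip(0, 1).astype(np.uint8)
    return q, scale, bias

def dequantize_1bit(q, scale, bias):
    return q.astype(np.float32) * scale + bias

g = np.array([-0.3, 0.1, 0.4, -0.2], dtype=np.float32)
q, s, b = quantize_1bit(g)
print(q)                         # [0 1 1 0]
print(dequantize_1bit(q, s, b))  # each value snaps to w_min or w_max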