Describe the bug
mx.quantized_matmul produces severely wrong results when the weight tensor has a stride-0 batch dimension (from mx.expand_dims, used for GQA broadcasting) and M is small enough to trigger the qmv kernel. The error is orders of magnitude larger than the shape-dependence issue described in #3473 (max_diff up to 12.46 in real inference vs ~2e-5 for the proj case), large enough to corrupt softmax distributions and cause LLM hallucinations.
In GQA (Grouped Query Attention) with a quantized KV cache, the standard pattern is:
# queries: (B, n_kv_heads, n_repeats, M, D)
# q_keys: (B, n_kv_heads, N, D//pack_factor) — quantized tuple
qk_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
# qk_e[0] shape: (B, n_kv_heads, 1, N, D//pack_factor), stride-0 on dim -3
scores = mx.quantized_matmul(queries, *qk_e, transpose=True, ...)
ensure_row_contiguous_matrix (called at the top of QuantizedMatmul::eval_gpu) only checks the last two dimensions for row-contiguity. A weight tensor with stride-0 on a batch dimension passes through without a copy, retaining its zero stride. That stride is then forwarded verbatim to the Metal kernel via add_strides_and_shapes. When M < vector_limit, the dispatch routes to qmv, which does not handle stride-0 batch dimensions correctly, producing wrong output for large N (error appears at N > ~2048 and grows with cache size). The qmm path handles stride-0 correctly.
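For readers less familiar with strided layouts, the situation can be illustrated with NumPy (an analogy only, not MLX internals): broadcasting the size-1 dimension inserted by expand_dims gives that axis a stride of 0 while the trailing two axes stay row-contiguous, which is exactly the layout that slips past a check limited to the last two dimensions.

import numpy as np

# NumPy analogy of the weight layout that reaches the kernel (illustrative only)
w = np.zeros((4, 4096, 32), dtype=np.uint32)      # (n_kv_heads, N, D // pack_factor)
w_e = np.expand_dims(w, axis=-3)                  # (4, 1, 4096, 32)
w_b = np.broadcast_to(w_e, (4, 4, 4096, 32))      # broadcast over n_repeats
print(w_b.strides)  # e.g. (524288, 0, 128, 4): stride 0 on the broadcast batch dim,
                    # last two dims still row-contiguous, so no copy is triggered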
This is a distinct bug from the shape-dependence issue in #3473: the problem here is not fp32 non-associativity across different reduction trees, but incorrect memory access patterns in qmv when batch strides are zero.
To Reproduce
import mlx.core as mx
from mlx.utils import tree_map
B, n_kv_heads, n_repeats, D = 1, 4, 4, 256
n_q_heads = n_kv_heads * n_repeats
key_bits, value_bits, group_size = 8, 4, 64
mx.random.seed(42)
def run(N, M):
    keys_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    values_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    q_keys = mx.quantize(keys_f, group_size=group_size, bits=key_bits)
    q_values = mx.quantize(values_f, group_size=group_size, bits=value_bits)
    queries = (mx.random.normal((B, n_q_heads, M, D)) * D**-0.5).astype(mx.float16)
    mx.eval(q_keys, q_values, queries)
    # reference: dequantize then float matmul
    keys_dq = mx.dequantize(*q_keys, group_size=group_size, bits=key_bits)
    values_dq = mx.dequantize(*q_values, group_size=group_size, bits=value_bits)
    qr = queries.reshape(B, n_kv_heads, n_repeats, M, D)
    s_ref = mx.softmax(qr @ keys_dq[:, :, None, :, :].transpose(0, 1, 2, 4, 3), axis=-1)
    out_ref = (s_ref @ values_dq[:, :, None, :, :]).reshape(B, n_q_heads, M, D)
    # quantized_matmul with expand_dims broadcast (GQA)
    qk_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
    qv_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
    s_qmm = mx.softmax(mx.quantized_matmul(qr, *qk_e, transpose=True, group_size=group_size, bits=key_bits), axis=-1)
    out_qmm = mx.quantized_matmul(s_qmm, *qv_e, transpose=False, group_size=group_size, bits=value_bits).reshape(B, n_q_heads, M, D)
    mx.eval(out_ref, out_qmm)
    diff = mx.max(mx.abs(out_ref.astype(mx.float32) - out_qmm.astype(mx.float32))).item()
    print(f"N={N:5d} M={M}: max_diff={diff:.6f}")

for N in [512, 2048, 4096, 7358]:
    for M in [1, 2]:
        run(N, M)
Output on M4 Pro:
N= 512 M=1: max_diff=0.000244 # within quantization noise
N= 512 M=2: max_diff=0.000244 # ok, qmm path
N= 2048 M=1: max_diff=0.000122
N= 2048 M=2: max_diff=0.140656 # qmv triggered, error grows
N= 4096 M=1: max_diff=0.000092
N= 4096 M=2: max_diff=0.118958 # wrong
N= 7358 M=1: max_diff=0.000061
N= 7358 M=2: max_diff=0.079422 # wrong, mean_diff ~0.009
M=1 is always correct (qmv handles the non-broadcast single-token decode case). The bug triggers for M >= 2 combined with stride-0 batch dims and N > ~2048.
Expected behavior
quantized_matmul with a stride-0 weight tensor (GQA expand_dims broadcast) should produce the same result as dequantizing the weights and using standard float matmul, regardless of M.
Desktop:
- OS Version: macOS 26.4.1
- Version: 0.31.1 and 0.31.2
Additional context
In actual inference (real model, N=7358), I measured max_diff up to 12.46 between the qmv and dequantize-then-float-matmul paths in certain attention layers. The practical symptom is severely degraded generation: repetition loops, wrong task execution. The issue only manifests when a quantized KV cache is used (routing through quantized_scaled_dot_product_attention in mlx-lm) combined with a small-M pass such as a 2-token verification step. Without KV quantization, mx.fast.scaled_dot_product_attention is used and the bug does not trigger.
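Until a fix lands, one possible user-side mitigation is to fold the n_repeats dimension into M before calling quantized_matmul, so the quantized keys/values are used with their natural (B, n_kv_heads, ...) batch shape and no stride-0 batch dimension is ever created. This is only a sketch using the names from the reproducer above; it is mathematically equivalent to the broadcast form, but I have not verified that it sidesteps the bug for every shape.

# Possible mitigation sketch (names from the reproducer above; not a fix):
# fold n_repeats into M so no expand_dims broadcast of the weights is needed.
q_merged = qr.reshape(B, n_kv_heads, n_repeats * M, D)
scores = mx.quantized_matmul(
    q_merged, *q_keys, transpose=True, group_size=group_size, bits=key_bits
)
s = mx.softmax(scores, axis=-1)
out = mx.quantized_matmul(
    s, *q_values, transpose=False, group_size=group_size, bits=value_bits
).reshape(B, n_q_heads, M, D)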
Proposed fix
After ensure_row_contiguous_matrix, check for zero strides on the batch dimensions of w before the kernel dispatch. If any are found, route to qmm, which already handles stride-0 batch dims correctly (confirmed by the reproducer: max_diff within quantization noise at M=2, N=512, where qmm is selected):
// In QuantizedMatmul::eval_gpu, after ensure_row_contiguous_matrix calls:
bool has_broadcast_batch = false;
for (int i = 0; i + 2 < w.ndim(); i++) {
  if (w.strides()[i] == 0) {
    has_broadcast_batch = true;
    break;
  }
}
if (has_broadcast_batch) {
  qmm(x, w, scales, biases, out, transpose_, group_size_, bits_, M, N, K, d, s, mode);
  return;
}
This is complementary to the MLX_NUMERICAL_STRICT_MODE opt-in in #3473, which addresses the shape-dependence issue for contiguous tensors. The stride-0 check fires automatically with no env var and adds negligible overhead. Standard M=1 decode with contiguous weights is completely unaffected.
It may also be worth checking whether GatherQMM::eval_gpu is exposed to the same issue for MoE models with GQA.
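If a routing fix along these lines lands, a small regression test could guard the broadcast-weight case. The sketch below is only a suggestion (test placement, harness, and the 1e-2 tolerance are my assumptions); it asserts that quantized_matmul with expand_dims-broadcast weights matches the dequantize-then-matmul reference for an M=2, N=4096 configuration from the table above.

# Hypothetical regression test sketch (placement and tolerance are assumptions):
import mlx.core as mx
from mlx.utils import tree_map

def test_quantized_matmul_broadcast_batch_dim():
    B, H, R, M, N, D, gs, bits = 1, 4, 4, 2, 4096, 256, 64, 8
    mx.random.seed(0)
    w = mx.random.normal((B, H, N, D)).astype(mx.float16)
    qw = mx.quantize(w, group_size=gs, bits=bits)
    x = (mx.random.normal((B, H, R, M, D)) * D**-0.5).astype(mx.float16)
    # stride-0 batch dim on the quantized weights, as in GQA with a quantized KV cache
    qw_e = tree_map(lambda a: mx.expand_dims(a, axis=-3), qw)
    out = mx.quantized_matmul(x, *qw_e, transpose=True, group_size=gs, bits=bits)
    w_dq = mx.dequantize(*qw, group_size=gs, bits=bits)
    ref = x @ w_dq[:, :, None, :, :].transpose(0, 1, 2, 4, 3)
    diff = mx.max(mx.abs(out.astype(mx.float32) - ref.astype(mx.float32))).item()
    assert diff < 1e-2, f"max_diff={diff}"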