[BUG] quantized_matmul produces wrong results for GQA expand_dims broadcasting when M < vector_limit #3480

@AirRunner

Description

Describe the bug

mx.quantized_matmul produces severely wrong results when the weight tensor has a stride-0 batch dimension (from mx.expand_dims, used for GQA broadcasting) and M is small enough to trigger the qmv kernel. The error is orders of magnitude larger than the shape-dependence issue described in #3473 (max_diff up to 12.46 in real inference vs ~2e-5 for the proj case), large enough to corrupt softmax distributions and cause LLM hallucinations.

In GQA (Grouped Query Attention) with a quantized KV cache, the standard pattern is:

# queries: (B, n_kv_heads, n_repeats, M, D)
# q_keys:  (B, n_kv_heads, N, D//pack_factor)  — quantized tuple
qk_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
# qk_e[0] shape: (B, n_kv_heads, 1, N, D//pack_factor), stride-0 on dim -3
scores = mx.quantized_matmul(queries, *qk_e, transpose=True, ...)

ensure_row_contiguous_matrix (called at the top of QuantizedMatmul::eval_gpu) only checks the last two dimensions for row-contiguity. A weight tensor with stride-0 on a batch dimension passes through without a copy, retaining its zero stride. That stride is then forwarded verbatim to the Metal kernel via add_strides_and_shapes. When M < vector_limit, the dispatch routes to qmv, which does not handle stride-0 batch dimensions correctly, producing wrong output for large N (error appears at N > ~2048 and grows with cache size). The qmm path handles stride-0 correctly.

This is a distinct bug from the shape-dependence issue in #3473: the problem here is not fp32 non-associativity across different reduction trees, but incorrect memory access patterns in qmv when batch strides are zero.
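For a rough sense of scale: reduction-order differences in fp32 alone (the mechanism behind #3473) shift a 4096-element sum by something on the order of 1e-4 or less, nowhere near the errors of ~0.1 to ~12 observed here. A small NumPy sketch:

```python
import math
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)

# Naive sequential fp32 accumulation vs an exactly rounded reference.
approx = float(sum(x, np.float32(0.0)))
exact = math.fsum(float(v) for v in x)

print(abs(approx - exact))  # tiny: pure reduction-order error
assert abs(approx - exact) < 1e-2
```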

To Reproduce

import mlx.core as mx
from mlx.utils import tree_map

B, n_kv_heads, n_repeats, D = 1, 4, 4, 256
n_q_heads = n_kv_heads * n_repeats
key_bits, value_bits, group_size = 8, 4, 64

mx.random.seed(42)

def run(N, M):
    keys_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    values_f = mx.random.normal((B, n_kv_heads, N, D)).astype(mx.float16)
    q_keys = mx.quantize(keys_f, group_size=group_size, bits=key_bits)
    q_values = mx.quantize(values_f, group_size=group_size, bits=value_bits)
    queries = (mx.random.normal((B, n_q_heads, M, D)) * D**-0.5).astype(mx.float16)
    mx.eval(q_keys, q_values, queries)

    # reference: dequantize then float matmul
    keys_dq = mx.dequantize(*q_keys, group_size=group_size, bits=key_bits)
    values_dq = mx.dequantize(*q_values, group_size=group_size, bits=value_bits)
    qr = queries.reshape(B, n_kv_heads, n_repeats, M, D)
    s_ref = mx.softmax(qr @ keys_dq[:,:,None,:,:].transpose(0,1,2,4,3), axis=-1)
    out_ref = (s_ref @ values_dq[:,:,None,:,:]).reshape(B, n_q_heads, M, D)

    # quantized_matmul with expand_dims broadcast (GQA)
    qk_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
    qv_e = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
    s_qmm = mx.softmax(mx.quantized_matmul(qr, *qk_e, transpose=True, group_size=group_size, bits=key_bits), axis=-1)
    out_qmm = mx.quantized_matmul(s_qmm, *qv_e, transpose=False, group_size=group_size, bits=value_bits).reshape(B, n_q_heads, M, D)

    mx.eval(out_ref, out_qmm)
    diff = mx.max(mx.abs(out_ref.astype(mx.float32) - out_qmm.astype(mx.float32))).item()
    print(f"N={N:5d} M={M}: max_diff={diff:.6f}")

for N in [512, 2048, 4096, 7358]:
    for M in [1, 2]:
        run(N, M)

Output on an M4 Pro:

N=  512 M=1: max_diff=0.000244   # within quantization noise
N=  512 M=2: max_diff=0.000244   # ok, qmm path
N= 2048 M=1: max_diff=0.000122
N= 2048 M=2: max_diff=0.140656   # qmv triggered, error grows
N= 4096 M=1: max_diff=0.000092
N= 4096 M=2: max_diff=0.118958   # wrong
N= 7358 M=1: max_diff=0.000061
N= 7358 M=2: max_diff=0.079422   # wrong, mean_diff ~0.009

M=1 is always correct (qmv handles the non-broadcast single-token decode case). The bug triggers for M >= 2 combined with stride-0 batch dims and N > ~2048.

Expected behavior

quantized_matmul with a stride-0 weight tensor (GQA expand_dims broadcast) should produce the same result as dequantizing the weights and using standard float matmul, regardless of M.

Desktop:

  • OS version: macOS 26.4.1
  • MLX version: 0.31.1 and 0.31.2

Additional context

In actual inference (real model, N=7358), I measured max_diff up to 12.46 between the qmv and dequantize-then-float-matmul paths in certain attention layers. The practical symptom is severely degraded generation: repetition loops, wrong task execution. The issue only manifests when a quantized KV cache is used (routing through quantized_scaled_dot_product_attention in mlx-lm) combined with a small-M pass such as a 2-token verification step. Without KV quantization, mx.fast.scaled_dot_product_attention is used and the bug does not trigger.

Proposed fix

After ensure_row_contiguous_matrix, check for zero strides on the batch dimensions of w before the kernel dispatch. If any are found, route to qmm, which already handles stride-0 batch dims correctly (confirmed by the reproducer: max_diff within quantization noise at M=2, N=512, where qmm is selected):

// In QuantizedMatmul::eval_gpu, after ensure_row_contiguous_matrix calls:
bool has_broadcast_batch = false;
for (int i = 0; i + 2 < w.ndim(); i++) {
    if (w.strides()[i] == 0) { has_broadcast_batch = true; break; }
}
if (has_broadcast_batch) {
    qmm(x, w, scales, biases, out, transpose_, group_size_, bits_, M, N, K, d, s, mode);
    return;
}
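For intuition, the same predicate can be sketched in pure Python over a (shape, strides) pair — the function name and the example stride values are illustrative, mirroring the NumPy layout of the expanded GQA weights:

```python
def has_broadcast_batch(shape, strides):
    """True if any batch dimension (all dims except the last two) has stride 0."""
    return any(s == 0 for s in strides[:max(len(shape) - 2, 0)])

# Expanded GQA weights: stride 0 on the inserted n_repeats dim -> route to qmm.
assert has_broadcast_batch((1, 4, 4, 2048, 32), (262144, 65536, 0, 32, 1))
# Contiguous weights: no zero batch strides -> existing qmv/qmm dispatch unchanged.
assert not has_broadcast_batch((1, 4, 2048, 32), (262144, 65536, 32, 1))
# A plain matrix with no batch dims is trivially fine.
assert not has_broadcast_batch((2048, 32), (32, 1))
```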

This is complementary to the MLX_NUMERICAL_STRICT_MODE opt-in in #3473, which addresses the shape-dependence issue for contiguous tensors. The stride-0 check fires automatically with no env var and adds negligible overhead. Standard M=1 decode with contiguous weights is completely unaffected.

It may also be worth checking whether GatherQMM::eval_gpu is exposed to the same issue for MoE models with GQA.
