diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index c8d5a31cb4..af8df31aec 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -1408,6 +1408,32 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
   auto mode = quantization_mode_to_string(mode_);
+
+  // Numerical strict mode (MLX_NUMERICAL_STRICT_MODE=1): bypass the
+  // shape-dependent fast paths (qmv, qmm_splitk, qvm_split_k) so output is
+  // bit-identical regardless of M. Costs ~1.5-2.3x slower decode at M=1
+  // (qmv is heavily optimized for the M=1 case) but gives path-independence
+  // required for prefix-cache reuse, batched-vs-streaming eval comparison,
+  // and distillation/RLHF teacher-student equality. See
+  // env::numerical_strict_mode() in mlx/utils.h.
+  if (env::numerical_strict_mode()) {
+    qmm(x,
+        w,
+        scales,
+        biases,
+        out,
+        transpose_,
+        group_size_,
+        bits_,
+        M,
+        N,
+        K,
+        d,
+        s,
+        mode);
+    return;
+  }
+
   // It is a matrix matrix product.
   if (M >= vector_limit) {
     // Use split-K qmm for small M with transposed weights (non-batched only)
@@ -1477,6 +1503,33 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
   auto mode = quantization_mode_to_string(mode_);
 
+  // Numerical strict mode (MLX_NUMERICAL_STRICT_MODE=1): bypass the
+  // shape-dependent fast paths (gather_qmv, gather_qvm, gather_qmm_rhs) so
+  // GatherQMM output is bit-identical regardless of M. Same justification as
+  // QuantizedMatmul::eval_gpu — gather_qmm is the reference path that uses
+  // sequential register-fma accumulation matching qmm. Necessary for
+  // bit-equivalence in MoE models (Mixtral-MoE, DeepSeek-MoE, etc.).
+  if (env::numerical_strict_mode()) {
+    gather_qmm(
+        x,
+        w,
+        scales,
+        biases,
+        lhs_indices,
+        rhs_indices,
+        out,
+        transpose_,
+        group_size_,
+        bits_,
+        M,
+        N,
+        K,
+        d,
+        s,
+        mode);
+    return;
+  }
+
   // We are walking x in order and w is also in order so we can batch up the
   // matmuls and reuse reading x and w.
   //
diff --git a/mlx/utils.h b/mlx/utils.h
index 7835a97028..99ce09c842 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -173,6 +173,25 @@ inline bool enable_tf32() {
   return enable_tf32_;
 }
 
+// When set, QuantizedMatmul forces the no-split qmm path for all 2D shapes,
+// bypassing the qmv (M < vector_limit) and qmm_splitk fast paths. This
+// guarantees that quantized_matmul output is independent of input shape:
+// q_proj(x[:, -L:]) is bit-identical to q_proj(x)[:, -L:] for any L.
+//
+// The fast paths use parallel reductions across K (simd-butterfly in qmv,
+// partition-then-sum in splitk) which produce different fp32 sums than qmm's
+// sequential register-level accumulation. Even when both paths use fp32
+// throughout, fp32 is non-associative so the bit patterns differ by ~ULP.
+//
+// This bites workloads that compare two equivalent paths — prefix-cache reuse,
+// batched-vs-streaming eval, distillation/RLHF teacher-student equality. For
+// straight inference / training the diff is invisible. Off by default.
+inline bool numerical_strict_mode() {
+  static bool numerical_strict_mode_ =
+      get_var("MLX_NUMERICAL_STRICT_MODE", 0);
+  return numerical_strict_mode_;
+}
+
 inline int nccl_timeout(int default_value) {
   static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
   return nccl_timeout;
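
Not part of the patch: a minimal Python sketch of how the path-independence guarantee could be exercised end to end, assuming the public mx.quantize / mx.quantized_matmul API with the default affine parameters (group_size=64, bits=4); the shapes and the q_proj helper name are illustrative only. Note that MLX_NUMERICAL_STRICT_MODE must be set before the first quantized matmul, because env::numerical_strict_mode() caches the value in a function-local static on first read.

# Illustrative check of the strict-mode guarantee:
# q_proj(x)[-1:] == q_proj(x[-1:]) bit-for-bit when MLX_NUMERICAL_STRICT_MODE=1.
import os

# Must be set before the first quantized matmul; the value is cached in a
# function-local static inside env::numerical_strict_mode().
os.environ["MLX_NUMERICAL_STRICT_MODE"] = "1"

import mlx.core as mx


def q_proj(x, w_q, scales, biases):
    # Identical call for both invocations; only M (the row count of x) differs.
    return mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4)


w = mx.random.normal((4096, 4096)).astype(mx.float16)
x = mx.random.normal((128, 4096)).astype(mx.float16)  # M = 128 -> qmm path
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

full = q_proj(x, w_q, scales, biases)       # batched "prefill" shape
last = q_proj(x[-1:], w_q, scales, biases)  # M = 1, normally the qmv path

# With strict mode on, both shapes run the same qmm kernel, so the shared row
# must match bit for bit, not merely within an allclose tolerance.
assert mx.array_equal(full[-1:], last).item()

Without the environment variable set, the final assertion would be expected to fail by roughly a ULP, per the rationale in the mlx/utils.h comment; that is exactly the shape-dependent discrepancy the mode removes.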