From 0b47c468355d9afd6bd4b348a31313ad30d80b44 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Sun, 26 Apr 2026 12:43:59 -0500 Subject: [PATCH 1/3] feat: MLX_SDPA_BLOCKS env var to override 2-pass vector kernel block count MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 2-pass SDPA vector kernel picks `blocks` (the partial-tile count) from a heuristic over device class + sequence length. The defaults are sensible across the upstream-tested matrix but leave money on the table on combinations the heuristic doesn't anticipate. This adds an `MLX_SDPA_BLOCKS` env var that overrides the heuristic to a positive integer. Unset / non-positive: heuristic unchanged. Empirical example: on a 2-rank M4-Ultra cluster running long-context MoE inference at K~50k, the heuristic picks 1024 but `blocks=88` is +6.5% decode tps with a sharp cliff at 92+, matching the ~352-concurrent-simdgroup capacity (4 kv_heads × 88 ≈ 1.1 dispatch rounds). Different workloads will sit at different optima — letting operators sweep without recompiling MLX is the value. Also adds a regression test (`test_sdpa_blocks_env_override`) that sweeps {16, 64, 256} and asserts numerics match the heuristic-default output, so future changes to the 2-pass dispatch path don't silently break the override. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../metal/scaled_dot_product_attention.cpp | 18 +++++++++++ python/tests/test_fast_sdpa.py | 32 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c79cd51ff0..197b341d5d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,4 +1,5 @@ // Copyright © 2024 Apple Inc. 
+#include <cstdlib> #include <sstream> #include "mlx/backend/common/compiled.h" @@ -15,6 +16,20 @@ namespace mlx::core::fast { namespace { +// Override the heuristic-chosen `blocks` count for the 2-pass SDPA vector +// kernel via the MLX_SDPA_BLOCKS env var. Returns -1 (no override) when +// unset or non-positive. Useful for tuning the partial-tile count on +// device/workload combinations the heuristic doesn't anticipate. +int sdpa_2pass_blocks_override() { + if (auto* env = std::getenv("MLX_SDPA_BLOCKS")) { + int v = std::atoi(env); + if (v > 0) { + return v; + } + } + return -1; +} + void sdpa_full_self_attention_nax( const Stream& s, metal::Device& d, @@ -474,6 +489,9 @@ void sdpa_vector_2pass( blocks = 32; } } + if (int override = sdpa_2pass_blocks_override(); override > 0) { + blocks = override; + } size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 7bd867084e..743050a67e 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -682,6 +682,38 @@ def test_sdpa_sliced(self): tolerance = {"rtol": 1e-2, "atol": 1e-2} self.assertTrue(mx.allclose(ref, out, **tolerance)) + def test_sdpa_blocks_env_override(self): + # The 2-pass vector kernel chooses `blocks` heuristically from the + # device + sequence length. MLX_SDPA_BLOCKS overrides that choice + # but must not change correctness. 
+ D = 128 + Nq = 4 + Nkv = 1 + N = 8192 # long enough to take the 2-pass path + scale = D**-0.5 + mx.random.seed(0) + q = mx.random.normal(shape=(1, Nq, 1, D)).astype(mx.float16) + k = mx.random.normal(shape=(1, Nkv, N, D)).astype(mx.float16) + v = mx.random.normal(shape=(1, Nkv, N, D)).astype(mx.float16) + + prev = os.environ.pop("MLX_SDPA_BLOCKS", None) + try: + ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + mx.eval(ref) + for blocks in ("16", "64", "256"): + os.environ["MLX_SDPA_BLOCKS"] = blocks + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + mx.eval(out) + self.assertTrue( + mx.allclose(ref, out, rtol=1e-3, atol=1e-3), + f"MLX_SDPA_BLOCKS={blocks} changed numerics", + ) + finally: + if prev is None: + os.environ.pop("MLX_SDPA_BLOCKS", None) + else: + os.environ["MLX_SDPA_BLOCKS"] = prev + if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True) From 4b87e47961a1cbfac6e347fb34a21a67590713e0 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Sun, 26 Apr 2026 20:52:40 -0500 Subject: [PATCH 2/3] sdpa: address review feedback on MLX_SDPA_BLOCKS - Rename helper away from the C++ contextual keyword `override` (sdpa_2pass_blocks_override -> sdpa_2pass_blocks_from_env, and the local `int override` at the call site -> `int blocks_env`). - Use existing env::get_var helper instead of manual std::getenv + std::atoi; drop the now-unused <cstdlib> include. - Drop test_sdpa_blocks_env_override per reviewer (it doesn't exercise behavior the existing 2-pass tests don't already cover). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../metal/scaled_dot_product_attention.cpp | 18 ++++------- python/tests/test_fast_sdpa.py | 32 ------------------- 2 files changed, 6 insertions(+), 44 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 197b341d5d..4229247b1d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,5 +1,4 @@ // Copyright © 2024 Apple Inc. -#include <cstdlib> #include <sstream> #include "mlx/backend/common/compiled.h" @@ -17,17 +16,12 @@ namespace mlx::core::fast { namespace { // Override the heuristic-chosen `blocks` count for the 2-pass SDPA vector -// kernel via the MLX_SDPA_BLOCKS env var. Returns -1 (no override) when +// kernel via the MLX_SDPA_BLOCKS env var. Returns 0 (no override) when // unset or non-positive. Useful for tuning the partial-tile count on // device/workload combinations the heuristic doesn't anticipate. -int sdpa_2pass_blocks_override() { - if (auto* env = std::getenv("MLX_SDPA_BLOCKS")) { - int v = std::atoi(env); - if (v > 0) { - return v; - } - } - return -1; +int sdpa_2pass_blocks_from_env() { + int v = env::get_var("MLX_SDPA_BLOCKS", 0); + return v > 0 ? v : 0; } void sdpa_full_self_attention_nax( @@ -489,8 +483,8 @@ void sdpa_vector_2pass( blocks = 32; } } - if (int override = sdpa_2pass_blocks_override(); override > 0) { - blocks = override; + if (int blocks_env = sdpa_2pass_blocks_from_env(); blocks_env > 0) { + blocks = blocks_env; } size_t k_head_stride = k.shape(1) == 1 ? 
k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 743050a67e..7bd867084e 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -682,38 +682,6 @@ def test_sdpa_sliced(self): tolerance = {"rtol": 1e-2, "atol": 1e-2} self.assertTrue(mx.allclose(ref, out, **tolerance)) - def test_sdpa_blocks_env_override(self): - # The 2-pass vector kernel chooses `blocks` heuristically from the - # device + sequence length. MLX_SDPA_BLOCKS overrides that choice - # but must not change correctness. - D = 128 - Nq = 4 - Nkv = 1 - N = 8192 # long enough to take the 2-pass path - scale = D**-0.5 - mx.random.seed(0) - q = mx.random.normal(shape=(1, Nq, 1, D)).astype(mx.float16) - k = mx.random.normal(shape=(1, Nkv, N, D)).astype(mx.float16) - v = mx.random.normal(shape=(1, Nkv, N, D)).astype(mx.float16) - - prev = os.environ.pop("MLX_SDPA_BLOCKS", None) - try: - ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) - mx.eval(ref) - for blocks in ("16", "64", "256"): - os.environ["MLX_SDPA_BLOCKS"] = blocks - out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) - mx.eval(out) - self.assertTrue( - mx.allclose(ref, out, rtol=1e-3, atol=1e-3), - f"MLX_SDPA_BLOCKS={blocks} changed numerics", - ) - finally: - if prev is None: - os.environ.pop("MLX_SDPA_BLOCKS", None) - else: - os.environ["MLX_SDPA_BLOCKS"] = prev - if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True) From a7a77ab6a9e876811d9a3d8788283b215ab55e9f Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Sun, 26 Apr 2026 21:30:28 -0500 Subject: [PATCH 3/3] sdpa: inline MLX_SDPA_BLOCKS env read, drop helper zcbenz noted the helper function was unnecessary. Inline env::get_var("MLX_SDPA_BLOCKS", 0) directly at the call site. 
--- mlx/backend/metal/scaled_dot_product_attention.cpp | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 4229247b1d..d387a5c08c 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -15,15 +15,6 @@ namespace mlx::core::fast { namespace { -// Override the heuristic-chosen `blocks` count for the 2-pass SDPA vector -// kernel via the MLX_SDPA_BLOCKS env var. Returns 0 (no override) when -// unset or non-positive. Useful for tuning the partial-tile count on -// device/workload combinations the heuristic doesn't anticipate. -int sdpa_2pass_blocks_from_env() { - int v = env::get_var("MLX_SDPA_BLOCKS", 0); - return v > 0 ? v : 0; -} - void sdpa_full_self_attention_nax( const Stream& s, metal::Device& d, @@ -483,7 +474,7 @@ void sdpa_vector_2pass( blocks = 32; } } - if (int blocks_env = sdpa_2pass_blocks_from_env(); blocks_env > 0) { + if (int blocks_env = env::get_var("MLX_SDPA_BLOCKS", 0); blocks_env > 0) { blocks = blocks_env; } size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);