From b73a77e2b9e287cb310e95419938a0af1f7ec003 Mon Sep 17 00:00:00 2001 From: Dex Date: Tue, 5 May 2026 15:24:39 -0400 Subject: [PATCH 1/5] feat(compile): CustomKernel stores and returns output shapes `mx::compile(shapeless=true)` calls `Primitive::output_shapes()` on every node when re-tracing a compiled function with changed input shapes. `CustomKernel` never implemented this override, so any compiled function containing a `metal_kernel` / `cuda_kernel` call would throw: [Primitive::output_shapes] CustomKernel cannot infer output shapes The output shapes are already provided by the caller at creation time via `metal_kernel()(inputs, output_shapes, ...)` and passed to `array::make_arrays`. They just weren't stored on the primitive. Fix: add an optional `output_shapes` parameter to the `CustomKernel` constructor (default `{}` for backward compatibility), store it in a new `output_shapes_` member, and override `output_shapes()` to return it. If the field is empty (legacy construction path), fall through to the base-class throw as before. Update both Metal and CUDA call sites to copy the shapes before `std::move`-ing them into `array::make_arrays` and pass the copy to the constructor. 
--- mlx/backend/cuda/custom_kernel.cpp | 7 +++++-- mlx/backend/metal/custom_kernel.cpp | 4 +++- mlx/fast_primitives.h | 14 ++++++++++++-- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 9a6837acbb..fdd127e50d 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -222,6 +222,7 @@ CustomKernelFunction cuda_kernel( << "```" << std::endl; } + auto output_shapes_copy = output_shapes; return array::make_arrays( std::move(output_shapes), std::move(output_dtypes), @@ -236,7 +237,8 @@ CustomKernelFunction cuda_kernel( init_value, std::vector{}, false, - shared_memory), + shared_memory, + std::move(output_shapes_copy)), std::move(inputs)); }; } @@ -270,7 +272,8 @@ std::vector precompiled_cuda_kernel( init_value, scalars, true, - shared_memory), + shared_memory, + output_shapes), inputs); } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 6d33ff5007..31e115394a 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -305,6 +305,7 @@ CustomKernelFunction metal_kernel( << "```" << std::endl; } + auto output_shapes_copy = output_shapes; return array::make_arrays( std::move(output_shapes), std::move(output_dtypes), @@ -319,7 +320,8 @@ CustomKernelFunction metal_kernel( init_value, std::vector{}, false, - 0), + 0, + std::move(output_shapes_copy)), std::move(inputs)); }; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..827a4eab6d 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,7 +375,8 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory) + int shared_memory, + std::vector output_shapes = {}) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -386,7 +387,8 @@ class CustomKernel : public Primitive { 
init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory) {} + shared_memory_(shared_memory), + output_shapes_(std::move(output_shapes)) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -397,6 +399,13 @@ class CustomKernel : public Primitive { override; DEFINE_NAME(CustomKernel); + + std::vector output_shapes(const std::vector&) override { + if (output_shapes_.empty()) + return Primitive::output_shapes({}); + return output_shapes_; + } + auto state() const { return std::make_tuple( name_, @@ -422,6 +431,7 @@ class CustomKernel : public Primitive { std::vector scalar_arguments_; bool is_precompiled_; int shared_memory_; + std::vector output_shapes_; }; } // namespace mlx::core::fast From 0b6cedbbf4b945753b5a6ef17e3a5f8df9dacfc3 Mon Sep 17 00:00:00 2001 From: Dex Date: Tue, 5 May 2026 15:48:03 -0400 Subject: [PATCH 2/5] feat(compile): GatherQMM implements output_shapes for shapeless compile output_shapes() is called on every primitive during shapeless=true retracing. GatherQMM was missing this override, causing compile to throw when any graph containing gather_qmm was retraced. The output shape is fully inferrable from inputs and stored fields: out_shape = lhs_indices.shape() + [x.shape(-2), w_outer_dims] where w_outer_dims = transpose ? w.shape(-2) : w.shape(-1)*32/bits. Input layout differs by mode: Affine has biases at index 3 (pushing indices to 4/5); other modes have indices at 3/4. 
--- mlx/primitives.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..8525e58253 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1700,6 +1700,21 @@ class GatherQMM : public UnaryPrimitive { DEFINE_GRADS() DEFINE_NAME(GatherQMM) bool is_equivalent(const Primitive& other) const override; + + // inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx} + // other → {x, w, scales, lhs_idx, rhs_idx} + std::vector output_shapes(const std::vector& inputs) override { + const auto& x = inputs[0]; + const auto& w = inputs[1]; + const auto& lhs_idx = + (mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3]; + int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_; + auto out_shape = lhs_idx.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w_outer); + return {out_shape}; + } + auto state() const { return std::make_tuple( group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); From fdff69b3b970885e9b6bcc7d2fc6caa8648c1552 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 19:32:45 -0400 Subject: [PATCH 3/5] fix(compile): remove std::move on const& and fix comment alignment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review feedback from zcbenz: - output_shapes is a const& in the lambda parameter, so std::move(output_shapes) compiles but silently copies rather than moves. Remove the misleading std::move in both metal and cuda backends — make_arrays receives a plain copy. - Fix one extra space in the GatherQMM input layout comment to correctly align lhs_idx under the Affine layout line. 
--- mlx/backend/cuda/custom_kernel.cpp | 2 +- mlx/backend/metal/custom_kernel.cpp | 2 +- mlx/primitives.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index fdd127e50d..17b0c488c9 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -224,7 +224,7 @@ CustomKernelFunction cuda_kernel( auto output_shapes_copy = output_shapes; return array::make_arrays( - std::move(output_shapes), + output_shapes, std::move(output_dtypes), std::make_shared( s, diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 31e115394a..ffc80ce27e 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -307,7 +307,7 @@ CustomKernelFunction metal_kernel( auto output_shapes_copy = output_shapes; return array::make_arrays( - std::move(output_shapes), + output_shapes, std::move(output_dtypes), std::make_shared( s, diff --git a/mlx/primitives.h b/mlx/primitives.h index 8525e58253..313ded3545 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1702,7 +1702,7 @@ class GatherQMM : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; // inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx} - // other → {x, w, scales, lhs_idx, rhs_idx} + // other → {x, w, scales, lhs_idx, rhs_idx} std::vector output_shapes(const std::vector& inputs) override { const auto& x = inputs[0]; const auto& w = inputs[1]; From 0f2473fce4432a5237cd2c153ec586f7b4100980 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 19:43:27 -0400 Subject: [PATCH 4/5] test(compile): add shapeless compile tests for CustomKernel and GatherQMM Verify that mx.compile(shapeless=True) correctly re-traces functions containing mx.fast.metal_kernel (CustomKernel) and mx.gather_qmm (GatherQMM) when input shapes change between calls. 
Both tests fail before the fix with the respective 'cannot infer output shapes' error and pass after output_shapes() is implemented. --- python/tests/test_compile.py | 57 ++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 20f1145223..fe6decb3d0 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -504,6 +504,63 @@ def ones_fun(x): self.assertEqual(compiled_zero_like(y).shape, y_shape) self.assertEqual(compiled_ones_like(y).shape, y_shape) + def test_shapeless_compile_custom_kernel(self): + # CustomKernel must implement output_shapes() so shapeless compile can + # re-trace without throwing "CustomKernel cannot infer output shapes". + if not mx.metal.is_available(): + return + + kernel = mx.fast.metal_kernel( + name="copy_kernel", + input_names=["inp"], + output_names=["out"], + source="out[thread_position_in_grid.x] = inp[thread_position_in_grid.x];", + ) + + def fn(x): + return kernel( + inputs=[x], + grid=(x.size, 1, 1), + threadgroup=(min(x.size, 256), 1, 1), + output_shapes=[x.shape], + output_dtypes=[x.dtype], + stream=mx.gpu, + )[0] + + cfn = mx.compile(fn, shapeless=True) + + x = mx.ones((4,), dtype=mx.float32) + self.assertTrue(mx.array_equal(cfn(x), x)) + + # Different shape — must reuse compiled graph without throwing. + x = mx.ones((8,), dtype=mx.float32) + self.assertTrue(mx.array_equal(cfn(x), x)) + + def test_shapeless_compile_gather_qmm(self): + # GatherQMM must implement output_shapes() so shapeless compile can + # re-trace without throwing "GatherQMM cannot infer output shapes". + K, N, num_experts = 64, 32, 4 + + w = mx.random.normal((num_experts, N, K)) + qw, s, b = mx.quantize(w) + mx.eval(qw, s, b) + + # Keep inputs outside fn so RandomBits doesn't enter the compiled graph. 
+ x4 = mx.ones((4, K)) + x8 = mx.ones((8, K)) + idx4 = mx.array([0, 1, 2, 3]) + idx8 = mx.array([0, 0, 1, 1, 2, 2, 3, 3]) + + def fn(x, lhs_indices): + return mx.gather_qmm(x, qw, s, b, lhs_indices=lhs_indices, transpose=True) + + cfn = mx.compile(fn, shapeless=True) + + self.assertEqual(cfn(x4, idx4).shape, fn(x4, idx4).shape) + + # Different M — must reuse compiled graph without throwing. + self.assertEqual(cfn(x8, idx8).shape, fn(x8, idx8).shape) + def test_compile_with_constant(self): # Test float @partial(mx.compile) From 1109c8ce5b14801f000fef87cb9cb75b858f8ec4 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 22:51:56 -0400 Subject: [PATCH 5/5] refactor(compile): simplify output_shapes copy in CustomKernel Remove the intermediate output_shapes_copy and pass output_shapes directly to the CustomKernel constructor, which takes it by value. --- mlx/backend/cuda/custom_kernel.cpp | 3 +-- mlx/backend/metal/custom_kernel.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 17b0c488c9..2608d0ea1b 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -222,7 +222,6 @@ CustomKernelFunction cuda_kernel( << "```" << std::endl; } - auto output_shapes_copy = output_shapes; return array::make_arrays( output_shapes, std::move(output_dtypes), @@ -238,7 +237,7 @@ CustomKernelFunction cuda_kernel( std::vector{}, false, shared_memory, - std::move(output_shapes_copy)), + output_shapes), std::move(inputs)); }; } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index ffc80ce27e..3eb41302ba 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -305,7 +305,6 @@ CustomKernelFunction metal_kernel( << "```" << std::endl; } - auto output_shapes_copy = output_shapes; return array::make_arrays( output_shapes, std::move(output_dtypes), @@ -321,7 +320,7 @@ 
CustomKernelFunction metal_kernel( std::vector{}, false, 0, - std::move(output_shapes_copy)), + output_shapes), std::move(inputs)); }; }