diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp
index 9a6837acbb..2608d0ea1b 100644
--- a/mlx/backend/cuda/custom_kernel.cpp
+++ b/mlx/backend/cuda/custom_kernel.cpp
@@ -223,7 +223,7 @@ CustomKernelFunction cuda_kernel(
     }
 
     return array::make_arrays(
-        std::move(output_shapes),
+        output_shapes,
         std::move(output_dtypes),
         std::make_shared<CustomKernel>(
             s,
@@ -236,7 +236,8 @@ CustomKernelFunction cuda_kernel(
             init_value,
             std::vector<ScalarArg>{},
             false,
-            shared_memory),
+            shared_memory,
+            output_shapes),
         std::move(inputs));
   };
 }
@@ -270,7 +271,8 @@ std::vector<array> precompiled_cuda_kernel(
           init_value,
           scalars,
           true,
-          shared_memory),
+          shared_memory,
+          output_shapes),
       inputs);
 }
 
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index 6d33ff5007..3eb41302ba 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -306,7 +306,7 @@ CustomKernelFunction metal_kernel(
     }
 
     return array::make_arrays(
-        std::move(output_shapes),
+        output_shapes,
         std::move(output_dtypes),
         std::make_shared<CustomKernel>(
             s,
@@ -319,7 +319,8 @@ CustomKernelFunction metal_kernel(
             init_value,
             std::vector<ScalarArg>{},
             false,
-            0),
+            0,
+            output_shapes),
         std::move(inputs));
   };
 }
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 4434830875..827a4eab6d 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -375,7 +375,8 @@ class CustomKernel : public Primitive {
       std::optional<float> init_value,
       std::vector<ScalarArg> scalar_arguments,
       bool is_precompiled,
-      int shared_memory)
+      int shared_memory,
+      std::vector<Shape> output_shapes = {})
       : Primitive(stream),
         name_(std::move(name)),
         source_(std::move(source)),
@@ -386,7 +387,8 @@ class CustomKernel : public Primitive {
         init_value_(init_value),
         scalar_arguments_(std::move(scalar_arguments)),
         is_precompiled_(is_precompiled),
-        shared_memory_(shared_memory) {}
+        shared_memory_(shared_memory),
+        output_shapes_(std::move(output_shapes)) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -397,6 +399,13 @@ class CustomKernel : public Primitive {
       override;
 
   DEFINE_NAME(CustomKernel);
+
+  std::vector<Shape> output_shapes(const std::vector<array>&) override {
+    if (output_shapes_.empty())
+      return Primitive::output_shapes({});
+    return output_shapes_;
+  }
+
   auto state() const {
     return std::make_tuple(
         name_,
@@ -422,6 +431,7 @@ class CustomKernel : public Primitive {
   std::vector<ScalarArg> scalar_arguments_;
   bool is_precompiled_;
   int shared_memory_;
+  std::vector<Shape> output_shapes_;
 };
 
 } // namespace mlx::core::fast
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 75fb978dce..313ded3545 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1700,6 +1700,21 @@ class GatherQMM : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_NAME(GatherQMM)
   bool is_equivalent(const Primitive& other) const override;
+
+  // inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx}
+  //                other  → {x, w, scales, lhs_idx, rhs_idx}
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
+    const auto& x = inputs[0];
+    const auto& w = inputs[1];
+    const auto& lhs_idx =
+        (mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3];
+    int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_;
+    auto out_shape = lhs_idx.shape();
+    out_shape.push_back(x.shape(-2));
+    out_shape.push_back(w_outer);
+    return {out_shape};
+  }
+
   auto state() const {
     return std::make_tuple(
         group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index 20f1145223..fe6decb3d0 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -504,6 +504,63 @@ def ones_fun(x):
         self.assertEqual(compiled_zero_like(y).shape, y_shape)
         self.assertEqual(compiled_ones_like(y).shape, y_shape)
 
+    def test_shapeless_compile_custom_kernel(self):
+        # CustomKernel must implement output_shapes() so shapeless compile can
+        # re-trace without throwing "CustomKernel cannot infer output shapes".
+        if not mx.metal.is_available():
+            return
+
+        kernel = mx.fast.metal_kernel(
+            name="copy_kernel",
+            input_names=["inp"],
+            output_names=["out"],
+            source="out[thread_position_in_grid.x] = inp[thread_position_in_grid.x];",
+        )
+
+        def fn(x):
+            return kernel(
+                inputs=[x],
+                grid=(x.size, 1, 1),
+                threadgroup=(min(x.size, 256), 1, 1),
+                output_shapes=[x.shape],
+                output_dtypes=[x.dtype],
+                stream=mx.gpu,
+            )[0]
+
+        cfn = mx.compile(fn, shapeless=True)
+
+        x = mx.ones((4,), dtype=mx.float32)
+        self.assertTrue(mx.array_equal(cfn(x), x))
+
+        # Different shape — must reuse compiled graph without throwing.
+        x = mx.ones((8,), dtype=mx.float32)
+        self.assertTrue(mx.array_equal(cfn(x), x))
+
+    def test_shapeless_compile_gather_qmm(self):
+        # GatherQMM must implement output_shapes() so shapeless compile can
+        # re-trace without throwing "GatherQMM cannot infer output shapes".
+        K, N, num_experts = 64, 32, 4
+
+        w = mx.random.normal((num_experts, N, K))
+        qw, s, b = mx.quantize(w)
+        mx.eval(qw, s, b)
+
+        # Keep inputs outside fn so RandomBits doesn't enter the compiled graph.
+        x4 = mx.ones((4, K))
+        x8 = mx.ones((8, K))
+        idx4 = mx.array([0, 1, 2, 3])
+        idx8 = mx.array([0, 0, 1, 1, 2, 2, 3, 3])
+
+        def fn(x, lhs_indices):
+            return mx.gather_qmm(x, qw, s, b, lhs_indices=lhs_indices, transpose=True)
+
+        cfn = mx.compile(fn, shapeless=True)
+
+        self.assertEqual(cfn(x4, idx4).shape, fn(x4, idx4).shape)
+
+        # Different M — must reuse compiled graph without throwing.
+        self.assertEqual(cfn(x8, idx8).shape, fn(x8, idx8).shape)
+
     def test_compile_with_constant(self):
         # Test float
         @partial(mx.compile)