8 changes: 5 additions & 3 deletions mlx/backend/cuda/custom_kernel.cpp
@@ -223,7 +223,7 @@ CustomKernelFunction cuda_kernel(
     }
 
     return array::make_arrays(
-        std::move(output_shapes),
+        output_shapes,
         std::move(output_dtypes),
         std::make_shared<CustomKernel>(
             s,
@@ -236,7 +236,8 @@ CustomKernelFunction cuda_kernel(
             init_value,
             std::vector<ScalarArg>{},
             false,
-            shared_memory),
+            shared_memory,
+            output_shapes),
         std::move(inputs));
   };
 }
@@ -270,7 +271,8 @@ std::vector<array> precompiled_cuda_kernel(
           init_value,
           scalars,
           true,
-          shared_memory),
+          shared_memory,
+          output_shapes),
       inputs);
 }

5 changes: 3 additions & 2 deletions mlx/backend/metal/custom_kernel.cpp
@@ -306,7 +306,7 @@ CustomKernelFunction metal_kernel(
     }
 
     return array::make_arrays(
-        std::move(output_shapes),
+        output_shapes,
         std::move(output_dtypes),
         std::make_shared<CustomKernel>(
             s,
@@ -319,7 +319,8 @@
             init_value,
             std::vector<ScalarArg>{},
             false,
-            0),
+            0,
+            output_shapes),
         std::move(inputs));
   };
 }
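A note on the std::move(output_shapes) → output_shapes change in the cuda_kernel and metal_kernel hunks above: the same vector is now passed twice to array::make_arrays, once as the output shapes and once inside the CustomKernel constructor, and C++ leaves the evaluation order of function arguments unspecified, so moving it in one argument could hand the constructor a moved-from (empty) vector. Copying on the first use keeps both reads valid.
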
14 changes: 12 additions & 2 deletions mlx/fast_primitives.h
@@ -375,7 +375,8 @@ class CustomKernel : public Primitive {
       std::optional<float> init_value,
       std::vector<ScalarArg> scalar_arguments,
       bool is_precompiled,
-      int shared_memory)
+      int shared_memory,
+      std::vector<Shape> output_shapes = {})
       : Primitive(stream),
         name_(std::move(name)),
         source_(std::move(source)),
@@ -386,7 +387,8 @@
         init_value_(init_value),
         scalar_arguments_(std::move(scalar_arguments)),
         is_precompiled_(is_precompiled),
-        shared_memory_(shared_memory) {}
+        shared_memory_(shared_memory),
+        output_shapes_(std::move(output_shapes)) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -397,6 +399,13 @@
       override;
 
   DEFINE_NAME(CustomKernel);
+
+  std::vector<Shape> output_shapes(const std::vector<array>&) override {
+    if (output_shapes_.empty())
+      return Primitive::output_shapes({});
+    return output_shapes_;
+  }
+
   auto state() const {
     return std::make_tuple(
         name_,
@@ -422,6 +431,7 @@
   std::vector<ScalarArg> scalar_arguments_;
   bool is_precompiled_;
   int shared_memory_;
+  std::vector<Shape> output_shapes_;
 };
 
 } // namespace mlx::core::fast

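Why the override matters: with shapeless=True, mx.compile records the graph once and, when called with new input shapes, re-derives every intermediate shape by asking each primitive for its output shapes rather than re-running the Python function. A primitive that cannot answer (the base Primitive::output_shapes) aborts the re-trace with the "cannot infer output shapes" error the tests below reference. A minimal sketch of that loop, in plain Python with illustrative names (Prim, Copy, and retrace are not MLX APIs):

# Toy sketch of the shapeless re-trace (illustrative only, not MLX internals).

class NoShapeRule(Exception):
    pass

class Prim:
    """Stand-in for a primitive without an output_shapes override."""

    def output_shapes(self, in_shapes):
        raise NoShapeRule(f"{type(self).__name__} cannot infer output shapes")

class Copy(Prim):
    """Element-wise primitive: the output shape follows the input."""

    def output_shapes(self, in_shapes):
        return in_shapes[0]

def retrace(nodes, shapes):
    # nodes: (primitive, input names, output name) in topological order;
    # shapes: the *new* shapes of the graph inputs.
    for prim, in_names, out_name in nodes:
        shapes[out_name] = prim.output_shapes([shapes[n] for n in in_names])
    return shapes

# The recorded graph is replayed with a new input shape; no Python
# tracing of the original function is needed.
print(retrace([(Copy(), ["x"], "y")], {"x": (8,)}))  # {'x': (8,), 'y': (8,)}
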
15 changes: 15 additions & 0 deletions mlx/primitives.h
@@ -1700,6 +1700,21 @@ class GatherQMM : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_NAME(GatherQMM)
   bool is_equivalent(const Primitive& other) const override;
+
+  // inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx}
+  //                other  → {x, w, scales, lhs_idx, rhs_idx}
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
+    const auto& x = inputs[0];
+    const auto& w = inputs[1];
+    const auto& lhs_idx =
+        (mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3];
+    int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_;
+    auto out_shape = lhs_idx.shape();
+    out_shape.push_back(x.shape(-2));
+    out_shape.push_back(w_outer);
+    return {out_shape};
+  }
+
   auto state() const {
     return std::make_tuple(
         group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);

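For concreteness, here is the shape rule the new override encodes, restated as a standalone Python helper (gather_qmm_out_shape is illustrative, not an MLX API): the result is indexed like lhs_indices, with one (M, N) matmul result appended per gathered index; when w is not transposed, its bit-packed last axis unpacks to w.shape[-1] * 32 // bits columns.

# Illustrative restatement of GatherQMM::output_shapes (not an MLX API).
def gather_qmm_out_shape(x_shape, w_shape, lhs_idx_shape, bits=4, transpose=True):
    # transpose=True treats w as (..., N, K), so the output dim is N;
    # otherwise the packed last axis unpacks to w_shape[-1] * 32 // bits.
    w_outer = w_shape[-2] if transpose else w_shape[-1] * 32 // bits
    return tuple(lhs_idx_shape) + (x_shape[-2], w_outer)

# Sizes from test_shapeless_compile_gather_qmm below: quantizing
# w = (4, 32, 64) at 4 bits packs K into uint32 words, giving qw = (4, 32, 8).
print(gather_qmm_out_shape((4, 64), (4, 32, 8), (4,)))  # (4, 4, 32)
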
57 changes: 57 additions & 0 deletions python/tests/test_compile.py
@@ -504,6 +504,63 @@ def ones_fun(x):
         self.assertEqual(compiled_zero_like(y).shape, y_shape)
         self.assertEqual(compiled_ones_like(y).shape, y_shape)
 
+    def test_shapeless_compile_custom_kernel(self):
+        # CustomKernel must implement output_shapes() so shapeless compile can
+        # re-trace without throwing "CustomKernel cannot infer output shapes".
+        if not mx.metal.is_available():
+            return
+
+        kernel = mx.fast.metal_kernel(
+            name="copy_kernel",
+            input_names=["inp"],
+            output_names=["out"],
+            source="out[thread_position_in_grid.x] = inp[thread_position_in_grid.x];",
+        )
+
+        def fn(x):
+            return kernel(
+                inputs=[x],
+                grid=(x.size, 1, 1),
+                threadgroup=(min(x.size, 256), 1, 1),
+                output_shapes=[x.shape],
+                output_dtypes=[x.dtype],
+                stream=mx.gpu,
+            )[0]
+
+        cfn = mx.compile(fn, shapeless=True)
+
+        x = mx.ones((4,), dtype=mx.float32)
+        self.assertTrue(mx.array_equal(cfn(x), x))
+
+        # Different shape — must reuse compiled graph without throwing.
+        x = mx.ones((8,), dtype=mx.float32)
+        self.assertTrue(mx.array_equal(cfn(x), x))
+
+    def test_shapeless_compile_gather_qmm(self):
+        # GatherQMM must implement output_shapes() so shapeless compile can
+        # re-trace without throwing "GatherQMM cannot infer output shapes".
+        K, N, num_experts = 64, 32, 4
+
+        w = mx.random.normal((num_experts, N, K))
+        qw, s, b = mx.quantize(w)
+        mx.eval(qw, s, b)
+
+        # Keep inputs outside fn so RandomBits doesn't enter the compiled graph.
+        x4 = mx.ones((4, K))
+        x8 = mx.ones((8, K))
+        idx4 = mx.array([0, 1, 2, 3])
+        idx8 = mx.array([0, 0, 1, 1, 2, 2, 3, 3])
+
+        def fn(x, lhs_indices):
+            return mx.gather_qmm(x, qw, s, b, lhs_indices=lhs_indices, transpose=True)
+
+        cfn = mx.compile(fn, shapeless=True)
+
+        self.assertEqual(cfn(x4, idx4).shape, fn(x4, idx4).shape)
+
+        # Different M — must reuse compiled graph without throwing.
+        self.assertEqual(cfn(x8, idx8).shape, fn(x8, idx8).shape)
+
     def test_compile_with_constant(self):
         # Test float
         @partial(mx.compile)