Skip to content

feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile#3485

Open
dexwritescode wants to merge 2 commits intoml-explore:mainfrom
dexwritescode:fix/custom-kernel-output-shapes
Open

feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile#3485
dexwritescode wants to merge 2 commits intoml-explore:mainfrom
dexwritescode:fix/custom-kernel-output-shapes

Conversation

@dexwritescode
Copy link
Copy Markdown

@dexwritescode dexwritescode commented May 5, 2026

Problem

mx::compile(shapeless=true) calls Primitive::output_shapes() on every
node when re-tracing a compiled function after input shapes change. Two
primitives were missing this override, causing compiled functions that
contain them to throw at runtime:

[Primitive::output_shapes] CustomKernel cannot infer output shapes
[Primitive::output_shapes] GatherQMM cannot infer output shapes

This makes it impossible to use mx::compile on models that combine custom
Metal kernels with gather-quantized-matmul — for example, hybrid SSM+attention
MoE models (like Qwen3 MoE) where the SSM step uses a custom Metal kernel and
the MoE routing uses gather_qmm.

Fix

CustomKernel (mlx/fast_primitives.h, mlx/backend/metal/custom_kernel.cpp, mlx/backend/cuda/custom_kernel.cpp)

The output shapes are already provided by the caller at creation time via
metal_kernel()(inputs, output_shapes, ...) and passed to array::make_arrays.
They just weren't stored on the primitive.

Add an optional output_shapes constructor parameter (default {} — backward
compatible), store in output_shapes_ member, override output_shapes() to
return it. Falls through to the base-class throw when empty (legacy path).

GatherQMM (mlx/primitives.h)

The output shape is fully inferrable from the stored fields and input shapes:

out_shape = lhs_indices.shape() + [x.shape(-2), w_outer_dims]

where w_outer_dims = transpose ? w.shape(-2) : w.shape(-1) * 32 / bits.

Input layout differs by quantization mode: Affine mode has biases at index 3,
pushing lhs_indices to index 4; other modes have lhs_indices at index 3.

Testing

Verified by enabling mx::compile(shapeless=true) on a 94-layer hybrid
SSM+attention MoE model (Qwen3.6-35B-A3B-4bit) where the GatedDeltaNet SSM
step uses a custom Metal kernel and the MoE routing uses gather_qmm.
Previously crashed on re-trace; with this fix the compiled graph is reused
correctly across decode steps.

`mx::compile(shapeless=true)` calls `Primitive::output_shapes()` on
every node when re-tracing a compiled function with changed input
shapes. `CustomKernel` never implemented this override, so any
compiled function containing a `metal_kernel` / `custom_kernel` call
would throw:

  [Primitive::output_shapes] CustomKernel cannot infer output shapes

The output shapes are already provided by the caller at creation time
via `metal_kernel()(inputs, output_shapes, ...)` and passed to
`array::make_arrays`. They just weren't stored on the primitive.

Fix: add an optional `output_shapes` parameter to the `CustomKernel`
constructor (default `{}` for backward compatibility), store it in a
new `output_shapes_` member, and override `output_shapes()` to return
it. If the field is empty (legacy construction path), fall through to
the base-class throw as before.

Update both Metal and CUDA call sites to copy the shapes before
`std::move`-ing them into `array::make_arrays` and pass the copy to
the constructor.
output_shapes() is called on every primitive during shapeless=true
retracing. GatherQMM was missing this override, causing compile to
throw when any graph containing gather_qmm was retraced.

The output shape is fully inferrable from inputs and stored fields:
  out_shape = lhs_indices.shape() + [x.shape(-2), w_outer_dims]
where w_outer_dims = transpose ? w.shape(-2) : w.shape(-1)*32/bits.

Input layout differs by mode: Affine has biases at index 3 (pushing
indices to 4/5); other modes have indices at 3/4.
@dexwritescode dexwritescode changed the title feat(compile): CustomKernel stores and returns output shapes feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile May 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant