feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile#3485
Open
dexwritescode wants to merge 2 commits intoml-explore:mainfrom
Open
feat(compile): CustomKernel and GatherQMM implement output_shapes for shapeless compile#3485dexwritescode wants to merge 2 commits intoml-explore:mainfrom
dexwritescode wants to merge 2 commits intoml-explore:mainfrom
Conversation
`mx::compile(shapeless=true)` calls `Primitive::output_shapes()` on
every node when re-tracing a compiled function with changed input
shapes. `CustomKernel` never implemented this override, so any
compiled function containing a `metal_kernel` / `custom_kernel` call
would throw:
[Primitive::output_shapes] CustomKernel cannot infer output shapes
The output shapes are already provided by the caller at creation time
via `metal_kernel()(inputs, output_shapes, ...)` and passed to
`array::make_arrays`. They just weren't stored on the primitive.
Fix: add an optional `output_shapes` parameter to the `CustomKernel`
constructor (default `{}` for backward compatibility), store it in a
new `output_shapes_` member, and override `output_shapes()` to return
it. If the field is empty (legacy construction path), fall through to
the base-class throw as before.
Update both Metal and CUDA call sites to copy the shapes before
`std::move`-ing them into `array::make_arrays` and pass the copy to
the constructor.
output_shapes() is called on every primitive during shapeless=true retracing. GatherQMM was missing this override, causing compile to throw when any graph containing gather_qmm was retraced. The output shape is fully inferrable from inputs and stored fields: out_shape = lhs_indices.shape() + [x.shape(-2), w_outer_dims] where w_outer_dims = transpose ? w.shape(-2) : w.shape(-1)*32/bits. Input layout differs by mode: Affine has biases at index 3 (pushing indices to 4/5); other modes have indices at 3/4.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
mx::compile(shapeless=true)callsPrimitive::output_shapes()on everynode when re-tracing a compiled function after input shapes change. Two
primitives were missing this override, causing compiled functions that
contain them to throw at runtime:
This makes it impossible to use
mx::compileon models that combine customMetal kernels with gather-quantized-matmul — for example, hybrid SSM+attention
MoE models (like Qwen3 MoE) where the SSM step uses a custom Metal kernel and
the MoE routing uses
gather_qmm.Fix
CustomKernel (
mlx/fast_primitives.h,mlx/backend/metal/custom_kernel.cpp,mlx/backend/cuda/custom_kernel.cpp)The output shapes are already provided by the caller at creation time via
metal_kernel()(inputs, output_shapes, ...)and passed toarray::make_arrays.They just weren't stored on the primitive.
Add an optional
output_shapesconstructor parameter (default{}— backwardcompatible), store in
output_shapes_member, overrideoutput_shapes()toreturn it. Falls through to the base-class throw when empty (legacy path).
GatherQMM (
mlx/primitives.h)The output shape is fully inferrable from the stored fields and input shapes:
where
w_outer_dims = transpose ? w.shape(-2) : w.shape(-1) * 32 / bits.Input layout differs by quantization mode: Affine mode has biases at index 3,
pushing lhs_indices to index 4; other modes have lhs_indices at index 3.
Testing
Verified by enabling
mx::compile(shapeless=true)on a 94-layer hybridSSM+attention MoE model (Qwen3.6-35B-A3B-4bit) where the GatedDeltaNet SSM
step uses a custom Metal kernel and the MoE routing uses
gather_qmm.Previously crashed on re-trace; with this fix the compiled graph is reused
correctly across decode steps.