[BUG] mx.fast.metal_kernel: use-after-free when multiple custom kernels compose in lazy graph #3347

@Ziqiao-git

Describe the bug

When multiple mx.fast.metal_kernel calls are composed in a single lazy evaluation graph, Metal API Validation reports:

failed assertion `command buffer references deallocated object which previously existed at address 0x...'

Without validation enabled, this manifests as non-deterministic NaN outputs (~2% per forward pass, amplified to ~100% in autoregressive generation).
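One way to read the jump from ~2% to ~100%: if each forward pass independently has probability p of producing a NaN, and a NaN persists for the rest of the sequence once it appears (for example through the KV cache), then an n-token generation is affected with probability 1 - (1 - p)^n. The sketch below works through that arithmetic. The independence assumption and the 200-token generation length are illustrative assumptions, not measurements from this report.

```python
# Illustrative arithmetic only: p is taken from the report, n_tokens is assumed.
p_nan_per_pass = 0.02   # ~2% NaN rate per forward pass (from the report)
n_tokens = 200          # hypothetical generation length

# Probability that at least one of n independent passes produces a NaN.
p_any_nan = 1.0 - (1.0 - p_nan_per_pass) ** n_tokens
print(f"P(at least one NaN in {n_tokens} passes) = {p_any_nan:.3f}")
```

Under these assumptions a generation of a few hundred tokens is almost certain to hit at least one bad pass, which is consistent with the near-100% failure rate observed.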

The root cause is the call to commandBufferWithUnretainedReferences() in Device::get_command_buffer() (device.cpp:392). A command buffer created this way does not retain the resources it references, so when the Python references to intermediate arrays are dropped before the command buffer completes, their backing buffers are deallocated while the GPU may still be using them.
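The lifetime problem can be illustrated with a pure-Python analogy. This is a toy model, not MLX or Metal code: `Buffer`, `CommandBuffer`, and `run` are hypothetical names, a weak reference stands in for an unretained reference, and a strong reference stands in for a retained one.

```python
import gc
import weakref


class Buffer:
    """Stand-in for a GPU allocation (hypothetical, not an MLX or Metal type)."""

    def __init__(self, name):
        self.name = name


class CommandBuffer:
    """Toy command buffer that records the buffers its work will touch."""

    def __init__(self, retain):
        self.retain = retain
        self.refs = []

    def encode(self, buf):
        # Retained mode keeps a strong reference; unretained mode keeps a weak one.
        self.refs.append(buf if self.retain else weakref.ref(buf))

    def commit(self):
        # Report which buffers are still alive when the "GPU" finally runs.
        # None marks a buffer that was freed before execution.
        alive = []
        for ref in self.refs:
            buf = ref if self.retain else ref()
            alive.append(None if buf is None else buf.name)
        return alive


def run(retain):
    cb = CommandBuffer(retain)
    tmp = Buffer("intermediate")
    cb.encode(tmp)
    del tmp        # the caller drops its only reference before commit
    gc.collect()   # make the deallocation explicit
    return cb.commit()


unretained = run(retain=False)
retained = run(retain=True)
print("unretained:", unretained)
print("retained:  ", retained)
```

In this toy model the unretained run sees `None` for the intermediate buffer, while the retained run still sees it. In real Metal the unretained case is not a clean `None` but a reference to freed memory, which is why the symptom is a validation assertion or corrupted values rather than an exception.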

To Reproduce

# Requires: pip install paroquant
# Run with: MTL_DEBUG_LAYER=1 python repro.py
import mlx.core as mx
from paroquant.inference.backends.mlx.load import load
from paroquant.inference.backends.mlx.modules import RotateQuantizedLinear
from mlx_lm.models.cache import make_prompt_cache

model, proc, _ = load("z-lab/Qwen3.5-4B-PARO", force_text=True)
for _, m in model.named_modules():
    if isinstance(m, RotateQuantizedLinear):
        m._force_eval = False

inner = model.language_model.model
h = inner.embed_tokens(mx.array([[0]]))
mx.eval(h)

cache = make_prompt_cache(model.language_model)
out = inner.layers[0](h, mask=None, cache=cache[0])
mx.eval(out)
# MTL_DEBUG_LAYER=1: crashes with "command buffer references deallocated object"
# Without: non-deterministic NaN

Expected behavior

mx.fast.metal_kernel outputs should be correct regardless of lazy graph depth.

Desktop (please complete the following information):

  • OS Version: macOS 15.x
  • Hardware: Apple M3 Max 48GB
  • MLX Version: 0.31.1

Additional context

Proposed fix — one-line change in mlx/backend/metal/device.cpp:392:

-    stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
+    stream.buffer = stream.queue->commandBuffer();

commandBuffer() creates a command buffer with retained references: Metal keeps a strong reference to every resource encoded into it until the command buffer completes.

Performance impact (M3 Max 48GB):

| Model | Unretained (current) | Retained (fix) |
| --- | --- | --- |
| Qwen3-8B-4bit (standard model, no custom kernel) | 77.0 tok/s | 76.9 tok/s |
| Qwen3.5-35B-A3B MoE (custom kernel + workaround) | 21 tok/s | 48 tok/s |

No regression on standard models. Models that use mx.fast.metal_kernel see a roughly 2.3x speedup (21 to 48 tok/s) because the mx.eval() workaround after every custom kernel call is no longer needed.

Branch with the fix: https://github.com/Ziqiao-git/mlx/tree/fix/retained-command-buffer
