Describe the bug
When multiple `mx.fast.metal_kernel` calls are composed in a single lazy evaluation graph, Metal API Validation reports:
```
failed assertion `command buffer references deallocated object which previously existed at address 0x...'
```
Without validation enabled, this manifests as non-deterministic NaN outputs (~2% per forward pass, amplified to ~100% in autoregressive generation).
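The jump from ~2% to ~100% is what you would expect if each forward pass fails independently and a single NaN spoils the rest of the generation (plausibly because it is written into the KV cache and read by every later step). The sketch below computes the chance of at least one failure over n passes under that independence assumption. The 2% rate comes from this report; the sequence lengths are arbitrary illustrative choices.

```python
def p_any_failure(p_per_pass: float, n_passes: int) -> float:
    """Probability that at least one of n forward passes yields NaN.

    Assumes each pass fails independently with probability p_per_pass.
    """
    return 1.0 - (1.0 - p_per_pass) ** n_passes


# 2% per pass is the reported rate; the lengths are illustrative.
for n in (1, 10, 100, 200):
    print(f"{n:>4} passes: {p_any_failure(0.02, n):.1%} chance of at least one NaN")
```

At 200 generated tokens the probability of at least one corrupted pass is already about 98%.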
The root cause is `commandBufferWithUnretainedReferences()` in `Device::get_command_buffer()` (`device.cpp:392`). A command buffer created this way does not retain the GPU buffers it references, so when the Python references to intermediate arrays are dropped before the command buffer finishes, their backing buffers are deallocated while the GPU is still using them.
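To make the lifetime problem concrete, here is a toy model in plain Python. It is not MLX or Metal code: `ToyBuffer`, `ToyCommandBuffer` and `simulate` are hypothetical stand-ins, unretained references are modelled as `weakref.ref`, and retained references as ordinary strong references held until the simulated completion. It relies on CPython freeing an object as soon as its last strong reference is dropped.

```python
import weakref


class ToyBuffer:
    """Hypothetical stand-in for the GPU allocation behind an intermediate array."""

    def __init__(self, name: str) -> None:
        self.name = name


class ToyCommandBuffer:
    """Hypothetical stand-in for a command buffer that records buffer references.

    retain=False: keep only weak references (like unretained references).
    retain=True:  keep strong references until completion (like retained ones).
    """

    def __init__(self, retain: bool) -> None:
        self.retain = retain
        self._refs = []

    def encode(self, buf: ToyBuffer) -> None:
        # Record that encoded work will use `buf` when it runs later.
        self._refs.append(buf if self.retain else weakref.ref(buf))

    def run_and_complete(self) -> int:
        """Simulate the GPU running the work; return how many buffers were already freed."""
        freed = 0
        for ref in self._refs:
            buf = ref if self.retain else ref()
            if buf is None:
                freed += 1
        self._refs.clear()  # completion releases any references we held
        return freed


def simulate(retain: bool) -> int:
    cb = ToyCommandBuffer(retain)
    tmp = ToyBuffer("kernel_a_output")  # intermediate consumed by a second kernel
    cb.encode(tmp)
    del tmp  # the last Python reference goes away before the "GPU" has run
    return cb.run_and_complete()


print("unretained:", simulate(retain=False))  # 1: buffer freed while still referenced
print("retained:  ", simulate(retain=True))   # 0: buffer kept alive until completion
```

In the unretained case the only thing keeping the intermediate alive is the Python-side reference, which is exactly the window the validation layer catches.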
To Reproduce
```python
# Requires: pip install paroquant
# Run with: MTL_DEBUG_LAYER=1 python repro.py
import mlx.core as mx
from paroquant.inference.backends.mlx.load import load
from paroquant.inference.backends.mlx.modules import RotateQuantizedLinear
from mlx_lm.models.cache import make_prompt_cache

model, proc, _ = load("z-lab/Qwen3.5-4B-PARO", force_text=True)
for _, m in model.named_modules():
    if isinstance(m, RotateQuantizedLinear):
        m._force_eval = False

inner = model.language_model.model
h = inner.embed_tokens(mx.array([[0]]))
mx.eval(h)
cache = make_prompt_cache(model.language_model)
out = inner.layers[0](h, mask=None, cache=cache[0])
mx.eval(out)
# MTL_DEBUG_LAYER=1: crashes with "command buffer references deallocated object"
# Without: non-deterministic NaN
```
Expected behavior
`mx.fast.metal_kernel` outputs should be correct regardless of lazy graph depth.
Desktop
- OS Version: macOS 15.x
- Hardware: Apple M3 Max 48GB
- MLX Version: 0.31.1
Additional context
Proposed fix, a one-line change in `mlx/backend/metal/device.cpp:392`:
```diff
- stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
+ stream.buffer = stream.queue->commandBuffer();
```
`commandBuffer()` uses retained references: the command buffer holds strong references to every buffer it uses until it completes, so intermediates stay alive even after their Python references are dropped.
Performance impact (M3 Max 48GB):

| Model | Unretained (current) | Retained (fix) |
|---|---|---|
| Qwen3-8B-4bit (standard model, no custom kernel) | 77.0 tok/s | 76.9 tok/s |
| Qwen3.5-35B-A3B MoE (custom kernel + workaround) | 21 tok/s | 48 tok/s |

No measurable regression on the standard model (77.0 vs 76.9 tok/s). The MoE model that uses `mx.fast.metal_kernel` sees a 2.3x speedup (21 to 48 tok/s) because the `mx.eval()` workaround after every custom kernel call is no longer needed.
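The ratios behind the "no regression" and "2.3x" statements can be recomputed from the table. Note that the MoE row compares the retained fix against the current build running the per-kernel `mx.eval()` workaround, so its ratio reflects removing that workaround as well as switching to retained references. The `RESULTS` and `speedup` names are illustrative.

```python
# Throughput figures copied from the table above (tokens/s, M3 Max 48GB).
RESULTS = {
    "Qwen3-8B-4bit": (77.0, 76.9),
    "Qwen3.5-35B-A3B MoE": (21.0, 48.0),
}


def speedup(before: float, after: float) -> float:
    """Ratio of new throughput to old throughput (>1 means faster)."""
    return after / before


for model, (unretained, retained) in RESULTS.items():
    ratio = speedup(unretained, retained)
    print(f"{model}: {ratio:.3f}x ({(ratio - 1) * 100:+.1f}%)")
```

The 8B model changes by about 0.1%, while the MoE model runs about 2.29x faster, which rounds to the 2.3x quoted.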
Branch with the fix: https://github.com/Ziqiao-git/mlx/tree/fix/retained-command-buffer