Reduce peak memory of large half-precision random.uniform/normal #3439
dogukanveziroglu wants to merge 5 commits into ml-explore:main
Conversation
A new primitive that runs the entire uniform RNG pipeline (threefry
hash → fp32 normalize → clip → cast → affine) per-thread in registers
for half-precision GPU outputs. Avoids materializing the fp32
intermediate buffer that the standard bits()/divide()/astype() chain
requires; peak memory drops 3x → 1x of target.
Activation conditions (all required): half-precision dtype (bf16 or
fp16), even total output size, scalar low/high, single key (shape
{2}), GPU stream. Bit-exact with vanilla on the same seed; matches
the rbitsc kernel's interleaved counter layout.
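For concreteness, a minimal Python sketch of that activation predicate (the real gate is C++ inside mlx/random.cpp; the function name and argument types here are illustrative only):

```python
def fused_uniform_eligible(dtype, size, low, high, key_shape, on_gpu):
    # All five conditions from the commit message must hold.
    is_half = dtype in ("bfloat16", "float16")
    scalar_bounds = isinstance(low, (int, float)) and isinstance(high, (int, float))
    single_key = key_shape == (2,)
    return is_half and size % 2 == 0 and scalar_bounds and single_key and on_gpu
```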
Performance: 15.4x faster on (16384, 16384) bf16 (108 ms → 7 ms)
because the fp32 intermediate no longer transits L2/HBM. Small
shapes (<1 MB) pay a slight kernel-launch overhead — chunked path
threshold ensures the fast path only activates when the win
dominates.
CUDA mirror added in a follow-up commit (untested; algorithmic
transcription of the validated Metal kernel).
Splits large GPU random calls (output ≥ ~512 MB fp32-equivalent)
along axis 0 into K independent sub-key chunks, computes each via
the existing fp32-then-cast pipeline, and writes into a
pre-allocated output via slice_update with eval per chunk.
Per-chunk fp32 transients are freed between iterations; peak drops
from 3x to ~1+2/K of target (1.09x at K=33 on the canary shape).
Heuristic: K = ceil(fp32_bytes / 256 MB), clamped to [4, 256].
Profiled in path-c/19-K-isolation.md: theory matches measurement
within 5% at K ≥ 32; allocator overhead at small K (2-16) adds
17-30% but amortizes away.
Sub-key derivation via random::split is cryptographically
independent and seed-deterministic. The same seed produces the
same chunked output across runs, but the bit pattern differs from
vanilla (which uses one key for the whole shape). Same trade-off
class as PR ml-explore#904; statistical quality preserved
per-chunk (chunked unique-value count ≥ vanilla baseline).
Activation rule (all required): GPU stream, scalar lo/hi, single
key, fp32-equivalent output size ≥ 512 MB, axis-0 dim ≥ 4. Falls
back to the vanilla path for everything else (small shapes,
multi-key, broadcast bounds, CPU). normal() uses the same chunked
pipeline when the target dtype is bf16/fp16/fp32.
Resolves OOM on (46341, 46341) bf16 normal: vanilla aborts at
12.88 GB peak, chunked completes at 4.69 GB. Tolerates up to
~11 GB of concurrent allocations on M4 16 GB before swap kicks in
(path-c/21-active-ballast.md).
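A worked check of the heuristic and peak formula on the canary shape (Python sketch; pick_chunk_count is reconstructed here from the description above, not copied from the diff):

```python
import math

def pick_chunk_count(fp32_bytes, chunk_target=256 * 2**20):
    # K = ceil(fp32_bytes / 256 MB), clamped to [4, 256]
    return max(4, min(256, math.ceil(fp32_bytes / chunk_target)))

n = 46341 * 46341            # canary: (46341, 46341) bf16 normal
k = pick_chunk_count(4 * n)  # fp32 intermediate is 4 bytes/element -> K = 33
peak_ratio = 1 + 2 / k       # ~1.06x of target predicted; 1.09x measured
```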
CUDA mirror of the Metal RandomUniform kernel (same threefry
counter mapping, same per-thread fp32-then-cast in registers, same
output dtype templating). Marked untested in code: no NVIDIA
hardware on this branch's CI; algorithmic equivalence to the
validated Metal kernel verified by inspection.
TestRandomChunked: 8 tests targeting the chunked path (shapes
≥ 1 GB so chunking activates). Each test uses a 5σ/√N statistical
tolerance for distribution stats (not hand-tuned); a seed
reproducibility test confirms deterministic output; an
odd-first-dim test exercises chunk-remainder handling; a unique-bit
test asserts ≥ 2000 distinct bf16 values per million samples (the
PR ml-explore#2361 quality floor). Brings test_random.py coverage
from 14 to 22 tests; full pytest remains 696 passed / 4 skipped /
9283 subtests on M4.
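A minimal sketch of the 5σ/√N style of tolerance those tests use (the assumed shape of the assertion, not the literal test code):

```python
import math
import mlx.core as mx

def assert_uniform_mean(shape, seed=0, sigma=5.0):
    n = math.prod(shape)
    x = mx.random.uniform(shape=shape, dtype=mx.bfloat16, key=mx.random.key(seed))
    mean = x.astype(mx.float32).mean().item()
    # U(0, 1) has mean 1/2 and variance 1/12; the sample mean's std shrinks
    # by sqrt(n), so the tolerance is derived rather than hand-tuned.
    tol = sigma * math.sqrt(1 / 12) / math.sqrt(n)
    assert abs(mean - 0.5) < tol
```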
The chunked dispatch in mlx/random.cpp had two correctness gaps
discovered by an adversarial drawback sweep against vanilla:
1. fp32 chunking is strictly worse than vanilla. Vanilla fp32
uniform/normal already operate at ~1x output peak (the
intermediate IS the target dtype), so chunking adds K-fold
sub-key derivation + slice_update overhead with zero memory
benefit. Measured ~25% latency regression and ~25% higher
peak memory at 12K^2+ shapes. Restrict chunkable_dtype to
{bfloat16, float16}.
2. Both the fused RandomUniform primitive and the chunked path
are illegal inside mx.compile / mx.vmap / mx.grad: the fused
primitive throws on RandomUniform::vmap, and the chunked
path's per-chunk eval() is rejected by the tracer. Gate
both dispatches on !detail::in_tracing() so any transform
falls back to the vanilla pipeline (which uses RandomBits,
DEFINE_VMAP()-supported).
Headline canary unchanged: (46341, 46341) bf16 normal still
peaks at 4.7 GB on M4-16GB.
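The gate's visible effect, sketched at the Python level (illustrative; whether the fused path actually fires also depends on the dispatch conditions, and the outputs match because the fused path is bit-exact with vanilla):

```python
import mlx.core as mx

def sample(key):
    # 8192^2 bf16: eligible for the fused path, too small for chunking.
    return mx.random.uniform(shape=(8192, 8192), dtype=mx.bfloat16, key=key)

key = mx.random.key(0)
eager = sample(key)               # may take the fused GPU path
traced = mx.compile(sample)(key)  # tracer active -> vanilla RandomBits path
assert mx.array_equal(eager, traced).item()
```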
Pre-PR cleanup pass: remove internal investigation references
("Variant D1/D4", "Phase X", *.md filenames, drawback Phase
references) from comments, compress the chunked-path docstrings,
and tighten throw messages to drop implementation detail leakage.
Also:
- mlx/random.cpp: replace if-cascade clamps in pick_chunk_count
with std::clamp; mark chunked_fp32_then_cast and pick_chunk_count
static; drop the redundant inner key-shape check (single_key
already guarantees Shape{2}); inline single-use bool 'even'.
- mlx/backend/metal/kernels/random.metal: collapse the two-output
per-thread block to one expression each; drop the "Step 4"
reference.
- mlx/backend/metal/primitives.cpp: drop the 7-line debugging
postmortem about constant-buffer packing.
- .gitignore: drop the path-c-only .venv / python/mlx/lib entries.
No behavior change. 22/22 random tests + 708 full pytest pass;
canary (46341, 46341) bf16 normal still peaks at 4.69 GB.
Net diff: -69 lines.
@angeloskath could you give some thoughts please? Thank you.
zcbenz left a comment
Wouldn't mx.compile fuse the ops?
That was actually the first thing that I tried, but it gave me the same memory peak with 3x more latency. The problem is the two compile fusion regressions for
I think we should try to make
Ok, I will try to make
Hi guys, so I've been messing around with mlx.random for the last few days, initially just trying to figure out why mx.random.normal((46341, 46341), dtype=mx.bfloat16) was crashing on my M4 16GB. I realised it uses more memory than it's supposed to, so I tried making some changes to the calculation. Opening as a draft because I want your gut check before I polish anything further. I might be trying to climb a vertical flat wall, but I'm not sure :D. If what I made is dumb, please tell me and guide me; I would love to get some criticism. Heads up: I haven't tried it on an NVIDIA GPU yet, I will soon...
What
mx.random.normal((46341, 46341), dtype=mx.bfloat16) aborts with
"Insufficient Memory" on a 16 GB Apple-Silicon device. The standard
bits → divide → cast → minimum → mul → add chain keeps the fp32
intermediate alive alongside the half-precision output: about
12.88 GB of peak (~3x) for a 4.3 GB output.
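The arithmetic behind those numbers (decimal GB):

```python
n = 46341 * 46341           # 2,147,488,281 elements
out_bf16 = 2 * n / 1e9      # ~4.29 GB half-precision output
fp32_tmp = 4 * n / 1e9      # ~8.59 GB fp32 intermediate (2x the output)
peak = out_bf16 + fp32_tmp  # ~12.88 GB total, i.e. ~3x the output
```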
This PR adds two narrow GPU paths to fix it. The headline canary
goes from abort (12.88 GB) to success at 4.69 GB peak on
M4-16GB.
How
Two new dispatch paths in mlx/random.cpp:
1. Fused RandomUniform Metal kernel for half-precision uniform
when bounds are scalar, total size is even, and a single key is in
flight. Computes bits → divide → clip → cast → affine per thread
in registers, so no fp32 intermediate ever lands in global memory.
The output is bit-identical to vanilla, peak drops from ~3x to
1x, and it's 5–10× faster on large shapes.
2. Chunked path that splits the existing fp32 pipeline along
axis 0 into K independent sub-keys (K = ⌈bits_bytes / 256 MB⌉,
clamped to [4, 256]); see the sketch after this list. Triggers
when the fp32-equivalent size is ≥ 512 MB. Peak drops to about
(1 + 2/K) × output. Bytes differ from vanilla because sub-keys
differ; each chunk still does fp32-then-cast.
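A hypothetical Python-level sketch of path 2 (the real implementation is C++ in mlx/random.cpp; chunked_uniform, the even-split assumption, and the per-chunk eval placement are illustrative):

```python
import mlx.core as mx

def chunked_uniform(shape, dtype, key, k):
    out = mx.zeros(shape, dtype=dtype)
    keys = mx.random.split(key, k)
    rows = shape[0] // k  # remainder handling omitted for brevity
    for i in range(k):
        # Each chunk runs the vanilla fp32-then-cast pipeline on its own sub-key.
        chunk = mx.random.uniform(shape=(rows, *shape[1:]), key=keys[i])
        out[i * rows:(i + 1) * rows] = chunk.astype(dtype)
        mx.eval(out)  # frees the chunk's fp32 transient before the next iteration
    return out
```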
Both paths are skipped when detail::in_tracing() is true, so
anything inside vmap/compile/vjp/jvp falls back to the vanilla
pipeline. fp32 is excluded from the chunked path on purpose:
vanilla fp32 already runs at ~1x output peak (the intermediate IS
the target dtype), so chunking only adds overhead.
Dispatch summary
Numbers (M4-16GB, 30 reps, fresh subprocess)
Half-precision uniform (fused kernel):
Half-precision normal (chunked):
fp32 (deliberately unchanged):
Quality
- Fused uniform: bit-identical to vanilla (mean/var/min/max identical for 10 tested seeds at 16384²).
- Chunked: bits differ from vanilla but distribution stats match; inter-chunk Pearson correlation < 0.011 at all chunk boundaries; unique bf16 value count 2315 at 100K samples.
- Identical seeds, same peak memory, ~2% better steady-state tok/s.
Trade-offs
Real but small:
- Small shapes are slower because the fused kernel's launch overhead dominates. Crossover above ~1K² gives the 5–90× speedups. Easy to add a min-size guard if you'd prefer.
- The chunked path is +9–11% slower in exchange for the 42–60% peak cut.
- Odd total sizes fall back to the vanilla uniform path (the kernel emits two outputs per thread).
Tests
python/tests/test_random.py (8 new tests in TestRandomChunked with mathematically derived 5σ/√N tolerances)
Files
- mlx/random.cpp — dispatch + chunking helper
- mlx/primitives.{h,cpp} — RandomUniform class
- mlx/backend/metal/kernels/random.metal — runiformc<T> kernel
- mlx/backend/metal/primitives.cpp — RandomUniform::eval_gpu
- mlx/backend/cpu/primitives.cpp — RandomUniform::eval_cpu stub
- mlx/backend/cuda/random.cu — CUDA mirror
- python/tests/test_random.py — TestRandomChunked
Repro the canary
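The repro snippet itself (assuming a current MLX build where the peak-memory query is mx.get_peak_memory; older versions expose it as mx.metal.get_peak_memory):

```python
import mlx.core as mx

mx.random.seed(0)
x = mx.random.normal((46341, 46341), dtype=mx.bfloat16)
mx.eval(x)
# Vanilla: aborts with Insufficient Memory at ~12.88 GB peak on a 16 GB M4.
# With this PR: completes at ~4.69 GB peak.
print(f"peak: {mx.get_peak_memory() / 1e9:.2f} GB")
```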
Checklist
Put an x in the boxes that apply.
- [ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes