
MFU on tomat training: bottleneck profile, experiments tried, ideas to try #4

@ryan-williams


Tracking issue for our ongoing fight to push MFU above the ~8-13% range
where current runs sit. Captures: (a) the profiler data that shapes
the search, (b) every experiment we've actually fired (with results,
not just "tried"), and (c) ideas queued but not yet tested.

Where we are now

run                                                   model  TPU     BS   ctx   grad_ckpt  shuffle   MFU
train-full-lmq-v2-200M-bs128-emd-do-8k-tpu16          200M   v6e-16  128  8192  on         off       ~8.7%
train-full-v3-200M-bs128-emd-do-8k-tpu16-shuf1k       200M   v6e-16  128  8192  on         win=1024  ~8%
train-full-lmq-v2-1B-bs256-emd-do-16k-tpu32           1B     v6e-32  256  8192  on         off       ~13%
tomat-train-cont-from-4711-1B-bs256-emd-do-12k-tpu32  1B     v6e-32  256  8192  on         off       ~13%

Numbers are wall-time MFU computed against the full padded sequence,
not just the useful (atoms+positions+density) prefix.
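For cross-checking numbers like these, here is a minimal sketch of padded-sequence MFU. It assumes the PaLM-style estimate of training FLOPs per token (6N for the weight matmuls plus 12·L·d·T for attention, with gradient-checkpoint recompute not counted) and the published ~918 TFLOP/s bf16 peak per v6e chip. It is not necessarily the formula in train_tomat_tpu.py. The function names, the example model shape (12 layers, d_model=1024), the throughput, and the 60% useful-token fraction are all hypothetical.

```python
V6E_PEAK_BF16_FLOPS = 918e12  # assumed: published per-chip bf16 peak for v6e (Trillium)


def train_flops_per_token(n_params: float, n_layers: int, d_model: int, seq_len: int) -> float:
    """PaLM-style training FLOPs per token (forward + backward).

    6 * N covers the weight matmuls; 12 * L * d * T covers the attention
    score/value matmuls. Checkpoint recompute is deliberately NOT counted,
    which is what separates MFU from hardware FLOPs utilization.
    """
    return 6 * n_params + 12 * n_layers * d_model * seq_len


def mfu(tokens_per_sec: float, n_params: float, n_layers: int, d_model: int,
        seq_len: int, n_chips: int, peak_per_chip: float = V6E_PEAK_BF16_FLOPS) -> float:
    """Model FLOPs utilization, counting every token of the padded sequence."""
    achieved = tokens_per_sec * train_flops_per_token(n_params, n_layers, d_model, seq_len)
    return achieved / (n_chips * peak_per_chip)


# Hypothetical example: 200M params, 12 layers, d_model=1024, ctx=8192 on v6e-16,
# at 500k padded tokens/s. Illustrative values, not measurements.
padded = mfu(500_000, 200e6, 12, 1024, 8192, n_chips=16)
# Counting only the useful prefix scales MFU down roughly by the useful-token
# fraction (a hypothetical 60% here).
useful = padded * 0.6
print(f"padded MFU ~ {padded:.1%}, useful-token MFU ~ {useful:.1%}")
```

Because MFU counts only model FLOPs, the recompute done under grad_ckpt=on shows up as lost MFU rather than as extra work. Counting only the useful prefix lowers the figure roughly in proportion to the useful-token fraction.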

What the profiler says is bottlenecking us

200M / v6e-16 / BS=128 / ctx=8192 / grad_ckpt=on (XLA HLO breakdown):

op                                                % cycles
attention                                         47%
element-wise fusion                               28%
custom-call (tensorstore / memcpy / collectives)  21%
inter-chip comm                                   5.5%

This is HBM-bandwidth-bound, not comm-bound: the chip spends most of
its time moving activations through attention scratch memory and
fused element-wise ops, not exchanging data with peers. A bigger
model (more compute / byte) is the cleanest lever.

1B / v6e-32 / BS=256 confirms the thesis: attn drops 47% → 17%,
matmul rises 21% → 40%, comm drops 5.5% → 2.1%, MFU 8.7% → 13%. We
are not comm-limited even at 4× scale.

(Memory: see mfu-200M-v6e-16-bottleneck.md /
mfu-1B-v6e-32-profile.md.)
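To make the compute-per-byte argument concrete, here is a rough roofline sketch. A kernel is HBM-bound when its arithmetic intensity (FLOPs per byte moved) is below the chip's ridge point, which is peak FLOP/s divided by HBM bandwidth. The v6e figures (~918 TFLOP/s bf16 and ~1640 GB/s HBM per chip) are published specs assumed here. The matrix shapes are hypothetical and are not read from our model configs.

```python
V6E_PEAK_BF16_FLOPS = 918e12    # assumed: published per-chip bf16 peak for v6e
V6E_HBM_BYTES_PER_SEC = 1640e9  # assumed: published per-chip HBM bandwidth for v6e

# Ridge point: FLOPs per byte a kernel needs before it stops being HBM-bound.
RIDGE = V6E_PEAK_BF16_FLOPS / V6E_HBM_BYTES_PER_SEC


def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per HBM byte for [m, k] @ [k, n], reading both inputs and writing the output once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def elementwise_intensity(flops_per_elem: int = 1, tensors_touched: int = 3,
                          bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for an element-wise op such as c = a + b (two reads, one write)."""
    return flops_per_elem / (tensors_touched * bytes_per_elem)


# Hypothetical shapes: 8192 tokens resident per chip, square weights of width d_model.
for d_model in (1024, 2048):
    ai = matmul_intensity(8192, d_model, d_model)
    bound = "compute" if ai > RIDGE else "HBM"
    print(f"d_model={d_model}: {ai:.0f} FLOP/B, {bound}-bound (ridge ~ {RIDGE:.0f} FLOP/B)")
print(f"element-wise add: {elementwise_intensity():.2f} FLOP/B")
```

In this sketch, widening d_model from 1024 to 2048 roughly doubles matmul intensity and pushes it past the ridge. Element-wise ops stay far below the ridge at any model size. This is consistent with the 200M → 1B shift above, where cycles move toward matmul.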

Experiments tried (with what we actually saw)

1. grad_ckpt=0 to recover recompute time (200M / v6e-16)

Goal: skip the backward-pass recompute and recover the ~30% (relative)
MFU that recompute usually costs under grad_ckpt=on.

Result: OOM at every BS / ctx combination we tried on v6e-16.

date        BS   ctx   HBM used / cap       overshoot
2026-04-29  128  8192  OOM at compile (no exact number)
2026-05-06  128  4608  71.87 GB / 31.25 GB  +40.6 GB (2.3×)
2026-05-06  64   4608  36.01 GB / 31.25 GB  +4.76 GB (1.15×)

Activation memory scales near-linearly with BS × ctx per chip. Halving
BS at ctx=4608 takes peak HBM from 71.87 to 36.01 GB, so there is no
fixed-overhead gap to exploit. The top XLA allocations at OOM were
f32[12,8,4608,1024] broadcasts (~1.7 GB each, with many siblings).

To fit on v6e-16, we'd need either (a) BS≤32, which is 4× fewer
sequences per step than the BS=128 baseline and unlikely to be
recovered by the saved recompute time, or (b) ctx<~2048, which would
require a fresh tokenize at much smaller patches.
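The two ctx=4608 rows are enough for a two-point linear fit of peak HBM against BS × ctx. That fit recovers both bounds above. The sketch assumes memory is exactly linear in BS × ctx (no superlinear attention term) and that the 31.25 GB cap is the usable per-chip budget. The helper names are hypothetical.

```python
# Peak HBM per chip for 200M / v6e-16 / grad_ckpt=0, from the OOM table above.
MEASURED = [  # (global BS, ctx, peak HBM in GB)
    (128, 4608, 71.87),
    (64, 4608, 36.01),
]
HBM_CAP_GB = 31.25


def fit_hbm_model(points):
    """Fit HBM_GB ~= intercept + slope * (BS * ctx) through two measured points."""
    (bs1, ctx1, gb1), (bs2, ctx2, gb2) = points
    x1, x2 = bs1 * ctx1, bs2 * ctx2
    slope = (gb1 - gb2) / (x1 - x2)
    return slope, gb1 - slope * x1


def predicted_hbm_gb(bs: int, ctx: int, slope: float, intercept: float) -> float:
    return intercept + slope * bs * ctx


def max_batch(ctx: int, slope: float, intercept: float, cap_gb: float = HBM_CAP_GB) -> int:
    """Largest BS whose predicted peak HBM stays under the cap at this ctx."""
    return int((cap_gb - intercept) // (slope * ctx))


slope, intercept = fit_hbm_model(MEASURED)
print(f"fixed overhead ~ {intercept:.2f} GB, ~{slope * 1e6:.1f} GB per 1M global (BS x ctx) tokens")
for ctx in (4608, 2048):
    print(f"ctx={ctx}: BS=128 needs ~{predicted_hbm_gb(128, ctx, slope, intercept):.1f} GB, "
          f"max BS ~ {max_batch(ctx, slope, intercept)}")
```

With these two points, the fixed overhead comes out near 0.15 GB. The largest BS that fits at ctx=4608 is about 55, which gives BS≤32 for power-of-two batches. BS=128 at ctx=2048 predicts roughly 32 GB, just over the cap.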

2. BS↑ on top of grad_ckpt=0 (200M / v6e-16)

Folded into above: at BS=128 we OOM by 2.3×, so BS=256 / 512 isn't
even a candidate. Not retried.

3. Block-shuffled training (200M / v6e-16)

Goal: close the train/val gap caused by unshuffled batches. With M=64
patches per material and BS=128, each batch drew from only ~2
materials.

Result: the gap collapsed (val − train ≈ 0.01% vs ~0.21% without
shuffling), but absolute NMAE barely moved (~2.0-2.2% at step 7k
either way). MFU didn't change; this was a generalization fix, not a
throughput fix.

BlockShuffleConfig(io_block_size=64, window_blocks=1024, perm_type=feistel): a 65k-row buffer (64 × 1024 rows, ~512 MB) shuffled in memory.
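The following is a generic sketch of what a two-level block shuffle does. It assumes that all rows for one material are stored contiguously. First, a keyed Feistel network permutes block IDs without materializing the full permutation. Then a window of blocks is buffered and its rows are shuffled in memory. This is not Levanter's (or our) BlockShuffleConfig implementation; the function names and the BLAKE2b round function are illustrative choices. At io_block_size=64 and window_blocks=1024, the buffer holds 65,536 rows, roughly 1,024 materials at M=64. A BS=128 batch therefore samples from a wide pool instead of ~2 materials.

```python
import hashlib
import random


def _round_fn(x: int, seed: int, rnd: int, mask: int) -> int:
    """Keyed round function: BLAKE2b over (seed, round, input), truncated to `mask` bits."""
    digest = hashlib.blake2b(f"{seed}:{rnd}:{x}".encode(), digest_size=8).digest()
    return int.from_bytes(digest, "little") & mask


def feistel_permute(i: int, n: int, seed: int, rounds: int = 4) -> int:
    """Map i in [0, n) to a unique index in [0, n) without building the permutation.

    A balanced Feistel network is a bijection on [0, 2**(2*h)); cycle-walking
    (re-applying it until the result falls below n) restricts it to [0, n).
    """
    h = max(1, ((n - 1).bit_length() + 1) // 2)
    mask = (1 << h) - 1
    x = i
    while True:
        left, right = x >> h, x & mask
        for r in range(rounds):
            left, right = right, left ^ _round_fn(right, seed, r, mask)
        x = (left << h) | right
        if x < n:
            return x


def block_shuffled_order(num_rows: int, io_block_size: int, window_blocks: int, seed: int) -> list:
    """Row visit order for a two-level block shuffle."""
    num_blocks = -(-num_rows // io_block_size)  # ceil division
    # Level 1: permute block IDs globally; each block is still read contiguously.
    block_order = [feistel_permute(b, num_blocks, seed) for b in range(num_blocks)]
    rng = random.Random(seed)
    order = []
    # Level 2: buffer `window_blocks` blocks at a time and shuffle their rows in memory.
    for start in range(0, num_blocks, window_blocks):
        buffer = []
        for b in block_order[start:start + window_blocks]:
            buffer.extend(range(b * io_block_size, min((b + 1) * io_block_size, num_rows)))
        rng.shuffle(buffer)
        order.extend(buffer)
    return order


# Toy scale: 64 materials x 64 contiguous rows each (M=64), BS=128, 16-block window.
rows_per_material, batch_size = 64, 128
order = block_shuffled_order(64 * rows_per_material, io_block_size=64, window_blocks=16, seed=0)
first_batch = order[:batch_size]
print("materials in first batch, unshuffled:", len({r // rows_per_material for r in range(batch_size)}))
print("materials in first batch, shuffled:  ", len({r // rows_per_material for r in first_batch}))
```

Permuting whole blocks keeps storage reads sequential. The in-memory shuffle over the window then provides the row-level mixing that breaks up per-material runs.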

4. v3 tokenizer at P=19 / M=64 (200M / v6e-16)

Goal: not directly an MFU experiment, but shorter preamble + more
density tokens / sequence = more useful tokens per step at the same
MFU. The compute still goes into the same padded ctx=8192.

Result: MFU unchanged (~8%), as expected. Useful tokens per step went
up, so compute per useful token went down. NMAE landed in the same
2.0-2.2% range at step 7k.

5. Phase 2: v3-p15 / ctx=4608 / grad_ckpt=0 (200M / v6e-16)

Goal: a shorter ctx (P=15 instead of 19) cuts activation memory ~44%
(4608 / 8192 ≈ 0.56) and is tile-aligned (4608 = 18 × 256 = 36 × 128);
we thought it might unlock grad_ckpt=0.

Result: see experiment 1 above — OOM'd at both BS=128 and BS=64.
ctx=4608 wasn't enough headroom.

train-full-v3-p15 + val-full-v3-p15 are tokenized and on GCS (~32
GB). They are reusable for future no-grad-ckpt attempts on larger
slices (v6e-32+), which have the same per-chip HBM but spread a global
batch over more chips.

6. 200M cont/extended runs (v6e-{8,16})

Resumed from step-7999 of an existing run, etc. Not MFU experiments
per se, but they confirm that MFU holds across LR-schedule changes and
longer training: it stayed in the same band.

7. Modal-vs-TPU bakeoff

Compared 200M training on Modal H100×8 vs marin v6e-16, looking for
any sign that one hardware platform was a better fit. Inconclusive on
the MFU delta (the two setups use separate FLOPs accounting, so the
numbers aren't directly comparable), but both confirmed our scaling
expectations.

Ideas queued (haven't fired yet)

High-confidence

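  • 1B + v3 + shuffle + (TBD MFU experiments) on v6e-32. Per-chip HBM
    is the same as on v6e-16, but with twice the chips each chip holds
    half as many sequences at a fixed global BS, so the no-grad-ckpt /
    BS↑ knobs that died on v6e-16 may actually fit. We haven't run 1B on
    v3 yet. Most likely path to a real MFU win.
  • Re-tokenize at ctx ≈ 2048 (P=12 or P=11). That is ~44% of the
    activation memory of ctx=4608 (2048 / 4608), which might let
    no-grad-ckpt fit at BS=128 on v6e-16. However, a linear extrapolation
    from the two ctx=4608 OOMs puts it at ~32 GB, right at the 31.25 GB
    cap. Tradeoff: less useful density per sequence, but the point of
    this experiment is MFU, not throughput-per-sequence.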

Medium-confidence

  • Sliding-window / local attention. 47% of cycles are in attn at
    ctx=8192. If O(N²) attn is the bottleneck, a fixed window W cuts
    attention cost to O(N·W), i.e. by roughly ctx / W. Levanter has a
    sliding window flag (use_sliding_window); we've never enabled it.
    Risk: NMAE regression if global context matters for density
    prediction.
  • Flash-attention block size sweep. Unset by default in our
    configs. Could matter for the 47% attn cycles.
  • Higher BS on v6e-32 with grad_ckpt=on. 1B/BS=256 is what we've
    tried; BS=512 might add a few MFU points if collectives don't
    saturate. Cheapest of the v6e-32 experiments.
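The following are two quick sizing helpers for the attention ideas above, each under stated assumptions. The first counts query/key score pairs for full causal attention versus a causal sliding window. The window of 1024 is a hypothetical choice. The saving only shows up in wall time if the kernel skips fully masked blocks; a dense kernel that just applies a window mask does the same work. The second lists candidate flash-attention block sizes that evenly divide ctx, assuming the kernel wants an exact tiling. The candidate list is an illustrative guess, not a Levanter default.

```python
def causal_pairs(seq_len: int) -> int:
    """Query/key pairs scored by full causal attention: T * (T + 1) / 2."""
    return seq_len * (seq_len + 1) // 2


def sliding_window_pairs(seq_len: int, window: int) -> int:
    """Pairs scored when query i sees keys max(0, i - window + 1) .. i."""
    if window >= seq_len:
        return causal_pairs(seq_len)
    # The first `window` queries see a growing prefix; every later query sees exactly `window` keys.
    return window * (window + 1) // 2 + window * (seq_len - window)


def tiling_block_sizes(seq_len: int, candidates=(128, 256, 512, 1024, 2048)) -> list:
    """Candidate flash-attention block sizes that evenly tile the sequence length."""
    return [b for b in candidates if seq_len % b == 0]


ctx, window = 8192, 1024  # window is hypothetical, not a tested setting
ratio = sliding_window_pairs(ctx, window) / causal_pairs(ctx)
print(f"sliding window keeps {ratio:.0%} of causal score pairs")
print("block sizes for ctx=8192:", tiling_block_sizes(8192))
print("block sizes for ctx=4608:", tiling_block_sizes(4608))
```

Note that ctx=4608 rules out 1024- and 2048-wide blocks under the exact-tiling assumption. A block-size sweep on the v3-p15 data would therefore need a narrower grid than one on ctx=8192.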

Low-confidence / parking

  • Replace attn_backend=None with explicit splash / pallas backend.
  • use_qk_norm / hybrid_norm toggles on the LM config. Not
    obvious how they'd interact with MFU.
  • Switch to Pythia-style or other architectures. Out of scope
    unless we want a side-by-side, but Levanter supports a few.

Open questions

  • Why does 1B at v6e-32 cap at 13% MFU when matmul cycles are 40%?
    We'd expect closer to 30-40% on a dense matmul-bound regime.
    Suspect non-matmul overhead is still meaningful.
  • v6e is bf16 / f32 mixed-precision only — would v6e Pallas /
    TPU f8 land any wins? (Probably not at this model size; just
    noting.)
  • Is the wall-time MFU calculation including or excluding profiler /
    eval / checkpoint cycles? Need to double-check the formula in
    train_tomat_tpu.py.
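The following back-of-envelope bound addresses the first question. It treats the profile's matmul row as the only source of model FLOPs, so that MFU ≈ matmul cycle fraction × (1 − recompute fraction) × average matmul efficiency. Solving for efficiency shows how hard the matmul units must be working while busy. The 25% recompute figure assumes full per-layer checkpointing, i.e. one extra forward pass, or 2N of 8N FLOPs per token. The actual checkpoint policy may recompute less. The function name is hypothetical.

```python
def implied_matmul_efficiency(mfu: float, matmul_cycle_frac: float, recompute_frac: float = 0.0) -> float:
    """Average fraction of peak the matmuls must reach while running to yield `mfu`.

    Simplified model: every model FLOP is a matmul FLOP, and `recompute_frac`
    of the matmul work is checkpoint recompute (executed, but not counted in MFU).
    """
    useful_matmul_frac = matmul_cycle_frac * (1.0 - recompute_frac)
    return mfu / useful_matmul_frac


# 1B / v6e-32 / BS=256 numbers from the profile: ~13% MFU, ~40% of cycles in matmul.
print(f"no recompute:  {implied_matmul_efficiency(0.13, 0.40):.0%} of peak")
# Assumption: full per-layer checkpointing (one extra forward, 2N of 8N FLOPs/token).
print(f"25% recompute: {implied_matmul_efficiency(0.13, 0.40, 0.25):.0%} of peak")
```

Under this simplified model, the matmuls would have to be averaging roughly a third of peak, or ~43% if a quarter of the matmul work is recompute. That suggests examining matmul efficiency alongside non-matmul overhead.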

Next-step recommendation

Pick one of:

  1. 1B + v3 + shuffle on v6e-32 as the next big run-of-record;
    uses the v3-p15 (or v3 default) tokenized data + scales the model
    beyond where MFU is HBM-bound.
  2. Re-tokenize at ctx ≈ 2048 and retry no-grad-ckpt on v6e-16,
    purely as an MFU diagnostic. It will either confirm or
    permanently rule out grad_ckpt=0 as a knob at this slice size.
