Tracking issue for our ongoing fight to push MFU above the ~8-13% range
where current runs sit. Captures: (a) the profiler data that shapes
the search, (b) every experiment we've actually fired (with results,
not just "tried"), and (c) ideas queued but not yet tested.
Where we are now
| run | model | TPU | BS | ctx | grad_ckpt | shuffle | MFU |
|---|---|---|---|---|---|---|---|
| train-full-lmq-v2-200M-bs128-emd-do-8k-tpu16 | 200M | v6e-16 | 128 | 8192 | on | off | ~8.7% |
| train-full-v3-200M-bs128-emd-do-8k-tpu16-shuf1k | 200M | v6e-16 | 128 | 8192 | on | win=1024 | ~8% |
| train-full-lmq-v2-1B-bs256-emd-do-16k-tpu32 | 1B | v6e-32 | 256 | 8192 | on | off | ~13% |
| tomat-train-cont-from-4711-1B-bs256-emd-do-12k-tpu32 | 1B | v6e-32 | 256 | 8192 | on | off | ~13% |
Numbers are wall-time MFU computed against the full padded sequence,
not just the useful (atoms+positions+density) prefix.
What the profiler says is bottlenecking us
200M / v6e-16 / BS=128 / ctx=8192 / grad_ckpt=on (XLA HLO breakdown):
| op | % cycles |
|---|---|
| attention | 47% |
| element-wise fusion | 28% |
| custom-call (tensorstore / memcpy / collectives) | 21% |
| inter-chip comm | 5.5% |
This is HBM-bandwidth-bound, not comm-bound: the chip spends most of
its time moving activations through attention scratch memory and
fused element-wise ops, not exchanging data with peers. A bigger
model (more compute / byte) is the cleanest lever.
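A rough roofline check makes the HBM-bound reading concrete. The sketch below compares an op's arithmetic intensity (FLOPs per HBM byte) with the chip's ridge point. The v6e peaks (918 TFLOP/s bf16 and 1640 GB/s HBM per chip) come from Google's published Trillium specs and are assumptions here. The matmul widths 1024 and 2048 are hypothetical, not our actual model dims. The byte counts assume each operand is read once and the output is written once.

```python
# Assumed per-chip v6e (Trillium) peaks; double-check against the Cloud TPU docs.
PEAK_BF16_FLOPS = 918e12   # FLOP/s
HBM_BYTES_PER_S = 1640e9   # bytes/s

# Arithmetic intensity (FLOPs/byte) above which an op becomes compute-bound.
RIDGE = PEAK_BF16_FLOPS / HBM_BYTES_PER_S


def attainable_flops(intensity: float) -> float:
    """Roofline: best-case FLOP/s for an op with the given FLOPs per HBM byte."""
    return min(PEAK_BF16_FLOPS, intensity * HBM_BYTES_PER_S)


def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """[m, k] @ [k, n]: 2*m*k*n FLOPs over reading A and B and writing C once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def elementwise_intensity(flops_per_elem: int = 1, bytes_per_elem: int = 2) -> float:
    """Unary element-wise op: one read and one write per element."""
    return flops_per_elem / (2 * bytes_per_elem)


if __name__ == "__main__":
    print(f"ridge point: {RIDGE:.0f} FLOPs/byte")
    for name, ai in [
        ("element-wise", elementwise_intensity()),
        ("matmul 8192x1024x1024", matmul_intensity(8192, 1024, 1024)),
        ("matmul 8192x2048x2048", matmul_intensity(8192, 2048, 2048)),
    ]:
        bound = "compute" if ai >= RIDGE else "memory"
        print(f"{name}: {ai:.1f} FLOPs/byte -> {bound}-bound")
```

Element-wise fusions sit at a fraction of a FLOP per byte, three orders of magnitude below the ridge (~560). At these sizes, doubling the matmul width moves the matmul from just under the ridge (~482) to above it (~910). That is the "more compute / byte" argument in numbers.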
1B / v6e-32 / BS=256 confirms the thesis: attn drops 47% → 17%,
matmul rises 21% → 40%, comm drops 5.5% → 2.1%, MFU 8.7% → 13%. We
are not comm-limited even at 4× scale.
(Memory: see mfu-200M-v6e-16-bottleneck.md /
mfu-1B-v6e-32-profile.md.)
Experiments tried (with what we actually saw)
1. grad_ckpt=0 to recover recompute time (200M / v6e-16)
Goal: skip backward-pass recompute and recover the ~30% relative MFU that
grad_ckpt=on usually costs.
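A back-of-envelope sketch of where the ~30% comes from. It uses the common 6N approximation: forward ≈ 2N FLOPs per token and backward ≈ 4N. Full rematerialization replays each layer's forward once more, so the hardware executes ~8N per token while MFU only credits 6N. This assumes grad_ckpt=on means full per-layer remat and that step time scales with hardware FLOPs. Selective checkpoint policies recompute less, and attention-score FLOPs are ignored.

```python
def hardware_flops_per_token(n_params: int, full_remat: bool) -> int:
    """Approximate FLOPs executed per training token (weight matmuls only)."""
    forward = 2 * n_params
    backward = 4 * n_params
    recompute = forward if full_remat else 0  # remat replays the forward pass
    return forward + backward + recompute


def remat_free_speedup(n_params: int) -> float:
    """Upper bound on the step-time speedup (and MFU ratio) from dropping remat."""
    return hardware_flops_per_token(n_params, True) / hardware_flops_per_token(n_params, False)


if __name__ == "__main__":
    gain = remat_free_speedup(200_000_000)  # 8N / 6N
    print(f"max speedup without remat: {gain:.3f}x")
    print(f"8.7% MFU would become at most {8.7 * gain:.1f}%")
```

For the 200M baseline, this caps the win at roughly 8.7% → 11.6% MFU. Because that run is HBM-bound rather than FLOP-bound, the realized gain depends on how much memory traffic the recompute itself generates.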
Result: OOM at every BS / ctx combo on v6e-16.
| date | BS | ctx | HBM used / cap | overshoot |
|---|---|---|---|---|
| 2026-04-29 | 128 | 8192 | OOM at compile (no exact number) | n/a |
| 2026-05-06 | 128 | 4608 | 71.87 GB / 31.25 GB | +40.6 GB (2.3×) |
| 2026-05-06 | 64 | 4608 | 36.01 GB / 31.25 GB | +4.76 GB (1.15×) |
Activation memory scales near-linearly with BS × ctx per chip; no
fixed-overhead gap to exploit. Top XLA allocations at OOM were
f32[12,8,4608,1024] broadcasts (~1.7 GB each, lots of siblings).
To fit on v6e-16, we'd need either (a) BS≤32, which processes 4× fewer
sequences per step than the BS=128 baseline, a loss unlikely to be
recovered by the saved recompute time, or (b) ctx<~2048, which would
require a fresh tokenize at much smaller patches.
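The linear-scaling claim and both escape hatches can be checked from the OOM table alone. This minimal sketch assumes peak HBM is linear in BS × ctx. Both measured points share ctx=4608, so the ctx direction is an extrapolation. It only holds if attention memory is linear in ctx, as with flash-style kernels. The helper names are illustrative.

```python
# No-grad-ckpt peak HBM on v6e-16 from the OOM table: (BS, ctx, GB).
MEASURED = [(128, 4608, 71.87), (64, 4608, 36.01)]
HBM_CAP_GB = 31.25


def fit_linear(points):
    """Two-point fit of used_gb = slope * (BS * ctx) + intercept."""
    (bs0, ctx0, gb0), (bs1, ctx1, gb1) = points
    x0, x1 = bs0 * ctx0, bs1 * ctx1
    slope = (gb1 - gb0) / (x1 - x0)
    return slope, gb0 - slope * x0


def predict_gb(bs: int, ctx: int, slope: float, intercept: float) -> float:
    return slope * bs * ctx + intercept


def max_ctx(bs: int, slope: float, intercept: float, cap_gb: float = HBM_CAP_GB) -> int:
    """Largest ctx predicted to fit under the cap at this global batch size."""
    return int((cap_gb - intercept) / (slope * bs))


def tensor_gib(shape, bytes_per_elem: int = 4) -> float:
    """Size of a dense tensor in GiB (f32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / 2**30


if __name__ == "__main__":
    slope, intercept = fit_linear(MEASURED)
    print(f"intercept (fixed overhead): {intercept:.2f} GB")
    print(f"BS=32, ctx=4608: {predict_gb(32, 4608, slope, intercept):.1f} GB")
    print(f"BS=128, ctx=2048: {predict_gb(128, 2048, slope, intercept):.1f} GB")
    print(f"max ctx at BS=128: {max_ctx(128, slope, intercept)}")
    print(f"f32[12,8,4608,1024]: {tensor_gib((12, 8, 4608, 1024)):.2f} GiB")
```

The intercept comes out at ~0.15 GB, so there is no fixed overhead worth reclaiming. The fit predicts the following:
- BS=32 at ctx=4608 fits at ~18 GB.
- The BS=128 ceiling sits at ctx ≈ 2000.
- ctx=2048 itself lands ~0.8 GB over the cap, consistent with the "ctx<~2048" bound.

Each f32[12,8,4608,1024] broadcast works out to 1.69 GiB, matching the ~1.7 GB XLA reported.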
2. BS↑ on top of grad_ckpt=0 (200M / v6e-16)
Folded into above: at BS=128 we OOM by 2.3×, so BS=256 / 512 isn't
even a candidate. Not retried.
3. Block-shuffled training (200M / v6e-16)
Goal: close the train/val gap caused by training without a shuffle: with
M=64 patches/material and BS=128, each batch drew from only ~2 materials.
Result: gap collapsed (val − train ≈ 0.01% vs ~0.21% no-shuffle),
but absolute NMAE didn't move much (~2.0-2.2% at step 7k either
way). MFU didn't change — this was a generalization fix, not a
throughput fix.
BlockShuffleConfig(io_block_size=64, window_blocks=1024, perm_type=feistel) — 65k-row buffer (~512 MB) shuffled in-memory.
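For intuition on what a windowed block shuffle does, here is a minimal pure-Python sketch under assumed semantics:
- Rows are grouped into contiguous io blocks.
- The block order is permuted with a keyed Feistel network, cycle-walked to stay in range, so no permutation table is materialized.
- Rows are then shuffled in memory within each window of window_blocks blocks.

The function names, round function and windowing details are hypothetical and are not Levanter's BlockShuffleConfig implementation.

```python
import random


def feistel_permute(i: int, n: int, key: int, rounds: int = 4) -> int:
    """Keyed bijection on range(n): balanced Feistel network plus cycle-walking."""
    if not 0 <= i < n:
        raise ValueError("i must be in [0, n)")
    bits = max(2, (n - 1).bit_length())
    bits += bits % 2  # a balanced Feistel network needs equal halves
    half = bits // 2
    mask = (1 << half) - 1
    x = i
    while True:
        left, right = x >> half, x & mask
        for r in range(rounds):
            # Any round function keeps the network bijective; this one just mixes bits.
            h = (right * 0x9E3779B1 + key * 0x85EBCA6B + r * 0xC2B2AE35) & 0xFFFFFFFF
            h ^= h >> 15
            left, right = right, left ^ (h & mask)
        x = (left << half) | right
        if x < n:  # cycle-walk until we land back in range
            return x


def block_shuffled_order(num_rows: int, io_block_size: int, window_blocks: int, seed: int) -> list:
    """Row read order: permuted blocks, rows shuffled within each window."""
    num_blocks = num_rows // io_block_size  # sketch assumes exact division
    blocks = [feistel_permute(b, num_blocks, seed) for b in range(num_blocks)]
    rng = random.Random(seed)
    order = []
    for start in range(0, num_blocks, window_blocks):
        window = blocks[start:start + window_blocks]
        rows = [b * io_block_size + k for b in window for k in range(io_block_size)]
        rng.shuffle(rows)  # in-memory shuffle of one window buffer
        order.extend(rows)
    return order
```

With io_block_size=64 and window_blocks=1024, each window holds 64 × 1024 = 65,536 rows, which is the 65k-row buffer. Reading whole blocks keeps I/O sequential, while the in-window shuffle mixes materials across batches.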
4. v3 tokenizer at P=19 / M=64 (200M / v6e-16)
Goal: not directly an MFU experiment, but shorter preamble + more
density tokens / sequence = more useful tokens per step at the same
MFU. The compute still goes into the same padded ctx=8192.
Result: MFU unchanged (~8%) as expected. Useful tokens per step went
up, so compute per useful token went down. NMAE landed in the same
2.0-2.2% range at step 7k.
5. Phase 2 — v3-p15 / ctx=4608 / grad_ckpt=0 (200M / v6e-16)
Goal: shorter ctx (P=15 instead of 19) cuts activation memory ~44%
and is tile-aligned (4608 = 18 × 256 = 36 × 128); thought it might
unlock grad_ckpt=0.
Result: see experiment 1 above — OOM'd at both BS=128 and BS=64.
ctx=4608 wasn't enough headroom.
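train-full-v3-p15 + val-full-v3-p15 are tokenized and on GCS (~32
GB), reusable for future no-grad-ckpt attempts on larger slices
(v6e-32+). Per-chip HBM is the same on every v6e slice, so any extra
headroom there comes from spreading the batch over more chips.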
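The ctx arithmetic behind this phase is easy to recheck. This small sketch assumes activation memory is linear in ctx. The candidate list of power-of-two tile sizes is hypothetical. The same divisibility check is one reasonable way to pick candidates for the flash-attention block size sweep in the queued ideas.

```python
def activation_cut(new_ctx: int, old_ctx: int) -> float:
    """Fractional activation-memory reduction if memory is linear in ctx."""
    return 1 - new_ctx / old_ctx


def aligned_tiles(ctx: int, candidates=(128, 256, 512, 1024, 2048)) -> list:
    """Candidate tile/block sizes that divide ctx with no padding."""
    return [b for b in candidates if ctx % b == 0]


if __name__ == "__main__":
    print(f"8192 -> 4608 cuts activations by {activation_cut(4608, 8192):.1%}")
    print("tiles dividing 4608:", aligned_tiles(4608))
    print("tiles dividing 8192:", aligned_tiles(8192))
```

4608 tiles cleanly at 128, 256 and 512 but not at 1024. That matters if a block-size sweep is run on the p15 data.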
6. 200M cont/extended runs (v6e-{8,16})
Resume from step-7999 of an existing run, etc. Not MFU experiments
per se, but they pin down "this MFU sticks across LR schedule
changes / longer training". MFU stayed in the same band.
7. Modal-vs-TPU bakeoff
Compared 200M training on Modal H100×8 vs marin v6e-16, looking for
any sign one architecture was a better fit. Inconclusive on MFU
delta (separate FLOPs accounting), but both confirmed our scaling
expectations.
Ideas queued (haven't fired yet)
High-confidence
- 1B + v3 + shuffle + (TBD MFU experiments) on v6e-32. Per-chip HBM
is the same as on v6e-16, but with the batch and sharded state spread
over twice as many chips, the no-grad-ckpt / BS↑ knobs that died on
v6e-16 may actually fit. We haven't run 1B on v3 data yet. Most likely
path to a real MFU win.
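- Re-tokenize at ctx ≈ 2048 (P=12 or P=11). That is under half the
activation memory of ctx=4608 (2048/4608 ≈ 0.44), which should bring
no-grad-ckpt at BS=128 on v6e-16 to about the HBM cap: scaling the
71.87 GB measurement linearly lands near 32 GB vs the 31.25 GB cap, so
ctx may need to sit slightly below 2048. Tradeoff: less useful density
per sequence, but the point of this experiment is MFU, not
throughput-per-sequence.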
Medium-confidence
- Sliding-window / local attention. 47% of cycles are in attn at
ctx=8192; if O(N²) attention is the bottleneck, a fixed window of W
tokens cuts the score computation to O(N·W). Levanter has a sliding
window flag (use_sliding_window); never enabled. Risk: NMAE
regression if global context matters for density prediction.
- Flash-attention block size sweep. Unset by default in our
configs. Could matter for the 47% attn cycles.
- Higher BS on v6e-32 with grad_ckpt=on. 1B/BS=256 is what we've
tried; BS=512 might add a few MFU points if collectives don't
saturate. Cheapest of the v6e-32 experiments.
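To size the sliding-window idea, here is a rough count of attention-score work. This minimal sketch assumes causal attention and counts only the QK^T and AV matmuls, at about 4·d FLOPs per scored query/key pair, forward only. W=1024 is a hypothetical window. The saving only shows up in wall time if the kernel skips fully masked blocks rather than computing them and masking afterwards.

```python
from typing import Optional


def causal_pairs(n: int, window: Optional[int] = None) -> int:
    """(query, key) pairs scored by causal attention over n tokens.

    With a window W, query i attends to keys max(0, i - W + 1) .. i.
    """
    if window is None or window >= n:
        return n * (n + 1) // 2
    # The first W queries see i + 1 keys each; every later query sees exactly W.
    return window * (window + 1) // 2 + (n - window) * window


def attn_score_flops(n: int, d_model: int, n_layers: int, window: Optional[int] = None) -> int:
    """Forward FLOPs for QK^T plus AV across all layers (2*d each per pair)."""
    return 4 * d_model * n_layers * causal_pairs(n, window)


if __name__ == "__main__":
    full = causal_pairs(8192)
    local = causal_pairs(8192, window=1024)
    print(f"full causal: {full:,} pairs, W=1024: {local:,} pairs, ratio {full / local:.2f}x")
```

At ctx=8192, a 1024-token window scores ~4.3× fewer pairs. Only the attention-score part of the 47% shrinks this way. Q/K/V/output projections, if the profiler counts them in that bucket, are unchanged.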
Low-confidence / parking
- Replace attn_backend=None with explicit splash / pallas backend.
- use_qk_norm / hybrid_norm toggles on the LM config. Not obvious how
they'd interact with MFU.
- Switch to Pythia-style or other architectures. Out of scope
unless we want a side-by-side, but Levanter supports a few.
Open questions
- Why does 1B at v6e-32 cap at 13% MFU when matmul cycles are 40%?
We'd expect closer to 30-40% on a dense matmul-bound regime.
Suspect non-matmul overhead is still meaningful.
- v6e is bf16 / f32 mixed-precision only — would v6e Pallas /
TPU f8 land any wins? (Probably not at this model size; just
noting.)
- Does the wall-time MFU calculation include or exclude profiler /
eval / checkpoint cycles? Need to double-check the formula in
train_tomat_tpu.py.
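For the last question, here is a sketch of what a wall-time MFU formula could look like, with an explicit hook for subtracting eval / checkpoint / profiler time. It uses one common convention: 6N per token plus 12·L·d·T for attention, as in the PaLM paper's appendix. This is not necessarily what train_tomat_tpu.py does, and the function and parameter names are hypothetical. Token counts should be padded tokens, to match how the numbers above are reported.

```python
def model_flops_per_token(n_params: int, n_layers: int, d_model: int, ctx: int) -> int:
    """6N for the weight matmuls plus 12 * L * d * T for attention scores."""
    return 6 * n_params + 12 * n_layers * d_model * ctx


def wall_time_mfu(
    tokens: int,
    wall_s: float,
    chips: int,
    peak_flops_per_chip: float,
    n_params: int,
    n_layers: int,
    d_model: int,
    ctx: int,
    excluded_s: float = 0.0,
) -> float:
    """MFU over a wall-clock window, minus time spent outside training steps."""
    train_s = wall_s - excluded_s  # e.g. eval, checkpoint, profiler time
    if train_s <= 0:
        raise ValueError("excluded_s must be smaller than wall_s")
    achieved = tokens * model_flops_per_token(n_params, n_layers, d_model, ctx) / train_s
    return achieved / (chips * peak_flops_per_chip)
```

If eval and checkpoint time stays in wall_s, MFU is understated by the fraction of wall time those phases take. Runs with different eval cadences would then differ in MFU for reasons unrelated to training throughput.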
Next-step recommendation
Pick one of:
- 1B + v3 + shuffle on v6e-32 as the next big run-of-record;
uses the v3-p15 (or v3 default) tokenized data + scales the model
beyond where MFU is HBM-bound.
- Re-tokenize at ctx ≈ 2048 and retry no-grad-ckpt on v6e-16,
purely as an MFU diagnostic — it'll either confirm or
permanently rule out grad_ckpt=0 as a knob on this size of TPU.