
Enable packing + compile for DPO with wasted tokens metric #1466

Merged
finbarrtimbers merged 101 commits into main from finbarr/profile-dpo on Feb 17, 2026
Conversation

@finbarrtimbers (Collaborator) commented on Feb 6, 2026

Summary

  • Enable packing + torch.compile for DPO training to reduce padding waste
  • Fix cu_seq_lens offset bug when concatenating padded chosen/rejected sequences
  • Add tokens_per_second_per_gpu metric to wandb

DPO Packing Performance Results (post MFU-fix)

All runs on finbarr/profile-dpo branch, 2 nodes x 8 H100 GPUs, OLMo2-7B, 16k seq len, activation_memory_budget=0.1.

Batch size comparison (no shard_degree, 16-way FSDP)

| Batch Size | Token Count | MFU (last) | TPS | TPS/GPU | sec/step | Experiment |
|---|---|---|---|---|---|---|
| bs=2 | 57,664 | 2.46% | 9,276 | 580 | 6.21s | 01KGSYNQ5RYZD1SVWWYDTQNY49 |
| bs=4 | 58,640 | 2.51% | 9,465 | 592 | 6.19s | 01KGT1DNDKHZXR1PM2D3ADHD5M |
| bs=8 | 431,120 | 19.33% | 69,611 | 4,351 | 6.19s | 01KGT22ECAJPVPVQP5BRAAA9Q1 |
| bs=16 | 453,184 | 19.71% | 73,067 | 4,567 | 6.19s | 01KGTEY8SZ3CWKHBX2KXD7HHAB |
| bs=32 | 524,288 | 22.48% | 84,513 | 5,282 | 6.19s | 01KGTE43M9SBGD9D1Y7Z010EZ3 |

shard_degree=8 vs no shard_degree (bs=8, full epoch, 57 steps after warmup)

| Config | MFU | TPS | TPS/GPU | sec/step | Token Count | Experiment |
|---|---|---|---|---|---|---|
| no shard (16-way FSDP) | 8.82% ± 2.54 | 32,608 ± 9,141 | 2,038 ± 571 | 4.212 ± 0.010 | 233,151 ± 59,605 | 01KH1S31HR9Y6BRN8MTNY50PV3 / wandb |
| shard=8, replicas=2 (8-way FSDP) | 8.84% ± 2.55 | 32,670 ± 9,166 | 2,042 ± 573 | 4.210 ± 0.009 | 233,151 ± 59,605 | 01KH1QG1HYHEYVHRFBXC79MYA3 / wandb |

Result: shard_degree=8 and no shard_degree produce identical performance. Over a full epoch (57 steps after warmup), all metrics are within noise. For DPO with 2 nodes, shard_degree has no effect — the default 16-way FSDP is fine.

Binary search: activation_memory_budget with bs=4

Tested both with shard_degree=8 and without shard_degree (16-way FSDP). All budgets > 0.1 OOM on the same 12.25 GiB logits allocation from DPO's concatenated forward pass.

With shard_degree=8 (8-way FSDP):

| Budget | Result | Experiment |
|---|---|---|
| 0.5 | OOM (688 MiB alloc failed) | 01KH1JHGY6STAF3GCCJTEPTWQS |
| 0.3 | OOM (12.25 GiB alloc failed) | 01KH1K8WYQ2PA3Y2NZBG2GKPEF |
| 0.2 | OOM (12.25 GiB alloc failed) | 01KH1KRZ8J56RM81A0RNK6SQJA |

Without shard_degree (16-way FSDP):

| Budget | Result | Experiment |
|---|---|---|
| 0.5 | OOM (688 MiB alloc failed) | 01KH1VFZB95VQF38A35HA7HHPV |
| 0.3 | OOM (12.25 GiB alloc failed) | 01KH1W3HV9DHGV87X8J6AK1GT8 |
| 0.2 | OOM (12.25 GiB alloc failed) | 01KH1WJ86T4Q3D9RNH171ZPY49 |
| 0.15 | OOM (12.25 GiB alloc failed) | 01KH1X0C2YTCR7K4HT5E9RJ95S |

Conclusion: The OOM bottleneck is the logits tensor from DPO's concatenated forward pass (~12.25 GiB ≈ 4 sequences × 16,384 tokens × 100,352 vocab × 2 bytes in bf16). This allocation is independent of activation checkpointing, so increasing activation_memory_budget beyond 0.1 always fails for bs=4 at 16k seq len. The budget only controls how much recomputation versus caching is used for intermediate activations; it cannot shrink the logits tensor.
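A back-of-the-envelope check of that allocation (a minimal sketch; the vocab size is taken from the public OLMo2-7B config and bf16 storage is assumed, neither is spelled out in this PR):

```python
# Rough size of the bf16 logits tensor produced by the concatenated forward pass.
# Assumed inputs: 4 packed sequences of 16,384 tokens and OLMo2-7B's 100,352-entry
# (padded) vocabulary; these numbers are illustrative, not read from the training config.
num_seqs = 4
seq_len = 16_384
vocab_size = 100_352
bytes_per_element = 2  # bf16

logits_bytes = num_seqs * seq_len * vocab_size * bytes_per_element
print(f"{logits_bytes / 1024**3:.2f} GiB")  # 12.25 GiB, matching the failed allocation
```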

Analysis

  • Step time is constant (~6.19s) across all batch sizes (without shard_degree) — the GPU does the same work regardless; the difference is in how many tokens are actually processed per step.
  • bs=2/4 waste ~87% of tokens due to poor packing utilization at 16k seq len (only ~58k of 524k theoretical max tokens).
  • bs=32 hits 100% packing utilization (524,288 = 16 GPUs × 16,384 tokens × 2 chosen+rejected).
  • bs=16 is the sweet spot — 86% packing utilization with minimal sequence truncation (1-4 seqs dropped).
  • shard_degree=8 has no effect on throughput — full-epoch comparison shows identical MFU/TPS/sec_per_step.
  • MFU and TPS are linearly consistent post-fix (MFU/TPS ratio ~0.000265 across all batch sizes), confirming the sequence counting fix works correctly.
  • Note: pre-fix runs (commits before 555d956) reported inflated MFU/TPS because token_count was computed from config (theoretical max) rather than actual tokens processed.

Bug fixes:

  1. Fixed cu_seq_lens offset bug (padding_free_collator.py): When concatenating padded chosen/rejected sequences, the offset for rejected cu_seq_lens now correctly uses the padded length instead of actual sequence length.

  2. Fixed MFU calculation (olmo_core_callbacks.py): Count actual sequences from batch cu_seq_lens instead of deriving from config, which overcounted when packing truncates sequences.
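A minimal sketch of both fixes; the names and shapes below are illustrative and are not claimed to match the code in padding_free_collator.py or olmo_core_callbacks.py.

```python
import torch

# Fix 1 (conceptual): the offset applied to the rejected cu_seq_lens must be where the
# rejected tokens actually start in the concatenated tensor, which is the *padded* length
# of the chosen block, not the number of real chosen tokens.
chosen_actual_len, chosen_padded_len = 5, 8
rejected_offset = chosen_padded_len  # the old code used chosen_actual_len, pointing into the padding


# Fix 2: derive sequence and token counts from the batch's cu_seq_lens so MFU reflects what
# was actually packed rather than the theoretical maximum from the config.
def batch_stats(cu_seq_lens: torch.Tensor) -> tuple[int, int]:
    num_seqs = int(cu_seq_lens.numel() - 1)   # one sequence per pair of adjacent boundaries
    num_tokens = int(cu_seq_lens[-1].item())  # last boundary = total real (non-padding) tokens
    return num_seqs, num_tokens


print(batch_stats(torch.tensor([0, 120, 260, 300])))  # (3, 300)
```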

Pre-filter DPO packing (no more mid-stream truncation)

When DPO packing is enabled, the collator now greedily packs as many complete examples as fit within max_seq_length, rather than packing all examples then truncating mid-stream. This prevents data loss from cutting sequences partway through.
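A sketch of the greedy pre-filter, assuming the budget is checked per side since the chosen and rejected packs are each padded to max_seq_length; the PR's actual count_features_within_token_budget helper may differ in signature and details, and the feature keys below are assumptions.

```python
def count_features_within_budget(features: list[dict], max_seq_length: int) -> int:
    """Return how many leading complete examples fit the packing budget."""
    chosen_total = rejected_total = 0
    count = 0
    for feature in features:
        chosen_len = len(feature["chosen_input_ids"])      # assumed key names
        rejected_len = len(feature["rejected_input_ids"])
        if chosen_total + chosen_len > max_seq_length or rejected_total + rejected_len > max_seq_length:
            break  # leftover examples are carried to the next batch rather than truncated mid-sequence
        chosen_total += chosen_len
        rejected_total += rejected_len
        count += 1
    return count
```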

Also fixed the reference cache data loader to use batch_size=1 per rank when packing, ensuring no examples are dropped during cache building.

Runs:

  1. Single GPU DPO: Beaker
  2. Multi-node DPO: Beaker

finbarrtimbers and others added 30 commits February 3, 2026 16:20
The prefill_flops() method was counting LM head FLOPs once per sequence,
but olmo-core counts them for each token position. This caused a ~1.6x
discrepancy between perf/mfu_avg and throughput/device/MFU metrics.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
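For context, a hedged sketch of the corrected accounting (the actual prefill_flops() may structure this differently):

```python
def lm_head_flops(num_tokens: int, hidden_size: int, vocab_size: int) -> int:
    # The output projection is a [hidden_size, vocab_size] matmul applied at every token
    # position, so its FLOPs scale with num_tokens (olmo-core's convention), not with the
    # number of sequences.
    return 2 * num_tokens * hidden_size * vocab_size
```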
Add new metrics to PerfCallback:
- perf/total_tokens: Cumulative tokens processed
- perf/data_loading_seconds: Time spent loading batch data
- perf/data_loading_pct: Data loading as % of wall clock time
- perf/wall_clock_per_step: Total time from one step start to next
- perf/step_overhead_pct: % of wall clock time on non-compute

Remove SpeedMonitorCallback from DPO since PerfCallback now provides
equivalent functionality with more detailed metrics.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…fCallback

Profiles what % of DPO training cycle is the actual step vs overhead.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
TP and CP are always 1 at runtime, so remove the dead parameters,
config creation, and mesh application logic. Simplify parallelism_factor
to only use pipeline_parallel_degree.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Breaks down the inter-step overhead to isolate where time is spent
between post_step callbacks and the next training step beginning.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Testing whether reducing GPU→CPU metric sync frequency improves
step overhead.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…Monitor

Adds torch.cuda.synchronize() before timing to distinguish host-side
vs GPU-side step duration. Reports cuda_sync_ms and host_step_time_ms
to identify how much time the GPU is still working after the host
returns from train_batch/optim_step.

Also re-enables SpeedMonitorCallback for olmo-core throughput metrics.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Targeting higher MFU by reducing padding waste. With avg seq length ~258
tokens, 16384 max_seq_length wastes 98.4% compute on padding. Reducing to
2048 and increasing batch_size to 8 keeps memory footprint similar
(batch*seq_len constant) while attention saves memory (O(seq_len^2)).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
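A quick check of the waste figure, assuming the ~258-token average quoted above:

```python
avg_seq_len, max_seq_len = 258, 16_384
padding_fraction = 1 - avg_seq_len / max_seq_len
print(f"{padding_fraction:.1%}")  # 98.4%
```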
Test if eliminating 98.4% padding waste via packing outperforms
torch.compile speedup. Packing concatenates sequences with cu_seq_lens
boundaries for flash attention, avoiding padding entirely.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Fix KeyError in concatenated_forward_olmo packing path: the padding-free
collator produces 'concatenated_cu_seq_lens_k' but the code accessed
'concatenated_cu_seq_lens'.

Also reduce exp 1 batch_size from 8 to 4 (grad_accum 4→8) to avoid OOM
in reference cache building, which uses 4x the training batch_size.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Retry packing experiment now that the concatenated_cu_seq_lens KeyError
is fixed in dpo_utils.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Packing collator doesn't produce attention_mask (all tokens are real).
Fall back to summing input_ids tensor sizes when attention_mask is absent.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Previous attempt with bs=4 OOM'd during backward pass (compiled model
used 55 GiB + 6 GiB allocation failed). Reducing to bs=2 to fit within
80 GiB H100 memory with compile + activation checkpointing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restoring seq_len=16384, bs=1 baseline to confirm code fixes don't break
the non-packing path before attempting further optimizations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add max_seq_length parameter to TensorDataCollatorWithFlattening.
After packing sequences, pad to max_seq_length so tensor shapes are
static (required for torch.compile). Flash attention uses cu_seq_lens
to ignore the padding.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
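A minimal sketch of the static-shape padding; the real logic lives inside TensorDataCollatorWithFlattening, and the function below is only illustrative.

```python
import torch
import torch.nn.functional as F

def pad_packed_to_max(input_ids: torch.Tensor, max_seq_length: int, pad_id: int = 0) -> torch.Tensor:
    # Pad the packed 1-D token tensor to a fixed length so every batch has the same shape
    # (avoiding torch.compile recompilation). Flash attention never attends to the tail
    # padding because cu_seq_lens only spans the real content.
    return F.pad(input_ids, (0, max_seq_length - input_ids.numel()), value=pad_id)
```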
With DPO packing to max_seq_length, both chosen and rejected get
padded to 16k, using more memory than variable-length batches.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
16k was OOMing because DPO concatenates chosen+rejected for forward
pass (32k total tokens). 8k should fit within memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When packing with max_seq_length padding, the concatenated tensor
includes padding that extends beyond the actual data boundaries
tracked by cu_seq_lens. Truncate to actual content length before
splitting.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Packing + DPO has multiple issues:
1. Reference logprobs caching doesn't support packing (batch["index"]
   handling assumes non-packed batches)
2. DPO concatenates chosen+rejected, doubling memory when padded to
   max_seq_length
3. Dataset sequences may exceed the configured max_seq_length

The core infrastructure changes (max_seq_length padding in collator,
get_batch_logps truncation) are still in place for future work.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Changes:
- Add truncation logic to collator when packed sequences exceed max_seq_length
- Update cu_seq_lens to clamp at max_seq_length during truncation
- Fix MFU calculation in reference cache to use cu_seq_lens when packing
- Remove redundant truncation from get_batch_logps (now handled by collator)
- Update multi_node.sh: batch_size=4, packing=true, activation_memory_budget=0.3

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…cking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When packing + padding to max_seq_length, the forward pass produces
logits for the full padded length, but cu_seq_lens only covers
the actual content. This truncation ensures the split_with_sizes
call works correctly.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace .tolist() and split-based loop with scatter_add for
torch.compile compatibility. Remove dynamic truncation to avoid
graph breaks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
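A sketch of the compile-friendly reduction (names and shapes are illustrative; the real get_batch_logps operates on indices derived from the packed cu_seq_lens):

```python
import torch

def per_sequence_logps(token_logps: torch.Tensor, seq_idx: torch.Tensor, num_seqs: int) -> torch.Tensor:
    # Sum per-token log-probs into per-sequence totals with one scatter_add instead of a
    # data-dependent Python loop over .tolist()/split, which would break the compiled graph.
    sums = torch.zeros(num_seqs, dtype=token_logps.dtype, device=token_logps.device)
    return sums.scatter_add(0, seq_idx, token_logps)

# e.g. tokens [a, b, c, d] belonging to sequences [0, 0, 1, 1] -> [a + b, c + d]
print(per_sequence_logps(torch.tensor([1.0, 2.0, 3.0, 4.0]), torch.tensor([0, 0, 1, 1]), 2))
```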
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use cu_seq_lens to count actual tokens instead of numel() when packing
is enabled. This fixes:
- global_num_tokens_in_batch in data_loader.py
- batch_tokens in reference cache building

Without this fix, token counts would be inflated by padding.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Binary search found 0.1 works, 0.15 OOMs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Keep same global batch size by reducing grad_accum from 2 to 1.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Testing memory limits with doubled batch size.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
finbarrtimbers and others added 5 commits February 12, 2026 11:41
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Avoids EADDRINUSE port conflicts on shared nodes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Use get_num_tokens helper in train_batch instead of chosen_attention_mask
  which doesn't exist when packing is enabled.
- Add missing ReduceType import from olmo_core.train.common.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Also update run-dpo-experiments skill to wait for single GPU success
before launching multi-node.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
finbarrtimbers and others added 9 commits February 12, 2026 13:23
…tream

When packing is enabled, greedily pack as many complete examples as fit
within max_seq_length rather than concatenating all examples then truncating.
This prevents data loss from cutting sequences mid-stream. Also set
cache_batch_size to 1 example per rank when packing to ensure no examples
are dropped during reference logprobs caching.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… safety

- Use clamp(max=...) instead of in-place clamp_max_() in _fit_to_max_length
- Use startswith/removeprefix instead of substring in/replace in _filter_feature_dicts
- Add is-not-None guards in _collect_flattened_features for type checker
- Deduplicate token counting in _get_batch_stats via padding_free_collator.get_num_tokens

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The _prefilter_features greedy packing ensures only complete sequences
are included, so sequences_dropped is always 0. Remove the metrics,
_align_dpo_seq_counts, and _num_valid_seqs/_sequences_dropped tracking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-filter guarantees packed length <= max_seq_length, so truncation
never fires. Replace _fit_to_max_length with _pad_to_max_length and
remove PackedTensors, logger, and cast imports.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Remove 3 redundant tests (test_indices_truncated_when_exceeding_max_length,
test_simulate_reference_cache_no_data_loss, test_asymmetric_lengths_pack_correctly)
that are subsumed by existing parameterized tests. Extract _collate and
_collate_and_get_logps helpers to reduce boilerplate. Add asymmetric case
to test_logps_count_matches_indices for coverage.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add single_example_collator to replace lambda x: x[0] everywhere,
  ensuring batch["index"] is always a tensor
- Track overflow in HFDataLoader._iter_batches so features filtered by
  the DPO collator are carried to the next batch instead of being
  silently dropped
- Overflow persists across epoch reshuffles
- Remove unnecessary "index" in features[0] guards in DPO collators

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
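A hedged sketch of the overflow pattern (the real implementation is in HFDataLoader._iter_batches; the names and the fits callback below are illustrative):

```python
from typing import Callable, Iterable, Iterator

def iter_packed_batches(
    examples: Iterable[dict],
    fits: Callable[[list[dict]], int],  # e.g. the greedy token-budget counter from the collator
    batch_size: int,
) -> Iterator[list[dict]]:
    """Features the collator cannot fit into the current packed batch are prepended to the
    next one instead of being silently dropped."""
    overflow: list[dict] = []
    pending: list[dict] = []
    for example in examples:
        pending.append(example)
        if len(pending) < batch_size:
            continue
        features = overflow + pending
        n_used = max(1, fits(features))          # always make progress
        yield features[:n_used]
        overflow, pending = features[n_used:], []
    leftovers = overflow + pending
    while leftovers:                              # drain the remaining overflow after the epoch
        n_used = max(1, fits(leftovers))
        yield leftovers[:n_used]
        leftovers = leftovers[n_used:]
```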
- Remove _pad_to_max_length function, inline padding directly in
  TensorDataCollatorWithFlattening.__call__
- Extract _prefilter_features as free function count_features_within_token_budget
  with docstring
- Remove stale comments in data_loader.py and dpo.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@hamishivi (Collaborator) left a comment

Generally seems good, but some more minor comments. I feel like somehow all this prepro feels too complicated, but that's a problem for another day.

Comment thread: open_instruct/data_loader.py

```python
micro_batch["chosen_attention_mask"].sum() + micro_batch["rejected_attention_mask"].sum()
).float()
micro_tokens = torch.tensor(
padding_free_collator.get_num_tokens(micro_batch), dtype=torch.float32, device=device
```
Collaborator:
nit: token counts can probably be longs.

Collaborator (Author):
Yeah, you're right. The only issue here is that we have to do a manual all-reduce for all of the metrics that we want to calculate the weighted mean for, and that happens in float32 because a number of those metrics are float32s:

```python
local_sums_list = [torch.tensor(total_tokens, dtype=torch.float32, device=device)] + [
```

We could make them all doubles. I don't know, do you think it's worth refactoring?
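For context, the reduction being discussed looks roughly like the sketch below (not the exact code): local sums are packed into a single float32 tensor so one all_reduce covers the token count and every metric at once.

```python
import torch
import torch.distributed as dist

def reduce_weighted_means(metric_sums: dict[str, float], total_tokens: int, device) -> dict[str, float]:
    # metric_sums holds local token-weighted sums; packing them with the local token count
    # into one float32 tensor lets a single all_reduce produce every global weighted mean.
    names = sorted(metric_sums)
    packed = torch.tensor(
        [float(total_tokens)] + [metric_sums[name] for name in names],
        dtype=torch.float32, device=device,
    )
    dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    global_tokens = packed[0]
    return {name: (packed[i + 1] / global_tokens).item() for i, name in enumerate(names)}
```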

Comment thread: open_instruct/padding_free_collator.py (Outdated)

```python
ret["labels"].append(separator)
ret["labels"].append(label_source[1:])

if return_flash_attn_kwargs and cu_seq_lens is not None:
```
Collaborator:
The logic around return_flash_attn_kwargs is weird to me; I don't like mutable returns. Shouldn't we just always return the cu_seq_lens and let the caller decide whether to keep them? Does it add much runtime?

Collaborator (Author):
No, you're right. I'll do an unconditional return.

finbarrtimbers and others added 3 commits February 14, 2026 10:19
- Drain overflow after epoch loop in HFDataLoader._iter_batches
- Keep token counts as plain Python ints until torch.stack point
- Always compute cu_seq_lens/pos_ids/seq_idx in _collect_flattened_features

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The overflow mechanism in HFDataLoader ensures no examples are dropped
regardless of batch size, making the batch_size=1 workaround unnecessary.
Also remove duplicate overflow drain loop in data_loader.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@finbarrtimbers added this pull request to the merge queue on Feb 17, 2026
Merged via the queue into main with commit ee0c460 on Feb 17, 2026, with 6 of 7 checks passed
@finbarrtimbers deleted the finbarr/profile-dpo branch on February 17, 2026 at 18:25