Enable packing + compile for DPO with wasted tokens metric #1466
finbarrtimbers merged 101 commits into main from finbarr/profile-dpo
Conversation
The prefill_flops() method was counting LM head FLOPs once per sequence, but olmo-core counts them for each token position. This caused a ~1.6x discrepancy between perf/mfu_avg and throughput/device/MFU metrics. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
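For reference, a minimal sketch of the two counting conventions (the function names and the hidden_size/vocab_size parameters are illustrative, not the actual olmo-core or open-instruct API):

```python
# Hedged sketch: the LM head is one (hidden_size x vocab_size) matmul, ~2*h*v FLOPs per call.
def lm_head_flops_once_per_sequence(hidden_size: int, vocab_size: int) -> int:
    # Old prefill_flops() behavior: a single LM-head matmul counted per sequence.
    return 2 * hidden_size * vocab_size


def lm_head_flops_per_token(hidden_size: int, vocab_size: int, seq_len: int) -> int:
    # olmo-core convention: the head runs at every token position, so FLOPs scale with seq_len.
    return 2 * hidden_size * vocab_size * seq_len
```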
Add new metrics to PerfCallback:
- perf/total_tokens: Cumulative tokens processed
- perf/data_loading_seconds: Time spent loading batch data
- perf/data_loading_pct: Data loading as % of wall clock time
- perf/wall_clock_per_step: Total time from one step start to the next
- perf/step_overhead_pct: % of wall clock time spent on non-compute work

Remove SpeedMonitorCallback from DPO since PerfCallback now provides equivalent functionality with more detailed metrics.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…fCallback Profiles what % of DPO training cycle is the actual step vs overhead. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
TP and CP are always 1 at runtime, so remove the dead parameters, config creation, and mesh application logic. Simplify parallelism_factor to only use pipeline_parallel_degree. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Breaks down the inter-step overhead to isolate where time is spent between post_step callbacks and the next training step beginning. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Testing whether reducing GPU→CPU metric sync frequency improves step overhead. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…Monitor Adds torch.cuda.synchronize() before timing to distinguish host-side vs GPU-side step duration. Reports cuda_sync_ms and host_step_time_ms to identify how much time the GPU is still working after the host returns from train_batch/optim_step. Also re-enables SpeedMonitorCallback for olmo-core throughput metrics. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
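A minimal sketch of the timing split this adds (the metric names come from the commit message; the actual callback wiring differs):

```python
import time

import torch


def timed_step(step_fn):
    """Hedged sketch: measure host-side step time vs GPU work still pending afterwards."""
    start = time.perf_counter()
    step_fn()  # e.g. train_batch(...) followed by optim_step(...)
    host_step_time_ms = (time.perf_counter() - start) * 1000

    sync_start = time.perf_counter()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for any GPU kernels still in flight
    cuda_sync_ms = (time.perf_counter() - sync_start) * 1000
    return host_step_time_ms, cuda_sync_ms
```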
Targeting higher MFU by reducing padding waste. With avg seq length ~258 tokens, 16384 max_seq_length wastes 98.4% compute on padding. Reducing to 2048 and increasing batch_size to 8 keeps memory footprint similar (batch*seq_len constant) while attention saves memory (O(seq_len^2)). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
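The padding-waste figure follows directly from the numbers in the commit message:

```python
avg_len, max_len = 258, 16384
waste = 1 - avg_len / max_len  # ≈ 0.984, i.e. ~98.4% of token positions are padding
```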
Test if eliminating 98.4% padding waste via packing outperforms torch.compile speedup. Packing concatenates sequences with cu_seq_lens boundaries for flash attention, avoiding padding entirely. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
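A toy illustration of what a packed batch looks like (shapes and names are illustrative; the real collator lives in padding_free_collator.py):

```python
import torch

# Three sequences of length 3, 5, and 2 packed into a single row with no padding.
seqs = [torch.arange(3), torch.arange(5), torch.arange(2)]
input_ids = torch.cat(seqs).unsqueeze(0)  # shape (1, 10)
cu_seq_lens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # boundaries for flash attention
position_ids = torch.cat([torch.arange(len(s)) for s in seqs]).unsqueeze(0)  # restart per sequence
```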
Fix KeyError in concatenated_forward_olmo packing path: the padding-free collator produces 'concatenated_cu_seq_lens_k' but the code accessed 'concatenated_cu_seq_lens'. Also reduce exp 1 batch_size from 8 to 4 (grad_accum 4→8) to avoid OOM in reference cache building, which uses 4x the training batch_size. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Retry packing experiment now that the concatenated_cu_seq_lens KeyError is fixed in dpo_utils.py. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Packing collator doesn't produce attention_mask (all tokens are real). Fall back to summing input_ids tensor sizes when attention_mask is absent. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
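A rough sketch of that fallback (the actual helper in the repo may be structured differently):

```python
def get_num_tokens(batch: dict) -> int:
    """Hedged sketch: count real tokens whether or not an attention_mask is present."""
    if "attention_mask" in batch:
        return int(batch["attention_mask"].sum())
    # Packed batches carry no attention_mask (every token is real), so sum the
    # sizes of the *_input_ids tensors instead.
    return sum(v.numel() for k, v in batch.items() if k.endswith("input_ids"))
```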
Previous attempt with bs=4 OOM'd during backward pass (compiled model used 55 GiB + 6 GiB allocation failed). Reducing to bs=2 to fit within 80 GiB H100 memory with compile + activation checkpointing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restoring seq_len=16384, bs=1 baseline to confirm code fixes don't break the non-packing path before attempting further optimizations. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add max_seq_length parameter to TensorDataCollatorWithFlattening. After packing sequences, pad to max_seq_length so tensor shapes are static (required for torch.compile). Flash attention uses cu_seq_lens to ignore the padding. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
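Roughly, the padding step looks like this (a sketch, assuming right-padding with an arbitrary pad id; the real collator also handles labels, position ids, etc.):

```python
import torch
import torch.nn.functional as F


def pad_packed_row(input_ids: torch.Tensor, max_seq_length: int, pad_token_id: int = 0) -> torch.Tensor:
    # Pad the packed row to a static length so torch.compile never sees a new shape.
    # cu_seq_lens still ends at the real content, so flash attention ignores the padding.
    pad_len = max_seq_length - input_ids.shape[-1]
    return F.pad(input_ids, (0, pad_len), value=pad_token_id)
```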
With DPO packing to max_seq_length, both chosen and rejected get padded to 16k, using more memory than variable-length batches. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
16k was OOMing because DPO concatenates chosen+rejected for forward pass (32k total tokens). 8k should fit within memory. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When packing with max_seq_length padding, the concatenated tensor includes padding that extends beyond the actual data boundaries tracked by cu_seq_lens. Truncate to actual content length before splitting. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Packing + DPO has multiple issues:
1. Reference logprobs caching doesn't support packing (batch["index"] handling assumes non-packed batches)
2. DPO concatenates chosen+rejected, doubling memory when padded to max_seq_length
3. Dataset sequences may exceed the configured max_seq_length

The core infrastructure changes (max_seq_length padding in collator, get_batch_logps truncation) are still in place for future work.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Changes:
- Add truncation logic to the collator when packed sequences exceed max_seq_length
- Update cu_seq_lens to clamp at max_seq_length during truncation
- Fix MFU calculation in the reference cache to use cu_seq_lens when packing
- Remove redundant truncation from get_batch_logps (now handled by the collator)
- Update multi_node.sh: batch_size=4, packing=true, activation_memory_budget=0.3

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…cking Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When packing + padding to max_seq_length, the forward pass produces logits for the full padded length, but cu_seq_lens only covers the actual content. This truncation ensures the split_with_sizes call works correctly. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
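The idea, as a sketch (tensor layout assumed to be (batch, seq, vocab); names are illustrative):

```python
import torch


def trim_to_content(logits: torch.Tensor, cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Drop logits over the pad region so that split_with_sizes over the per-sequence
    # lengths (cu_seq_lens[1:] - cu_seq_lens[:-1]) adds up exactly to the logits length.
    content_len = int(cu_seq_lens[-1])
    return logits[:, :content_len]
```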
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace .tolist() and split-based loop with scatter_add for torch.compile compatibility. Remove dynamic truncation to avoid graph breaks. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
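A compile-friendly sketch of the scatter_add pattern (names are illustrative; the real code operates on the concatenated chosen/rejected logps):

```python
import torch


def per_sequence_logps(token_logps: torch.Tensor, seq_idx: torch.Tensor, num_seqs: int) -> torch.Tensor:
    # token_logps: (num_tokens,) per-token log-probs; seq_idx: (num_tokens,) long tensor
    # mapping each token to its sequence. No Python loop over cu_seq_lens, so no graph breaks.
    return token_logps.new_zeros(num_seqs).scatter_add(0, seq_idx, token_logps)
```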
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Use cu_seq_lens to count actual tokens instead of numel() when packing is enabled. This fixes:
- global_num_tokens_in_batch in data_loader.py
- batch_tokens in reference cache building

Without this fix, token counts would be inflated by padding.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
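A sketch of the corrected count (key names are illustrative; the repo uses keys like concatenated_cu_seq_lens_k):

```python
def batch_num_tokens(batch: dict) -> int:
    # With packing + padding to max_seq_length, numel() would count the pad region too;
    # the last cu_seq_lens entry is the number of real tokens.
    if "cu_seq_lens" in batch:
        return int(batch["cu_seq_lens"][-1])
    return int(batch["attention_mask"].sum())
```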
Binary search found 0.1 works, 0.15 OOMs. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Keep same global batch size by reducing grad_accum from 2 to 1. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Testing memory limits with doubled batch size. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Avoids EADDRINUSE port conflicts on shared nodes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
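One common way to do this, sketched under the assumption that the fix picks a free rendezvous port at launch time (the actual change may differ):

```python
import socket


def find_free_port() -> int:
    # Bind to port 0 and let the OS hand back an unused port, avoiding EADDRINUSE
    # when several jobs share a node.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]
```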
- Use the get_num_tokens helper in train_batch instead of chosen_attention_mask, which doesn't exist when packing is enabled.
- Add missing ReduceType import from olmo_core.train.common.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Also update run-dpo-experiments skill to wait for single GPU success before launching multi-node. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…tream When packing is enabled, greedily pack as many complete examples as fit within max_seq_length rather than concatenating all examples then truncating. This prevents data loss from cutting sequences mid-stream. Also set cache_batch_size to 1 example per rank when packing to ensure no examples are dropped during reference logprobs caching. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
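The greedy pre-filter, sketched (the real helper, later extracted as count_features_within_token_budget, accounts for DPO's chosen/rejected pairs and may differ in detail):

```python
def count_features_within_token_budget(features: list[dict], max_seq_length: int) -> int:
    # Take complete examples in order until the next one would overflow the budget;
    # nothing is ever cut mid-sequence.
    kept, total = 0, 0
    for feat in features:
        length = len(feat["input_ids"])
        if total + length > max_seq_length:
            break
        total += length
        kept += 1
    return kept
```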
… safety
- Use clamp(max=...) instead of in-place clamp_max_() in _fit_to_max_length
- Use startswith/removeprefix instead of substring in/replace in _filter_feature_dicts
- Add is-not-None guards in _collect_flattened_features for the type checker
- Deduplicate token counting in _get_batch_stats via padding_free_collator.get_num_tokens

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The _prefilter_features greedy packing ensures only complete sequences are included, so sequences_dropped is always 0. Remove the metrics, _align_dpo_seq_counts, and _num_valid_seqs/_sequences_dropped tracking. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-filter guarantees packed length <= max_seq_length, so truncation never fires. Replace _fit_to_max_length with _pad_to_max_length and remove PackedTensors, logger, and cast imports. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Remove 3 redundant tests (test_indices_truncated_when_exceeding_max_length, test_simulate_reference_cache_no_data_loss, test_asymmetric_lengths_pack_correctly) that are subsumed by existing parameterized tests. Extract _collate and _collate_and_get_logps helpers to reduce boilerplate. Add asymmetric case to test_logps_count_matches_indices for coverage. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add single_example_collator to replace lambda x: x[0] everywhere, ensuring batch["index"] is always a tensor
- Track overflow in HFDataLoader._iter_batches so features filtered by the DPO collator are carried to the next batch instead of being silently dropped
- Overflow persists across epoch reshuffles
- Remove unnecessary "index" in features[0] guards in DPO collators

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
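A sketch of the overflow mechanism (simplified; the real _iter_batches also handles sharding, shuffling, and collation), reusing the greedy budget check sketched above:

```python
from typing import Iterable, Iterator


def iter_batches_with_overflow(
    batches: Iterable[list[dict]], max_seq_length: int
) -> Iterator[list[dict]]:
    overflow: list[dict] = []
    for batch in batches:
        features = overflow + batch
        kept, total = 0, 0
        for feat in features:  # greedy budget check, as in the pre-filter sketch
            if total + len(feat["input_ids"]) > max_seq_length:
                break
            total += len(feat["input_ids"])
            kept += 1
        overflow = features[kept:]  # carried forward instead of being silently dropped
        yield features[:kept]
    if overflow:  # drained after the loop (see a later commit)
        yield overflow
```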
- Remove the _pad_to_max_length function, inline padding directly in TensorDataCollatorWithFlattening.__call__
- Extract _prefilter_features as the free function count_features_within_token_budget with a docstring
- Remove stale comments in data_loader.py and dpo.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
hamishivi
left a comment
Generally seems good, but some more minor comments. I feel like somehow all this prepro feels too complicated, but that's a problem for another day.
| micro_batch["chosen_attention_mask"].sum() + micro_batch["rejected_attention_mask"].sum() | ||
| ).float() | ||
| micro_tokens = torch.tensor( | ||
| padding_free_collator.get_num_tokens(micro_batch), dtype=torch.float32, device=device |
nit: token counts can probably be longs.
Yeah, you're right. The only issue here is that we have to do a manual all-reduce for all of the metrics that we want to calculate the weighted mean for, and that happens in float32 as a bunch of those metrics are float32s.
We could make them all doubles. Idk do you think it's worth refactoring?
| ret["labels"].append(separator) | ||
| ret["labels"].append(label_source[1:]) | ||
|
|
||
| if return_flash_attn_kwargs and cu_seq_lens is not None: |
the logic around return_flash_attn_kwargs is weird to me, I don't like mutable returns. Shouldn't we just always return the cu_seq_lens and let the caller decide whether to keep it or not? Does it add much runtime?
no, you're right. I'll do an unconditional return.
- Drain overflow after the epoch loop in HFDataLoader._iter_batches
- Keep token counts as plain Python ints until the torch.stack point
- Always compute cu_seq_lens/pos_ids/seq_idx in _collect_flattened_features

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The overflow mechanism in HFDataLoader ensures no examples are dropped regardless of batch size, making the batch_size=1 workaround unnecessary. Also remove duplicate overflow drain loop in data_loader.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Adds a tokens_per_second_per_gpu metric to wandb.

DPO Packing Performance Results (post MFU-fix)
All runs on the finbarr/profile-dpo branch, 2 nodes x 8 H100 GPUs, OLMo2-7B, 16k seq len, activation_memory_budget=0.1.

Batch size comparison (no shard_degree, 16-way FSDP)
shard_degree=8 vs no shard_degree (bs=8, full epoch, 57 steps after warmup)
Result: shard_degree=8 and no shard_degree produce identical performance. Over a full epoch (57 steps after warmup), all metrics are within noise. For DPO with 2 nodes, shard_degree has no effect — the default 16-way FSDP is fine.
Binary search: activation_memory_budget with bs=4
Tested both with shard_degree=8 and without shard_degree (16-way FSDP). All budgets > 0.1 OOM on the same 12.25 GiB logits allocation from DPO's concatenated forward pass.
With shard_degree=8 (8-way FSDP):
Without shard_degree (16-way FSDP):
Conclusion: The OOM bottleneck is the logits tensor in DPO's concatenated forward pass (~12.25 GiB = 4 seqs × 2 × 16384 tokens × 152064 vocab × 2 bytes). This allocation is independent of activation checkpointing, so increasing activation_memory_budget beyond 0.1 always fails for bs=4 at 16k seq len. The budget only controls how much recomputation vs caching is used for intermediate activations; it cannot reduce the logits tensor size.

Analysis
token_count was computed from config (theoretical max) rather than from the actual tokens processed.

Bug fixes:
Fixed cu_seq_lens offset bug (padding_free_collator.py): When concatenating padded chosen/rejected sequences, the offset for rejected cu_seq_lens now correctly uses the padded length instead of the actual sequence length.

Fixed MFU calculation (olmo_core_callbacks.py): Count actual sequences from the batch cu_seq_lens instead of deriving them from config, which overcounted when packing truncates sequences.

Pre-filter DPO packing (no more mid-stream truncation)
When DPO packing is enabled, the collator now greedily packs as many complete examples as fit within max_seq_length, rather than packing all examples and then truncating mid-stream. This prevents data loss from cutting sequences partway through.

Also fixed the reference cache data loader to use batch_size=1 per rank when packing, ensuring no examples are dropped during cache building.

Runs: