Problem
`ScheduleRunner._preprocess_data` (`fast_llm/engine/schedule/runner.py:313-355`) eagerly pulls every micro-batch from the data iterator at the start of each step, runs a global `share_batch_data` allreduce across all of them, and preprocesses every one (device transfer, attention masks, RoPE freqs, reference-model forward) — all before the first forward step runs. The existing TODO at `trainer.py:221-222` flags it.
Two motivations to defer:
- Data-loader pressure with high gradient accumulation.
- RL on-policy staleness: with a streaming dataset, eager pulling forces later micro-batches to be generated by the rollout server earlier than necessary, so they are staler relative to the current policy by the time they are consumed.
`_get_forward_input` (`runner.py:423-433`) already pulls lazily when needed; the generator just needs to do one MB per `yield` instead of all on the first `next()`.
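A minimal sketch of the shape of the change (simplified stand-ins, not the actual runner code):

```python
def _preprocess_data_eager(data_iterator, num_micro_batches):
    # Today: everything is pulled and preprocessed before the first yield.
    batches = [next(data_iterator) for _ in range(num_micro_batches)]
    # ... share_batch_data allreduce + preprocessing over all of them ...
    yield from batches

def _preprocess_data_lazy(data_iterator, num_micro_batches):
    # Proposed: micro-batch k is pulled only when forward step k asks for it.
    for _ in range(num_micro_batches):
        batch = next(data_iterator)
        # ... per-micro-batch preprocessing here ...
        yield batch
```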
Constraint
`share_batch_data` (`data/document/language_model.py`, `token.py`) allreduces `num_labels` / `num_documents` across all MBs × DP, and the result is used as a divisor inside the loss kernel (`loss/loss.py:124`, `loss/grpo.py:50`). In `functional/cross_entropy.py:174` the divisor scales both the loss value and the upstream gradient, so we can't drop it from the kernel without a story for the gradient scale.
`num_labels` depends on the loaded sample (loss masks, document boundaries, completion lengths) — we can't get it without pulling the data. So fully lazy pulls require restructuring how the divisor flows.
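For intuition, a simplified stand-in for that kernel (plain PyTorch with assumed shapes and a `-100` ignore index, not the fused implementation):

```python
import torch
import torch.nn.functional as F

def divided_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, divisor: float) -> torch.Tensor:
    # Sum of per-token losses; masked positions carry the ignore_index label.
    token_loss_sum = F.cross_entropy(
        logits.flatten(0, -2), labels.flatten(), reduction="sum", ignore_index=-100
    )
    # Dividing here scales the returned loss AND, via autograd, the logits
    # gradient by 1/divisor, which is why the divisor can't simply be dropped.
    return token_loss_sum / divisor
```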
Option A — Mathematically correct
Defer the global division to end-of-batch. Use a cheap, locally-known scalar at kernel time to keep gradients in 16-bit range.
- Kernel divisor: total tokens in the MB (size of the `tokens` tensor) — known immediately at load time, no extra computation. It approximates `num_labels` to within a few percent in practice, so the kernel-time gradient stays in its familiar scale.
- Loss return: `(local_loss, local_divisor)`, where `local_loss = sum_token_loss / total_tokens` and `local_divisor` is the local `num_labels` (the actual count we want to normalize by globally).
- End-of-step in `_reduce_losses`: allreduce `local_divisor` → `D_global` (replaces eager `share_batch_data`); allreduce `local_loss` → `loss_sum`; reported loss = `loss_sum × total_tokens_total / D_global` (exact); apply the `total_tokens_total / D_global` (≈1) correction to grad shards before the optimizer step.
- Quantities used only for metrics (e.g. `num_documents_in_batch` in GRPO's `new_logprobs_mean / num_documents_in_batch`) don't need any pre-division in the kernel — compute them post-hoc at reduction time.
No per-step state, no bootstrap. Touches loss kernels, `LossDef.reduce`, and the runner.
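A minimal sketch of that end-of-step correction, assuming the divisor and loss are reduced with a plain sum; the function name, argument plumbing, and reduction semantics are illustrative and would have to match what `_reduce_losses` already does:

```python
import torch
import torch.distributed as dist

def reduce_option_a(local_loss_sum, local_num_labels, local_total_tokens, group):
    # local_loss_sum: sum over this rank's MBs of sum_token_loss / total_tokens.
    scalars = torch.tensor(
        [local_loss_sum, float(local_num_labels), float(local_total_tokens)],
        device="cuda",
    )
    dist.all_reduce(scalars, group=group)  # one sum across DP replaces share_batch_data
    loss_sum, d_global, total_tokens_total = scalars.tolist()

    reported_loss = loss_sum * total_tokens_total / d_global  # exact global loss
    grad_scale = total_tokens_total / d_global  # ~1, applied to grad shards pre-optimizer
    return reported_loss, grad_scale
```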
Performance note: today's `share_batch_data` does a single up-front allreduce, and the result is reused for every loss and metric that needs it. With Option A, each loss/metric naturally allreduces its own divisor and value at end of step, so several small allreduces per step instead of one. Probably fine, but if it matters, the optimization is to pack everything that needs reducing into one tensor and do one bundled allreduce, with losses that share a divisor (e.g. all consumers of `num_labels_in_batch`) sharing a slot in the bundle.
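A sketch of that bundling, with an invented slot layout:

```python
import torch
import torch.distributed as dist

# Hypothetical slot layout; every loss, metric, and divisor gets one index.
LM_LOSS, GRPO_LOSS, NUM_LABELS, NUM_DOCUMENTS = range(4)

def reduce_bundled(local: torch.Tensor, group) -> torch.Tensor:
    # `local` is a 1-D float tensor filled per the slot layout above.
    # One collective replaces one allreduce per scalar; consumers that
    # normalize by num_labels_in_batch all read the shared NUM_LABELS slot.
    dist.all_reduce(local, group=group)
    return local
```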
Option B — Simple, opt-in via flag
Add `per_micro_batch_normalization: bool = False` on `ScheduleConfig`. When off: current behavior, unchanged. When on:
- `_preprocess_data` becomes per-MB lazy (sketched after this list): pull one set of `model_inputs`, set `num_labels_in_batch ← allreduce(num_labels, group=DP) × num_micro_batches` (one cheap scalar allreduce per MB across DP only — no cross-MB aggregation, no blocking on data not yet pulled), set `num_documents_in_batch` similarly, run `preprocess_batch`, yield.
- No changes to loss kernels — they keep dividing by `num_labels_in_batch`.
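A sketch of the per-MB generator, assuming PyTorch distributed and a DP process group; the `model_inputs` attribute names and the `_preprocess_batch` call site are stand-ins, not the runner's real internals:

```python
import torch
import torch.distributed as dist

def _preprocess_data_per_mb(self, data_iterator, num_micro_batches, dp_group):
    for _ in range(num_micro_batches):
        model_inputs = next(data_iterator)  # pull exactly one micro-batch
        # One cheap scalar allreduce across DP only, then extrapolate across
        # MBs by assuming uniform sizes (the opt-in approximation).
        counts = torch.tensor(
            [model_inputs.num_labels, model_inputs.num_documents],
            dtype=torch.int64, device="cuda",
        )
        dist.all_reduce(counts, group=dp_group)
        model_inputs.num_labels_in_batch = counts[0].item() * num_micro_batches
        model_inputs.num_documents_in_batch = counts[1].item() * num_micro_batches
        self._preprocess_batch(model_inputs)  # stand-in for the real preprocessing call
        yield model_inputs
```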
Math: DP is always exact. Across MBs, each MB is divided by `DP_sum_for_this_MB × num_micro_batches` instead of the true global sum. Identical to today when MB sizes are uniform across MBs (the typical case); per-MB weighted otherwise. The user opts in and acknowledges the approximation.
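A tiny worked example of that weighting difference, with invented numbers:

```python
# One DP rank, two micro-batches:
s1, n1 = 20.0, 100   # token-loss sum and num_labels for MB 1 (0.2 per token)
s2, n2 = 30.0, 300   # MB 2 (0.1 per token)
m = 2                # num_micro_batches

true_loss  = (s1 + s2) / (n1 + n2)          # 0.125: today's global normalization
opt_b_loss = s1 / (n1 * m) + s2 / (n2 * m)  # 0.150: each MB weighted by its own size
# The two agree exactly when n1 == n2; the gap grows with MB-size imbalance.
```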
Touches `runner.py` and `schedule/config.py` only.
Comparison
| | Option A | Option B |
| --- | --- | --- |
| Loss/grad math vs. today | Identical | Identical for MB-uniform sizes; per-MB weighted otherwise |
| Files touched | Runner + loss kernels + `LossDef.reduce` | Runner + config |
| Bootstrap state | None | None |
Open questions
- Is per-MB weighting acceptable as an opt-in for RL? Most RL setups weight per trajectory anyway.
- Combine — Option B as opt-in fast path, Option A as eventual default — or land just one?