Lazy micro-batch input preparation in the schedule runner #510

@jlamypoirier

Description

Problem

ScheduleRunner._preprocess_data (fast_llm/engine/schedule/runner.py:313-355) eagerly pulls every micro-batch from the data iterator at the start of each step, runs a global share_batch_data allreduce across all of them, and preprocesses every one (device transfer, attention masks, RoPE freqs, reference-model forward) — all before the first forward step runs. The existing TODO at trainer.py:221-222 flags it.

Two motivations to defer:

  • Data-loader pressure with high gradient accumulation.
  • RL on-policy staleness: with a streaming dataset, later micro-batches were generated by the rollout server earlier than necessary.

_get_forward_input (runner.py:423-433) already pulls lazily when needed; the generator just needs to do one MB per yield instead of all on the first next().

Constraint

share_batch_data (data/document/language_model.py, token.py) allreduces num_labels / num_documents across all MBs × DP, and the result is used as a divisor inside the loss kernel (loss/loss.py:124, loss/grpo.py:50). In functional/cross_entropy.py:174 the divisor scales both the loss value and the upstream gradient, so we can't drop it from the kernel without a story for the gradient scale.

num_labels depends on the loaded sample (loss masks, document boundaries, completion lengths) — we can't get it without pulling the data. So fully lazy pulls require restructuring how the divisor flows.
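To make the constraint concrete, here is a toy single-token cross-entropy with its analytic gradient (pure illustration, not the actual kernel in `functional/cross_entropy.py`): the divisor scales the loss value and the logit gradient identically, so dropping it from the kernel changes the gradient magnitude unless it is reapplied elsewhere.

```python
import math

def cross_entropy_grad(logits, label, divisor):
    """Toy single-token cross-entropy returning (loss, grad wrt logits).
    The divisor scales both outputs, mirroring how the globally allreduced
    num_labels divides loss and upstream gradient inside the real kernel."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    loss = -math.log(probs[label]) / divisor
    # d(loss)/d(logit_i) = (softmax_i - onehot_i) / divisor
    grad = [(p - (i == label)) / divisor for i, p in enumerate(probs)]
    return loss, grad


# Halving the divisor doubles both the loss and every gradient component:
loss1, g1 = cross_entropy_grad([0.0, 0.0], 0, 1.0)
loss2, g2 = cross_entropy_grad([0.0, 0.0], 0, 2.0)
assert abs(loss1 - 2 * loss2) < 1e-9
assert all(abs(a - 2 * b) < 1e-9 for a, b in zip(g1, g2))
```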

Option A — Mathematically correct

Defer the global division to end-of-batch. Use a cheap, locally-known scalar at kernel time to keep gradients in 16-bit range.

  • Kernel divisor: total tokens in the MB (size of the tokens tensor) — known immediately at load time, no extra computation. It approximates num_labels to within a few percent in practice, so the kernel-time gradient stays in its familiar scale.
  • Loss return: (local_loss, local_divisor), where local_loss = sum_token_loss / total_tokens and local_divisor = the local num_labels (the actual count we want to normalize by globally).
  • End-of-step in _reduce_losses: allreduce local_divisor → D_global (replaces the eager share_batch_data); allreduce local_loss → loss_sum; reported loss = loss_sum × total_tokens_total / D_global (exact); apply a total_tokens_total / D_global (≈ 1) correction to the grad shards before the optimizer step.
  • Quantities used only for metrics (e.g. num_documents_in_batch in GRPO's new_logprobs_mean / num_documents_in_batch metric) don't need any pre-division in the kernel; compute them post-hoc at reduction time.
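A single-process sketch of the Option A end-of-step reduction, simulating the allreduces with plain sums over (rank, micro-batch) contributions. It assumes the loss allreduce averages while divisor allreduces sum, which the issue text does not pin down, and all names are hypothetical:

```python
def reduce_option_a(per_mb):
    """per_mb: list of (sum_token_loss, total_tokens, num_labels), one entry
    per micro-batch across all DP ranks. Kernel-time division used the cheap
    total_tokens; the exact num_labels division is deferred to here."""
    n = len(per_mb)
    loss_mean = sum(s / t for s, t, _ in per_mb) / n  # allreduce-mean of local_loss
    tokens_total = sum(t for _, t, _ in per_mb)       # allreduce-sum of total_tokens
    d_global = sum(lab for _, _, lab in per_mb)       # allreduce-sum of local_divisor
    # Undo the per-MB total_tokens division, apply the true global divisor
    # (exact when total_tokens is uniform across micro-batches):
    reported_loss = loss_mean * tokens_total / d_global
    grad_correction = tokens_total / d_global         # ≈ 1, applied to grad shards
    return reported_loss, grad_correction


# Two micro-batches, uniform token counts: matches the exact global average.
per_mb = [(2.0, 4, 3), (6.0, 4, 5)]
loss, corr = reduce_option_a(per_mb)
assert abs(loss - (2.0 + 6.0) / (3 + 5)) < 1e-9  # exact sum/num_labels
```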

No per-step state, no bootstrap. Touches loss kernels, LossDef.reduce, and the runner.

Performance note: today's share_batch_data does a single up-front allreduce, and the result is reused for every loss and metric that needs it. With Option A, each loss/metric naturally allreduces its own divisor and value at end of step, so several small allreduces per step instead of one. Probably fine, but if it matters, the optimization is to pack everything that needs reducing into one tensor and do one bundled allreduce, with losses that share a divisor (e.g. all consumers of num_labels_in_batch) sharing a slot in the bundle.
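The bundling idea can be sketched as follows; this is a single-process simulation where summing per-rank tensors stands in for one `all_reduce(SUM)` on the packed tensor, and all names are illustrative:

```python
def bundled_allreduce(per_rank_values):
    """Pack every scalar that needs reducing into one flat vector in a fixed
    name order, reduce once, and unpack by name. `per_rank_values` is a list
    of {name: local_scalar} dicts, one per simulated rank; in the real runner
    a single dist.all_reduce on the packed tensor would replace the sum."""
    names = sorted(per_rank_values[0])
    flat = [[float(rank[n]) for n in names] for rank in per_rank_values]
    reduced = [sum(col) for col in zip(*flat)]  # one bundled allreduce(SUM)
    return dict(zip(names, reduced))


# Losses that share a divisor (e.g. all consumers of num_labels_in_batch)
# would map to the same slot, so it is reduced only once:
ranks = [{"num_labels": 3, "loss_sum": 0.5}, {"num_labels": 5, "loss_sum": 1.5}]
assert bundled_allreduce(ranks) == {"loss_sum": 2.0, "num_labels": 8.0}
```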

Option B — Simple, opt-in via flag

Add per_micro_batch_normalization: bool = False on ScheduleConfig. When off: current behavior, unchanged. When on:

  • _preprocess_data becomes per-MB lazy: pull one set of model_inputs, set num_labels_in_batch ← allreduce(num_labels, group=DP) × num_micro_batches (one cheap scalar allreduce per MB across DP only — no cross-MB aggregation, no blocking on data not yet pulled), set num_documents_in_batch similarly, run preprocess_batch, yield.
  • No changes to loss kernels — they keep dividing by num_labels_in_batch.

Math: DP is always exact. Across MBs, each MB is divided by DP_sum_for_this_MB × num_micro_batches instead of the true global sum. Identical to today when MB sizes are uniform across MBs (the typical case); per-MB weighted otherwise. The user opts in and acknowledges the approximation.
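A toy arithmetic check of that claim, with no distributed machinery (function names are made up for the illustration):

```python
def option_b_divisor(labels_per_mb_per_rank, mb_index):
    """Divisor Option B assigns to micro-batch `mb_index`:
    DP-sum for that micro-batch times the number of micro-batches."""
    num_mb = len(labels_per_mb_per_rank)
    return sum(labels_per_mb_per_rank[mb_index]) * num_mb


def true_divisor(labels_per_mb_per_rank):
    """Today's divisor: global sum over all micro-batches and DP ranks."""
    return sum(sum(mb) for mb in labels_per_mb_per_rank)


# Uniform label counts across MBs (the typical case): identical divisors.
uniform = [[4, 4], [4, 4]]  # 2 MBs x 2 DP ranks
assert option_b_divisor(uniform, 0) == true_divisor(uniform) == 16

# Non-uniform counts: each MB is normalized by its own scaled DP-sum,
# i.e. per-MB weighting instead of the true global sum of 12.
skewed = [[4, 4], [2, 2]]
assert option_b_divisor(skewed, 0) == 16
assert option_b_divisor(skewed, 1) == 8
```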

Touches runner.py and schedule/config.py only.

Comparison

|                          | Option A                               | Option B                                                  |
| ------------------------ | -------------------------------------- | --------------------------------------------------------- |
| Loss/grad math vs. today | Identical                              | Identical for MB-uniform sizes; per-MB weighted otherwise |
| Files touched            | Runner + loss kernels + LossDef.reduce | Runner + config                                           |
| Bootstrap state          | None                                   | None                                                      |

Open questions

  1. Is per-MB weighting acceptable as an opt-in for RL? Most RL setups weight per trajectory anyway.
  2. Combine — Option B as opt-in fast path, Option A as eventual default — or land just one?
