Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
98836e6
grpo: add policy-gradient metrics behind compute_extra_metrics flag
bigximik Apr 27, 2026
b856e39
grpo: align metric names with DeepSpeed path
bigximik Apr 27, 2026
b07b999
gspo: add sequence-level IS-ratio clipping loss
bigximik Apr 28, 2026
fecc978
schedule: add rollouts_per_step to auto-set depth_first_micro_batches
bigximik Apr 28, 2026
7d8ec0c
grpo: dynamic docs_per_step accumulation and normalize_by_documents
bigximik Apr 28, 2026
014ba59
grpo: temperature scaling for IS ratio parity with actor sampling
bigximik Apr 29, 2026
d8cb9ef
head: fp32_lm_head flag to match vLLM bf16_last_layer_fp32 precision
bigximik May 4, 2026
0f90f20
head: fix fp32_lm_head gradient flow via detach + manual weight grad …
bigximik May 4, 2026
557a3c4
grpo: decouple loss/gradient divisors and fix SDP loss double-counting
bigximik May 5, 2026
a1b3f32
Merge branch 'main' into grpo-metrics
jlamypoirier May 5, 2026
9c07626
Merge branch 'grpo-metrics' into gspo
jlamypoirier May 5, 2026
d360a46
grpo: address review feedback on metrics
jlamypoirier May 5, 2026
bb6315c
grpo: address review follow-ups
jlamypoirier May 5, 2026
89ed062
grpo: round-3 review fixes
jlamypoirier May 5, 2026
b0852fd
grpo: GRPOMetrics as NamedTuple
jlamypoirier May 5, 2026
61ad4f7
grpo: fix entropy under tensor-parallel + minor review fixes
jlamypoirier May 5, 2026
15ae8d6
Merge remote-tracking branch 'origin/grpo-metrics' into gspo
jlamypoirier May 5, 2026
dfd2ce3
Merge remote-tracking branch 'origin/main' into gspo
jlamypoirier May 6, 2026
2fc2dfe
Merge remote-tracking branch 'origin/main' into gspo
jlamypoirier May 6, 2026
8547a56
gspo: address coarse-review easy items (#7, #13, #14)
jlamypoirier May 6, 2026
fc96d07
gspo: register as a sibling loss type instead of policy_loss switch
jlamypoirier May 7, 2026
d2c051a
gspo: collapse loss subclasses, dispatch kernel via self._call_kernel
jlamypoirier May 7, 2026
0d0185c
gspo: dispatch via self._forward = <kernel> instead of wrapper method
jlamypoirier May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fast_llm/data/document/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig):
use_preference_spans: bool = Field(default=False)
use_grpo_data: bool = Field(default=False)
return_label_counts: bool = Field(default=False)
return_document_index: bool = Field(default=False)

def _validate(self) -> None:
super()._validate()
Expand Down
13 changes: 10 additions & 3 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class LanguageModelTargetInput(ModelInput):
advantages: torch.Tensor | None = None
old_log_probabilities: torch.Tensor | None = None
label_counts: torch.Tensor | None = None
document_index: torch.Tensor | None = None
num_labels: int | None = None
num_labels_in_batch: int | None = None

Expand Down Expand Up @@ -84,6 +85,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
LanguageModelKwargs.document_index: [target.document_index for target in self.targets],
LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets],
}
if self.image_patches is not None:
Expand Down Expand Up @@ -177,7 +179,11 @@ def _set_target_inputs(
document_begin += length

mask = labels >= 0
label_counts = self._get_label_counts(mask) if config.return_label_counts else None
label_counts, document_index = (
self._get_label_counts(mask)
if config.return_label_counts or config.return_document_index
else (None, None)
)

for input_index, model_input in enumerate(model_inputs):
label_end = model_input.sequence_k_dim.size + prediction_distance
Expand All @@ -188,6 +194,7 @@ def _set_target_inputs(
tokens=labels[label_begin:label_end].clone(),
mask=mask[label_begin:label_end] if config.return_prediction_mask else None,
label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None,
document_index=document_index[label_begin:label_end] if config.return_document_index else None,
# Set value for the first input only so `share_batch_data` generates the correct sum.
# TODO: ====== Make optional?
num_labels=(
Expand All @@ -202,7 +209,7 @@ def _set_target_inputs(

model_input.targets.append(target_input)

def _get_label_counts(self, mask: torch.Tensor):
def _get_label_counts(self, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Count the number of non-masked labels in each document through cumulative sums.
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0)
Expand All @@ -214,4 +221,4 @@ def _get_label_counts(self, mask: torch.Tensor):
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
return labels_per_document[document_index]
return labels_per_document[document_index], document_index
10 changes: 10 additions & 0 deletions fast_llm/engine/schedule/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ class ScheduleConfig(Config):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
docs_per_step: int = Field(
default=0,
desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. "
"When >0, each training step dynamically accumulates microbatches until the globally all-reduced "
"document count reaches this value, then triggers the optimizer step. "
"depth_first_micro_batches is ignored when this is set. "
"0 = use depth_first_micro_batches as-is (fixed microbatch count per step).",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
breadth_first_micro_batches: int = Field(
default=1,
desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.",
Expand Down
5 changes: 3 additions & 2 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def _preprocess_data(
if context.schedule.phase.is_training
else None
)
model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)]
n_micro_batches = context.schedule._eff_sequential_micro_batches
model_inputs = [next(data_iterator) for _ in range(n_micro_batches)]
model_inputs[0][0].share_batch_data(
[model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed
)
Expand All @@ -336,7 +337,7 @@ def _preprocess_data(
extra_kwargs={
"grad_output": grad_output,
"micro_batch": micro_batch,
"num_micro_batches": self._config.sequential_micro_batches,
"num_micro_batches": n_micro_batches,
"micro_batch_splits": self._config.micro_batch_splits,
},
)
Expand Down
38 changes: 25 additions & 13 deletions fast_llm/engine/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ def __init__(
batch_meta: list[ModelInput],
distributed_config: DistributedConfig,
phase: PhaseType,
_depth_first_override: int | None = None,
):
super().__init__(config)
self._depth_first_override = _depth_first_override
self._multi_stage = multi_stage
self._distributed_config = distributed_config
self._num_stages = len(self._multi_stage.stages)
self._phase = phase
self._is_training = self._phase.is_training

if self._config.num_inputs < self._distributed_config.pipeline_parallel:
if self._eff_num_inputs < self._distributed_config.pipeline_parallel:
warnings.warn("Not enough input to achieve true pipeline parallelism.")

# Setup the activation metas.
Expand Down Expand Up @@ -155,9 +157,25 @@ def __init__(
def phase(self) -> PhaseType:
return self._phase

@property
def _eff_depth_first(self) -> int:
return (
self._depth_first_override
if self._depth_first_override is not None
else self._config.depth_first_micro_batches
)

@property
def _eff_sequential_micro_batches(self) -> int:
    """Total micro-batches processed sequentially per step: the effective
    depth-first count times the breadth-first count from the config."""
    breadth_first = self._config.breadth_first_micro_batches
    return breadth_first * self._eff_depth_first

@property
def _eff_num_inputs(self) -> int:
    """Number of distinct inputs consumed per step: sequential micro-batches
    times the number of micro-batch splits."""
    splits = self._config.micro_batch_splits
    return splits * self._eff_sequential_micro_batches

@property
def samples_per_batch(self) -> int:
    """Global batch size in samples: effective sequential micro-batches times
    the batch data-parallel degree."""
    data_parallel = self._distributed_config.batch_data_parallel
    return data_parallel * self._eff_sequential_micro_batches

def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]:
return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank])
Expand Down Expand Up @@ -189,7 +207,7 @@ def _create_index(self) -> None:
Assert.in_range(
step.index,
0,
self._config.num_inputs,
self._eff_num_inputs,
)
Assert.incl(step.type_, (StepType.forward, StepType.backward))
step.global_index = i
Expand All @@ -205,7 +223,7 @@ def _create_index(self) -> None:
Assert.custom(all, self._device_steps)
# Consistency checks
step_map = self._step_map.copy()
for data_index in range(self._config.num_inputs):
for data_index in range(self._eff_num_inputs):
for type_ in (StepType.forward, StepType.backward):
for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
assert (
Expand Down Expand Up @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]:
first_grad_stage += 1
else:
first_grad_stage = self._num_stages
for depth_first_micro_batch in range(self._config.depth_first_micro_batches):
for depth_first_micro_batch in range(self._eff_depth_first):
for stage in range(self._num_stages):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in range(self._config.micro_batch_splits):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand All @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]:
for stage in reversed(range(first_grad_stage, self._num_stages)):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in reversed(range(self._config.micro_batch_splits)):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand Down
61 changes: 54 additions & 7 deletions fast_llm/engine/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ def setup(self, distributed: Distributed, run: Run) -> None:
preprocessing_config = self._multi_stage.get_preprocessing_config(
PhaseType.training, self._config.schedule.micro_batch_splits
)
self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size)
self._schedule_cache: dict[int, Schedule] = {}
self._schedule = Schedule(
config=self._config.schedule,
multi_stage=self._multi_stage,
batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size),
batch_meta=self._single_mb_meta,
distributed_config=self._config.model.distributed,
phase=PhaseType.training,
)
Expand All @@ -140,6 +142,41 @@ def setup(self, distributed: Distributed, run: Run) -> None:

self._is_setup = True

def _get_or_build_schedule(self, n_microbatches: int) -> Schedule:
    """Return a training ``Schedule`` for the given micro-batch count, building
    and caching one on first use.

    The depth-first override is derived as ``n_microbatches`` divided by
    ``breadth_first_micro_batches``; the caller is responsible for passing a
    count that is an exact multiple (integer division would silently floor
    otherwise).
    """
    schedule = self._schedule_cache.get(n_microbatches)
    if schedule is None:
        depth_first = n_microbatches // self._config.schedule.breadth_first_micro_batches
        schedule = Schedule(
            config=self._config.schedule,
            multi_stage=self._multi_stage,
            batch_meta=self._single_mb_meta,
            distributed_config=self._config.model.distributed,
            phase=PhaseType.training,
            _depth_first_override=depth_first,
        )
        self._schedule_cache[n_microbatches] = schedule
    return schedule

def _prefetch_to_doc_target(self, data_iterator) -> list:
    """Pull micro-batches from ``data_iterator`` until the accumulated document
    count reaches ``schedule.docs_per_step``, padded so the micro-batch count
    is a multiple of ``breadth_first_micro_batches``.

    Every model input in the returned buffer has ``num_documents_in_batch``
    overwritten with the step-wide total so downstream normalization uses the
    per-step document count rather than the per-microbatch one.

    Returns:
        The list of fetched micro-batches (each itself a list of model inputs).
    """
    target = self._config.schedule.docs_per_step
    bfmb = self._config.schedule.breadth_first_micro_batches
    buffer = []
    total_docs = 0
    # Keep fetching past the document target when necessary so the number of
    # micro-batches is always a multiple of breadth_first_micro_batches.
    # Stopping purely on the document count could leave a partial breadth-first
    # group and trip the divisibility assertion below, since the documents per
    # microbatch are data-dependent.
    while total_docs < target or len(buffer) % bfmb != 0:
        mb = next(data_iterator)
        mb[0].share_batch_data(mb, self._distributed)
        total_docs += mb[0].num_documents_in_batch
        buffer.append(mb)
    # Sanity check; always holds given the loop condition above.
    Assert.eq(
        len(buffer) % bfmb,
        0,
        msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}",
    )
    # Reset num_documents_in_batch to the step total on all microbatches
    for mb in buffer:
        for mi in mb:
            mi.num_documents_in_batch = total_docs
    return buffer

@abc.abstractmethod
def _get_data(self) -> Data:
pass
Expand Down Expand Up @@ -220,12 +257,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:

# TODO: Data loader hates getting all micro-batches at once.
# (Also preprocessing adds overhead)
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
if self._config.schedule.docs_per_step > 0:
buffer = self._prefetch_to_doc_target(train_iterator)
step_schedule = self._get_or_build_schedule(len(buffer))
reduced_losses, update_successful, train_metrics = self._runner.run_step(
iter(buffer),
step_schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
else:
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)

# Advanced, skipped, and Nan iterations.
if update_successful:
Expand Down
5 changes: 4 additions & 1 deletion fast_llm/functional/triton/grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def triton_grpo_loss_forward_backward(
logits_scale_factor: float = 1.0,
num_labels_in_seq: torch.Tensor | None = None,
divisor: float | None = None,
grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor)
block_size: int | None = None,
num_warps: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Expand All @@ -148,6 +149,8 @@ def triton_grpo_loss_forward_backward(
n_cols = logits.size(-1)
if divisor is None:
divisor = n_rows
if grad_divisor is None:
grad_divisor = divisor
if block_size is None:
block_size = min(triton.next_power_of_2(n_cols), 32768)
if num_warps is None:
Expand All @@ -171,7 +174,7 @@ def triton_grpo_loss_forward_backward(
grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits
backward_kwargs = {
"grad_logits_ptr": grad_logits,
"grad_losses": grad_output / divisor,
"grad_losses": grad_output / grad_divisor,
"grad_logits_stride_0": grad_logits.stride(-2),
"accumulate": accumulate,
}
Expand Down
8 changes: 8 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs):
sample_map = "sample_map"
embedding_map = "embedding_map"
num_documents_in_batch = "num_documents_in_batch"
document_index = "document_index"
# TODO: These are generic
phase = "phase"
loss_mask = "loss_mask"
Expand Down Expand Up @@ -119,6 +120,13 @@ class LanguageModelHeadConfig(BlockConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
fp32_lm_head: bool = Field(
default=False,
desc="Upcast input and weight to float32 before the lm_head linear. "
"Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs "
"are computed at the same numerical precision, keeping the IS ratio near 1 at init.",
hint=FieldHint.feature,
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
Expand Down
32 changes: 27 additions & 5 deletions fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.tensor import TensorMeta
from fast_llm.tensor import TensorMeta, accumulate_gradient
from fast_llm.utils import Assert, safe_merge_dicts

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -242,9 +242,17 @@ def _logits_loss_forward_backward_partial(
split_index: int = 0,
return_logits: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if self._config.fp32_lm_head:
input_dtype = input_.dtype
input_ = input_.to(torch.float32)
# detach → requires_grad=False → output_parallel_linear_backward skips weight grad
weight = self.output_weights.detach().to(torch.float32)
else:
weight = self.output_weights

logits, context = output_parallel_linear_forward(
input_=input_,
weight=self.output_weights,
weight=weight,
bias=None,
group=self._parallel_dim.group if self._vocab_parallel else None,
sequence_parallel=self._sequence_parallel and self._vocab_parallel,
Expand Down Expand Up @@ -273,9 +281,23 @@ def _logits_loss_forward_backward_partial(
if loss_value is not None:
losses_.append(loss_value.detach())

return sum(losses_) if losses_ else None, (
output_parallel_linear_backward(grad, context) if self.training else None
)
if not self.training or grad is None:
return sum(losses_) if losses_ else None, None

input_grad = output_parallel_linear_backward(grad, context)
if self._config.fp32_lm_head:
# Weight grad was skipped because weight.requires_grad=False; accumulate manually.
# context: (input_, weight, bias, group, sequence_parallel, ...)
saved_input = context[0]
if context[4]: # sequence_parallel
from fast_llm.core.ops import gather_op

saved_input = gather_op(saved_input, context[3], dim=0)
grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2))
accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype))
input_grad = input_grad.to(input_dtype)

return sum(losses_) if losses_ else None, input_grad

def get_loss_definitions(self) -> list[LossDef]:
return [
Expand Down
Loading
Loading