25 changes: 23 additions & 2 deletions fast_llm/data/sample/language_model.py
@@ -98,21 +98,41 @@ def __init__(
chosen_spans: RangeBatch | None = None,
rejected_spans: RangeBatch | None = None,
image_patches: PatchBatch | None = None,
valid_tokens: int | None = None,
):
self.tokens = tokens
self.loss_masking_spans = loss_masking_spans
self.chosen_spans = chosen_spans
self.rejected_spans = rejected_spans
self.image_patches = image_patches
self.valid_tokens = valid_tokens

@classmethod
def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self:
samples = list(samples)
token_batch = TokenBatch.from_samples([sample.tokens for sample in samples])
loss_masking_spans = _merge_optional(
RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]
)

# Calculate valid tokens for this batch (used for gradient accumulation weighting)
valid_tokens = None
if loss_masking_spans is not None:
batch_size, sequence_length = token_batch.tokens.shape
# Start with all tokens
valid_tokens = batch_size * sequence_length
# Subtract masked tokens
for sample_ranges in loss_masking_spans.ranges:
for begin, end in sample_ranges:
valid_tokens -= end - begin

return cls(
TokenBatch.from_samples([sample.tokens for sample in samples]),
_merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]),
token_batch,
loss_masking_spans,
_merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]),
_merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]),
_merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]),
valid_tokens,
)

def crop(self, begin: int, end: int) -> typing.Self:
@@ -122,6 +142,7 @@ def crop(self, begin: int, end: int) -> typing.Self:
_crop_optional(self.chosen_spans, begin, end),
_crop_optional(self.rejected_spans, begin, end),
_crop_optional(self.image_patches, begin, end),
valid_tokens=None, # Cropped batches don't have valid token counts
)

def to_device_(self, device: "torch.device | str"):
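A minimal standalone sketch of the valid-token arithmetic added above; plain Python lists stand in for the `TokenBatch`/`RangeBatch` containers, and `count_valid_tokens` is a hypothetical helper, not part of Fast-LLM:

```python
# Minimal sketch of the valid-token count above; plain lists stand in for
# TokenBatch / RangeBatch (the real containers live in fast_llm.data.sample).
def count_valid_tokens(
    batch_size: int, sequence_length: int, loss_masking_spans: list[list[tuple[int, int]]]
) -> int:
    # Start with every token in the (batch_size, sequence_length) grid ...
    valid_tokens = batch_size * sequence_length
    # ... then subtract the tokens covered by each sample's masked ranges.
    for sample_ranges in loss_masking_spans:
        for begin, end in sample_ranges:
            valid_tokens -= end - begin
    return valid_tokens

# Two samples of length 8; the second masks tokens [0, 3) and [5, 8).
assert count_valid_tokens(2, 8, [[], [(0, 3), (5, 8)]]) == 10
```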
5 changes: 4 additions & 1 deletion fast_llm/data/sample/range.py
@@ -33,12 +33,15 @@ def __init__(self, ranges: list[tuple[int, int]], sample_size: int):

@classmethod
def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
"""
Used to merge ranges from multiple documents, i.e. when multiple documents are packed together.
"""
document: RangeSample
ranges = []
sample_size = 0
for document in documents:
for begin, end in document.ranges:
ranges.extend((begin + sample_size, end + sample_size))
ranges.append((begin + sample_size, end + sample_size))
sample_size += document.sample_size
return cls(ranges, sample_size)

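The one-line fix above matters because `list.extend` unpacks the `(begin, end)` tuple into two bare integers, while `append` keeps the pair that downstream code iterates over; a quick illustration:

```python
# Why the append/extend fix matters: extend() unpacks the tuple into bare ints,
# append() keeps the (begin, end) pair that callers expect to unpack per range.
ranges_extend: list = []
ranges_extend.extend((2, 5))  # -> [2, 5]    (two ints, range structure lost)

ranges_append: list[tuple[int, int]] = []
ranges_append.append((2, 5))  # -> [(2, 5)]  (one range, as intended)

assert ranges_extend == [2, 5]
assert ranges_append == [(2, 5)]
```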
1 change: 1 addition & 0 deletions fast_llm/engine/base_model/base_model.py
@@ -179,6 +179,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
total_valid_tokens: int | None = None,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass
6 changes: 6 additions & 0 deletions fast_llm/engine/multi_stage/config.py
@@ -115,6 +115,12 @@ class StageConfig(Config):
hint=FieldHint.logging,
valid=check_field(Assert.geq, 0),
)
debug_losses: int = Field(
default=0,
desc="Log loss values after reduction.",
hint=FieldHint.logging,
valid=check_field(Assert.geq, 0),
)
debug_param_update: int = Field(
default=0,
desc="Log the parameters after update.",
47 changes: 43 additions & 4 deletions fast_llm/engine/schedule/runner.py
@@ -10,6 +10,7 @@

from fast_llm.config import Configurable
from fast_llm.core.distributed import all_reduce, recv, safe_barrier, send
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import get_run, log_pipeline_parallel_main_rank
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
@@ -18,7 +19,8 @@
from fast_llm.engine.optimizer.optimizer import Optimizer
from fast_llm.engine.schedule.config import EventType, ScheduleConfig, StepType, StreamType
from fast_llm.engine.schedule.schedule import Schedule, Step
from fast_llm.logging import log_memory_usage
from fast_llm.logging import log_memory_usage, log_tensor
from fast_llm.models.gpt.config import GPTBatchConfig
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)
@@ -295,6 +297,10 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
else:
reduced_loss = 0.0
reduced_losses[name] = reduced_loss
if isinstance(reduced_loss, torch.Tensor) and self._multi_stage.config.multi_stage.debug_losses:
log_tensor(
f"loss: {name}", reduced_loss, level=self._multi_stage.config.multi_stage.debug_losses, log_fn=None
)
return {
name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss
for name, reduced_loss in reduced_losses.items()
@@ -319,19 +325,52 @@ def _train_step(self, context: BatchContext, step: Step) -> None:
def _preprocess_data(
self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool
) -> typing.Generator[None, None, None]:
batch_config = context.schedule.batch_config
grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs
from fast_llm.layers.language_model.config import LanguageModelKwargs

batch_config: GPTBatchConfig = context.schedule.batch_config
default_grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs

# We need an additional pass over the data iterator to compute the total number of valid tokens, which is required to set gradient weights correctly when loss masks are combined with gradient accumulation.
# TODO: add conditions? This should not run unconditionally.
all_micro_batches = []
total_valid_tokens = None
for micro_batch in range(batch_config.sequential_micro_batches):
micro_batch_data = next(data_iterator)
micro_batch_data: LanguageModelBatch = next(data_iterator)
all_micro_batches.append(micro_batch_data)

# Sum valid tokens across all microbatches (if loss masking is used)
if (
not preprocessed
and hasattr(micro_batch_data, "valid_tokens")
and micro_batch_data.valid_tokens is not None
):
if total_valid_tokens is None:
total_valid_tokens = 0
total_valid_tokens += micro_batch_data.valid_tokens

# Second pass: Preprocess and yield each microbatch with correct gradient weighting
for micro_batch, micro_batch_data in enumerate(all_micro_batches):
if not preprocessed:
micro_batch_data = self._multi_stage.base_model.preprocess_batch(
micro_batch_data,
context.schedule.preprocessed_meta,
phase=context.phase,
iteration=context.iteration,
metrics=context.metrics,
total_valid_tokens=total_valid_tokens,
)
for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data):
# Compute grad_output based on valid tokens when loss masking is used
if LanguageModelKwargs.loss_mask in kwargs and total_valid_tokens is not None:
loss_mask = kwargs[LanguageModelKwargs.loss_mask]
valid_tokens = loss_mask.sum().item()
# Weight this micro-batch by its proportion of valid tokens. This is required to scale the gradients correctly when micro-batches have different numbers of valid tokens.
grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) * (
valid_tokens / total_valid_tokens
)
else:
grad_output = default_grad_output

kwargs.update(
grad_output=grad_output,
micro_batch=micro_batch,
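A rough, self-contained sketch of the weighting logic above; `micro_batch_grad_outputs` is a hypothetical helper that only captures the valid-token proportion and the optimizer's grad scale, and ignores micro-batch splits and the unmasked default path:

```python
import torch

# Hypothetical helper mirroring the weighting above: each micro-batch's
# grad_output is its share of the valid (unmasked) tokens, so gradients from
# sparsely masked micro-batches are not over-weighted relative to dense ones.
def micro_batch_grad_outputs(loss_masks: list[torch.Tensor], grad_scale: float = 1.0) -> list[float]:
    valid_per_micro_batch = [mask.sum().item() for mask in loss_masks]
    total_valid_tokens = sum(valid_per_micro_batch)
    return [grad_scale * valid / total_valid_tokens for valid in valid_per_micro_batch]

# One micro-batch with 6 valid tokens, one with 2: weights 0.75 and 0.25,
# instead of the uniform 0.5 each that dividing by num_inputs would give.
masks = [torch.tensor([1, 1, 1, 1, 1, 1, 0, 0]), torch.tensor([1, 1, 0, 0, 0, 0, 0, 0])]
print(micro_batch_grad_outputs(masks))  # [0.75, 0.25]
```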
83 changes: 46 additions & 37 deletions fast_llm/functional/cross_entropy.py
@@ -35,12 +35,10 @@ def _torch_cross_entropy_forward_backward(
logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
)
else:
loss = (
torch.nn.functional.cross_entropy(
logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
)
* loss_mask
).mean()
per_sample_loss = torch.nn.functional.cross_entropy(
logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
)
loss = (per_sample_loss * loss_mask).sum() / loss_mask.sum()
if grad_output is None:
grad = None
else:
@@ -129,7 +127,8 @@ def _fused_cross_entropy_forward_backward(
else:
grad_base = exp_logits - sum_exp_logits * target

grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits)
normalizer = loss_mask.sum() if loss_mask is not None else logits.size(0)
grad = grad_base.mul((grad_output / normalizer) / sum_exp_logits)
if logits_scale_factor != 1.0:
grad *= logits_scale_factor
if loss_mask is not None:
@@ -155,7 +154,8 @@ def _fused_cross_entropy_forward_backward(
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

loss = per_sample_loss.mean()
valid_tokens = loss_mask.sum() if loss_mask is not None else logits.size(0)
loss = per_sample_loss.sum() / valid_tokens
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.AVG, group=group)
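The change from `.mean()` to `sum() / loss_mask.sum()` in both code paths above can be checked in isolation: a masked mean over all positions divides by the full batch size, so the loss shrinks as more tokens are masked. A small plain-PyTorch sketch (not the fused or Triton kernels):

```python
import torch

# Plain-PyTorch check of the normalization change: dividing by the number of
# *valid* tokens rather than the full batch keeps the loss comparable across
# batches with different amounts of masking.
torch.manual_seed(0)
logits = torch.randn(4, 10)                      # 4 positions, vocabulary of 10
target = torch.randint(0, 10, (4,))
loss_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])   # only 2 of 4 positions count

per_sample = torch.nn.functional.cross_entropy(logits, target, reduction="none")
old_loss = (per_sample * loss_mask).mean()                    # divides by 4
new_loss = (per_sample * loss_mask).sum() / loss_mask.sum()   # divides by 2

assert torch.isclose(old_loss * 2, new_loss)  # old version is biased low by the mask ratio
```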

@@ -227,7 +227,7 @@ def distributed_log_softmax(
return logits_norm - sum_exp_logits.log() # log_softmax


def _torch_reverse_kl_forward_backward(
def _reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
@@ -261,36 +261,45 @@

# Compute log probabilities
teacher_log_probs = distributed_log_softmax(target.float(), group=group)
# batch_size = logits.shape[0]
with torch.enable_grad():
logits_ = logits.float().detach().requires_grad_(grad_output is not None)
student_log_probs = distributed_log_softmax(logits_, group=group)

# Reverse KL: input=teacher_log_probs, target=student_probs
loss_terms = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="none",
log_target=True,
).sum(dim=-1)
if loss_mask is not None:
# loss mask is the same on all ranks for TP over vocab.
valid = loss_mask.to(loss_terms.dtype)
loss_terms = loss_terms * valid
valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype)
else:
valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
loss = loss_terms.sum() # sums over batch and seq. len.
student_log_probs = distributed_log_softmax(logits, group=group)

# Reverse KL: input=teacher_log_probs, target=student_probs
loss_terms = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="none",
log_target=True,
).sum(dim=-1)
if loss_mask is not None:
# loss mask is the same on all ranks for TP over vocab.
valid = loss_mask.to(loss_terms.dtype)
loss_terms = loss_terms * valid
valid_tokens = valid.sum()
else:
valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
loss = loss_terms.sum() # sums over batch and seq. len.

if group is not None:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= valid_tokens

if grad_output is not None:
# need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
log_ratio = student_log_probs - teacher_log_probs
expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
# expected E_q(log s - log t) -- this is actually dependent on the full vocab!
if group is not None:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= valid_tokens
all_reduce(expected, op=ReduceOp.SUM, group=group)
grad_base = torch.exp(student_log_probs) * (log_ratio - expected)

if grad_output is not None:
loss.backward(torch.full_like(loss, grad_output))
grad = logits_.grad.to(logits.dtype)
else:
grad = None
if loss_mask is not None:
valid = loss_mask.to(logits.dtype).unsqueeze(-1)
grad_base = grad_base * valid

grad = grad_base.mul(grad_output / valid_tokens)
grad = grad.to(logits.dtype)
else:
grad = None

return loss.detach_(), grad

@@ -339,7 +348,7 @@ def reverse_kl_forward_backward(
Assert.eq(loss_mask.shape, logits.shape[:-1])

# TODO: implement fused?
distillation_loss, distillation_grad = _torch_reverse_kl_forward_backward(
distillation_loss, distillation_grad = _reverse_kl_forward_backward(
logits=logits,
target=target,
loss_mask=loss_mask,
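The manual gradient introduced above follows from differentiating the reverse KL, D_KL(q || p) = sum_v q_v * (log q_v - log p_v), with respect to the student logits, which gives q_v * (log q_v - log p_v - E_q[log q - log p]). A single-process sketch checking that expression against autograd, with no tensor-parallel group and no loss mask (both are assumptions of this sketch):

```python
import torch

# Single-process sanity check of the manual reverse-KL gradient used above
# (no tensor-parallel group, no loss mask, uniform normalization over 3 positions):
# d/dz D_KL(q || p) = q * (log q - log p - E_q[log q - log p]), with q = softmax(z).
torch.manual_seed(0)
student_logits = torch.randn(3, 7, requires_grad=True)
teacher_logits = torch.randn(3, 7)

student_log_probs = torch.log_softmax(student_logits, dim=-1)
teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)

# Reverse KL, summed over the vocabulary and averaged over the 3 positions.
loss = torch.nn.functional.kl_div(
    teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1).sum() / 3
loss.backward()

# Manual gradient, mirroring grad_base / valid_tokens in the code above.
log_ratio = student_log_probs.detach() - teacher_log_probs
expected = (student_log_probs.detach().exp() * log_ratio).sum(dim=-1, keepdim=True)
manual_grad = student_log_probs.detach().exp() * (log_ratio - expected) / 3

assert torch.allclose(student_logits.grad, manual_grad, atol=1e-6)
```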
15 changes: 12 additions & 3 deletions fast_llm/functional/triton/cross_entropy.py
@@ -144,13 +144,22 @@ def triton_cross_entropy_forward_backward(
losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
# TODO: Safe to do inplace?
grad_logits = None if grad_output is None else torch.empty_like(logits)

# Compute valid token count for loss masking
if target_format == TargetFormat.labels:
# For labels format, masking is done via negative labels
valid_count = (target >= 0).sum().item() # Convert to Python scalar
else:
# For logits/probabilities format, masking is done via loss_mask
valid_count = loss_mask.sum().item() if loss_mask is not None else n_rows

if target_format == TargetFormat.labels:
triton_cross_entropy_forward_backward_kernel[(n_rows,)](
logits,
target,
grad_logits,
losses,
None if grad_output is None else grad_output / n_rows,
None if grad_output is None else grad_output / valid_count,
n_cols,
logits.stride(0),
None if grad_output is None else grad_logits.stride(0),
@@ -167,7 +176,7 @@
loss_mask,
grad_logits,
losses,
None if grad_output is None else grad_output / n_rows,
None if grad_output is None else grad_output / valid_count,
n_cols,
logits.stride(0),
target.stride(0),
@@ -177,4 +186,4 @@
num_warps=num_warps,
from_logits=target_format == TargetFormat.logits,
)
return losses.mean(), grad_logits
return losses.sum() / valid_count, grad_logits
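A quick illustration of the two masking conventions the wrapper above now has to count (hypothetical tensors; `-100` stands in for any negative ignore label):

```python
import torch

# The two masking conventions handled above: the labels format marks ignored
# positions with a negative label, while the logits/probabilities formats
# carry an explicit loss_mask tensor.
labels = torch.tensor([5, -100, 2, -100, 7])           # labels format
valid_from_labels = (labels >= 0).sum().item()         # -> 3

loss_mask = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])    # logits/probabilities format
valid_from_mask = int(loss_mask.sum().item())          # -> 3

assert valid_from_labels == valid_from_mask == 3
```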
1 change: 1 addition & 0 deletions fast_llm/layers/language_model/config.py
@@ -31,6 +31,7 @@ class LanguageModelKwargs(BlockKwargs):
chosen_spans = "chosen_spans"
rejected_spans = "rejected_spans"
loss_mask = "loss_mask"
total_valid_tokens = "total_valid_tokens"
mask_inputs = "mask_inputs"


18 changes: 16 additions & 2 deletions fast_llm/layers/language_model/head.py
@@ -375,6 +375,21 @@ def _logits_cross_entropy_forward_backward(
lm_loss, lm_grad = None, None

if distillation_target is not None and self._config.distillation_loss_factor > 0.0:
# We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches.
# The runner averages losses by dividing by num_micro_batches, so we need to account for that.
# Note: for gradients, this scaling is already included in `grad_output`.
total_valid_tokens = kwargs.get(
LanguageModelKwargs.total_valid_tokens
) # number of unmasked tokens across all micro-batches.
num_micro_batches = kwargs.get("num_micro_batches", 1)

if loss_mask is None or total_valid_tokens is None:
loss_scaler_df = 1
else:
valid_tokens = loss_mask.sum()
# Scale by (valid_tokens * num_micro_batches) / total_valid_tokens
# This accounts for the runner dividing by num_micro_batches
loss_scaler_df = (valid_tokens * num_micro_batches) / total_valid_tokens
if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl:
distillation_loss, distillation_grad = reverse_kl_forward_backward(
logits.flatten(0, -2),
@@ -405,13 +420,12 @@
raise ValueError(
f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}"
)
distillation_loss = distillation_loss * self._config.distillation_loss_factor
distillation_loss = distillation_loss * self._config.distillation_loss_factor * loss_scaler_df
else:
distillation_loss, distillation_grad = None, None

# TODO: de-allocate earlier.
del logits

# TODO: Accumulate grads in-place to reduce memory and compute overhead.
grad = _add_tensors(dpo_grad, lm_grad, distillation_grad)

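A back-of-the-envelope check of the scaling factor above: if the runner later averages the reported losses by dividing by `num_micro_batches`, then scaling each micro-batch's per-valid-token loss by `valid_tokens * num_micro_batches / total_valid_tokens` recovers the true average over all valid tokens in the global batch. A small sketch with hypothetical per-token losses:

```python
# Sanity check of the scaling above: with the runner averaging reported losses
# over micro-batches, scaling each micro-batch's per-valid-token loss by
# valid_tokens * num_micro_batches / total_valid_tokens yields the average
# over all valid tokens in the global batch.
per_token_losses = [[2.0, 4.0, 6.0], [10.0]]           # valid-token losses per micro-batch
num_micro_batches = len(per_token_losses)
total_valid_tokens = sum(len(mb) for mb in per_token_losses)  # 4

reported = []
for mb in per_token_losses:
    valid_tokens = len(mb)
    micro_batch_loss = sum(mb) / valid_tokens           # what the loss function returns
    scale = valid_tokens * num_micro_batches / total_valid_tokens
    reported.append(micro_batch_loss * scale)

runner_average = sum(reported) / num_micro_batches      # runner divides by num_micro_batches
true_average = sum(sum(mb) for mb in per_token_losses) / total_valid_tokens
assert abs(runner_average - true_average) < 1e-12       # both are 5.5
```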